Update modeling_moonshot.py
modeling_moonshot.py (CHANGED, +36 -12)
@@ -427,21 +427,27 @@ class Attention(nn.Module):
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
     ):
         # Standard scaled dot-product attention
-        batch_size,
-
-
-        #
-        if num_kv_heads != num_heads:
-            # Each query head uses the corresponding key-value head (num_heads // num_kv_heads) times
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        # Prepare for attention computation (batch_size, num_heads, seq_length, head_dim)
+        batch_size, q_len, num_heads, head_dim = query_states.shape
+        bsz, kv_seq_len, num_kv_heads, _ = key_states.shape
+
+        # Transpose query states for matmul: (batch_size, num_heads, q_len, head_dim)
         query_states = query_states.transpose(1, 2)
+
+        # Transpose key/value states for repeat_kv: (batch_size, num_kv_heads, kv_seq_len, head_dim)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
+
+        # Handle grouped-query attention by repeating k/v heads if necessary
+        # repeat_kv expects (batch, num_key_value_heads, slen, head_dim)
+        # repeat_kv outputs (batch, num_attention_heads, slen, head_dim)
+        if self.num_key_value_groups > 1:
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # (batch_size, num_heads,
+        # key_states is now (batch_size, num_heads, kv_seq_len, head_dim)
+        # value_states is now (batch_size, num_heads, kv_seq_len, head_dim)
+
+        # Attention score calculation: (batch_size, num_heads, q_len, kv_seq_len)
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
 
         if softmax_scale is None:
@@ -449,15 +455,33 @@ class Attention(nn.Module):
         attn_weights = attn_weights * softmax_scale
 
         if attention_mask is not None:
+            # The attention mask passed from _flash_attention_forward is the padding_mask
+            # which is (batch_size, seq_len). We need the causal mask prepared in the main forward pass.
+            # This part needs adjustment depending on how the causal mask is passed.
+            # For now, assuming the correct mask is passed somehow.
+            # If attention_mask is the padding mask, it needs expanding and causal masking added.
+            # This standard attention path currently doesn't receive the full causal mask.
+
+            # Let's log a warning for now as this mask handling is likely incorrect
+            # compared to the original Llama attention or FlashAttention's causal=True
+            logger.warning_once(
+                "Standard attention mask handling might be incomplete. "
+                "Ensure the correct causal mask is being used if not using Flash Attention."
+            )
+            # Assuming attention_mask is already the correct shape [bsz, 1, q_len, kv_seq_len]
+            # If it's the padding mask [bsz, kv_seq_len], it needs expansion + causal.
             attn_weights = attn_weights + attention_mask
 
         # Apply softmax and dropout
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
 
-        # Context vectors
+        # Context vectors: (batch_size, num_heads, q_len, head_dim)
         attn_output = torch.matmul(attn_weights, value_states)
+
+        # Reshape to original format: (batch_size, num_heads, q_len, head_dim) -> (batch_size, q_len, hidden_size)
         attn_output = attn_output.transpose(1, 2).contiguous()
+        # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) # This reshape happens outside this function
 
         return attn_output
 
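For reference, repeat_kv is the Llama-style helper that this change now calls only when self.num_key_value_groups > 1. A minimal sketch of that helper, assuming the common expand-and-reshape implementation used in transformers-style modeling code (not copied from modeling_moonshot.py itself):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_key_value_heads, slen, head_dim) to
    # (batch, num_key_value_heads * n_rep, slen, head_dim).
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repetition axis, broadcast it, then fold it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

With, say, 32 query heads and 8 key/value heads, num_key_value_groups is 4, so a (bsz, 8, kv_seq_len, head_dim) tensor becomes (bsz, 32, kv_seq_len, head_dim), which is what the matmul against the transposed query_states expects.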
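The new comments flag that this standard path may receive only the 2D padding mask of shape (batch_size, kv_seq_len) rather than the 4D additive causal mask. As an illustration of the expansion those comments describe, and not the module's actual mask-preparation code, a hypothetical build_causal_mask helper could combine the causal constraint and the padding mask into the [bsz, 1, q_len, kv_seq_len] shape the code assumes:

import torch

def build_causal_mask(padding_mask: torch.Tensor, q_len: int, dtype: torch.dtype) -> torch.Tensor:
    # padding_mask: (batch_size, kv_seq_len) with 1 for real tokens and 0 for padding.
    bsz, kv_seq_len = padding_mask.shape
    device = padding_mask.device

    # Query i sits at absolute position kv_seq_len - q_len + i (queries align with the end
    # of the key/value sequence when a cache is present) and may attend to keys at or before it.
    q_pos = torch.arange(kv_seq_len - q_len, kv_seq_len, device=device)[:, None]
    k_pos = torch.arange(kv_seq_len, device=device)[None, :]
    allowed = (k_pos <= q_pos)[None, None, :, :] & padding_mask[:, None, None, :].bool()

    # Additive mask: 0 where attention is allowed, a large negative value elsewhere.
    mask = torch.zeros(bsz, 1, q_len, kv_seq_len, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)

Adding such a mask to attn_weights broadcasts over the head dimension, so attn_weights = attn_weights + attention_mask behaves the same whether or not the key/value heads were repeated.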