Update modeling_moonshot.py
modeling_moonshot.py (+36 -12)
@@ -427,21 +427,27 @@ class Attention(nn.Module):
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
     ):
         # Standard scaled dot-product attention
-        batch_size,
-
-
-        #
-        if num_kv_heads != num_heads:
-            # Each query head uses the corresponding key-value head (num_heads // num_kv_heads) times
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        # Prepare for attention computation (batch_size, num_heads, seq_length, head_dim)
+        batch_size, q_len, num_heads, head_dim = query_states.shape
+        bsz, kv_seq_len, num_kv_heads, _ = key_states.shape
+
+        # Transpose query states for matmul: (batch_size, num_heads, q_len, head_dim)
         query_states = query_states.transpose(1, 2)
+
+        # Transpose key/value states for repeat_kv: (batch_size, num_kv_heads, kv_seq_len, head_dim)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
+
+        # Handle grouped-query attention by repeating k/v heads if necessary
+        # repeat_kv expects (batch, num_key_value_heads, slen, head_dim)
+        # repeat_kv outputs (batch, num_attention_heads, slen, head_dim)
+        if self.num_key_value_groups > 1:
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # (batch_size, num_heads,
+        # key_states is now (batch_size, num_heads, kv_seq_len, head_dim)
+        # value_states is now (batch_size, num_heads, kv_seq_len, head_dim)
+
+        # Attention score calculation: (batch_size, num_heads, q_len, kv_seq_len)
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
 
         if softmax_scale is None:
@@ -449,15 +455,33 @@
         attn_weights = attn_weights * softmax_scale
 
         if attention_mask is not None:
+            # The attention mask passed from _flash_attention_forward is the padding_mask
+            # which is (batch_size, seq_len). We need the causal mask prepared in the main forward pass.
+            # This part needs adjustment depending on how the causal mask is passed.
+            # For now, assuming the correct mask is passed somehow.
+            # If attention_mask is the padding mask, it needs expanding and causal masking added.
+            # This standard attention path currently doesn't receive the full causal mask.
+
+            # Let's log a warning for now as this mask handling is likely incorrect
+            # compared to the original Llama attention or FlashAttention's causal=True
+            logger.warning_once(
+                "Standard attention mask handling might be incomplete. "
+                "Ensure the correct causal mask is being used if not using Flash Attention."
+            )
+            # Assuming attention_mask is already the correct shape [bsz, 1, q_len, kv_seq_len]
+            # If it's the padding mask [bsz, kv_seq_len], it needs expansion + causal.
             attn_weights = attn_weights + attention_mask
 
         # Apply softmax and dropout
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)
 
-        # Context vectors
+        # Context vectors: (batch_size, num_heads, q_len, head_dim)
         attn_output = torch.matmul(attn_weights, value_states)
+
+        # Reshape to original format: (batch_size, num_heads, q_len, head_dim) -> (batch_size, q_len, hidden_size)
         attn_output = attn_output.transpose(1, 2).contiguous()
+        # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) # This reshape happens outside this function
 
         return attn_output
 
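The new grouped-query attention branch relies on repeat_kv expanding key/value heads from (batch, num_key_value_heads, slen, head_dim) to (batch, num_attention_heads, slen, head_dim), as the added comments state. For reference, a Llama-style implementation of such a helper looks roughly like the sketch below; whether modeling_moonshot.py defines repeat_kv exactly this way is an assumption.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat key/value heads for grouped-query attention (Llama-style sketch).

    Input shape:  (batch, num_key_value_heads, slen, head_dim)
    Output shape: (batch, num_key_value_heads * n_rep, slen, head_dim)
    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, expand it without copying, then fold it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

For example, with 32 query heads and 8 key/value heads, num_key_value_groups is 4 and repeat_kv turns a (2, 8, 128, 64) key tensor into (2, 32, 128, 64), so the matmul against the (2, 32, q_len, 64) query tensor lines up head for head.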
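The added comments around the attention_mask branch note that this path may receive only a 2-D padding mask of shape (batch_size, kv_seq_len), while the code adds a mask directly to attn_weights, which requires the additive shape [bsz, 1, q_len, kv_seq_len]. Below is a minimal sketch of one way to expand such a padding mask into a combined causal-plus-padding additive mask; the helper name build_additive_causal_mask and its placement are illustrative assumptions, not part of this commit.

import torch

def build_additive_causal_mask(
    padding_mask: torch.Tensor,  # (batch_size, kv_seq_len); 1 = attend, 0 = padding (assumed convention)
    q_len: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Returns an additive mask of shape (batch_size, 1, q_len, kv_seq_len):
    # 0 where attention is allowed, a large negative value where it is not,
    # suitable for "attn_weights = attn_weights + attention_mask".
    bsz, kv_seq_len = padding_mask.shape
    device = padding_mask.device

    # Causal rule: queries are aligned with the last q_len key positions, so query i
    # may attend to key position j only if j <= i + (kv_seq_len - q_len).
    q_pos = torch.arange(q_len, device=device)[:, None] + (kv_seq_len - q_len)
    k_pos = torch.arange(kv_seq_len, device=device)[None, :]
    causal_ok = k_pos <= q_pos                                   # (q_len, kv_seq_len)

    keep = causal_ok[None, None, :, :] & padding_mask[:, None, None, :].bool()
    mask = torch.zeros(bsz, 1, q_len, kv_seq_len, dtype=dtype, device=device)
    return mask.masked_fill(~keep, torch.finfo(dtype).min)

With a mask built this way and passed in as attention_mask, the existing line attn_weights = attn_weights + attention_mask applies both causal and padding masking, matching the behavior FlashAttention gets from causal=True plus its padding handling.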