mrfakename committed (verified)
Commit ac2566f · 1 Parent(s): e2c50c8

Update modeling_moonshot.py
Files changed (1): modeling_moonshot.py (+36, -12)
modeling_moonshot.py CHANGED
@@ -427,21 +427,27 @@ class Attention(nn.Module):
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
     ):
         # Standard scaled dot-product attention
-        batch_size, q_length, num_heads, head_dim = query_states.shape
-        _, kv_length, num_kv_heads, _ = key_states.shape
-
-        # Handle grouped-query attention by repeating k/v heads if necessary
-        if num_kv_heads != num_heads:
-            # Each query head uses the corresponding key-value head (num_heads // num_kv_heads) times
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        # Prepare for attention computation (batch_size, num_heads, seq_length, head_dim)
+        batch_size, q_len, num_heads, head_dim = query_states.shape
+        bsz, kv_seq_len, num_kv_heads, _ = key_states.shape
+
+        # Transpose query states for matmul: (batch_size, num_heads, q_len, head_dim)
         query_states = query_states.transpose(1, 2)
+
+        # Transpose key/value states for repeat_kv: (batch_size, num_kv_heads, kv_seq_len, head_dim)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
+
+        # Handle grouped-query attention by repeating k/v heads if necessary
+        # repeat_kv expects (batch, num_key_value_heads, slen, head_dim)
+        # repeat_kv outputs (batch, num_attention_heads, slen, head_dim)
+        if self.num_key_value_groups > 1:
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)

-        # (batch_size, num_heads, query_length, key_length)
+        # key_states is now (batch_size, num_heads, kv_seq_len, head_dim)
+        # value_states is now (batch_size, num_heads, kv_seq_len, head_dim)
+
+        # Attention score calculation: (batch_size, num_heads, q_len, kv_seq_len)
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))

         if softmax_scale is None:
@@ -449,15 +455,33 @@ class Attention(nn.Module):
         attn_weights = attn_weights * softmax_scale

         if attention_mask is not None:
+            # The attention mask passed from _flash_attention_forward is the padding_mask
+            # which is (batch_size, seq_len). We need the causal mask prepared in the main forward pass.
+            # This part needs adjustment depending on how the causal mask is passed.
+            # For now, assuming the correct mask is passed somehow.
+            # If attention_mask is the padding mask, it needs expanding and causal masking added.
+            # This standard attention path currently doesn't receive the full causal mask.
+
+            # Let's log a warning for now as this mask handling is likely incorrect
+            # compared to the original Llama attention or FlashAttention's causal=True
+            logger.warning_once(
+                "Standard attention mask handling might be incomplete. "
+                "Ensure the correct causal mask is being used if not using Flash Attention."
+            )
+            # Assuming attention_mask is already the correct shape [bsz, 1, q_len, kv_seq_len]
+            # If it's the padding mask [bsz, kv_seq_len], it needs expansion + causal.
             attn_weights = attn_weights + attention_mask

         # Apply softmax and dropout
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training)

-        # Context vectors
+        # Context vectors: (batch_size, num_heads, q_len, head_dim)
         attn_output = torch.matmul(attn_weights, value_states)
+
+        # Reshape to original format: (batch_size, num_heads, q_len, head_dim) -> (batch_size, q_len, hidden_size)
         attn_output = attn_output.transpose(1, 2).contiguous()
+        # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) # This reshape happens outside this function

         return attn_output
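For reference, the new comments describe repeat_kv as mapping (batch, num_key_value_heads, slen, head_dim) to (batch, num_attention_heads, slen, head_dim). Below is a minimal sketch of what a Llama-style repeat_kv helper typically looks like; the actual definition lives elsewhere in modeling_moonshot.py and may differ, and n_rep here plays the role of self.num_key_value_groups.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, slen, head_dim) -> (batch, num_key_value_heads * n_rep, slen, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, expand it (no copy), then fold it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)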
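The added comments warn that this eager path may receive a 2-D padding mask rather than the 4-D additive mask it adds to the scores. The sketch below shows the expansion those comments describe; build_additive_causal_mask is a hypothetical helper written for illustration, not part of the file.

import torch

def build_additive_causal_mask(padding_mask: torch.Tensor, q_len: int, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical helper: expand a (bsz, kv_seq_len) padding mask (1 = keep, 0 = pad)
    # into the (bsz, 1, q_len, kv_seq_len) additive mask this eager path expects.
    bsz, kv_seq_len = padding_mask.shape
    device = padding_mask.device
    # Causal part: query i may attend to key j only if j <= i + (kv_seq_len - q_len),
    # which also covers cached keys when kv_seq_len > q_len.
    causal = torch.tril(
        torch.ones(q_len, kv_seq_len, dtype=torch.bool, device=device),
        diagonal=kv_seq_len - q_len,
    )
    # Combine with the padding mask: padded key positions are blocked for every query.
    allowed = causal[None, None, :, :] & padding_mask[:, None, None, :].to(torch.bool)
    # Allowed positions contribute 0; blocked positions contribute a large negative value.
    mask = torch.full((bsz, 1, q_len, kv_seq_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    return mask.masked_fill(allowed, 0.0)

With such a helper, the direct addition in the diff would become attn_weights + build_additive_causal_mask(padding_mask, q_len, attn_weights.dtype) whenever only the padding mask is available.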
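As a quick sanity check of the eager path in the diff (transpose, score matmul, scale, softmax, value matmul, transpose back), the sketch below compares it against PyTorch's torch.nn.functional.scaled_dot_product_attention with no mask or dropout. It assumes the default softmax_scale resolved on the elided line is 1/sqrt(head_dim), which matches SDPA's default scaling.

import math
import torch
import torch.nn.functional as F

bsz, q_len, num_heads, head_dim = 2, 5, 8, 64
# Inputs in the layout this function receives: (batch, seq_len, num_heads, head_dim).
q = torch.randn(bsz, q_len, num_heads, head_dim)
k = torch.randn(bsz, q_len, num_heads, head_dim)
v = torch.randn(bsz, q_len, num_heads, head_dim)

# Eager path: move heads forward, score, scale, softmax, weight the values, move heads back.
qt, kt, vt = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
scores = torch.matmul(qt, kt.transpose(2, 3)) * (1.0 / math.sqrt(head_dim))
weights = F.softmax(scores, dim=-1)
eager = torch.matmul(weights, vt).transpose(1, 2).contiguous()

# PyTorch's fused SDPA with its default scale should agree to numerical precision.
reference = F.scaled_dot_product_attention(qt, kt, vt).transpose(1, 2).contiguous()
assert torch.allclose(eager, reference, atol=1e-5)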