mrfakename committed on
Commit
e2c50c8
·
verified ·
1 Parent(s): 8cf9e16

Update modeling_moonshot.py

Browse files
Files changed (1) hide show
  1. modeling_moonshot.py +8 -2
modeling_moonshot.py CHANGED
@@ -428,9 +428,15 @@ class Attention(nn.Module):
428
  ):
429
  # Standard scaled dot-product attention
430
  batch_size, q_length, num_heads, head_dim = query_states.shape
 
431
 
432
- # Prepare the query, key, value for attention computation
433
- # (batch_size, num_heads, seq_length, head_dim)
 
 
 
 
 
434
  query_states = query_states.transpose(1, 2)
435
  key_states = key_states.transpose(1, 2)
436
  value_states = value_states.transpose(1, 2)
 
428
  ):
429
  # Standard scaled dot-product attention
430
  batch_size, q_length, num_heads, head_dim = query_states.shape
431
+ _, kv_length, num_kv_heads, _ = key_states.shape
432
 
433
+ # Handle grouped-query attention by repeating k/v heads if necessary
434
+ if num_kv_heads != num_heads:
435
+ # Each query head uses the corresponding key-value head (num_heads // num_kv_heads) times
436
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
437
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
438
+
439
+ # Prepare for attention computation (batch_size, num_heads, seq_length, head_dim)
440
  query_states = query_states.transpose(1, 2)
441
  key_states = key_states.transpose(1, 2)
442
  value_states = value_states.transpose(1, 2)