Update modeling_moonshot.py
modeling_moonshot.py  CHANGED  +8 -2
@@ -428,9 +428,15 @@ class Attention(nn.Module):
     ):
         # Standard scaled dot-product attention
         batch_size, q_length, num_heads, head_dim = query_states.shape
+        _, kv_length, num_kv_heads, _ = key_states.shape
 
-        #
-
+        # Handle grouped-query attention by repeating k/v heads if necessary
+        if num_kv_heads != num_heads:
+            # Each query head uses the corresponding key-value head (num_heads // num_kv_heads) times
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        # Prepare for attention computation (batch_size, num_heads, seq_length, head_dim)
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
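
The repeat_kv helper called in the added lines is defined elsewhere in modeling_moonshot.py and is not shown in this diff. Because it is applied before the transpose(1, 2) calls, it must operate on tensors with the head axis at position 2, i.e. (batch, seq_length, num_kv_heads, head_dim). A minimal sketch consistent with those shapes follows; the actual implementation in the file may differ.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Hypothetical sketch of the helper used in the diff: repeat each of the
    # num_kv_heads key/value heads n_rep times so every query head has a match.
    # (batch, seq_length, num_kv_heads, head_dim)
    #   -> (batch, seq_length, num_kv_heads * n_rep, head_dim)
    batch, seq_length, num_kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis next to the head axis, broadcast, then fold it back in,
    # so kv head j serves query heads j*n_rep .. (j+1)*n_rep - 1.
    hidden_states = hidden_states[:, :, :, None, :].expand(
        batch, seq_length, num_kv_heads, n_rep, head_dim
    )
    return hidden_states.reshape(batch, seq_length, num_kv_heads * n_rep, head_dim)

# Example: 8 query heads sharing 2 kv heads, so n_rep = num_heads // num_kv_heads = 4
key_states = torch.randn(1, 16, 2, 64)
print(repeat_kv(key_states, 4).shape)  # torch.Size([1, 16, 8, 64])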
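For context on the final three transposes: the (batch, num_heads, seq_length, head_dim) layout they produce is what torch.nn.functional.scaled_dot_product_attention expects, matching the "Standard scaled dot-product attention" comment. The attention call itself is outside this hunk; the sketch below shows the step that presumably follows, with the is_causal flag and all shapes chosen purely for illustration.

import torch
import torch.nn.functional as F

batch_size, q_length, num_heads, head_dim = 1, 16, 8, 64
query_states = torch.randn(batch_size, q_length, num_heads, head_dim)
key_states = torch.randn(batch_size, q_length, num_heads, head_dim)    # after repeat_kv
value_states = torch.randn(batch_size, q_length, num_heads, head_dim)  # after repeat_kv

# The transposes from the diff: move heads ahead of sequence length.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# Hypothetical continuation: SDPA consumes (batch, num_heads, seq_length, head_dim).
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, is_causal=True
)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])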