fix sliding window merging
modeling_phi4flash.py (+7 -7)
@@ -573,7 +573,7 @@ class SambaYFlashAttention2(SambaYAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx]
+        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention"

         if past_key_value is not None:
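As shown in the removed line, the old gate tested the truthiness of self.config.layer_types[self.layer_idx], which is a non-empty string for every layer, so use_sliding_windows came out true on full-attention layers as well. A minimal sketch of the difference (the toy layer_types list and sliding_window value are assumptions; only the attribute names and the comparison come from the diff):

# Minimal sketch, not the model code: the toy values below are assumptions,
# only the attribute names and the string comparison come from the diff above.
layer_types = ["full_attention", "sliding_attention", "full_attention"]
sliding_window = 2048

for layer_idx, layer_type in enumerate(layer_types):
    old = sliding_window is not None and layer_types[layer_idx]   # non-empty string -> always truthy
    new = sliding_window is not None and layer_types[layer_idx] == "sliding_attention"
    print(layer_idx, layer_type, bool(old), new)

# The old gating is truthy for every layer, so full-attention layers were also
# run with a sliding window; the new gating is True only for layers actually
# tagged "sliding_attention".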
@@ -710,8 +710,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.
-                    self.config.
+                    self.config.sliding_window -1,
+                    self.config.sliding_window -1,
                 ),
             )
@@ -735,8 +735,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.
-                    self.config.
+                    self.config.sliding_window -1,
+                    self.config.sliding_window -1,
                 ),
             )
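Both flash-attention call sites above now pass self.config.sliding_window - 1 as the left and right window bounds. A hedged sketch of what that buys, assuming flash-attn's convention that window_size=(left, right) lets query position i attend to keys in [i - left, i + right]: with left = sliding_window - 1, a causal query sees exactly sliding_window keys, itself included. The snippet below builds the equivalent mask in plain PyTorch with toy sizes; nothing in it comes from the model code except the sliding_window - 1 arithmetic.

import torch

# Toy sizes; only the "- 1" arithmetic mirrors the diff above.
sliding_window = 4
seq_len = 8
left = right = sliding_window - 1

i = torch.arange(seq_len)[:, None]   # query positions
j = torch.arange(seq_len)[None, :]   # key positions
causal = j <= i
in_window = (j >= i - left) & (j <= i + right)
mask = causal & in_window

# Each row attends to at most `sliding_window` keys: itself plus the
# sliding_window - 1 previous tokens, which is why the bound subtracts 1.
print(mask.int())
print(mask.sum(dim=-1))  # tensor([1, 2, 3, 4, 4, 4, 4, 4])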
@@ -1085,9 +1085,9 @@ class SambaYDecoderLayer(nn.Module):
             residual = residual.to(torch.float32)
             self_attn_weights = None
         else:
-            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx]
+            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention" and attention_mask is not None: # efficient SDPA and no padding
                 if past_key_value is not None and cache_position[0] > 0: # when decoding
-                    attention_mask = attention_mask[:, -self.config.
+                    attention_mask = attention_mask[:, -self.config.sliding_window:]
             #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
             # Self Attention
             attn_outputs, self_attn_weights, yoco_key_values = self.attn(
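The decoder-layer change crops the 2-D padding mask before the attention call when decoding through a sliding-attention layer: once more than sliding_window tokens have been seen, only the last sliding_window key/value positions remain relevant, so the mask is cut to the same columns to stay aligned with the cached keys. A toy illustration (batch size, lengths, and padding pattern are assumptions; only the [:, -sliding_window:] slicing mirrors the diff):

import torch

# Toy illustration of the mask cropping added above. Shapes and values are
# assumptions; only the slicing expression comes from the diff.
sliding_window = 4
batch, seen_tokens = 2, 7                       # 7 tokens processed so far
attention_mask = torch.ones(batch, seen_tokens, dtype=torch.long)
attention_mask[1, :3] = 0                       # left padding on the second sequence

# When decoding, a sliding-window cache only holds the last `sliding_window`
# keys, so the padding mask is cropped to the same positions.
cropped = attention_mask[:, -sliding_window:]
print(cropped.shape)   # torch.Size([2, 4])
print(cropped)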