renll committed (verified)
Commit 5aeaea6 · 1 Parent(s): 571031a

fix sliding window merging

Files changed (1):
  modeling_phi4flash.py  +7 -7
modeling_phi4flash.py CHANGED
@@ -573,7 +573,7 @@ class SambaYFlashAttention2(SambaYAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] is not None
+        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention"
 
         if past_key_value is not None:
 
@@ -710,8 +710,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.layer_types[self.layer_idx] -1,
-                    self.config.layer_types[self.layer_idx] -1,
+                    self.config.sliding_window -1,
+                    self.config.sliding_window -1,
                 ),
             )
 
@@ -735,8 +735,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.layer_types[self.layer_idx] -1,
-                    self.config.layer_types[self.layer_idx] -1,
+                    self.config.sliding_window -1,
+                    self.config.sliding_window -1,
                 ),
             )
 
@@ -1085,9 +1085,9 @@ class SambaYDecoderLayer(nn.Module):
             residual = residual.to(torch.float32)
             self_attn_weights = None
         else:
-            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] is not None and attention_mask is not None: # efficient SDPA and no padding
+            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention" and attention_mask is not None: # efficient SDPA and no padding
                 if past_key_value is not None and cache_position[0] > 0: # when decoding
-                    attention_mask = attention_mask[:, -self.config.layer_types[self.layer_idx]:]
+                    attention_mask = attention_mask[:, -self.config.sliding_window:]
             #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
             # Self Attention
             attn_outputs, self_attn_weights, yoco_key_values = self.attn(
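
The substance of the fix: in current Transformers configs, config.layer_types holds a per-layer string tag (e.g. "sliding_attention" or "full_attention"), while the window width itself lives in config.sliding_window. The old code used the layer type as if it were the window length. A minimal sketch of the corrected selection logic, using a hypothetical stand-in config object rather than the actual SambaY/Phi-4-flash config class:

    from types import SimpleNamespace

    # Hypothetical stand-in for the model config; values are illustrative only.
    config = SimpleNamespace(
        sliding_window=2048,
        layer_types=["sliding_attention", "full_attention", "sliding_attention"],
    )

    def flash_attn_window(config, layer_idx):
        """Return the (left, right) window_size tuple to pass to flash-attn.

        Mirrors the fixed logic: the layer type string only decides *whether*
        the layer uses a sliding window; the width comes from sliding_window.
        """
        use_sliding_windows = (
            config.sliding_window is not None
            and config.layer_types[layer_idx] == "sliding_attention"
        )
        if not use_sliding_windows:
            return (-1, -1)  # flash-attn convention for "no window"
        # sliding_window - 1 mirrors the values the model code passes to flash-attn.
        return (config.sliding_window - 1, config.sliding_window - 1)

    print(flash_attn_window(config, 0))  # (2047, 2047)
    print(flash_attn_window(config, 1))  # (-1, -1)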
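
The second half of the fix concerns incremental decoding: a sliding-attention layer's KV cache only retains the last sliding_window positions, so the 2-D padding mask must be cropped to the same length. Before the change the slice width came from the (now string-valued) layer type; after it, from config.sliding_window. A tiny illustrative sketch with a toy window size:

    import torch

    sliding_window = 4                      # toy value; the real config is much larger
    batch_size, past_len = 2, 10

    # 2-D padding mask over all prompt/generated tokens (1 = attend, 0 = padding).
    attention_mask = torch.ones(batch_size, past_len, dtype=torch.long)

    # When decoding (cache_position[0] > 0), keep only the positions that are
    # still present in the sliding-window KV cache.
    attention_mask = attention_mask[:, -sliding_window:]
    print(attention_mask.shape)             # torch.Size([2, 4])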