Fix: Respect `is_causal=False` config in forward to enable bidirectional attention

#37
Opened by Bool1020
Files changed (1)
  1. modeling_qwen.py +4 -4
modeling_qwen.py CHANGED
@@ -350,7 +350,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
350
  past_key_value: Optional[Cache] = None,
351
  output_attentions: bool = False,
352
  use_cache: bool = False,
353
- is_causal: bool = True,
354
  **kwargs,
355
  ):
356
  if "padding_mask" in kwargs:
@@ -646,7 +646,7 @@ class Qwen2SdpaAttention(Qwen2Attention):
646
  past_key_value: Optional[Cache] = None,
647
  output_attentions: bool = False,
648
  use_cache: bool = False,
649
- is_causal: bool = False,
650
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
651
  if output_attentions:
652
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -965,7 +965,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
965
  output_hidden_states: Optional[bool] = None,
966
  return_dict: Optional[bool] = None,
967
  labels: Optional[torch.LongTensor] = None,
968
- is_causal: Optional[bool] = True,
969
  ) -> Union[Tuple, BaseModelOutputWithPast]:
970
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
971
  output_hidden_states = (
@@ -1160,7 +1160,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1160
  output_attentions: Optional[bool] = None,
1161
  output_hidden_states: Optional[bool] = None,
1162
  return_dict: Optional[bool] = None,
1163
- is_causal: Optional[bool] = True,
1164
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1165
  r"""
1166
  Args:
 
350
  past_key_value: Optional[Cache] = None,
351
  output_attentions: bool = False,
352
  use_cache: bool = False,
353
+ is_causal: bool = False,
354
  **kwargs,
355
  ):
356
  if "padding_mask" in kwargs:
 
646
  past_key_value: Optional[Cache] = None,
647
  output_attentions: bool = False,
648
  use_cache: bool = False,
649
+ is_causal: bool = True,
650
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
651
  if output_attentions:
652
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
 
965
  output_hidden_states: Optional[bool] = None,
966
  return_dict: Optional[bool] = None,
967
  labels: Optional[torch.LongTensor] = None,
968
+ is_causal: Optional[bool] = False,
969
  ) -> Union[Tuple, BaseModelOutputWithPast]:
970
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
971
  output_hidden_states = (
 
1160
  output_attentions: Optional[bool] = None,
1161
  output_hidden_states: Optional[bool] = None,
1162
  return_dict: Optional[bool] = None,
1163
+ is_causal: Optional[bool] = False,
1164
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1165
  r"""
1166
  Args: