Fix: Respect `is_causal=False` config in forward to enable bidirectional attention
modeling_qwen.py (+4 −4)
```diff
@@ -350,7 +350,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        is_causal: bool =
+        is_causal: bool = False,
         **kwargs,
     ):
         if "padding_mask" in kwargs:
@@ -646,7 +646,7 @@ class Qwen2SdpaAttention(Qwen2Attention):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        is_causal: bool =
+        is_causal: bool = True,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -965,7 +965,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
-        is_causal: Optional[bool] =
+        is_causal: Optional[bool] = False,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1160,7 +1160,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        is_causal: Optional[bool] =
+        is_causal: Optional[bool] = False,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
```
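Once `is_causal` is threaded through each `forward` signature, the `is_causal=False` setting the commit title refers to can actually take effect. For context, in the SDPA path this flag maps onto the `is_causal` argument of PyTorch's `torch.nn.functional.scaled_dot_product_attention`, the kernel `Qwen2SdpaAttention` dispatches to. A minimal standalone sketch of what the flag changes:

```python
import torch
import torch.nn.functional as F

# Toy tensors with shape (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# is_causal=True applies a lower-triangular mask: position i attends
# only to positions <= i (the decoder-only default).
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# is_causal=False applies no mask: every position attends to every
# other position, i.e. bidirectional attention.
bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)

print(torch.equal(causal, bidirectional))  # almost always False
```

And a sketch of the intended call pattern after this fix, assuming the patched `modeling_qwen.py` is loaded via `trust_remote_code=True`; the repository name below is a placeholder, not a real checkpoint:

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo = "your-org/your-qwen2-checkpoint"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("A sentence to embed.", return_tensors="pt")
with torch.no_grad():
    # is_causal=False is now respected: it is forwarded from
    # Qwen2Model.forward down to the attention layers instead of being ignored.
    out = model(**inputs, is_causal=False)

# Mean-pool the bidirectional hidden states into a single embedding vector.
mask = inputs["attention_mask"].unsqueeze(-1).to(out.last_hidden_state.dtype)
embedding = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
```

Note that the layer-level SDPA default stays `True` while the model-level entry points default to `False`, so at the model level bidirectional attention is the default and the causal mask becomes an explicit opt-in.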