Fix `AttributeError: 'Phi3ForCausalLM' object has no attribute 'generate'` with transformers>=4.52
#9
by
sylwia-kuros
- opened
- modeling_phi3.py +2 -1
modeling_phi3.py
CHANGED
@@ -26,6 +26,7 @@ from torch import nn
|
|
26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
from transformers.activations import ACT2FN
|
28 |
from transformers.cache_utils import Cache, DynamicCache
|
|
|
29 |
from transformers.modeling_attn_mask_utils import \
|
30 |
_prepare_4d_causal_attention_mask
|
31 |
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
@@ -1201,7 +1202,7 @@ class Phi3Model(Phi3PreTrainedModel):
|
|
1201 |
)
|
1202 |
|
1203 |
|
1204 |
-
class Phi3ForCausalLM(Phi3PreTrainedModel):
|
1205 |
_tied_weights_keys = ['lm_head.weight']
|
1206 |
|
1207 |
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
|
|
|
26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
from transformers.activations import ACT2FN
|
28 |
from transformers.cache_utils import Cache, DynamicCache
|
29 |
+
from transformers.generation import GenerationMixin
|
30 |
from transformers.modeling_attn_mask_utils import \
|
31 |
_prepare_4d_causal_attention_mask
|
32 |
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
|
|
1202 |
)
|
1203 |
|
1204 |
|
1205 |
+
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
|
1206 |
_tied_weights_keys = ['lm_head.weight']
|
1207 |
|
1208 |
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
|