Update modeling_n2_eye.py
modeling_n2_eye.py  CHANGED  (+10 -2)
@@ -5,8 +5,11 @@ from transformers import (
     AutoModelForCausalLM,
     CLIPVisionModel,
     PreTrainedModel,
-    PretrainedConfig
+    PretrainedConfig,
+    AutoConfig,
+    AutoModel
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
 from typing import Optional
 
 
@@ -209,4 +212,9 @@ class MultimodalLFM2Model(PreTrainedModel):
         projection_state_dict = torch.load(projection_path, map_location="cpu")
         model.vision_projection.load_state_dict(projection_state_dict)
 
-        return model
+        return model
+
+
+# Register the model with transformers
+AutoConfig.register("multimodal_lfm2", MultimodalLFM2Config)
+AutoModelForCausalLM.register(MultimodalLFM2Config, MultimodalLFM2Model)