FlameF0X commited on
Commit
62e955e
·
verified ·
1 Parent(s): df8daca

Update modeling_n2_eye.py

Browse files
Files changed (1) hide show
  1. modeling_n2_eye.py +10 -2
modeling_n2_eye.py CHANGED
@@ -5,8 +5,11 @@ from transformers import (
5
  AutoModelForCausalLM,
6
  CLIPVisionModel,
7
  PreTrainedModel,
8
- PretrainedConfig
 
 
9
  )
 
10
  from typing import Optional
11
 
12
 
@@ -209,4 +212,9 @@ class MultimodalLFM2Model(PreTrainedModel):
209
  projection_state_dict = torch.load(projection_path, map_location="cpu")
210
  model.vision_projection.load_state_dict(projection_state_dict)
211
 
212
- return model
 
 
 
 
 
 
5
  AutoModelForCausalLM,
6
  CLIPVisionModel,
7
  PreTrainedModel,
8
+ PretrainedConfig,
9
+ AutoConfig,
10
+ AutoModel
11
  )
12
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
13
  from typing import Optional
14
 
15
 
 
212
  projection_state_dict = torch.load(projection_path, map_location="cpu")
213
  model.vision_projection.load_state_dict(projection_state_dict)
214
 
215
+ return model
216
+
217
+
218
+ # Register the model with transformers
219
+ AutoConfig.register("multimodal_lfm2", MultimodalLFM2Config)
220
+ AutoModelForCausalLM.register(MultimodalLFM2Config, MultimodalLFM2Model)