FlameF0X commited on
Commit
df8daca
·
verified ·
1 Parent(s): 39385a2

Update modeling_n2_eye.py

Browse files
Files changed (1) hide show
  1. modeling_n2_eye.py +37 -12
modeling_n2_eye.py CHANGED
@@ -167,21 +167,46 @@ class MultimodalLFM2Model(PreTrainedModel):
167
  @classmethod
168
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
169
  """
170
- Custom loading method - loads from flat directory structure.
171
  """
172
  config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
173
  model = cls(config)
174
 
175
- # Load language model state dict
176
- language_model_path = os.path.join(pretrained_model_name_or_path, "language_model.bin")
177
- if os.path.exists(language_model_path):
178
- language_state_dict = torch.load(language_model_path, map_location="cpu")
179
- model.language_model.load_state_dict(language_state_dict)
180
-
181
- # Load vision projection
182
- projection_path = os.path.join(pretrained_model_name_or_path, "vision_projection.bin")
183
- if os.path.exists(projection_path):
184
- projection_state_dict = torch.load(projection_path, map_location="cpu")
185
- model.vision_projection.load_state_dict(projection_state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  return model
 
167
  @classmethod
168
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
169
  """
170
+ Custom loading method - works with your current structure.
171
  """
172
  config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
173
  model = cls(config)
174
 
175
+ # Try to load from pytorch_model.bin (your current structure)
176
+ main_model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
177
+ if os.path.exists(main_model_path):
178
+ # Load the full model state dict
179
+ full_state_dict = torch.load(main_model_path, map_location="cpu")
180
+
181
+ # Separate language model and vision projection weights
182
+ language_state_dict = {}
183
+ projection_state_dict = {}
184
+
185
+ for key, value in full_state_dict.items():
186
+ if key.startswith("language_model."):
187
+ # Remove the "language_model." prefix
188
+ new_key = key[len("language_model."):]
189
+ language_state_dict[new_key] = value
190
+ elif key.startswith("vision_projection."):
191
+ # Remove the "vision_projection." prefix
192
+ new_key = key[len("vision_projection."):]
193
+ projection_state_dict[new_key] = value
194
+
195
+ # Load the separated state dicts
196
+ if language_state_dict:
197
+ model.language_model.load_state_dict(language_state_dict)
198
+ if projection_state_dict:
199
+ model.vision_projection.load_state_dict(projection_state_dict)
200
+ else:
201
+ # Fallback to separate files
202
+ language_model_path = os.path.join(pretrained_model_name_or_path, "language_model.bin")
203
+ if os.path.exists(language_model_path):
204
+ language_state_dict = torch.load(language_model_path, map_location="cpu")
205
+ model.language_model.load_state_dict(language_state_dict)
206
+
207
+ projection_path = os.path.join(pretrained_model_name_or_path, "vision_projection.bin")
208
+ if os.path.exists(projection_path):
209
+ projection_state_dict = torch.load(projection_path, map_location="cpu")
210
+ model.vision_projection.load_state_dict(projection_state_dict)
211
 
212
  return model