Update modeling_n2_eye.py
Browse files- modeling_n2_eye.py +37 -12
modeling_n2_eye.py
CHANGED
@@ -167,21 +167,46 @@ class MultimodalLFM2Model(PreTrainedModel):
|
|
167 |
@classmethod
|
168 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
169 |
"""
|
170 |
-
Custom loading method -
|
171 |
"""
|
172 |
config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
|
173 |
model = cls(config)
|
174 |
|
175 |
-
#
|
176 |
-
|
177 |
-
if os.path.exists(
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
return model
|
|
|
167 |
@classmethod
|
168 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
169 |
"""
|
170 |
+
Custom loading method - works with your current structure.
|
171 |
"""
|
172 |
config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
|
173 |
model = cls(config)
|
174 |
|
175 |
+
# Try to load from pytorch_model.bin (your current structure)
|
176 |
+
main_model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
|
177 |
+
if os.path.exists(main_model_path):
|
178 |
+
# Load the full model state dict
|
179 |
+
full_state_dict = torch.load(main_model_path, map_location="cpu")
|
180 |
+
|
181 |
+
# Separate language model and vision projection weights
|
182 |
+
language_state_dict = {}
|
183 |
+
projection_state_dict = {}
|
184 |
+
|
185 |
+
for key, value in full_state_dict.items():
|
186 |
+
if key.startswith("language_model."):
|
187 |
+
# Remove the "language_model." prefix
|
188 |
+
new_key = key[len("language_model."):]
|
189 |
+
language_state_dict[new_key] = value
|
190 |
+
elif key.startswith("vision_projection."):
|
191 |
+
# Remove the "vision_projection." prefix
|
192 |
+
new_key = key[len("vision_projection."):]
|
193 |
+
projection_state_dict[new_key] = value
|
194 |
+
|
195 |
+
# Load the separated state dicts
|
196 |
+
if language_state_dict:
|
197 |
+
model.language_model.load_state_dict(language_state_dict)
|
198 |
+
if projection_state_dict:
|
199 |
+
model.vision_projection.load_state_dict(projection_state_dict)
|
200 |
+
else:
|
201 |
+
# Fallback to separate files
|
202 |
+
language_model_path = os.path.join(pretrained_model_name_or_path, "language_model.bin")
|
203 |
+
if os.path.exists(language_model_path):
|
204 |
+
language_state_dict = torch.load(language_model_path, map_location="cpu")
|
205 |
+
model.language_model.load_state_dict(language_state_dict)
|
206 |
+
|
207 |
+
projection_path = os.path.join(pretrained_model_name_or_path, "vision_projection.bin")
|
208 |
+
if os.path.exists(projection_path):
|
209 |
+
projection_state_dict = torch.load(projection_path, map_location="cpu")
|
210 |
+
model.vision_projection.load_state_dict(projection_state_dict)
|
211 |
|
212 |
return model
|