FlameF0X committed
Commit a388ffd · verified · 1 Parent(s): 4f7c22a

Create modeling_n2_eye.py

Files changed (1)
  1. modeling_n2_eye.py +220 -0
modeling_n2_eye.py ADDED
@@ -0,0 +1,220 @@
+ import os
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoModelForCausalLM,
+     CLIPVisionModel,
+     PreTrainedModel,
+     PretrainedConfig,
+     AutoConfig,
+     AutoModel
+ )
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
+ from typing import Optional
+
+
+ class MultimodalLFM2Config(PretrainedConfig):
+     model_type = "multimodal_lfm2"
+
+     def __init__(
+         self,
+         lfm2_model_name="LiquidAI/LFM2-1.2B",
+         clip_model_name="google/siglip2-so400m-patch14-384",
+         vision_projection_dim=512,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.lfm2_model_name = lfm2_model_name
+         self.clip_model_name = clip_model_name
+         self.vision_projection_dim = vision_projection_dim
+
+
+ class MultimodalLFM2Model(PreTrainedModel):
+     config_class = MultimodalLFM2Config
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         # --- Language Model ---
+         self.language_model = AutoModelForCausalLM.from_pretrained(
+             config.lfm2_model_name,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True
+         )
+
+         # --- Vision Encoder ---
+         self.vision_encoder = CLIPVisionModel.from_pretrained(config.clip_model_name)
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+         # --- Projection Layer ---
+         self.language_hidden_size = self.language_model.config.hidden_size
+         self.vision_hidden_size = self.vision_encoder.config.hidden_size
+         self.vision_projection = nn.Sequential(
+             nn.Linear(self.vision_hidden_size, config.vision_projection_dim),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(config.vision_projection_dim, self.language_hidden_size),
+             nn.LayerNorm(self.language_hidden_size)
+         )
+         self.image_token_id = None
+
+     def gradient_checkpointing_enable(self, **kwargs):
+         """Delegates gradient checkpointing to the language model."""
+         self.language_model.gradient_checkpointing_enable(**kwargs)
+
+     def _prepare_multimodal_inputs(
+         self,
+         input_ids: torch.Tensor,
+         images: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Prepares input embeddings by combining text and image features.
+         """
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+         vision_outputs = self.vision_encoder(pixel_values=images)
+         image_features = vision_outputs.last_hidden_state
+         projected_image_features = self.vision_projection(image_features).to(self.language_model.dtype)
+
+         batch_size = input_ids.shape[0]
+         image_token_mask = (input_ids == self.image_token_id)
+
+         for i in range(batch_size):
+             image_positions = torch.where(image_token_mask[i])[0]
+             if len(image_positions) > 0:
+                 img_feat = projected_image_features[i]
+                 # match length
+                 if len(image_positions) > img_feat.shape[0]:
+                     repeat_times = (len(image_positions) + img_feat.shape[0] - 1) // img_feat.shape[0]
+                     img_feat = img_feat.repeat(repeat_times, 1)[:len(image_positions)]
+                 elif len(image_positions) < img_feat.shape[0]:
+                     img_feat = img_feat[:len(image_positions)]
+                 inputs_embeds[i, image_positions] = img_feat
+
+         return inputs_embeds
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         images: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         **kwargs
+     ):
+         """
+         Forward pass for training.
+         """
+         if images is not None and self.image_token_id is not None:
+             inputs_embeds = self._prepare_multimodal_inputs(input_ids, images)
+             final_input_ids = None
+         else:
+             inputs_embeds = None
+             final_input_ids = input_ids
+
+         return self.language_model(
+             input_ids=final_input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             labels=labels,
+             return_dict=True
+         )
+
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         images: Optional[torch.Tensor] = None,
+         **kwargs
+     ):
+         """
+         Generation method for inference.
+         """
+         if images is not None and self.image_token_id is not None:
+             inputs_embeds = self._prepare_multimodal_inputs(input_ids, images)
+             final_input_ids = None
+         else:
+             inputs_embeds = None
+             final_input_ids = input_ids
+
+         return self.language_model.generate(
+             input_ids=final_input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             **kwargs
+         )
+
+     def save_pretrained(self, save_directory, **kwargs):
+         """
+         Custom save method - saves everything in one directory.
+         """
+         os.makedirs(save_directory, exist_ok=True)
+
+         # Save config
+         self.config.save_pretrained(save_directory)
+
+         # Save language model state dict directly
+         torch.save(
+             self.language_model.state_dict(),
+             os.path.join(save_directory, "language_model.bin")
+         )
+
+         # Save language model config
+         self.language_model.config.save_pretrained(save_directory, config_file_name="language_model_config.json")
+
+         # Save vision projection
+         torch.save(
+             self.vision_projection.state_dict(),
+             os.path.join(save_directory, "vision_projection.bin")
+         )
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         """
+         Custom loading method - works with your current structure.
+         """
+         config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
+         model = cls(config)
+
+         # Try to load from pytorch_model.bin (your current structure)
+         main_model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+         if os.path.exists(main_model_path):
+             # Load the full model state dict
+             full_state_dict = torch.load(main_model_path, map_location="cpu")
+
+             # Separate language model and vision projection weights
+             language_state_dict = {}
+             projection_state_dict = {}
+
+             for key, value in full_state_dict.items():
+                 if key.startswith("language_model."):
+                     # Remove the "language_model." prefix
+                     new_key = key[len("language_model."):]
+                     language_state_dict[new_key] = value
+                 elif key.startswith("vision_projection."):
+                     # Remove the "vision_projection." prefix
+                     new_key = key[len("vision_projection."):]
+                     projection_state_dict[new_key] = value
+
+             # Load the separated state dicts
+             if language_state_dict:
+                 model.language_model.load_state_dict(language_state_dict)
+             if projection_state_dict:
+                 model.vision_projection.load_state_dict(projection_state_dict)
+         else:
+             # Fallback to separate files
+             language_model_path = os.path.join(pretrained_model_name_or_path, "language_model.bin")
+             if os.path.exists(language_model_path):
+                 language_state_dict = torch.load(language_model_path, map_location="cpu")
+                 model.language_model.load_state_dict(language_state_dict)
+
+             projection_path = os.path.join(pretrained_model_name_or_path, "vision_projection.bin")
+             if os.path.exists(projection_path):
+                 projection_state_dict = torch.load(projection_path, map_location="cpu")
+                 model.vision_projection.load_state_dict(projection_state_dict)
+
+         return model
+
+
+ # Register the model with transformers
+ AutoConfig.register("multimodal_lfm2", MultimodalLFM2Config)
+ AutoModelForCausalLM.register(MultimodalLFM2Config, MultimodalLFM2Model)
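
For reference, below is a minimal, untested usage sketch of the class added in this commit. The checkpoint directory ("./n2-eye-checkpoint"), the "<image>" placeholder token, and the example image path are illustrative assumptions, not part of the file; modeling_n2_eye.py leaves image_token_id as None, so the caller must register an image token and assign its id before passing images.

# Usage sketch only; paths and the "<image>" token below are assumptions.
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoImageProcessor

from modeling_n2_eye import MultimodalLFM2Model

# Hypothetical local directory produced by the custom save_pretrained above.
model = MultimodalLFM2Model.from_pretrained("./n2-eye-checkpoint")
model.eval()

config = model.config
tokenizer = AutoTokenizer.from_pretrained(config.lfm2_model_name, trust_remote_code=True)

# image_token_id is left unset in __init__; add a placeholder token (only needed
# if the tokenizer does not already have one) and point the model at its id.
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
model.language_model.resize_token_embeddings(len(tokenizer))
model.image_token_id = tokenizer.convert_tokens_to_ids("<image>")

# Preprocess the image with the processor matching the vision encoder checkpoint.
image_processor = AutoImageProcessor.from_pretrained(config.clip_model_name)
pixel_values = image_processor(images=Image.open("example.jpg"), return_tensors="pt").pixel_values

prompt = "<image>\nDescribe this image."
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        images=pixel_values,
        max_new_tokens=64,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note that with a single "<image>" token in the prompt, _prepare_multimodal_inputs truncates the projected patch features to one embedding; a prompt with as many image tokens as vision patches would keep all of them.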