import os
from typing import Optional

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    CLIPVisionModel,
    PretrainedConfig,
    PreTrainedModel,
)


class MultimodalLFM2Config(PretrainedConfig):
    model_type = "multimodal_lfm2"

    def __init__(
        self,
        lfm2_model_name="LiquidAI/LFM2-1.2B",
        clip_model_name="openai/clip-vit-base-patch32",
        vision_projection_dim=512,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.lfm2_model_name = lfm2_model_name
        self.clip_model_name = clip_model_name
        self.vision_projection_dim = vision_projection_dim


class MultimodalLFM2Model(PreTrainedModel):
    config_class = MultimodalLFM2Config

    def __init__(self, config):
        super().__init__(config)

        # --- Language model (trainable) ---
        self.language_model = AutoModelForCausalLM.from_pretrained(
            config.lfm2_model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        # --- Vision encoder (frozen) ---
        self.vision_encoder = CLIPVisionModel.from_pretrained(config.clip_model_name)
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # --- Projection layer: maps CLIP hidden states into the LM embedding space ---
        self.language_hidden_size = self.language_model.config.hidden_size
        self.vision_hidden_size = self.vision_encoder.config.hidden_size

        self.vision_projection = nn.Sequential(
            nn.Linear(self.vision_hidden_size, config.vision_projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(config.vision_projection_dim, self.language_hidden_size),
            nn.LayerNorm(self.language_hidden_size),
        )

        # Set externally (e.g. by the training script) once the image
        # placeholder token has been added to the tokenizer.
        self.image_token_id = None

    def gradient_checkpointing_enable(self, **kwargs):
        """Delegates gradient checkpointing to the language model."""
        self.language_model.gradient_checkpointing_enable(**kwargs)

    def _prepare_multimodal_inputs(
        self,
        input_ids: torch.Tensor,
        images: torch.Tensor,
    ) -> torch.Tensor:
        """Builds input embeddings by splicing projected image features into the
        positions occupied by image placeholder tokens."""
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        vision_outputs = self.vision_encoder(pixel_values=images)
        image_features = vision_outputs.last_hidden_state
        projected_image_features = self.vision_projection(image_features).to(
            self.language_model.dtype
        )

        batch_size = input_ids.shape[0]
        image_token_mask = input_ids == self.image_token_id

        for i in range(batch_size):
            image_positions = torch.where(image_token_mask[i])[0]
            if len(image_positions) == 0:
                continue

            img_feat = projected_image_features[i]

            # Match the number of image features to the number of placeholder
            # tokens by repeating or truncating as needed.
            if len(image_positions) > img_feat.shape[0]:
                repeat_times = (
                    len(image_positions) + img_feat.shape[0] - 1
                ) // img_feat.shape[0]
                img_feat = img_feat.repeat(repeat_times, 1)[: len(image_positions)]
            elif len(image_positions) < img_feat.shape[0]:
                img_feat = img_feat[: len(image_positions)]

            inputs_embeds[i, image_positions] = img_feat

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        images: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Forward pass for training."""
        if images is not None and self.image_token_id is not None:
            inputs_embeds = self._prepare_multimodal_inputs(input_ids, images)
            final_input_ids = None
        else:
            inputs_embeds = None
            final_input_ids = input_ids

        return self.language_model(
            input_ids=final_input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
""" if images is not None and self.image_token_id is not None: inputs_embeds = self._prepare_multimodal_inputs(input_ids, images) final_input_ids = None else: inputs_embeds = None final_input_ids = input_ids return self.language_model.generate( input_ids=final_input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs ) def save_pretrained(self, save_directory, **kwargs): """ Custom save method - saves everything in one directory. """ os.makedirs(save_directory, exist_ok=True) # Save config self.config.save_pretrained(save_directory) # Save language model state dict directly torch.save( self.language_model.state_dict(), os.path.join(save_directory, "language_model.bin") ) # Save language model config self.language_model.config.save_pretrained(save_directory, config_file_name="language_model_config.json") # Save vision projection torch.save( self.vision_projection.state_dict(), os.path.join(save_directory, "vision_projection.bin") ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): """ Custom loading method - works with your current structure. """ config = cls.config_class.from_pretrained(pretrained_model_name_or_path) model = cls(config) # Try to load from pytorch_model.bin (your current structure) main_model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") if os.path.exists(main_model_path): # Load the full model state dict full_state_dict = torch.load(main_model_path, map_location="cpu") # Separate language model and vision projection weights language_state_dict = {} projection_state_dict = {} for key, value in full_state_dict.items(): if key.startswith("language_model."): # Remove the "language_model." prefix new_key = key[len("language_model."):] language_state_dict[new_key] = value elif key.startswith("vision_projection."): # Remove the "vision_projection." prefix new_key = key[len("vision_projection."):] projection_state_dict[new_key] = value # Load the separated state dicts if language_state_dict: model.language_model.load_state_dict(language_state_dict) if projection_state_dict: model.vision_projection.load_state_dict(projection_state_dict) else: # Fallback to separate files language_model_path = os.path.join(pretrained_model_name_or_path, "language_model.bin") if os.path.exists(language_model_path): language_state_dict = torch.load(language_model_path, map_location="cpu") model.language_model.load_state_dict(language_state_dict) projection_path = os.path.join(pretrained_model_name_or_path, "vision_projection.bin") if os.path.exists(projection_path): projection_state_dict = torch.load(projection_path, map_location="cpu") model.vision_projection.load_state_dict(projection_state_dict) return model # Register the model with transformers AutoConfig.register("multimodal_lfm2", MultimodalLFM2Config) AutoModelForCausalLM.register(MultimodalLFM2Config, MultimodalLFM2Model)