import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
import PIL.Image  # `import PIL` alone does not load the submodule used as `PIL.Image.Resampling`
import torch
import torch.nn as nn
from accelerate import PartialState
from einops import rearrange
from safetensors.torch import load_file
from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig, AutoTokenizer, LlamaTokenizerFast
from transformers import Qwen3ForCausalLM, SiglipImageProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME

from modeling_siglip import SiglipVisionModel
from configuration_siglip import SiglipVisionConfig
from configuration_qwen3 import Qwen3Config

# Label value excluded from the loss (used for image positions in `labels`).
IGNORE_INDEX = -100
# Sentinel id that `tokenizer_image_token` emits where an image belongs.
IMAGE_TOKEN_INDEX = -200


class PromptBuilder(ABC):
    """Abstract interface for incrementally building a chat prompt."""

    def __init__(self, system_prompt: Optional[str] = None) -> None:
        # Only some models define a system prompt => let subclasses handle this logic!
        self.system_prompt = system_prompt

    @abstractmethod
    def add_turn(self, role: str, message: str) -> str:
        """Append one turn to the prompt and return the text that was added."""
        ...

    @abstractmethod
    def get_potential_prompt(self, user_msg: str) -> str:
        """Return the prompt as it would read with ``user_msg`` as the next user turn."""
        ...

    @abstractmethod
    def get_prompt(self) -> str:
        """Return the prompt accumulated so far."""
        ...
# Textual placeholder marking where an image goes in a prompt.
# `InfiMed.process_messages` inserts it and `tokenizer_image_token` splits on it,
# so both always agree on the marker.
# NOTE(review): the literal previously used here was an empty string, which
# made `str.split` raise "empty separator"; "<image>" follows the LLaVA
# convention — confirm it matches the training data.
DEFAULT_IMAGE_TOKEN = "<image>"


class Qwen3PromptBuilder(PromptBuilder):
    """Incrementally build a Qwen3 chat prompt, one turn at a time.

    Even-numbered turns (0, 2, ...) are wrapped as user turns and odd-numbered
    turns as assistant turns; a fixed system preamble is prepended to turn 0.
    """

    def __init__(self, system_prompt: Optional[str] = None) -> None:
        super().__init__(system_prompt)
        # The constructor argument is overridden by this fixed system preamble.
        self.system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        self.bos, self.eos = "", "<|im_end|>"

        # Role-specific "wrap" functions.
        # NOTE(review): unlike canonical ChatML there is no "\n<|im_start|>"
        # between "<|im_end|>" and "assistant"; kept byte-for-byte because the
        # checkpoint may have been trained on this exact format — confirm.
        self.wrap_human = lambda msg: f"<|im_start|>user\n{msg}<|im_end|>assistant\n"
        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}\n"

        # `self.prompt` gets built up over multiple turns.
        self.prompt, self.turn_count = "", 0

    def add_turn(self, role: str, message: str) -> str:
        """Append one turn and return the text that was added to the prompt.

        ``role`` is not inspected: the wrapping is chosen purely by turn parity.
        """
        message = message.strip()

        if self.turn_count == 0:
            # First turn carries the system preamble in front of the user message.
            wrapped_message = self.system_prompt + self.wrap_human(message)
        elif self.turn_count % 2 == 0:
            wrapped_message = self.wrap_human(message)
        else:
            wrapped_message = self.wrap_gpt(message)

        self.prompt += wrapped_message
        self.turn_count += 1
        return wrapped_message

    def get_potential_prompt(self, message: str) -> str:
        """Return the prompt as it would read after a user turn ``message``.

        The builder itself is not modified. Assumes it is the user's turn.
        """
        prompt_copy = str(self.prompt)
        if self.turn_count == 0:
            prompt_copy += self.system_prompt + self.wrap_human(message)
        else:
            prompt_copy += self.wrap_human(message)
        return prompt_copy.rstrip()

    def get_prompt(self) -> str:
        """Return the accumulated prompt with trailing whitespace removed."""
        return self.prompt.rstrip()


class InfiMedConfig(PretrainedConfig):
    """Configuration for :class:`InfiMed` (SigLIP vision tower + Qwen3 LLM).

    Args:
        vision_config: kwargs for ``SiglipVisionConfig``; defaults when ``None``.
        llm_config: kwargs for the language-model config; its
            ``architectures[0]`` must be ``"Qwen3ForCausalLM"``. Defaults to a
            default ``Qwen3Config`` when ``None``.
        run_dir: training run directory; ``InfiMed`` picks its projector
            layout based on whether this contains ``"finetune"``.
        load_precision: one of ``"fp32"``, ``"fp16"``, ``"bf16"``.
        max_length: used by ``InfiMed`` as ``max_new_tokens`` for generation.
        temperature: sampling temperature; ``InfiMed`` decodes greedily at 0.
        **kwargs: forwarded to ``PretrainedConfig``.

    Raises:
        ValueError: if ``llm_config['architectures'][0]`` is unsupported.
    """

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        run_dir: str = None,
        load_precision: str = "bf16",
        max_length: int = 128,
        temperature: float = 1.0,
        **kwargs
    ):
        if vision_config is None:
            vision_config = {}
            print('vision_config is None. Initializing the SiglipVisionConfig with default values.')

        if llm_config is None:
            llm_config = {'architectures': ['Qwen3ForCausalLM']}
            print('llm_config is None. Initializing the Qwen3Config config with default values')

        self.vision_config = SiglipVisionConfig(**vision_config)
        if llm_config['architectures'][0] == 'Qwen3ForCausalLM':
            self.llm_config = Qwen3Config(**llm_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))

        self.run_dir = run_dir
        self.load_precision = load_precision
        self.max_length = max_length
        self.temperature = temperature
        super().__init__(**kwargs)


class AvgPoolProjector(nn.Module):
    """Average-pool a square grid of visual tokens, then project to the LLM width.

    Output has ``hw * hw`` tokens where ``hw = int(sqrt(query_num))``; the MLP
    is ``layer_num`` Linear layers separated by GELU.
    """

    def __init__(
        self,
        layer_num: int = 2,
        query_num: int = 144,
        mm_hidden_size: int = 1024,
        llm_hidden_size: int = 4096,
    ):
        super().__init__()
        self.layer_num = layer_num
        self.query_num = query_num
        self.mm_hidden_size = mm_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.build_net()

    def build_net(self):
        """Create the pooling layer and the MLP projector."""
        hw = int(self.query_num ** 0.5)
        self.sampler = nn.AdaptiveAvgPool2d((hw, hw))

        modules = [nn.Linear(self.mm_hidden_size, self.llm_hidden_size)]
        for _ in range(1, self.layer_num):
            modules.append(nn.GELU())
            modules.append(nn.Linear(self.llm_hidden_size, self.llm_hidden_size))
        self.mlp_projector = nn.Sequential(*modules)
        print(f"patch size {hw} average pooling layer initialized")

    def forward(self, visual_feat: torch.Tensor) -> torch.Tensor:
        """Pool ``(batch, seq_len, dim)`` features; ``seq_len`` must be a perfect square."""
        _, seq_len, _ = visual_feat.shape
        hw = int(seq_len ** 0.5)
        shaped_visual_feat = rearrange(visual_feat, "b (h w) d -> b d h w", h=hw, w=hw)
        pooled_visual_feat = self.sampler(shaped_visual_feat)
        reshaped_visual_feat = rearrange(pooled_visual_feat, "b d h w -> b (h w) d")
        return self.mlp_projector(reshaped_visual_feat)


class InfiMed(PreTrainedModel):
    """SigLIP vision encoder + average-pool projector + Qwen3 causal LM.

    Projected image features overwrite the input embeddings at positions whose
    token id equals ``img_context_token_id``.
    """

    config_class = InfiMedConfig

    def __init__(self, config: InfiMedConfig, vision_model=None, language_model=None):
        super().__init__(config)
        self.run_dir = Path(config.run_dir) if config.run_dir else None
        self.model_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[config.load_precision]
        self.distributed_state = PartialState()

        # Generation defaults used by `generate_output`.
        self.max_new_tokens = config.max_length
        self.temperature = config.temperature
        # InfiMedConfig does not declare these; fall back to neutral values when absent.
        self.top_p = getattr(config, "top_p", 1.0)
        self.repetition_penalty = getattr(config, "repetition_penalty", 1.0)

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = SiglipVisionModel(config.vision_config)

        if language_model is not None:
            self.language_model = language_model
            self.config.llm_config = language_model.config
        elif config.llm_config.architectures[0] == 'Qwen3ForCausalLM':
            self.language_model = Qwen3ForCausalLM(config.llm_config)
        else:
            raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, use_fast=True)
        # NOTE(review): the empty-string entries below look like tokens lost
        # during export (e.g. an "<image>" special token); kept as found so the
        # vocabulary matches the checkpoint — confirm.
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "", "<|pad|>"]})
        self.tokenizer.pad_token = "<|pad|>"
        self.tokenizer.bos_token = ""
        self.offset = 1 if self.tokenizer.encode("\n")[0] == self.tokenizer.bos_token_id else 0

        # "finetune" runs use a 729-query projector; anything else (including a
        # missing run_dir) uses the 144-query default.
        if config.run_dir and "finetune" in str(config.run_dir):
            self.arch_specifier = "full-align+729-avgpool"
        else:
            self.arch_specifier = "no-align+avgpool"

        pool_spec = self.arch_specifier.split("+")[-1].split("-")[0]
        query_dim = int(pool_spec) if pool_spec != "avgpool" else 144
        self.projector = AvgPoolProjector(
            query_num=query_dim,
            mm_hidden_size=config.vision_config.hidden_size,
            llm_hidden_size=config.llm_config.hidden_size,
        )

        self.vision_backbone_requires_grad = False
        self.img_context_token_id = 151655
        self.image_processor = SiglipImageProcessor.from_pretrained(
            config._name_or_path,
            size={"height": 384, "width": 384},
            resample=PIL.Image.Resampling.BICUBIC,
            crop_size={"height": 384, "width": 384},
            do_center_crop=True,
            do_normalize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_std=[0.5, 0.5, 0.5],
            do_convert_rgb=True,
        )

    @classmethod
    def from_pretrained_ckpt(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build a model from a config path and load weights from a sibling ``checkpoints`` dir.

        Weight files are tried in ``<dirname(path)>/checkpoints`` in this order:
        ``SAFE_WEIGHTS_NAME``, ``WEIGHTS_NAME`` (its ``"model"`` entry),
        ``latest-checkpoint.pt`` (its ``"model"`` entry). The loaded mapping is
        indexed by ``"llm_backbone"`` and ``"projector"`` and, if present,
        ``"vision_backbone"``. Keys of the LLM sub-dict have ``"llm."`` removed.

        The returned model is on CUDA in bfloat16, frozen and in eval mode.

        Raises:
            FileNotFoundError: if none of the weight files exist.
        """
        config = InfiMedConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        model = cls(config)

        ckpt_dir = Path(os.path.join(os.path.dirname(pretrained_model_name_or_path), "checkpoints"))
        safe_path = ckpt_dir / SAFE_WEIGHTS_NAME
        weights_path = ckpt_dir / WEIGHTS_NAME
        latest_path = ckpt_dir / "latest-checkpoint.pt"
        # NOTE(review): `load_file` returns a flat name->tensor mapping, so the
        # nested "llm_backbone"/"projector" lookups below presumably only work
        # for the torch.load branches — confirm the safetensors layout.
        # torch.load unpickles data; only use with trusted local checkpoints.
        if safe_path.exists():
            state_dict = load_file(safe_path)
        elif weights_path.exists():
            state_dict = torch.load(weights_path, map_location="cpu")["model"]
        elif latest_path.exists():
            state_dict = torch.load(latest_path, map_location="cpu")["model"]
        else:
            raise FileNotFoundError("No model weights found in the directory.")

        if "vision_backbone" in state_dict:
            model.vision_model.load_state_dict(state_dict["vision_backbone"])
        llm_state_dict = {key.replace("llm.", ""): value for key, value in state_dict["llm_backbone"].items()}
        model.language_model.load_state_dict(llm_state_dict)
        model.projector.load_state_dict(state_dict["projector"])

        # NOTE(review): `config.load_precision` / `model.model_dtype` are not
        # honoured here; the model is always cast to bfloat16.
        model.to("cuda", dtype=torch.bfloat16)
        model.requires_grad_(False)
        model.eval()
        return model

    def save_checkpoint(self, save_path):
        """Save model weights/config, tokenizer and image processor to ``save_path``."""
        os.makedirs(save_path, exist_ok=True)
        self.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        self.image_processor.save_pretrained(save_path)

    def process_messages(self, messages):
        """Turn a message dict into ``(input_ids, pixel_values)`` for `generate`.

        ``messages["prompt"]`` is the user text. With an ``"image"`` key the
        image is encoded and one placeholder is put before the prompt; with an
        ``"images"`` key a prefix is built per image; with neither the prompt is
        text-only. Placeholders typed by the user are stripped first so the
        number of image tokens always matches the encoded images.

        Returns:
            ``input_ids`` of shape ``(1, seq_len)`` and ``pixel_values`` (``None``
            unless ``"image"`` was given), both on ``self.device``.
        """
        prompt_builder = Qwen3PromptBuilder()
        user_prompt = messages['prompt'].replace(DEFAULT_IMAGE_TOKEN, '')
        if "image" in messages:
            processed_prompt = DEFAULT_IMAGE_TOKEN + "\n" + user_prompt
        elif "images" in messages:
            # NOTE(review): this branch emits no image placeholder and computes
            # no pixel values, so it effectively runs text-only; the per-image
            # prefix is kept as found — confirm the intended multi-image format.
            processed_prompt = "".join(": " for _ in messages['images'])
            processed_prompt += "\n" + user_prompt
        else:
            processed_prompt = user_prompt

        msg = prompt_builder.add_turn("user", processed_prompt).strip()
        turn_input_ids, _ = tokenizer_image_token(msg, self.tokenizer)

        # Expand each image sentinel into one context token per projected
        # visual token, i.e. the projector's pooled grid size.
        num_image_tokens = int(self.projector.query_num ** 0.5) ** 2
        expanded_ids = []
        for token_id in turn_input_ids:
            if token_id == IMAGE_TOKEN_INDEX:
                expanded_ids.extend([self.img_context_token_id] * num_image_tokens)
            else:
                expanded_ids.append(token_id)

        input_ids = torch.tensor(expanded_ids)[: self.tokenizer.model_max_length].unsqueeze(0)

        if "image" in messages:
            pixel_values = self.image_processor(images=messages["image"], return_tensors="pt")["pixel_values"]
            pixel_values = pixel_values.to(self.device)
        else:
            pixel_values = None
        return input_ids.to(self.device), pixel_values

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> CausalLMOutputWithPast:
        """Run the LLM on text embeddings with image features scattered in.

        ``image_flags`` is accepted for signature compatibility and ignored.
        When ``labels`` is given, a next-token cross-entropy loss is returned.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vit_embeds = self.extract_feature(pixel_values)
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        selected = input_ids.reshape(B * N) == self.img_context_token_id
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            # Best-effort fallback when the number of image tokens and visual
            # features disagree: fill as many positions as there are features.
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
        input_embeds = input_embeds.reshape(B, N, C)

        if attention_mask is None:
            # No padding information: attend to every position.
            attention_mask = torch.ones(input_embeds.shape[:2], dtype=torch.bool, device=input_embeds.device)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def extract_feature(self, pixel_values):
        """Encode images and project them to the LLM hidden size.

        Uses the vision tower's second-to-last hidden state, which the projector
        requires to have a perfect-square number of tokens.
        """
        vit_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True,
        ).hidden_states[-2]
        return self.projector(vit_embeds)

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate from ``input_ids``, injecting image features when given.

        ``visual_features`` (precomputed projector output) is used instead of
        encoding ``pixel_values``, but only when ``pixel_values`` is not None.
        ``return_dict`` is accepted for compatibility and not forwarded.
        """
        assert self.img_context_token_id is not None
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            vit_embeds = visual_features if visual_features is not None else self.extract_feature(pixel_values)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            selected = input_ids.reshape(B * N) == self.img_context_token_id
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            input_embeds = input_embeds.reshape(B, N, C)

        if attention_mask is None:
            # No padding information: attend to every position.
            attention_mask = torch.ones(input_embeds.shape[:2], dtype=torch.bool, device=input_embeds.device)

        return self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

    @torch.no_grad()
    def generate_output(self, messages):
        """Answer one message dict (see `process_messages`) and return the decoded text.

        Decoding is greedy when ``self.temperature == 0``, sampled otherwise.
        """
        input_ids, pixel_values = self.process_messages(messages)
        generated_ids = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            temperature=self.temperature,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
            max_new_tokens=self.max_new_tokens,
            do_sample=self.temperature != 0,
        )
        # Presumably only new tokens are returned because generation is driven
        # by `inputs_embeds`, hence no prompt trimming — verify on upgrades.
        output_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]

    def generate_outputs(self, messages_list):
        """Run `generate_output` on each message dict, in order."""
        return [self.generate_output(messages) for messages in messages_list]


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None,
                          image_token=DEFAULT_IMAGE_TOKEN):
    """Tokenize ``prompt``, replacing each ``image_token`` with ``image_token_index``.

    Args:
        prompt: text possibly containing ``image_token`` placeholders.
        tokenizer: callable returning an object with ``input_ids``; its
            ``bos_token_id`` is compared against the first token.
        image_token_index: id emitted in ``input_ids`` for each placeholder.
        return_tensors: ``None`` for lists or ``'pt'`` for ``torch.long`` tensors.
        image_token: placeholder string to split on (must be non-empty).

    Returns:
        ``(input_ids, labels)``; ``labels`` equals ``input_ids`` except that
        image positions hold ``IGNORE_INDEX``.

    Raises:
        ValueError: for an unsupported ``return_tensors`` value.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(image_token)]

    def insert_separator(X, sep):
        # Interleave `sep` between the items of X (no trailing separator).
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    labels = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        # Keep one leading BOS, then drop the first token of every chunk and
        # separator (the separator is doubled so exactly one copy survives).
        offset = 1
        input_ids.append(prompt_chunks[0][0])
        labels.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])
    for x in insert_separator(prompt_chunks, [IGNORE_INDEX] * (offset + 1)):
        labels.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids, labels