from io import BytesIO
from typing import Any, Dict, Optional

import PIL.Image
import torch
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import SiglipImageProcessor


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        # Images are preprocessed with the SigLIP image processor shipped with the
        # checkpoint; text goes through the tokenizer set up by the base class.
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        # Pixel values are only present for batches that contain images; cast
        # them to the model's dtype before the forward pass.
        if "pixel_values" in features:
            trans_features["pixel_values"] = features["pixel_values"].to(
                self.auto_model.dtype
            )
        sentence_embedding = self.auto_model(**trans_features, **kwargs)[
            "sentence_embedding"
        ]
        features.update({"sentence_embedding": sentence_embedding})
        return features

    def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
        # Each image is represented in the text by a fixed-length run of
        # placeholder tokens delimited by special start/end markers.
        img_start_token = "<|jasper_img_start|>"
        img_token = "<|jasper_img_token|>"
        img_end_token = "<|jasper_img_end|>"
        num_img_tokens = 300

        def process_text_item(item):
            # Plain strings are text-only; lists of dicts mix text and images.
            if isinstance(item, str):
                return item, []
            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] == "image_bytes":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(
                        PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    )
                elif sub_item["type"] == "image_path":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
                else:
                    raise ValueError(f"unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)
        ipt = self.tokenizer(
            all_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        if all_images:
            ipt["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return ipt
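

# A minimal usage sketch, not part of the module above. It assumes this file is
# loaded as a custom module of a SentenceTransformer checkpoint via
# trust_remote_code=True and that the checkpoint's modules.json points at
# MultiModalTransformer; "/path/to/checkpoint" and "cat.jpg" are placeholders,
# not names confirmed by this file.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("/path/to/checkpoint", trust_remote_code=True)

    # Text-only inputs are plain strings, handled by the str branch of tokenize().
    text_vecs = model.encode(["a cat sleeping on a sofa"])

    # Multimodal inputs are lists of {"type", "content"} dicts per example,
    # matching the formats handled by MultiModalTransformer.tokenize above.
    doc = [
        {"type": "image_path", "content": "cat.jpg"},
        {"type": "text", "content": "a cat sleeping on a sofa"},
    ]
    doc_vecs = model.encode([doc])

    print(text_vecs.shape, doc_vecs.shape)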