# Provenance: commit 1d33766 (verified) by infgrad --
# "Introduce a custom Sentence Transformer module for smooth multi-modality (#1)"
from typing import Any, Dict, Optional
import PIL
import torch
import PIL
import torch
from typing import Dict
from io import BytesIO
from transformers import SiglipImageProcessor
from sentence_transformers.models import Transformer as BaseTransformer
class MultiModalTransformer(BaseTransformer):
    """Sentence Transformers ``Transformer`` module that also accepts images.

    Plain strings are tokenized as usual. Multimodal inputs are sequences of
    ``{"type": ..., "content": ...}`` dicts; every image item is replaced in
    the text stream by a fixed run of placeholder tokens, and its pixels are
    prepared with a ``SiglipImageProcessor`` loaded from the same checkpoint.
    The wrapped ``auto_model`` is called with ``input_ids``,
    ``attention_mask`` and, when images are present, ``pixel_values``, and its
    output is indexed with ``"sentence_embedding"``.
    """

    # Placeholder tokens spliced into the text once per image. Presumably
    # NUM_IMG_TOKENS must match what the checkpoint's vision side expects --
    # TODO confirm against the model config before changing it.
    IMG_START_TOKEN = "<|jasper_img_start|>"
    IMG_TOKEN = "<|jasper_img_token|>"
    IMG_END_TOKEN = "<|jasper_img_end|>"
    NUM_IMG_TOKENS = 300

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Load the text model/tokenizer via the base class plus an image processor.

        Args:
            model_name_or_path: Hub id or local path shared by the model,
                tokenizer and image processor.
            cache_dir: Download cache directory, used for the model,
                tokenizer and image processor alike.
            tokenizer_args: Extra keyword arguments for
                ``SiglipImageProcessor.from_pretrained``. NOTE(review): despite
                the name, these are not forwarded to the text tokenizer --
                confirm this is intended.
            **kwargs: Forwarded unchanged to the base ``Transformer``.
        """
        # Forward cache_dir so the model/tokenizer honour it too; previously
        # only the image processor used it.
        super().__init__(model_name_or_path, cache_dir=cache_dir, **kwargs)
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **(tokenizer_args or {})
        )

    def save(self, output_path: str, *args, **kwargs) -> None:
        """Save via the base class, then write the image processor config too.

        ``self.processor`` exists only in this subclass, so the base ``save``
        does not persist it; writing it next to the model lets ``__init__``
        reload it from ``output_path``.
        """
        super().save(output_path, *args, **kwargs)
        self.processor.save_pretrained(output_path)

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Run the wrapped model and store its ``sentence_embedding`` in ``features``.

        Args:
            features: Output of :meth:`tokenize`; must contain ``input_ids``
                and ``attention_mask`` and may contain ``pixel_values``.
            **kwargs: Extra keyword arguments passed to the model call.

        Returns:
            The same ``features`` dict, updated in place with
            ``"sentence_embedding"``.
        """
        model_inputs = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        pixel_values = features.get("pixel_values")
        if pixel_values is not None:
            # Cast to the model's dtype so e.g. fp16/bf16 models accept the
            # processor output (presumably float32).
            model_inputs["pixel_values"] = pixel_values.to(self.auto_model.dtype)
        outputs = self.auto_model(**model_inputs, **kwargs)
        features["sentence_embedding"] = outputs["sentence_embedding"]
        return features

    def tokenize(
        self,
        texts: list[str | list[Dict]],
        padding: str | bool = "longest",
    ) -> dict[str, torch.Tensor]:
        """Tokenize a batch of plain strings and/or multimodal item sequences.

        Each element of ``texts`` is either a ``str`` or an iterable of dicts
        with keys ``"type"`` and ``"content"``:

        * ``"text"``: ``content`` is a string appended verbatim;
        * ``"image_bytes"``: ``content`` is encoded image bytes;
        * ``"image_path"``: ``content`` is a path to an image file.

        Each image contributes ``IMG_START_TOKEN + IMG_TOKEN * NUM_IMG_TOKENS
        + IMG_END_TOKEN`` to its text and one RGB image to ``pixel_values``;
        images from the whole batch are concatenated in input order.

        Args:
            texts: The batch to encode.
            padding: Padding strategy passed to the tokenizer.

        Returns:
            Tokenizer output (``input_ids``, ``attention_mask``, ...) plus
            ``pixel_values`` when at least one image was given.

        Raises:
            ValueError: If an item has an unknown ``"type"``.
        """
        # `import PIL` alone does not guarantee the PIL.Image submodule is
        # loaded; import it explicitly.
        from PIL import Image

        image_span = (
            self.IMG_START_TOKEN
            + self.IMG_TOKEN * self.NUM_IMG_TOKENS
            + self.IMG_END_TOKEN
        )

        def load_rgb(source):
            # Close the file/buffer deterministically; convert() returns an
            # independent, fully loaded image.
            with Image.open(source) as img:
                return img.convert("RGB")

        def process_item(item):
            # Return (text with image placeholders, list of RGB images).
            if isinstance(item, str):
                return item, []
            pieces, images = [], []
            for sub_item in item:
                kind = sub_item["type"]
                if kind == "text":
                    pieces.append(sub_item["content"])
                elif kind == "image_bytes":
                    pieces.append(image_span)
                    images.append(load_rgb(BytesIO(sub_item["content"])))
                elif kind == "image_path":
                    pieces.append(image_span)
                    images.append(load_rgb(sub_item["content"]))
                else:
                    raise ValueError(f"unknown data type {kind}")
            return "".join(pieces), images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # NOTE(review): truncation can cut through an image placeholder span,
        # leaving fewer image tokens than pixel_values imply -- confirm
        # max_seq_length is large enough for multimodal inputs.
        encoded = self.tokenizer(
            all_texts,
            padding=padding,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        if all_images:
            encoded["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return encoded