"""Custom sentence-transformers module for llamaindex/vdr-2b-multi-v1: a
Qwen2-VL based encoder that embeds text queries and document page images
into a shared vector space."""

import base64
import math
from io import BytesIO
from typing import Dict, List, Optional, Union

import torch
from PIL import Image
from torch import nn
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


class Transformer(nn.Module):
    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = 'llamaindex/vdr-2b-multi-v1',
        processor_name_or_path: Optional[str] = None,
        max_pixels: int = 768 * 28 * 28,
        min_pixels: int = 1 * 28 * 28,
        dimension: int = 2048,
        cache_dir: Optional[str] = None,
        device: str = 'cuda:0',
        **kwargs,
    ) -> None:
        super().__init__()
        self.device = device
        self.dimension = dimension
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels

        # Initialize model
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name_or_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map=device,
            cache_dir=cache_dir,
            **kwargs
        ).eval()

        # Initialize processor
        self.processor = AutoProcessor.from_pretrained(
            processor_name_or_path or model_name_or_path,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            cache_dir=cache_dir
        )

        # Left padding keeps the last token of every sequence meaningful,
        # which matters because embeddings are pooled from the final position.
        self.model.padding_side = "left"
        self.processor.tokenizer.padding_side = "left"

        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"

    def _smart_resize(self, height: int, width: int) -> tuple[int, int]:
        # Snap both sides to multiples of 28 (the vision patch grid), then
        # rescale so the total pixel count lands in [min_pixels, max_pixels].
        h_bar = max(28, self._round_by_factor(height, 28))
        w_bar = max(28, self._round_by_factor(width, 28))
        if h_bar * w_bar > self.max_pixels:
            beta = math.sqrt((height * width) / self.max_pixels)
            h_bar = self._floor_by_factor(height / beta, 28)
            w_bar = self._floor_by_factor(width / beta, 28)
        elif h_bar * w_bar < self.min_pixels:
            beta = math.sqrt(self.min_pixels / (height * width))
            h_bar = self._ceil_by_factor(height * beta, 28)
            w_bar = self._ceil_by_factor(width * beta, 28)
        return w_bar, h_bar

    @staticmethod
    def _round_by_factor(number: float, factor: int) -> int:
        return round(number / factor) * factor

    @staticmethod
    def _ceil_by_factor(number: float, factor: int) -> int:
        return math.ceil(number / factor) * factor

    @staticmethod
    def _floor_by_factor(number: float, factor: int) -> int:
        return math.floor(number / factor) * factor

    def _resize_image(self, image: Image.Image) -> Image.Image:
        # _smart_resize returns (width, height), the order PIL's resize expects.
        new_size = self._smart_resize(image.height, image.width)
        return image.resize(new_size)

    @staticmethod
    def _decode_data_image(data_image_str: str) -> Image.Image:
        # Decode a data URI of the form "data:image/...;base64,<payload>".
        header, data = data_image_str.split(',', 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def _process_input(self, texts: List[Union[str, Image.Image]]) -> tuple[List[str], List[Image.Image]]:
        processed_texts = []
        processed_images = []
        # The processor expects one image per prompt, so text queries are
        # paired with a minimal 56x56 placeholder image.
        dummy_image = Image.new('RGB', (56, 56))

        for sample in texts:
            if isinstance(sample, str):
                processed_texts.append(self.query_prompt % sample)
                processed_images.append(dummy_image)
            elif isinstance(sample, Image.Image):
                processed_texts.append(self.document_prompt)
                processed_images.append(self._resize_image(sample))
            else:
                # Silently skipping would misalign inputs and embeddings.
                raise ValueError(f'Unsupported input type: {type(sample)}')

        return processed_texts, processed_images
    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # cache_position must span the sequence length (dim 1), not the
        # batch size (dim 0).
        cache_position = torch.arange(0, features['input_ids'].shape[1])
        inputs = self.model.prepare_inputs_for_generation(
            **features, cache_position=cache_position, use_cache=False
        )
        with torch.no_grad():
            output = self.model(
                **inputs,
                return_dict=True,
                output_hidden_states=True
            )
        # Last-token pooling (valid thanks to left padding), truncated to the
        # first `dimension` components and L2-normalized.
        embeddings = output.hidden_states[-1][:, -1]
        features['sentence_embedding'] = torch.nn.functional.normalize(
            embeddings[:, :self.dimension], p=2, dim=-1
        )
        return features

    def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
        processed_texts, processed_images = self._process_input(texts)

        inputs = self.processor(
            text=processed_texts,
            images=processed_images,
            videos=None,
            padding=padding,
            return_tensors='pt'
        )

        return {k: v.to(self.device) for k, v in inputs.items()}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)
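
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API). Assumes a
# CUDA device with flash-attention-2 installed, since the model is loaded with
# attn_implementation="flash_attention_2" above. The query string and the
# blank placeholder page below are made up for demonstration; in practice the
# image would be a rendered document page.
if __name__ == '__main__':
    model = Transformer()  # downloads llamaindex/vdr-2b-multi-v1 by default

    query = 'What percentage of students graduated in 2023?'  # hypothetical query
    page = Image.new('RGB', (448, 448), 'white')  # stand-in for a real page image

    # tokenize() builds the chat prompts and moves tensors to the model device;
    # forward() (via __call__) pools and normalizes the embeddings.
    features = model.tokenize([query, page])
    embeddings = model(features)['sentence_embedding']

    # Cosine similarity reduces to a dot product because the embeddings
    # are L2-normalized.
    score = (embeddings[0] @ embeddings[1]).item()
    print(embeddings.shape, score)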