import logging
import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration
from transformers.image_utils import load_image

logger = logging.getLogger(__name__)


def load_images(images, lazy_load: bool = True):
    # Disable PIL's decompression-bomb threshold so large images can be read.
    pil_max_px = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None

    images_batch = []
    for image in images:
        if isinstance(image, Image.Image):
            images_batch.append(image)
        else:
            pil_image = load_image(image)
            if lazy_load:
                images_batch.append(pil_image)
            else:
                # Copy and close the file handle to avoid "Too many open files" errors.
                images_batch.append(pil_image.copy())
                pil_image.close()

    Image.MAX_IMAGE_PIXELS = pil_max_px
    return images_batch


def formatting_prompts_func(
    query: str,
    doc: str,
    query_type: str = 'text',
    doc_type: str = 'text',
    prefix_str: str = '',
) -> str:
    """
    Format the prompt for any combination of query and document types.

    Args:
        query: Query text, or an image path/URL when query_type is 'image'
        doc: Document text, or an image path/URL when doc_type is 'image'
        query_type: Either 'text' or 'image'
        doc_type: Either 'text' or 'image'
        prefix_str: Optional prefix string prepended to the prompt
    """
    # Format the query part
    if query_type == 'image':
        query_part = "**Query**:\n<|vision_start|><|image_pad|><|vision_end|>"
    else:
        query_part = f"**Query**:\n{query}"

    # Format the document part
    if doc_type == 'image':
        doc_part = "**Document**:\n<|vision_start|><|image_pad|><|vision_end|>"
    else:
        doc_part = f"**Document**:\n{doc}"

    # The document comes first, followed by the query
    prompt = doc_part + '\n' + query_part

    # Add the prefix if provided
    if prefix_str:
        prompt = prefix_str + '\n' + prompt

    return prompt
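
# Illustrative example of the prompt layout produced by `formatting_prompts_func`
# (derived from the function above; the query/document strings are made up):
#
#   formatting_prompts_func("what is the capital of France?", "Paris is the capital of France.")
#   # -> "**Document**:\nParis is the capital of France.\n**Query**:\nwhat is the capital of France?"
#
# With doc_type='image' (or query_type='image'), the corresponding text is replaced by the
# Qwen2-VL vision placeholder "<|vision_start|><|image_pad|><|vision_end|>" and the actual
# image is passed to the processor separately (see `compute_score` below).
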
class JinaVLForRanking(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.padding_side = "left"
        self.num_labels = 1  # config.num_labels

        # Replace the lm_head with an identity, since only the hidden states are needed
        self.lm_head = nn.Identity()

        # Following `Qwen2ForRewardModel`, use a small MLP head to produce the final score
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )

        # Initialize weights and apply final processing
        self.post_init()

        # Token id appended after each prompt; the hidden state at this position
        # is fed to the score head (see `compute_score`)
        self.score_token_id = 100

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # These are set explicitly below, so drop any caller-supplied values
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)

        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"

        outputs = super().forward(
            *args,
            use_cache=False,
            output_hidden_states=True,
            **kwargs,
        )

        # get the hidden states of the last layer
        hidden_states = outputs.hidden_states[-1]

        # IMPORTANT: padding must be on the left side so the last position is the score token;
        # take the hidden state of the last token and apply the score head
        pooled_logits = self.score(hidden_states[:, -1])

        return pooled_logits.squeeze(-1)

    @torch.no_grad()
    def compute_score(
        self,
        pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
        batch_size: int = 8,
        max_length: int = 8192,
        max_query_length: int = 512,
        max_doc_length: Optional[int] = None,
        query_type: str = 'text',
        doc_type: str = 'text',
        show_progress: bool = False,
    ) -> List[float]:
        """Compute relevance scores for (query, document) pairs.

        Returns a list of floats, or a single float if only one pair is given.
        """
        if not hasattr(self, "_processor"):
            from transformers import AutoProcessor

            self._processor = AutoProcessor.from_pretrained(self.name_or_path, trust_remote_code=True)

        assert isinstance(pairs, (list, tuple))
        if isinstance(pairs[0], str):
            # A single (query, document) pair was passed
            pairs = [pairs]

        max_length = max_length or self.config.max_length

        if max_doc_length is None:
            max_doc_length = max(max_length - max_query_length, max_query_length)

        if max_doc_length < max_query_length:
            warnings.warn(
                f"max_doc_length={max_doc_length} should not be smaller than max_query_length={max_query_length}"
            )

        assert (
            max_doc_length + max_query_length <= max_length
        ), f"max_doc_length ({max_doc_length}) + max_query_length ({max_query_length}) must not exceed max_length ({max_length})"

        # Reserve one position for the score token appended below
        max_length = max_length - 1

        all_scores = []
        device = next(self.parameters()).device

        batch_iter = range(0, len(pairs), batch_size)
        if show_progress:
            from tqdm import trange

            batch_iter = trange(0, len(pairs), batch_size, desc="Computing scores")

        for start_index in batch_iter:
            mini_batch = pairs[start_index : start_index + batch_size]

            batch_inputs = []
            for q, d in mini_batch:
                # TEMP FIX: pre-truncate long text documents so the query is not cut off
                if doc_type == 'text':
                    tokens = self._processor.tokenizer(d, truncation=True, max_length=max_doc_length)
                    if len(tokens['input_ids']) >= max_doc_length:
                        d = self._processor.tokenizer.decode(tokens['input_ids'])
                batch_inputs.append(
                    formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type)
                )

            batch_images = None
            if doc_type == 'image':
                batch_images = load_images([d for (q, d) in mini_batch])
            elif query_type == 'image':
                batch_images = load_images([q for (q, d) in mini_batch])

            batch = self._processor(
                text=batch_inputs,
                images=batch_images,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )

            # Append the score token to input_ids and extend the attention mask accordingly
            mini_batch_size = batch["input_ids"].size(0)
            batch["input_ids"] = torch.cat(
                [
                    batch["input_ids"],
                    torch.full((mini_batch_size, 1), self.score_token_id, device=batch["input_ids"].device),
                ],
                dim=1,
            )
            batch["attention_mask"] = torch.cat(
                [
                    batch["attention_mask"],
                    torch.ones(
                        (mini_batch_size, 1),
                        dtype=batch["attention_mask"].dtype,
                        device=batch["attention_mask"].device,
                    ),
                ],
                dim=1,
            )

            # Move the batch to the model's device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            scores = self.forward(**batch).view(-1).cpu().float().numpy().tolist()
            all_scores.extend(scores)

        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores
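

# Minimal usage sketch (illustrative only). The checkpoint path below is a placeholder,
# not a real model id; any checkpoint exported with this class and a matching
# Qwen2-VL processor is assumed.
if __name__ == "__main__":
    model = JinaVLForRanking.from_pretrained(
        "path/to/vl-reranker-checkpoint",  # placeholder: replace with your checkpoint
        torch_dtype=torch.bfloat16,
    ).eval()

    pairs = [
        ("what is the capital of France?", "Paris is the capital of France."),
        ("what is the capital of France?", "Berlin is the capital of Germany."),
    ]
    scores = model.compute_score(pairs, batch_size=2)
    print(scores)  # one relevance score per (query, document) pair; higher = more relevant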