|
|
|
|
|
|
|
import os |
|
from dataclasses import dataclass |
|
from enum import Enum |
|
from functools import partial |
|
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast |
|
|
|
import numpy as np |
|
import torch |
|
from huggingface_hub import snapshot_download |
|
from peft import PeftModel, LoraConfig |
|
from peft.utils.hotswap import hotswap_adapter |
|
from PIL import Image |
|
from torch import nn |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
from transformers import BatchFeature |
|
from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor |
|
from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config |
|
|
from .custom_lora_module import Linear |
|
|
|
class PromptType(str, Enum): |
|
query = "query" |
|
passage = "passage" |
|
|
|
|
|
class TaskType(str, Enum): |
|
retrieval = "retrieval" |
|
code = "code" |
|
text_matching = "text-matching" |
|
|
|
|
|
|
PREFIX_DICT = {"query": "Query", "passage": "Passage"} |
|
TRUNCATE_DIMS = [128, 256, 512, 1024] |
|
VECTOR_TYPES = ["single_vector", "multi_vector"] |
|
|
|
|
|
class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor): |
|
def __init__(self, *args, **kwargs) -> None: |
|
Qwen2_5_VLProcessor.__init__(self, *args, **kwargs) |
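        # Character length of the chat-template preamble that `process_images` strips from each
        # rendered conversation, and the maximum number of text tokens kept during tokenization.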
|
self.assistant_prefix_len = 58 |
|
self.text_max_length = 8192 |
|
|
|
def process_images( |
|
self, |
|
images: Union[List[Image.Image], List[List[Image.Image]]], |
|
) -> BatchFeature: |
|
|
|
if isinstance(images[0], list): |
|
images = cast(List[List[Image.Image]], images) |
|
text_doc = [] |
|
for i in range(len(images)): |
|
conversation = [ |
|
{"role": "user", "content": [{"type": "image"}] * len(images[i])} |
|
] |
|
template = self.apply_chat_template( |
|
conversation, add_generation_prompt=False |
|
) |
|
text_doc.append(template[self.assistant_prefix_len :]) |
|
|
|
else: |
|
images = cast(List[Image.Image], images) |
|
text_doc = [ |
|
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n" |
|
] * len(images) |
|
|
|
|
|
batch_doc = self(text=text_doc, images=images, padding="longest", return_tensors="pt") |
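        # `image_grid_thw` holds (t, h, w) per image; with t == 1 for still images, h * w is the
        # number of patch rows each image contributes to the flat `pixel_values` tensor. Split at
        # these offsets and zero-pad each chunk to a common length so the batch can be stacked.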
|
|
|
offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2] |
|
|
|
pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist()) |
|
|
|
max_length = max([len(pv) for pv in pixel_values]) |
|
|
|
pixel_values = [ |
|
torch.cat( |
|
[ |
|
pv, |
|
torch.zeros( |
|
(max_length - len(pv), pv.shape[1]), |
|
dtype=pv.dtype, |
|
device=pv.device, |
|
), |
|
] |
|
) |
|
for pv in pixel_values |
|
] |
|
|
|
batch_doc["pixel_values"] = torch.stack(pixel_values) |
|
return batch_doc |
|
|
|
def process_texts( |
|
self, |
|
texts: List[str], |
|
max_length: Optional[int] = None, |
|
prefix: Optional[str] = None, |
|
padding: Optional[str] = None, |
|
) -> BatchFeature: |
|
|
|
max_length = ( |
|
self.text_max_length |
|
if max_length is None |
|
else min(max_length, self.text_max_length) |
|
) |
|
        prefixed_texts: List[str] = []
|
|
|
for text in texts: |
|
if prefix: |
|
text = f"{prefix}: {text}" |
|
            prefixed_texts.append(text)
|
|
|
text_batch = self( |
|
            text=prefixed_texts,
|
return_tensors="pt", |
|
padding=padding or "longest", |
|
max_length=max_length, |
|
truncation=True, |
|
) |
|
|
|
return text_batch |
|
|
|
|
|
@dataclass |
|
class JinaEmbeddingsV4ModelOutput: |
|
""" |
|
Base class for the Hybrid Model outputs. |
|
Args: |
|
vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM. |
|
single_vec_emb (torch.Tensor, optional): Single-vector embeddings. |
|
multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings. |
|
""" |
|
|
|
vlm_last_hidden_states: Optional[torch.Tensor] = None |
|
single_vec_emb: Optional[torch.Tensor] = None |
|
multi_vec_emb: Optional[torch.Tensor] = None |
|
|
|
|
|
class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): |
|
config_class = JinaEmbeddingsV4Config |
|
main_input_name: ClassVar[str] = "doc_input_ids" |
|
|
|
def __init__(self, config: JinaEmbeddingsV4Config): |
|
Qwen2_5_VLForConditionalGeneration.__init__(self, config) |
|
self._init_projection_layers(config) |
|
self.post_init() |
|
self.processor = JinaEmbeddingsV4Processor.from_pretrained( |
|
self.name_or_path, trust_remote_code=True |
|
) |
|
self.single_vector_projector_dim = config.single_vector_projector_dim |
|
self.multi_vector_projector_dim = config.multi_vector_projector_dim |
|
|
|
def get_last_hidden_states( |
|
self, |
|
input_ids: torch.LongTensor, |
|
attention_mask: torch.Tensor, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
if "pixel_values" in kwargs: |
|
offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] |
|
kwargs["pixel_values"] = torch.cat( |
|
[pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0 |
|
) |
|
|
|
position_ids, rope_deltas = super().get_rope_index( |
|
input_ids=input_ids, |
|
image_grid_thw=kwargs.get("image_grid_thw", None), |
|
attention_mask=attention_mask, |
|
) |
|
|
|
kwargs["output_hidden_states"] = True |
|
outputs = super().forward( |
|
input_ids, |
|
attention_mask, |
|
**kwargs, |
|
position_ids=position_ids, |
|
rope_deltas=rope_deltas, |
|
use_cache=False, |
|
) |
|
|
|
hidden_states = outputs.hidden_states |
|
if not hidden_states: |
|
raise ValueError("Hidden states not found in model output") |
|
|
|
return hidden_states[-1] |
|
|
|
def _init_projection_layers(self, config) -> None: |
|
""" |
|
Initializes projection layers. |
|
""" |
|
self.config.single_vector_projector_dim = config.single_vector_projector_dim |
|
self.config.multi_vector_projector_dim = config.multi_vector_projector_dim |
|
|
|
self.single_vector_projector = nn.Linear( |
|
in_features=self.config.hidden_size, |
|
out_features=self.config.single_vector_projector_dim, |
|
) |
|
|
|
self.multi_vector_projector = nn.Linear( |
|
in_features=self.config.hidden_size, |
|
out_features=self.config.multi_vector_projector_dim, |
|
) |
|
|
|
def project_to_single_vector_embeddings( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Project the hidden states to single-vector embeddings. |
|
""" |
|
if self._input_has_image(input_ids[0]): |
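            # Mean-pool over the vision span only: tokens from the vision start marker through
            # the vision end marker of each sequence.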
|
img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1] |
|
img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1] |
|
|
|
batch_size, seq_len = input_ids.shape |
|
position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1) |
|
image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1)) |
|
|
|
masked_hidden_states = hidden_states * image_mask.unsqueeze(-1) |
|
pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True) |
|
|
|
else: |
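            # Text-only input: attention-mask-weighted mean pooling over all non-padding tokens.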
|
pooled_output = torch.sum( |
|
hidden_states * attention_mask.unsqueeze(-1), dim=1 |
|
) / torch.sum(attention_mask, dim=1, keepdim=True) |
|
|
|
single_vec_emb = self.single_vector_projector(pooled_output) |
|
return torch.nn.functional.normalize(single_vec_emb, dim=-1) |
|
|
|
def project_to_multi_vector_embeddings( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" |
|
Project the hidden states to multi-vector embeddings. |
|
""" |
|
multi_vec_emb = self.multi_vector_projector(hidden_states) |
|
multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1) |
|
return multi_vec_emb * attention_mask.unsqueeze(-1) |
|
|
|
def _input_has_image(self, input_ids): |
|
return self.config.vision_start_token_id in input_ids |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor, |
|
attention_mask: torch.Tensor, |
|
output_vlm_last_hidden_states: bool = False, |
|
**kwargs, |
|
) -> JinaEmbeddingsV4ModelOutput: |
|
""" |
|
Forward pass through the model. Returns both single-vector and multi-vector embeddings. |
|
Args: |
|
input_ids (torch.Tensor): The input tokens tensor. |
|
attention_mask (torch.Tensor): The attention mask tensor. |
|
Returns: |
|
            JinaEmbeddingsV4ModelOutput:
                single_vec_emb (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
                multi_vec_emb (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
|
""" |
|
|
|
hidden_states = self.get_last_hidden_states( |
|
input_ids=input_ids, attention_mask=attention_mask, **kwargs |
|
) |
|
|
|
single_vec_emb = self.project_to_single_vector_embeddings( |
|
hidden_states, attention_mask, input_ids=input_ids |
|
) |
|
multi_vec_emb = self.project_to_multi_vector_embeddings( |
|
hidden_states, attention_mask |
|
) |
|
|
|
return JinaEmbeddingsV4ModelOutput( |
|
vlm_last_hidden_states=( |
|
hidden_states if output_vlm_last_hidden_states else None |
|
), |
|
single_vec_emb=single_vec_emb, |
|
multi_vec_emb=multi_vec_emb, |
|
) |
|
|
|
def _process_batches( |
|
self, |
|
data: List[Union[str, Image.Image]], |
|
processor_fn: Callable, |
|
desc: str, |
|
vector_type: str = "single_vector", |
|
return_numpy: bool = False, |
|
batch_size: int = 32, |
|
truncate_dim: Optional[int] = None, |
|
) -> Union[np.ndarray, List[torch.Tensor]]: |
|
dataloader = DataLoader( |
|
dataset=data, |
|
batch_size=batch_size, |
|
shuffle=False, |
|
collate_fn=processor_fn, |
|
) |
|
results = [] |
|
self.eval() |
|
for batch in tqdm(dataloader, desc=desc): |
|
with torch.no_grad(): |
|
batch = {k: v.to(self.device) for k, v in batch.items()} |
|
with torch.autocast(device_type=torch.device(self.device).type): |
|
embeddings = self(**batch) |
|
if vector_type == "single_vector": |
|
embeddings = embeddings.single_vec_emb |
|
if truncate_dim is not None: |
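                            # Keep only the leading dimensions (Matryoshka-style truncation).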
|
embeddings = embeddings[:, :truncate_dim] |
|
else: |
|
embeddings = embeddings.multi_vec_emb |
|
results.append( |
|
embeddings.cpu() |
|
if return_numpy |
|
else list(torch.unbind(embeddings)) |
|
) |
|
if return_numpy: |
|
            return np.concatenate([result.float().numpy() for result in results], axis=0)
|
return [item for sublist in results for item in sublist] |
|
|
|
def _validate_encoding_params( |
|
self, |
|
vector_type: Optional[str] = None, |
|
truncate_dim: Optional[int] = None, |
|
prompt_name: Optional[str] = None, |
|
) -> Dict[str, Any]: |
|
encode_kwargs = {} |
|
if prompt_name is not None: |
|
if prompt_name not in PREFIX_DICT: |
|
raise ValueError( |
|
f"Invalid prompt_name: {prompt_name}. Must be one of {list(PREFIX_DICT.keys())}." |
|
) |
|
else: |
|
encode_kwargs["prefix"] = ( |
|
PREFIX_DICT[prompt_name] |
|
if self.task != TaskType.text_matching |
|
else PREFIX_DICT["query"] |
|
) |
|
|
|
vector_type = vector_type or "single_vector" |
|
if vector_type not in VECTOR_TYPES: |
|
raise ValueError( |
|
f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}." |
|
) |
|
else: |
|
encode_kwargs["vector_type"] = vector_type |
|
|
|
truncate_dim = truncate_dim or self.config.truncate_dim |
|
if truncate_dim is not None and truncate_dim not in TRUNCATE_DIMS: |
|
raise ValueError( |
|
f"Invalid truncate_dim: {truncate_dim}. Must be one of {TRUNCATE_DIMS}." |
|
) |
|
else: |
|
encode_kwargs["truncate_dim"] = truncate_dim |
|
|
|
return encode_kwargs |
|
|
|
def encode_texts( |
|
self, |
|
texts: List[str], |
|
max_length: int = 8192, |
|
batch_size: int = 8, |
|
vector_type: Optional[str] = None, |
|
return_numpy: bool = False, |
|
truncate_dim: Optional[int] = None, |
|
prompt_name: Optional[str] = None, |
|
    ) -> Union[List[torch.Tensor], np.ndarray]:
|
""" |
|
Encodes a list of texts into embeddings. |
|
|
|
Args: |
|
texts: List of text strings to encode |
|
max_length: Maximum token length for text processing |
|
batch_size: Number of texts to process at once |
|
vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector') |
|
return_numpy: Whether to return numpy arrays instead of torch tensors |
|
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024) |
|
prompt_name: Type of text being encoded ('query' or 'passage') |
|
|
|
Returns: |
|
List of text embeddings as tensors or numpy arrays |
|
""" |
|
prompt_name = prompt_name or "query" |
|
encode_kwargs = self._validate_encoding_params( |
|
vector_type, truncate_dim, prompt_name |
|
) |
|
|
|
processor_fn = partial( |
|
self.processor.process_texts, |
|
max_length=max_length, |
|
prefix=encode_kwargs.pop("prefix"), |
|
) |
|
|
|
embeddings = self._process_batches( |
|
data=texts, |
|
processor_fn=processor_fn, |
|
desc="Encoding texts...", |
|
return_numpy=return_numpy, |
|
batch_size=batch_size, |
|
**encode_kwargs, |
|
) |
|
|
|
return embeddings |
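
    # Usage sketch (illustrative inputs; assumes `model` was loaded via
    # JinaEmbeddingsV4Model.from_pretrained with task="retrieval"):
    #
    #   q_emb = model.encode_texts(["What is the capital of France?"], prompt_name="query")
    #   p_emb = model.encode_texts(["Paris is the capital of France."], prompt_name="passage")
    #   score = (q_emb[0] @ p_emb[0]).item()  # dot product of the normalized single vectors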
|
|
|
def encode_images( |
|
self, |
|
images: List[Image.Image], |
|
batch_size: int = 8, |
|
vector_type: Optional[str] = None, |
|
return_numpy: bool = False, |
|
truncate_dim: Optional[int] = None, |
|
max_pixels: Optional[int] = None, |
|
    ) -> Union[List[torch.Tensor], np.ndarray]:
|
""" |
|
Encodes a list of images into embeddings. |
|
|
|
Args: |
|
images: List of PIL images to encode |
|
batch_size: Number of images to process at once |
|
vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector') |
|
return_numpy: Whether to return numpy arrays instead of torch tensors |
|
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024) |
|
max_pixels: Maximum number of pixels to process per image |
|
|
|
Returns: |
|
List of image embeddings as tensors or numpy arrays |
|
""" |
|
if max_pixels: |
|
default_max_pixels = self.processor.image_processor.max_pixels |
|
self.processor.image_processor.max_pixels = max_pixels |
|
|
|
encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim) |
|
|
|
embeddings = self._process_batches( |
|
data=images, |
|
processor_fn=self.processor.process_images, |
|
desc="Encoding images...", |
|
batch_size=batch_size, |
|
return_numpy=return_numpy, |
|
**encode_kwargs, |
|
) |
|
|
|
if max_pixels: |
|
self.processor.image_processor.max_pixels = default_max_pixels |
|
|
|
return embeddings |
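
    # Usage sketch (illustrative; the image path below is an assumption):
    #
    #   img_emb = model.encode_images([Image.open("photo.jpg")])
    #   # one single-vector embedding per image; pass vector_type="multi_vector" for per-token vectors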
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_model_name_or_path, |
|
*args, |
|
**kwargs, |
|
): |
|
""" |
|
Loads a pretrained model and configures it with the appropriate task adapter (`retrieval` by default). |
|
""" |
|
if "torch_dtype" not in kwargs: |
|
kwargs["torch_dtype"] = "auto" |
|
|
|
        task_value = kwargs.pop("task", "retrieval")
|
try: |
|
task = TaskType(task_value) |
|
except ValueError: |
|
valid_tasks = [t.value for t in TaskType] |
|
raise ValueError( |
|
f"Invalid task: {task_value}. Must be one of {valid_tasks}." |
|
) |
|
|
|
base_model = super().from_pretrained( |
|
pretrained_model_name_or_path, *args, **kwargs |
|
) |
|
|
|
|
|
if os.path.isdir(base_model.name_or_path): |
|
adapter_dir = os.path.join(base_model.name_or_path, "adapters") |
|
else: |
|
adapter_cache_path = snapshot_download( |
|
repo_id=base_model.name_or_path, allow_patterns=["adapters/*"] |
|
) |
|
adapter_dir = os.path.join(adapter_cache_path, "adapters") |
|
|
|
base_model.adapter_dir = adapter_dir |
|
base_model.task = task |
|
|
|
lora_config = LoraConfig.from_pretrained(os.path.join(adapter_dir, task.value)) |
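
        # Map torch.nn.Linear to this repository's custom LoRA `Linear` so PEFT wraps the target
        # layers with that implementation when the task adapter is loaded.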
|
lora_config._custom_modules = {torch.nn.modules.linear.Linear: Linear} |
|
|
|
peft_model = PeftModel.from_pretrained( |
|
model=base_model, model_id=os.path.join(adapter_dir, task.value), config=lora_config |
|
) |
|
|
|
|
|
def set_task_method(self, task: Union[str, TaskType]): |
|
""" |
|
Set the task adapter for the model. |
|
|
|
Args: |
|
                task (Union[str, TaskType]): The task adapter to activate; a TaskType member
                    or one of 'retrieval', 'text-matching', 'code'.
|
""" |
|
if isinstance(task, str): |
|
try: |
|
task = TaskType(task) |
|
except ValueError: |
|
valid_tasks = [t.value for t in TaskType] |
|
raise ValueError( |
|
f"Invalid task: {task}. Must be one of {valid_tasks}" |
|
) |
|
if self.model.task != task: |
|
adapter_path = os.path.join(self.adapter_dir, task.value) |
|
hotswap_adapter(self, adapter_path, adapter_name="default") |
|
self.model.task = task |
|
|
|
def get_task_method(self): |
|
""" |
|
            Return the name of the currently active task adapter.
|
""" |
|
return self.model.task.value |
|
|
|
|
|
peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model)) |
|
peft_model.get_task = get_task_method.__get__(peft_model, type(peft_model)) |
|
|
|
return peft_model |
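

# Loading sketch (a hedged example; the repository id and device below are assumptions):
#
#   model = JinaEmbeddingsV4Model.from_pretrained(
#       "jinaai/jina-embeddings-v4", task="retrieval", trust_remote_code=True
#   ).to("cuda")
#   model.set_task("text-matching")  # hot-swaps the LoRA adapter in place
#   assert model.get_task() == "text-matching"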
|
|