# BAAI / BGE-VL-Screenshot / modeling_bge_vl_screenshot.py
# Uploaded by JUNJIE99 ("Upload folder using huggingface_hub"), commit 8f2764b (verified).
import logging
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
import torch
from PIL import Image
from typing import List, Optional, Tuple, Union
logger = logging.getLogger(__name__)
class BGE_VL_Screenshot(Qwen2_5_VLForConditionalGeneration):
    """Embedding model built on a Qwen2.5-VL backbone.

    ``forward`` runs the vision-language backbone and returns the
    L2-normalised hidden state at the final sequence position instead of
    language-model logits. Model inputs are built with :meth:`data_process`,
    which requires :meth:`set_processor` to have been called first.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ):
        """Encode a batch into unit-norm embeddings.

        Image/video features from the vision tower are scattered into the
        token embeddings at the image/video placeholder-token positions.
        Position ids are computed with ``get_rope_index`` when not supplied,
        and the backbone's hidden state at sequence position ``-1`` is
        returned after L2 normalisation. Because the last position is pooled,
        inputs must be left-padded (``set_processor`` enforces this).

        ``labels`` and the ``rope_deltas`` argument are accepted for signature
        compatibility but never read. Position deltas are cached on
        ``self.rope_deltas`` instead.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``, unit L2 norm along the
            last dimension.

        Raises:
            ValueError: if the number of image (or video) placeholder tokens in
                ``input_ids`` differs from the number of visual features.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )

                # Replace every image placeholder embedding, in order, with a visual feature.
                mask = input_ids == self.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                # Same scatter as for images, keyed on the video placeholder token.
                mask = input_ids == self.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]  # (Bs, L, D)
        # Last-position pooling: with left padding the final column is the
        # last real token of every row.
        embeddings = hidden_states[:, -1, :]
        embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        return embeddings

    def set_processor(self, model_name_or_path, max_len=3072, eos_token_id=151643, min_image_token=64, max_image_token=2500):
        """Attach the processor and tokenization limits used by ``data_process``.

        Args:
            model_name_or_path: Passed to ``AutoProcessor.from_pretrained``.
            max_len: Maximum tokenized length; longer inputs are truncated.
            eos_token_id: Token id forced into the last position of sequences
                that reach ``max_len``. NOTE(review): presumably the id of
                ``<|endoftext|>`` in the Qwen tokenizer; confirm.
            min_image_token, max_image_token: Per-image visual-token budget,
                converted to pixel bounds at 28*28 pixels per token.
                NOTE(review): 28 presumably = 14px patch x 2x2 merge; confirm.

        The processor is created with left padding, which ``forward`` relies
        on for last-position pooling.
        """
        self.max_len = max_len
        self.eos_token_id = eos_token_id
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            padding_side='left',
            min_pixels=min_image_token * 28 * 28,
            max_pixels=max_image_token * 28 * 28
        )
        assert self.processor.tokenizer.padding_side == 'left'

    def prepare_text_input(self, image=None, text=None, q_or_c=None, task_instruction=None):
        """Build the chat-formatted prompt string for one query or candidate.

        Args:
            image: Only its presence matters here. When not ``None`` an
                ``<|vision_start|><|image_pad|><|vision_end|>`` placeholder is
                inserted into the prompt.
            text: Optional text for the item.
            q_or_c: ``"query"``/``"q"`` or ``"candidate"``/``"c"``.
            task_instruction: System prompt for queries. When omitted, a generic
                system prompt is used and a warning is logged.

        For candidates given both text and an image, the image prompt takes
        precedence and the text is ignored.

        Returns:
            The formatted prompt, which ends with ``<|endoftext|>``.

        Raises:
            ValueError: if both ``image`` and ``text`` are ``None`` (previously
                this produced an ``UnboundLocalError`` for candidates or a
                literal ``"None"`` prompt for queries).
        """
        assert q_or_c in ["query", "candidate", "q", "c"]
        if image is None and text is None:
            raise ValueError("image and text cannot be both None.")

        prompt_template = "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"

        if "q" in q_or_c:
            if task_instruction is None:
                system_prompt = "You are a helpful assistant."
                task_instruction_example_csr = "Represent the given image with the given query."
                logger.warning(
                    "Warning: For optimal performance, UniSE-MLLM requires the task instruction to be specified in the query. "
                    "For example, for the composed screenshot retrieval task, you might use a specific instruction like: %s.",
                    task_instruction_example_csr,
                )
            else:
                system_prompt = task_instruction
            if image is None:
                user_prompt = text
            elif text is not None:
                user_prompt = f"Query:{text}<|vision_start|><|image_pad|><|vision_end|>"
            else:
                user_prompt = "<|vision_start|><|image_pad|><|vision_end|>"
        else:
            if image is not None:
                system_prompt = "Represent the given text-rich image, focusing on extracting and interpreting both its rich text content and visual features."
                user_prompt = "<|vision_start|><|image_pad|><|vision_end|>"
            else:
                system_prompt = "Represent the given text."
                user_prompt = f"{text}"

        return prompt_template.format(system_prompt, user_prompt)

    @staticmethod
    def _load_image(image):
        """Return *image* as an RGB PIL image; paths/file objects are opened and closed."""
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        # `convert` loads the pixels into a new image, so the source file can
        # be closed immediately instead of leaking the handle until GC.
        with Image.open(image) as opened:
            return opened.convert("RGB")

    def data_process(self, images=None, text=None, q_or_c=None, task_instruction=None):
        """Turn raw queries/candidates into batched model inputs on ``self.device``.

        Args:
            images: One image or a list of images. Each may be a path/file
                object accepted by ``PIL.Image.open`` or an already-loaded
                ``PIL.Image.Image``. ``None`` for text-only inputs.
            text: One string or a list of strings (entries may be ``None``);
                ``None`` for image-only inputs.
            q_or_c: ``"query"``/``"q"`` or ``"candidate"``/``"c"``.
            task_instruction: Forwarded to ``prepare_text_input``.

        Whether the input is batched is decided by ``images`` when given,
        otherwise by ``text``. Sequences reaching ``self.max_len`` are
        truncated and their last token is replaced by ``self.eos_token_id``.
        NOTE(review): truncation may also cut image placeholder tokens, which
        ``forward`` would then reject; keep ``max_len`` generous.

        Returns:
            The processor's batch encoding, moved to ``self.device``.

        Raises:
            ValueError: if both inputs are ``None``, the batch is empty, or
                image and text lists differ in length.
        """
        if images is not None:
            _is_list = isinstance(images, list)
        elif text is not None:
            _is_list = isinstance(text, list)
        else:
            raise ValueError("images and text cannot be both None.")
        assert q_or_c in ["query", "candidate", "q", "c"]

        # Normalise single items to one-element batches so one code path serves both.
        if not _is_list:
            images = None if images is None else [images]
            text = [text]
        elif text is None:
            text = [None] * len(images)
        elif images is not None and len(images) != len(text):
            raise ValueError(
                f"images and text must have the same length: {len(images)} != {len(text)}."
            )
        if len(text) == 0:
            raise ValueError("images and text cannot be empty.")

        # Text-only batches pair each text with a None image slot (zipping
        # with images=None used to raise TypeError).
        image_slots = images if images is not None else [None] * len(text)
        text_input = [
            self.prepare_text_input(_image, _text, q_or_c, task_instruction)
            for _image, _text in zip(image_slots, text)
        ]

        processor_kwargs = dict(return_tensors="pt", padding=True, truncation=True, max_length=self.max_len)
        if images is not None:
            pil_images = [self._load_image(_image) for _image in images]
            inputs = self.processor(images=pil_images, text=text_input, **processor_kwargs)
        else:
            inputs = self.processor(text=text_input, **processor_kwargs)

        if inputs.input_ids.size(-1) == self.max_len:
            # Truncation dropped the trailing <|endoftext|>; restore it so the
            # pooled last position is always the end-of-text token.
            inputs.input_ids[:, -1] = self.eos_token_id
        assert (inputs.input_ids[:, -1] == self.eos_token_id).all()
        # With left padding the last column must be a real token in every row.
        assert (inputs.attention_mask[:, -1] == 1).all()

        inputs = inputs.to(self.device)
        return inputs