|
|
|
import logging
from typing import Any, Dict, List, Union

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer
from transformers.processing_utils import BatchFeature, ProcessorMixin
|
|
|
# Literal marker that stands in for one image inside the rendered chat prompt.
# OpenCUAProcessor.prepare_vllm_inputs repeats it once per merged image patch.
PLACEHOLDER = "<|media_placeholder|>"
|
|
|
class OpenCUAProcessor(ProcessorMixin):
    """Bundle an image processor and a tokenizer for OpenCUA checkpoints.

    The class delegates chat templating to the tokenizer and offers
    ``prepare_vllm_inputs`` to expand every image placeholder in a rendered
    prompt to the number of merged patches the image processor reports.

    Attributes:
        image_processor: Object called as ``image_processor(images=...,
            return_tensors="pt")``; its result must expose ``image_grid_thw``.
        tokenizer: Object providing ``apply_chat_template``.
        image_token_id: Integer stored as given; no method of this class
            reads it.
        merge_size: Merge factor; the placeholder count per image is
            ``t * h * w // merge_size ** 2``.
    """

    # NOTE(review): ProcessorMixin helpers iterate over ``attributes``; the two
    # plain-int entries are kept as-is for compatibility -- confirm that
    # ``save_pretrained`` copes with them in the targeted transformers version.
    attributes = ["image_processor", "tokenizer", "image_token_id", "merge_size"]

    def __init__(self, image_processor, tokenizer, image_token_id: int = 151664, merge_size: int = 2, **kwargs):
        """Store the wrapped components.

        Args:
            image_processor: Image processor instance to wrap.
            tokenizer: Tokenizer instance to wrap.
            image_token_id: Token id recorded on the instance.
            merge_size: Fallback merge factor, used only when
                ``image_processor`` has no ``merge_size`` attribute.
            **kwargs: Accepted and ignored, so that loader options forwarded
                by ``from_pretrained`` (e.g. ``trust_remote_code``) do not
                raise.
        """
        # ProcessorMixin.__init__ is intentionally not invoked, matching the
        # original implementation; the attributes are assigned directly.
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.image_token_id = image_token_id
        # The image processor's own setting takes precedence over the argument.
        self.merge_size = getattr(image_processor, "merge_size", merge_size)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the tokenizer and image processor and build a processor.

        The custom ``TikTokenV3`` tokenizer is tried first; if importing or
        loading it fails for any reason, ``AutoTokenizer`` is used instead.

        Args:
            pretrained_model_name_or_path: Identifier or path forwarded
                unchanged to the underlying ``from_pretrained`` calls.
            **kwargs: ``trust_remote_code`` is read from here (default
                ``True``); the whole mapping is also forwarded to ``cls``.

        Returns:
            A new instance of ``cls``.
        """
        # NOTE(review): defaulting to True executes code shipped with the
        # checkpoint; callers loading untrusted repositories should pass
        # trust_remote_code=False explicitly.
        trust = kwargs.get("trust_remote_code", True)

        try:
            from tokenization_opencua import TikTokenV3

            tok = TikTokenV3.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        except Exception as err:
            # Deliberate best-effort fallback; record the cause instead of
            # discarding it silently.
            logging.getLogger(__name__).debug(
                "TikTokenV3 unavailable (%s); falling back to AutoTokenizer.", err
            )
            tok = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)

        imgproc = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        return cls(imgproc, tok, **kwargs)

    def apply_chat_template(self, messages: List[Dict[str, Any]], **kwargs) -> Union[str, List[int]]:
        """Delegate to ``self.tokenizer.apply_chat_template``.

        Args:
            messages: Chat messages passed through unchanged.
            **kwargs: Forwarded unchanged to the tokenizer.

        Returns:
            Whatever the tokenizer returns.
        """
        return self.tokenizer.apply_chat_template(messages, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Return a fixed dummy batch; all arguments are ignored.

        NOTE(review): this looks like a stub kept so the object is callable
        like a regular processor -- confirm against callers.

        Returns:
            ``BatchFeature`` whose ``input_ids`` is a ``(1, 1)`` tensor of
            ``torch.long`` zeros.
        """
        data = {"input_ids": torch.zeros(1, 1, dtype=torch.long)}
        return BatchFeature(data=data)

    @staticmethod
    def _expand_placeholders(text, counts):
        """Repeat the i-th ``PLACEHOLDER`` in ``text`` ``counts[i]`` times.

        Placeholders beyond ``len(counts)`` are left as a single marker and
        surplus counts are ignored.
        """
        pieces = text.split(PLACEHOLDER)
        rebuilt = [pieces[0]]
        for idx, tail in enumerate(pieces[1:]):
            repeat = counts[idx] if idx < len(counts) else 1
            rebuilt.append(PLACEHOLDER * repeat)
            rebuilt.append(tail)
        return "".join(rebuilt)

    def prepare_vllm_inputs(self, messages, images, add_generation_prompt=True):
        """Render the prompt and expand each image placeholder.

        Args:
            messages: Chat messages rendered via ``apply_chat_template`` with
                ``tokenize=False``.
            images: Images handed to ``self.image_processor``. ``None`` or an
                empty list/tuple skips image processing.
            add_generation_prompt: Forwarded to ``apply_chat_template``.

        Returns:
            ``(text, images)``: the expanded prompt string and the ``images``
            argument unchanged.
        """
        text = self.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)

        # Text-only prompt: there is nothing to feed to the image processor.
        if images is None or (isinstance(images, (list, tuple)) and len(images) == 0):
            return text, images

        proc = self.image_processor(images=images, return_tensors="pt")
        # assumes one (t, h, w) row per image -- TODO confirm against the image processor
        grid = torch.as_tensor(proc["image_grid_thw"]).reshape(-1, 3)
        merge = getattr(self, "merge_size", 2)
        counts = [int(t) * int(h) * int(w) // (merge ** 2) for t, h, w in grid.tolist()]

        # Expand in a single pass. Repeated ``str.replace(..., 1)`` would keep
        # matching inside the run produced for the first image, so later
        # placeholders would never be expanded.
        return self._expand_placeholders(text, counts), images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|