import math
from typing import List, Union

import torch
from PIL import Image

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (
    ProcessingKwargs,
    ProcessorMixin,
    Unpack,
    _validate_images_text_input_order,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

from .modeling_vora import VoRAForCausalLM


def smart_resize(
    height: int, width: int, factor: int = 14, min_pixels: int = 14 * 14, max_pixels: int = 14 * 14 * 160 * 160
):
"""Rescales the image so that the following conditions are met: |
|
|
|
1. Both dimensions (height and width) are divisible by 'factor'. |
|
|
|
2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. |
|
|
|
3. The aspect ratio of the image is maintained as closely as possible. |
|
|
|
""" |
|
if height < factor or width < factor: |
|
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") |
|
elif max(height, width) / min(height, width) > 200: |
|
raise ValueError( |
|
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" |
|
) |
|
h_bar = round(height / factor) * factor |
|
w_bar = round(width / factor) * factor |
|
if h_bar * w_bar > max_pixels: |
|
beta = math.sqrt((height * width) / max_pixels) |
|
h_bar = math.floor(height / beta / factor) * factor |
|
w_bar = math.floor(width / beta / factor) * factor |
|
elif h_bar * w_bar < min_pixels: |
|
beta = math.sqrt(min_pixels / (height * width)) |
|
h_bar = math.ceil(height * beta / factor) * factor |
|
w_bar = math.ceil(width * beta / factor) * factor |
|
return h_bar, w_bar |
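
# Illustrative check with the defaults (a comment-only example, not part of the API):
# smart_resize(1000, 1500) -> (994, 1498): both sides are multiples of 14, the
# aspect ratio is nearly preserved, and 994 * 1498 pixels lies within
# [14 * 14, 14 * 14 * 160 * 160].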
|
|
class VoRAProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
    }
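
# Per-call kwargs are merged over these defaults by `ProcessorMixin._merge_kwargs`;
# e.g. passing text_kwargs={"padding": "longest"} to the processor call would
# override the `padding: False` default above (an illustrative override, not a
# requirement of this module).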
|
|
class VoRAProcessor(ProcessorMixin):
    """Processor for VoRA that wraps an image processor and a tokenizer into a single object."""

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = [
        "chat_template",
        "image_token",
    ]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
|
    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        image_token="<image>",
        image_token_index=-200,
        **kwargs,
    ):
        self.image_token = image_token
        # Sentinel spliced into `input_ids` wherever an image appears; it is not a
        # real vocabulary id and is handed to the model as `vision_placeholder_index`.
        self.image_token_index = image_token_index
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        **kwargs: Unpack[VoRAProcessorKwargs],
    ):
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")
|
        images, text = _validate_images_text_input_order(images, text)
        output_kwargs = self._merge_kwargs(
            VoRAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
|
        if images is not None:
            # Each sample is a list of frames; take the first frame and resize it so
            # both sides are multiples of the patch size (see `smart_resize` above).
            images = [[self.anyres_resize(image[0])] for image in images]
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}
|
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) or not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
        input_ids = [self.tokenizer_vision_placeholder(t) for t in text]
        attention_mask = [[1] * len(ids) for ids in input_ids]
        # NOTE: with `padding` disabled by default, batched prompts must tokenize to
        # the same length for the tensor conversion below to succeed.
        text_inputs = dict(
            input_ids=torch.as_tensor(input_ids, dtype=torch.int64),
            attention_mask=torch.as_tensor(attention_mask, dtype=torch.int64),
        )
        if images is not None:
            # Rename the image processor's `pixel_values` to the `frames` key the
            # model consumes, and record frame counts and the placeholder index.
            image_inputs["frames"] = image_inputs.pop("pixel_values")
            image_inputs["n_frames"] = [len(_images) for _images in images]
            image_inputs["vision_placeholder_index"] = self.image_token_index
        return BatchFeature(data={**text_inputs, **image_inputs})
|
    def anyres_resize(self, pil_img: Image.Image):
        # `Image.size` is (width, height); unpack it in that order, then pass
        # (height, width) to `smart_resize` and (width, height) to `Image.resize`.
        w, h = pil_img.size
        h, w = smart_resize(h, w)
        image = pil_img.resize((w, h))
        return image
|
    def tokenizer_vision_placeholder(self, prompt, add_bos=False):
        def join_lists(*lists, sep):
            # Concatenate the token-id lists, inserting `sep` between consecutive chunks.
            result = []
            for i, lst in enumerate(lists):
                if i > 0 and sep:
                    result.append(sep)
                result.extend(lst)
            return result

        # Split the prompt on the image token, tokenize each text chunk, and splice
        # the placeholder index in at every image position.
        prompt_chunks = [self.tokenizer.encode(chunk) for chunk in prompt.split(self.image_token)]
        input_ids = join_lists(*prompt_chunks, sep=self.image_token_index)
        if add_bos:
            input_ids = [self.tokenizer.bos_token_id] + input_ids
        return input_ids
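
# For example, with image_token "<image>" and image_token_index -200, the prompt
# "Hi <image> there" becomes encode("Hi ") + [-200] + encode(" there").

# Minimal usage sketch (hypothetical checkpoint path and prompt; assumes the
# checkpoint directory ships the image processor and tokenizer configs):
#
#     from PIL import Image
#
#     processor = VoRAProcessor.from_pretrained("path/to/vora-checkpoint")
#     image = Image.open("example.jpg")
#     inputs = processor(images=[[image]], text="<image>\nDescribe the image.")
#     # -> BatchFeature with input_ids, attention_mask, frames, n_frames,
#     #    and vision_placeholder_index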