|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Processor class for Sarashina2Vision.
|
""" |
|
|
|
from copy import deepcopy |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
from transformers import ( |
|
AutoImageProcessor, |
|
PreTrainedTokenizer, |
|
Qwen2VLImageProcessor, |
|
StoppingCriteria, |
|
StoppingCriteriaList, |
|
) |
|
from transformers.feature_extraction_utils import BatchFeature |
|
from transformers.image_transforms import ( |
|
convert_to_rgb, |
|
to_channel_dimension_format, |
|
) |
|
from transformers.image_utils import ( |
|
ChannelDimension, |
|
ImageInput, |
|
VideoInput, |
|
get_image_size, |
|
infer_channel_dimension_format, |
|
is_scaled_image, |
|
make_list_of_images, |
|
to_numpy_array, |
|
) |
|
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize |
|
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack |
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
|
from transformers.utils import logging |
|
|
|
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(name=__name__)
|
|
|
|
|
class GenerationStopper(StoppingCriteria):
    """Stopping criterion that ends generation once a stop string appears in newly generated text.

    On its first call the criterion records where the generated tokens begin. On every call it
    decodes at most the last ``decode_suffix_length`` of those tokens and checks them for any of
    the stop strings.

    Args:
        stop_str_list: Strings whose appearance in the decoded suffix stops generation.
        tokenizer: Tokenizer used for decoding; a deep copy is stored on the instance.
        decode_suffix_length: Maximum number of trailing generated tokens decoded per call.
    """

    def __init__(
        self,
        stop_str_list: list[str],
        tokenizer: PreTrainedTokenizer,
        decode_suffix_length: int = 5,
    ):
        self.stop_str_list = stop_str_list
        # NOTE(review): a deep copy is kept, presumably to isolate this criterion from later
        # changes to the caller's tokenizer -- confirm intent.
        self.tokenizer = deepcopy(tokenizer)
        self.decode_suffix_length = decode_suffix_length
        # Index of the first generated token in ``input_ids``; set lazily on the first call.
        # NOTE(review): it is only reset when a stop string is found, so an instance reused for
        # another generation that ended without a stop string would slice from a stale index.
        self.input_ids_end = None

    def __repr__(self):
        return f"Stopping words: {self.stop_str_list}"

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return ``True`` if a stop string occurs in the recently generated text.

        Only the first row of the batch (``input_ids[0]``) is inspected.
        """
        if self.input_ids_end is None:
            # Presumably one generated token has already been appended when this is first
            # called, so generated text starts at ``length - 1`` (clamped at 0).
            self.input_ids_end = max(input_ids.shape[1] - 1, 0)

        decode_ids = input_ids[0][self.input_ids_end :][-self.decode_suffix_length :]
        decoded = self.tokenizer.decode(decode_ids) if len(decode_ids) > 0 else ""

        if any(stop_str in decoded for stop_str in self.stop_str_list):
            # Reset so the next generation re-detects where its generated tokens start.
            self.input_ids_end = None
            return True
        return False

    @property
    def criteria(self):
        """This criterion wrapped in a ``StoppingCriteriaList`` for ``generate``."""
        return StoppingCriteriaList([self])

    def format(self, sentence: str):
        """Strip trailing stop strings from ``sentence``.

        Stop strings are checked in list order, each against the (possibly already trimmed)
        sentence, and removed once if the sentence ends with it. Empty stop strings are
        ignored; previously they erased the whole sentence.
        """
        for w in self.stop_str_list:
            if w and sentence.endswith(w):
                sentence = sentence[: -len(w)]
        return sentence
|
|
|
|
|
class Sarashina2VisionImageProcessor(Qwen2VLImageProcessor):
    """Qwen2-VL image processor variant whose resize step uses ``torch`` bicubic interpolation.

    Compared with the parent's ``_preprocess``, images are rescaled and normalized first and
    resized afterwards with ``torch.nn.functional.interpolate(mode="bicubic")``. The ``resample``
    argument is therefore accepted but not used.
    """

    @staticmethod
    def _bicubic_resize(image: np.ndarray, height: int, width: int) -> np.ndarray:
        """Resize a channels-first ``(C, H, W)`` array to ``(C, height, width)`` with bicubic interpolation."""
        # F.interpolate needs a leading batch axis and resizes the trailing two (spatial) axes.
        batched = torch.from_numpy(image).unsqueeze(0)
        resized = F.interpolate(batched, size=(height, width), mode="bicubic")
        return resized.squeeze(0).numpy()

    def _preprocess(
        self,
        images: Union[ImageInput, VideoInput],
        do_resize: bool = None,
        resample: Image.Resampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images into flattened patches.

        Adapted from `Qwen2VLImageProcessor._preprocess`. Rescaling and normalization happen before
        resizing, and resizing uses bicubic `torch.nn.functional.interpolate`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values
                range from 0 to 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*):
                Whether to resize the images. The target size comes from `smart_resize` applied to the size of the
                first image, and is used for every image.
            resample (`PILImageResampling`, *optional*):
                Accepted for signature compatibility; not used (resizing is always bicubic).
            do_rescale (`bool`, *optional*):
                Whether to rescale the images.
            rescale_factor (`float`, *optional*):
                Scale factor to use if rescaling the images.
            do_normalize (`bool`, *optional*):
                Whether to normalize the images.
            image_mean (`float` or `List[float]`, *optional*):
                Mean to use if normalizing the images.
            image_std (`float` or `List[float]`, *optional*):
                Standard deviation to use if normalizing the images.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the images to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                Accepted for signature compatibility. Patches are always extracted from a channels-first layout, so
                the returned values do not depend on this argument.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input images. Inferred from the first image if unset.

        Returns:
            `Tuple[np.ndarray, Tuple[int, int, int]]`: The flattened patches, of shape
            `(grid_t * grid_h * grid_w, channels * temporal_patch_size * patch_size * patch_size)`, and the grid
            size `(grid_t, grid_h, grid_w)`.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # The channel layout of the first image is assumed for all images.
            input_data_format = infer_channel_dimension_format(images[0])

        # The target size depends only on the first image, so it is computed once, outside the loop.
        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        if do_resize:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=self.patch_size * self.merge_size,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )

        processed_images = []
        for image in images:
            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image,
                    mean=image_mean,
                    std=image_std,
                    input_data_format=input_data_format,
                )

            # Always work channels-first. The bicubic resize acts on the last two axes, so a
            # channels-last array would be resized along (W, C) instead of (H, W). The patch
            # extraction below also expects (C, H, W).
            image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)

            if do_resize:
                image = self._bicubic_resize(image, resized_height, resized_width)

            processed_images.append(image)

        # Stack into (T, C, H, W), where T is the number of images/frames.
        patches = np.array(processed_images)
        # Pad T up to a multiple of temporal_patch_size by repeating the last frame.
        remainder = patches.shape[0] % self.temporal_patch_size
        if remainder != 0:
            repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - remainder, axis=0)
            patches = np.concatenate([patches, repeats], axis=0)

        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w,
            channel * self.temporal_patch_size * self.patch_size * self.patch_size,
        )

        return flatten_patches, (grid_t, grid_h, grid_w)
|
|
|
|
|
class Srashina2VisionProcessorKwargs(ProcessingKwargs, total=False):
    # Per-modality default keyword arguments; this class is passed to `self._merge_kwargs`
    # in `Srashina2VisionProcessor.__call__`. Text is tokenized without padding by default.
    _defaults = {"text_kwargs": {"padding": False}}
|
|
|
|
|
class Srashina2VisionProcessor(ProcessorMixin):
    r"""
    Constructs a Sarashina2Vision processor which wraps a Sarashina2Vision image processor and a Llama tokenizer into
    a single processor.

    [`Srashina2VisionProcessor`] offers all the functionalities of [`Sarashina2VisionImageProcessor`] and
    [`LlamaTokenizerFast`]. See [`~Srashina2VisionProcessor.__call__`] and [`~Srashina2VisionProcessor.decode`] for
    more information.

    Args:
        image_processor ([`Sarashina2VisionImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*):
            A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        # Use the tokenizer's own image token when it defines one, otherwise "<|file|>".
        self.image_token = getattr(tokenizer, "image_token", "<|file|>")
        # Removed from decoded text by `decode` / `batch_decode`.
        self.stop_symbol = "\n###"
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        **kwargs: Unpack[Srashina2VisionProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model.

        If `images` is not `None`, they are forwarded with the image kwargs to the image processor. Each occurrence of
        the image token in `text` is then expanded to `prod(image_grid_thw[k]) // merge_size**2` image tokens,
        consuming one grid entry per occurrence in order. If `text` is not `None`, the expanded text and the text
        kwargs are forwarded to the tokenizer's `__call__`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). A list passed here is
                not modified.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is
              not `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Srashina2VisionProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if text is None:
            # Image-only call. Previously `[None]` was passed to the tokenizer and failed.
            return BatchFeature(data={**image_inputs})

        # Copy, so the placeholder expansion below never mutates the caller's list.
        text = list(text) if isinstance(text, list) else [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i, sample in enumerate(text):
                # Expand image tokens one at a time through a temporary placeholder, so the
                # inserted tokens are not themselves matched again by this loop.
                while self.image_token in sample:
                    num_image_tokens = int(image_grid_thw[index].prod()) // merge_length
                    sample = sample.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                    index += 1
                text[i] = sample.replace("<|placeholder|>", self.image_token)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        Forward all arguments to the tokenizer's [`~PreTrainedTokenizer.batch_decode`], then remove every occurrence
        of `self.stop_symbol` from each decoded string.
        """
        return [output.replace(self.stop_symbol, "") for output in self.tokenizer.batch_decode(*args, **kwargs)]

    def decode(self, *args, **kwargs):
        """
        Forward all arguments to the tokenizer's [`~PreTrainedTokenizer.decode`], then remove every occurrence of
        `self.stop_symbol` from the decoded string.
        """
        return self.tokenizer.decode(*args, **kwargs).replace(self.stop_symbol, "")

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Unlike `decode` / `batch_decode`, this does not remove `self.stop_symbol`.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def get_stopping_criteria(self, stop_symbols: List[str]):
        """Build a fresh `StoppingCriteriaList` that stops generation on any of `stop_symbols`."""
        stopping_criteria = GenerationStopper(stop_str_list=stop_symbols, tokenizer=self.tokenizer)
        return stopping_criteria.criteria
|
|
|
|
|
# Module import side effects: expose the processor to `AutoProcessor` (for custom-code loading)
# and register the image processor class with `AutoImageProcessor`.
# NOTE(review): the first argument to `AutoImageProcessor.register` is a string here, whereas the
# transformers API documents a config class as the key -- confirm this registration is effective.
Srashina2VisionProcessor.register_for_auto_class("AutoProcessor")

AutoImageProcessor.register("Sarashina2VisionImageProcessor", Sarashina2VisionImageProcessor)
|
|