import base64
import copy
import io
import math
import os
import tempfile
import uuid
from typing import Dict, List, Optional, Union
from urllib.parse import urlparse
import av
import cv2
import numpy as np
import requests
import torch
from decord import VideoReader, cpu
from PIL import Image, UnidentifiedImageError
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
def determine_possible_resolutions(anyres: bool, max_num_grids: int, grid_size: int, use_1x1_grid: bool = False):
"""
Finds and returns possible resolution combinations with a total number of grids less than or equal to max_num_grids.
For example, if max_num_grids is 4 and use_1x1_grid is False (the default), the possible grid combinations are:
[1x2, 1x3, 1x4, 2x1, 2x2, 3x1, 4x1], and the resolutions are calculated accordingly.
Example:
>>> possible_resolutions = determine_possible_resolutions(anyres=True, max_num_grids=4, grid_size=336)
>>> print(possible_resolutions)
[[336, 672], [336, 1008], [336, 1344], [672, 336], [672, 672], [1008, 336], [1344, 336]]
Args:
anyres (bool): Whether to allow any resolution combinations up to the maximum grid count.
max_num_grids (int): The maximum number of grids allowed (height x width must be ≤ this value).
grid_size (int): The size of each grid in pixels (e.g., 336).
use_1x1_grid (bool, optional): Whether to include the 1x1 grid as a valid resolution. Defaults to False.
Returns:
List[List[int]]: A list of possible [height, width] resolution pairs.
"""
possible_resolutions = []
if anyres:
assert max_num_grids > 0
for i in range(1, max_num_grids + 1):
for j in range(1, max_num_grids + 1):
if i == 1 and j == 1 and not use_1x1_grid:
continue
if i * j <= max_num_grids:
possible_resolutions.append([i, j])
possible_resolutions = [[ys * grid_size, xs * grid_size] for ys, xs in possible_resolutions]
return possible_resolutions
def divide_to_grids(image: np.array, grid_size: int, input_data_format=None) -> List[np.array]:
"""
Divides a local image into grids of size (grid_size x grid_size).
Args:
image (np.array): Input image as a NumPy array.
grid_size (int): The size (in pixels) of each square grid.
input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
Returns:
List[np.array]: A list of image patches, each of size (grid_size x grid_size).
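Example (illustrative: a 672x672 channels-last image split with grid_size=336 yields four 336x336 patches):
>>> img = np.zeros((672, 672, 3), dtype=np.uint8)
>>> grids = divide_to_grids(img, grid_size=336, input_data_format=ChannelDimension.LAST)
>>> len(grids), grids[0].shape
(4, (336, 336, 3))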
"""
grids = []
height, width = get_image_size(image, channel_dim=input_data_format)
for i in range(0, height, grid_size):
for j in range(0, width, grid_size):
if input_data_format == ChannelDimension.LAST:
grid = image[i : i + grid_size, j : j + grid_size]
else:
grid = image[:, i : i + grid_size, j : j + grid_size]
grids.append(grid)
return grids
def pad(
image: np.array,
target_size: tuple,
background_color=(127, 127, 127),
input_data_format=None,
) -> np.array:
"""
Pads the input image on the sides (top/bottom and left/right) to match the target height and width.
Args:
image (np.array): Input image as a NumPy array.
target_size (tuple): Target size as (target_height, target_width).
background_color (tuple, optional): RGB color value used for padding. Defaults to (127, 127, 127).
input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
Returns:
np.array: The padded image with the specified target size.
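Example (illustrative, assuming a channels-last input; an 80x60 image is centered on a 128x128 canvas):
>>> img = np.zeros((80, 60, 3), dtype=np.uint8)
>>> pad(img, target_size=(128, 128)).shape
(128, 128, 3)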
"""
target_height, target_width = target_size
height, width = get_image_size(image, channel_dim=input_data_format)
# result = np.ones((target_height, target_width, image.shape[2]), dtype=image.dtype) * background_color
result = np.empty((target_height, target_width, image.shape[2]), dtype=image.dtype)
for i in range(image.shape[2]):
result[..., i].fill(background_color[i])
paste_x = (target_width - width) // 2
paste_y = (target_height - height) // 2
result[paste_y : paste_y + height, paste_x : paste_x + width, :] = image
return result
def expand2square(
image: np.array,
bboxes_dict=None,
background_color=(127, 127, 127),
input_data_format=None,
) -> np.array:
"""
Expands the input image to a square shape by placing it at the center of a new square canvas,
with padding added to the shorter side (either top/bottom or left/right).
The image is always centered on the new canvas, and padding is applied symmetrically.
Args:
image (np.array): Input image as a NumPy array.
bboxes_dict (dict, optional): A dictionary of bounding boxes, where each value is an NDArray of shape (N, 4, 2)
with box coordinates in the format [[xtl, ytl], [xtr, ytr], [xbr, ybr], [xbl, ybl]].
Supports multiple categories (e.g., "ocr", "html") simultaneously.
background_color (tuple, optional): RGB color to fill the padding area. Defaults to (127, 127, 127).
input_data_format (optional): Optional format specifier for image data (e.g., "channels_first" or "channels_last").
Returns:
np.array: A square-shaped image with the original image centered and padded as needed.
Example:
>>> _img = np.ones((80, 100, 3), dtype=np.uint8) * 100
>>> _bboxes_dict = {"words": np.array([[[10, 10], [20, 10], [20, 20], [10, 20]],
... [[30, 30], [40, 30], [40, 40], [30, 40]]])}
>>> _img, _bboxes_dict = expand2square(_img, _bboxes_dict, (255, 255, 255))
>>> _img.shape
(100, 100, 3)
>>> guessed_ocr_bboxes = np.array([[[10, 20], [20, 20], [20, 30], [10, 30]],
... [[30, 40], [40, 40], [40, 50], [30, 50]]])
>>> np.testing.assert_array_almost_equal(_bboxes_dict["words"], guessed_ocr_bboxes) is None
True
"""
height, width = get_image_size(image, channel_dim=input_data_format)
if width == height:
return image, bboxes_dict
elif width > height:
# result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
result = np.empty((width, width, image.shape[2]), dtype=image.dtype)
for i in range(image.shape[2]):
result[..., i].fill(background_color[i])
result[(width - height) // 2 : (width - height) // 2 + height, :] = image
if bboxes_dict is not None:
for key in bboxes_dict:
bboxes_dict[key][:, :, 1] += (width - height) // 2
return result, bboxes_dict
else:
# result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
result = np.empty((height, height, image.shape[2]), dtype=image.dtype)
for i in range(image.shape[2]):
result[..., i].fill(background_color[i])
result[:, (height - width) // 2 : (height - width) // 2 + width] = image
if bboxes_dict is not None:
for key in bboxes_dict:
bboxes_dict[key][:, :, 0] += (height - width) // 2
return result, bboxes_dict
def resize_longside(
image: np.array,
size: int,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Resizes the image so that its longer side matches the specified size, maintaining the original aspect ratio.
Args:
image (np.array): Input image as a NumPy array.
size (int): Target size for the longer side of the image.
resample (PILImageResampling, optional): Resampling method to use during resizing. Defaults to BICUBIC.
data_format (str or ChannelDimension, optional): Output data format (e.g., "channels_first" or "channels_last").
input_data_format (str or ChannelDimension, optional): Input data format of the image.
Returns:
np.array: The resized image with its aspect ratio preserved.
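Example (illustrative: the longer side is scaled to `size` and the aspect ratio is preserved):
>>> img = np.zeros((200, 400, 3), dtype=np.uint8)
>>> resize_longside(img, size=100).shape
(50, 100, 3)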
"""
height, width = get_image_size(image, channel_dim=input_data_format)
if width == height:
target_height, target_width = size, size
elif width > height:
target_width = size
target_height = math.ceil(height / width * size)
else:
target_width = math.ceil(width / height * size)
target_height = size
return resize(
image,
size=(target_height, target_width),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
)
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
"""
Selects the best-fit resolution from a list of possible resolutions based on the original image size.
This function, adapted from LLaVA-Next
(https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llava_next/image_processing_llava_next.py),
evaluates each resolution by computing its effective and wasted area compared to the original size.
The optimal resolution is the one that maximizes the effective area while minimizing unused (wasted) space.
Args:
original_size (tuple): The original image size in the format (height, width).
possible_resolutions (list): A list of candidate resolutions in the format [(height1, width1), (height2, width2), ...].
Returns:
tuple: The best-fit resolution in the format (height, width).
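Example (illustrative: a 300x500 image fits best into the 336x672 canvas, which maximizes the effective area):
>>> select_best_resolution((300, 500), [(336, 336), (336, 672), (672, 336)])
(336, 672)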
"""
original_height, original_width = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for height, width in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (height, width)
return best_fit
def _get_local_grids_output_size(image: np.array, target_resolution: tuple, input_data_format=None):
"""
Computes the output size (in pixels) used when resizing an image to fit within the target resolution
while preserving its aspect ratio, for use in local grid processing.
Args:
image (np.array): Input image as a NumPy array.
target_resolution (tuple): Target resolution in the format (target_height, target_width).
input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
Returns:
tuple: The resized output size as (new_height, new_width) in pixels, fitted inside the target resolution.
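Example (illustrative: a 150x448 image resized to fit a 336x672 target keeps its aspect ratio):
>>> img = np.zeros((150, 448, 3), dtype=np.uint8)
>>> _get_local_grids_output_size(img, (336, 672))
(225, 672)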
"""
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
def determine_anyres_num_vision_patches(
num_grids,
image_size,
grid_size,
patch_size,
possible_resolutions,
anyres=False,
unpad=True,
num_queries_vis_abstractor=0,
num_queries_vis_abstractor_slow=0,
is_video=False,
first_last_frames_slow=False, # sample-wise option
is_first_or_last_frames=False, # grid-wise option
):
"""
Computes the number of visual tokens (patches) based on image resolution, grid configuration, and patch size.
This function supports both fixed-size and any-resolution settings, as well as video-specific configurations
such as handling slow frames and frame position flags.
Args:
num_grids (int): Number of grids per image (e.g., 1 for 1x1, 4 for 2x2, etc.).
image_size (tuple): The original image size as (height, width).
grid_size (int): Size of each grid in pixels (e.g., 336).
patch_size (int): Size of each vision patch (e.g., 14 for ViT models).
possible_resolutions (list): List of possible resolution tuples [(h1, w1), (h2, w2), ...].
anyres (bool, optional): Whether to use any-resolution mode. Defaults to False.
unpad (bool, optional): Whether to unpad the image before computing patches. Defaults to True.
num_queries_vis_abstractor (int, optional): Number of query tokens for vision abstractor (fast path).
num_queries_vis_abstractor_slow (int, optional): Number of query tokens for vision abstractor (slow path).
is_video (bool, optional): Whether the input is a video. Defaults to False.
first_last_frames_slow (bool, optional): Whether to treat first/last video frames as "slow". Defaults to False.
is_first_or_last_frames (bool, optional): Whether current grid corresponds to first/last frame. Defaults to False.
Returns:
int: Total number of visual tokens (patches) after processing.
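Example (illustrative: a 672x672 image mapped to a 2x2 grid of 336-pixel crops with 14-pixel patches,
without unpadding, yields 4 * 24 * 24 local patches plus 24 * 24 global-image patches):
>>> determine_anyres_num_vision_patches(
...     num_grids=5, image_size=(672, 672), grid_size=336, patch_size=14,
...     possible_resolutions=[[336, 672], [672, 336], [672, 672]], anyres=True, unpad=False,
... )
2880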
"""
if not anyres:
return num_queries_vis_abstractor if num_queries_vis_abstractor > 0 else (grid_size // patch_size) ** 2
if num_queries_vis_abstractor > 0:
num_patch_per_grid = int(num_queries_vis_abstractor**0.5)
else:
num_patch_per_grid = grid_size // patch_size
num_global_per_grid = num_patch_per_grid
# In anyres mode, a global image is included, so there are always at least 2 grids.
# However, for video inputs, there is no global image, so it's possible to have only 1 grid.
# Therefore, the assertion below is commented out:
# assert num_grids > 1
# Compute the number of vision patches.
height, width = select_best_resolution(image_size, possible_resolutions)
num_patch_height = (height // grid_size) * num_patch_per_grid
num_patch_width = (width // grid_size) * num_patch_per_grid
# local images
if unpad:
original_height, original_width = image_size
original_aspect_ratio = original_width / original_height
current_aspect_ratio = num_patch_width / num_patch_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = num_patch_width / original_width
new_height = int(original_height * scale_factor)
padding = (num_patch_height - new_height) // 2
num_patch_height = num_patch_height - padding * 2
else:
scale_factor = num_patch_height / original_height
new_width = int(original_width * scale_factor)
padding = (num_patch_width - new_width) // 2
num_patch_width = num_patch_width - padding * 2
num_patches = num_patch_width * num_patch_height + num_patch_height
else:
num_patches = num_patch_width * num_patch_height
# In the "slow" strategy, when applying to first and last frames only, it is applied exclusively to those two frames.
if num_queries_vis_abstractor_slow > 0:
if first_last_frames_slow:
if is_first_or_last_frames:
num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
else:
num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
# The slowfast feature is only applicable when unpad is set to False.
assert unpad is False
# Global image is not included for video inputs.
if not is_video:
num_patches += num_global_per_grid**2
return num_patches
class HCXVisionProcessor(BaseImageProcessor):
r"""
Constructs a VLM image processor.
This processor is based on [`CLIPImageProcessor`] and incorporates additional techniques
for handling high-resolution images, such as flexible resolution support (`anyres`), unpadding,
square padding, and multi-grid patching strategies.
Args:
do_resize (bool): Whether to resize the image.
size (Dict[str, int], optional): Target size for resizing, typically with keys `"height"` and `"width"`.
anyres (bool): Whether to enable the any-resolution (`anyres`) feature, which allows flexible resolution handling via grid division.
unpad (bool): When `anyres` is enabled, whether to remove visual tokens corresponding to pure padding regions.
max_num_grids (int): Maximum number of grids allowed per image.
max_image_cnt (int): Maximum number of images that can be processed at once (used for batching).
num_queries_vis_abstractor (int): Number of visual query tokens per grid when using a visual resampler (e.g., Perceiver).
num_queries_vis_abstractor_video_fast (int): Number of visual queries for fast-path video frames.
num_queries_vis_abstractor_video_slow (int): Number of visual queries for slow-path video frames (e.g., first/last).
possible_resolutions (List): List of allowed resolution pairs when `anyres` is enabled. Example: [[336, 336], [336, 672], [672, 336]].
patch_size (int): Patch size for the Vision Transformer (ViT).
pad_to_square (bool): Whether to pad images to a square shape. If `False`, a center crop is applied to fit ViT input.
resample (PILImageResampling): Resampling method to use for resizing. Default is `BICUBIC`.
do_center_crop (bool): Whether to apply center cropping.
crop_size (Dict[str, int], optional): Size for center cropping.
do_rescale (bool): Whether to rescale pixel values.
rescale_factor (float or int): Factor to use for rescaling pixel values (typically `1/255`).
do_normalize (bool): Whether to normalize pixel values using `image_mean` and `image_std`.
image_mean (float or List[float], optional): Mean values for normalization. Can be a single float or list of floats per channel.
image_std (float or List[float], optional): Standard deviation values for normalization. Can be a single float or list of floats per channel.
do_convert_rgb (bool): Whether to convert the input image to RGB.
first_last_frames_slow (bool): Whether to treat the first and last frames of a video as “slow path” (processed differently).
Attributes:
model_input_names (List[str]): Names of the expected model inputs. Defaults to `["pixel_values"]`.
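Example (illustrative sketch, not necessarily the model's shipped configuration; `pil_image` is a placeholder
for any loaded `PIL.Image.Image`):
>>> possible_resolutions = determine_possible_resolutions(anyres=True, max_num_grids=9, grid_size=512)
>>> processor = HCXVisionProcessor(anyres=True, possible_resolutions=possible_resolutions)
>>> batch = processor.preprocess(images=[pil_image], is_video_list=[False], return_tensors="pt")  # doctest: +SKIP
>>> batch["pixel_values"][0][0].ndim  # (num_grids, C, H, W)  # doctest: +SKIP
4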
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
anyres: bool = False,
unpad: bool = False,
max_num_grids: int = 9,
max_image_cnt: int = 12,
num_queries_vis_abstractor: int = 0,
num_queries_vis_abstractor_video_fast: int = 0,
num_queries_vis_abstractor_video_slow: int = 0,
possible_resolutions: List = [],
patch_size: int = 14,
pad_to_square: bool = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
first_last_frames_slow: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 512}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.do_resize = do_resize
self.size = size
self.anyres = anyres
self.unpad = unpad
self.max_num_grids = max_num_grids
self.max_image_cnt = max_image_cnt
self.num_queries_vis_abstractor = num_queries_vis_abstractor
self.num_queries_vis_abstractor_video_fast = num_queries_vis_abstractor_video_fast
self.num_queries_vis_abstractor_video_slow = num_queries_vis_abstractor_video_slow
self.possible_resolutions = [_resolution for _resolution in possible_resolutions]
self.patch_size = patch_size
self.pad_to_square = pad_to_square
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.first_last_frames_slow = first_last_frames_slow
assert self.crop_size["height"] == self.crop_size["width"]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resizes the input image to the specified target size.
Args:
image (np.ndarray): The input image to resize.
size (Dict[str, int]): A dictionary specifying the target size with keys `"height"` and `"width"`.
resample (PILImageResampling, optional): The resampling filter to use. Defaults to `BICUBIC`.
data_format (str or ChannelDimension, optional): The desired output data format (e.g., "channels_last").
input_data_format (str or ChannelDimension, optional): The input data format of the image.
**kwargs: Additional keyword arguments, if any.
Returns:
np.ndarray: The resized image as a NumPy array.
"""
default_to_square = True
if "shortest_edge" in size:
size = size["shortest_edge"]
default_to_square = False
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
output_size = get_resize_output_image_size(
image,
size=size,
default_to_square=default_to_square,
input_data_format=input_data_format,
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def _preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Image.Image:
"""
Applies a sequence of preprocessing operations to the input image(s), including resizing, cropping, rescaling,
normalization, and format conversion.
This method is typically used internally to prepare images for model input.
Args:
images (ImageInput): A single image or a batch of images to preprocess.
do_resize (bool, optional): Whether to resize the image(s).
size (Dict[str, int], optional): Target size for resizing, with keys `"height"` and `"width"`.
resample (PILImageResampling, optional): Resampling method to use for resizing.
do_center_crop (bool, optional): Whether to apply center cropping.
crop_size (int, optional): Size of the center crop (applied to both height and width).
do_rescale (bool, optional): Whether to rescale the image pixel values.
rescale_factor (float, optional): Factor to use when rescaling pixel values (e.g., 1/255).
do_normalize (bool, optional): Whether to normalize the image using `image_mean` and `image_std`.
image_mean (float or List[float], optional): Mean value(s) used for normalization.
image_std (float or List[float], optional): Standard deviation value(s) used for normalization.
data_format (ChannelDimension, optional): The desired output data format (e.g., `ChannelDimension.FIRST`).
input_data_format (str or ChannelDimension, optional): The format of the input image(s).
Returns:
Image.Image: The preprocessed image or batch of images, ready for model input.
"""
images = make_list_of_images(images)
if do_resize:
images = [
self.resize(
image=image,
size=size,
resample=resample,
input_data_format=input_data_format,
)
for image in images
]
if do_center_crop:
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
if do_rescale:
images = [
self.rescale(
image=image,
scale=rescale_factor,
input_data_format=input_data_format,
)
for image in images
]
if do_normalize:
images = [
self.normalize(
image=image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
for image in images
]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
return images
def _resize_for_local_grids(
self,
image: np.array,
target_resolution: tuple,
resample,
input_data_format: ChannelDimension,
) -> np.array:
"""
Resizes the image to the given target resolution for use in local grid processing.
This function ensures that the image is properly resized to match the (height, width) specified
in `target_resolution`, using the provided resampling method. It supports channel-first and
channel-last formats based on `input_data_format`.
Args:
image (np.array): Input image as a NumPy array.
target_resolution (tuple): Target resolution as (height, width) for resizing.
resample: Resampling method to use (e.g., `PILImageResampling.BICUBIC`).
input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`).
Returns:
np.array: The resized image in NumPy array format.
"""
new_height, new_width = _get_local_grids_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(
image,
(new_height, new_width),
resample=resample,
input_data_format=input_data_format,
)
return resized_image
def _pad_for_patching(
self,
image: np.array,
target_resolution: tuple,
input_data_format: ChannelDimension,
) -> np.array:
"""
Pads the image to match the target resolution, ensuring compatibility with patch-based models.
This is typically used to make sure the image dimensions are divisible by the patch size or to
meet specific model input requirements. Padding is applied symmetrically where needed.
Args:
image (np.array): Input image as a NumPy array.
target_resolution (tuple): The desired resolution after padding, in the format (height, width).
input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`).
Returns:
np.array: The padded image as a NumPy array.
"""
target_height, target_width = target_resolution
background_color = tuple(int(x * 255) for x in self.image_mean)
padded_image = pad(
image,
target_size=(target_height, target_width),
background_color=background_color,
input_data_format=input_data_format,
)
return padded_image
def get_image_grids(
self,
image: np.array,
possible_resolutions,
grid_size: int,
resample: PILImageResampling,
data_format: ChannelDimension,
input_data_format: ChannelDimension,
) -> List[np.array]:
"""
Splits the input image into multiple local grids based on possible resolutions and grid size.
The function selects the best resolution from the provided list, resizes the image accordingly,
and divides it into non-overlapping grid patches of size (grid_size x grid_size). It is commonly
used for any-resolution (anyres) visual processing.
Args:
image (np.array): Input image as a NumPy array.
possible_resolutions (List[Tuple[int, int]]): List of allowed resolutions to choose from.
grid_size (int): The size of each grid patch (e.g., 336 pixels).
resample (PILImageResampling): Resampling method used during resizing.
data_format (ChannelDimension): Output data format (e.g., `ChannelDimension.FIRST`).
input_data_format (ChannelDimension): Input data format of the image.
Returns:
List[np.array]: A list of grid image patches as NumPy arrays.
"""
if not isinstance(possible_resolutions, list):
raise ValueError("possible_resolutions must be a list of possible resolutions.")
image_size = get_image_size(image, channel_dim=input_data_format)
best_resolution = select_best_resolution(image_size, possible_resolutions)
resized_image = self._resize_for_local_grids(
image,
best_resolution,
resample=resample,
input_data_format=input_data_format,
)
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
local_grids = divide_to_grids(padded_image, grid_size=grid_size, input_data_format=input_data_format)
# make sure that all patches are in the input data format
local_grids = [
to_channel_dimension_format(grid, channel_dim=data_format, input_channel_dim=input_data_format)
for grid in local_grids
]
return local_grids
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
anyres: bool = None,
unpad: bool = None,
is_video_list: List[bool] = None,
possible_resolutions: List = None,
patch_size: int = None,
pad_to_square: bool = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
is_first_or_last_frames: List[bool] = False,
):
"""
Preprocesses images using HCXVisionProcessor.
This method prepares images for visual language models by applying resizing, padding, cropping,
normalization, and tokenization into visual patches. In video mode, each frame is converted to
a 1D sequence of patches. The `unpad` option is disabled when processing videos.
Args:
images (ImageInput): A single image or a batch of images (PIL, NumPy, or tensor format).
do_resize (bool, optional): Whether to resize the image(s).
size (Dict[str, int], optional): Resize target with keys `"height"` and `"width"`.
anyres (bool, optional): Whether to use any-resolution processing with grid splitting.
unpad (bool, optional): Whether to remove visual tokens that belong to padding areas (only in non-video mode).
is_video_list (List[bool], optional): A list indicating which inputs are video frames.
possible_resolutions (List, optional): List of resolution pairs allowed in `anyres` mode.
patch_size (int, optional): Patch size for the Vision Transformer (ViT).
pad_to_square (bool, optional): Whether to pad the image to a square.
resample (PILImageResampling, optional): Resampling method to use for resizing.
do_center_crop (bool, optional): Whether to apply center cropping.
crop_size (int, optional): Target crop size for center cropping.
do_rescale (bool, optional): Whether to rescale image pixel values.
rescale_factor (float, optional): Factor for pixel rescaling, e.g., `1/255`.
do_normalize (bool, optional): Whether to normalize using mean and std.
image_mean (float or List[float], optional): Mean value(s) for normalization.
image_std (float or List[float], optional): Standard deviation(s) for normalization.
do_convert_rgb (bool, optional): Whether to convert the image to RGB.
return_tensors (str or TensorType, optional): Desired output tensor type (e.g., "pt" for PyTorch).
data_format (ChannelDimension, optional): Output data format (e.g., `ChannelDimension.FIRST`).
input_data_format (str or ChannelDimension, optional): Format of the input image.
is_first_or_last_frames (List[bool], optional): Flags indicating whether each image is a first/last video frame.
Returns:
BatchFeature: A batch feature with the following keys (each wrapped in an outer batch list):
- pixel_values (List[List[torch.Tensor]]): 4D image tensors (num_grids x C x H x W) ready for model input.
- image_sizes (List[List[List[int]]]): The original [width, height] of each image.
- vision_query_lengths (List[List[int]]): The number of visual tokens each image contributes to the LLM input.
- is_videos, num_queries_vis_abstractors, num_queries_vis_abstractors_slow, first_last_frames_slows:
Per-image video flags and visual-query configuration.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
anyres = anyres if anyres is not None else self.anyres
unpad = unpad if unpad is not None else self.unpad
possible_resolutions = possible_resolutions if possible_resolutions is not None else self.possible_resolutions
patch_size = patch_size if patch_size is not None else self.patch_size
pad_to_square = pad_to_square if pad_to_square is not None else self.pad_to_square
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
images = make_list_of_images(images)
# Default to treating every input as a still image when no video flags are provided.
if is_video_list is None:
    is_video_list = [False] * len(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
new_images = []
image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
vision_query_lengths = []
assert crop_size["height"] == crop_size["width"]
# Padding operations for the global image can become a bottleneck when the original image width or height is large.
# To mitigate this, the image is first resized such that the longest side is scaled proportionally based on size["shortest_edge"],
# and then padding is applied to reach the target dimensions.
if anyres:
anyres_global_images = copy.deepcopy(images)
if pad_to_square:
background_color = tuple(int(x * 255) for x in self.image_mean)
anyres_global_images = [
resize_longside(
copy.deepcopy(image),
size["shortest_edge"],
resample,
input_data_format,
)
for image in anyres_global_images
]
anyres_global_images = [
expand2square(
image,
background_color=background_color,
input_data_format=input_data_format,
)[0]
for image in anyres_global_images
]
else:
anyres_global_images = [
self.resize(
image=image,
size={
"height": size["shortest_edge"],
"width": size["shortest_edge"],
},
resample=resample,
input_data_format=input_data_format,
)
for image in anyres_global_images
]
else:
anyres_global_images = [None for _ in range(len(images))]
if pad_to_square:
background_color = tuple(int(x * 255) for x in self.image_mean)
images = [
resize_longside(image, size["shortest_edge"], resample, input_data_format) for image in images
]
images = [
expand2square(
image,
background_color=background_color,
input_data_format=input_data_format,
)[0]
for image in images
]
num_queries_vis_abstractors = []
num_queries_vis_abstractors_slow = []
first_last_frames_slows = []
for image, is_video, anyres_global_image, image_size in zip(
images, is_video_list, anyres_global_images, image_sizes
):
if is_video:
num_queries_vis_abstractor = self.num_queries_vis_abstractor_video_fast
num_queries_vis_abstractor_slow = self.num_queries_vis_abstractor_video_slow
else:
num_queries_vis_abstractor = self.num_queries_vis_abstractor
num_queries_vis_abstractor_slow = 0
num_queries_vis_abstractors.append(num_queries_vis_abstractor)
num_queries_vis_abstractors_slow.append(num_queries_vis_abstractor_slow)
first_last_frames_slows.append(self.first_last_frames_slow)
if anyres:
# convert image into a list of grids
# we intentionally use the same data format as the input data format
image_grids = self.get_image_grids(
image,
possible_resolutions,
grid_size=crop_size["height"],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
# Global image (thumbnail) is not used for video inputs.
if not is_video:
image_grids = [anyres_global_image] + image_grids
else:
image_grids = [image]
pixel_values = self._preprocess(
image_grids,
do_resize=do_resize,
size=size,
resample=resample,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
input_data_format=input_data_format,
)
pixel_values = np.array(pixel_values)
new_images.append(pixel_values)
num_grids = pixel_values.shape[0]
vision_query_length = determine_anyres_num_vision_patches(
num_grids=num_grids,
image_size=image_size,
grid_size=crop_size["height"],
patch_size=patch_size,
possible_resolutions=possible_resolutions,
anyres=anyres,
unpad=False if is_video else unpad,
num_queries_vis_abstractor=num_queries_vis_abstractor,
num_queries_vis_abstractor_slow=num_queries_vis_abstractor_slow,
is_video=is_video,
first_last_frames_slow=self.first_last_frames_slow,
is_first_or_last_frames=self.first_last_frames_slow,
)
vision_query_lengths.append(vision_query_length)
data = {
"pixel_values": [[torch.tensor(new_image) for new_image in new_images]],
"image_sizes": [[[image_size[1], image_size[0]] for image_size in image_sizes]],
"vision_query_lengths": [vision_query_lengths],
"is_videos": [is_video_list],
"num_queries_vis_abstractors": [num_queries_vis_abstractors],
"num_queries_vis_abstractors_slow": [num_queries_vis_abstractors_slow],
"first_last_frames_slows": [first_last_frames_slows],
}
return BatchFeature(data=data)
def load_images_videos(self, vlm_chat):
"""
Loads and prepares images or video frames from a VLM chat input.
This function parses the input `vlm_chat` object, extracts image or video sources,
and loads them into memory as PIL or NumPy images, ready for preprocessing.
Args:
vlm_chat: A VLM chat input structure containing multimodal elements
(e.g., images, videos, URLs, or file paths). The format is typically a list of messages
with associated media fields.
Returns:
Tuple:
new_vlm_chat: The chat messages with each video entry expanded into per-frame entries.
all_images (List[PIL.Image.Image]): All loaded images, including grid-combined frames extracted from videos.
is_video_list (List[bool]): Flags indicating, for each loaded image, whether it came from a video.
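Example (illustrative; the chat structure and URL are placeholders, and `processor` is an HCXVisionProcessor instance):
>>> vlm_chat = [
...     {"role": "user", "content": {"text": "Describe the image."}},
...     {"role": "user", "content": {"image": "https://example.com/cat.jpg"}},
... ]
>>> new_chat, all_images, is_video_list = processor.load_images_videos(vlm_chat)  # doctest: +SKIP
>>> is_video_list  # doctest: +SKIP
[False]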
"""
vlm_chat = copy.deepcopy(vlm_chat)
new_vlm_chat = []
all_images = [] # images + images_from_videos
is_video_list = []
for line in vlm_chat:
if "content" in line:
content = line["content"]
if "image" in content:
if "filename" not in content:
content["filename"] = f"{uuid.uuid4().hex}.jpg"
image_pil = load_image(content["image"])
all_images.append(image_pil)
is_video_list.append(False)
new_vlm_chat.append(line)
elif "video" in content:
video_bytesio = load_video_to_bytesio(content["video"])
pil_img_frames, video_time_stamp = process_video(
video_bytesio, self.max_num_grids, self.max_image_cnt, self.crop_size["width"]
)
all_images.extend(pil_img_frames)
is_video_list.extend([True] * len(pil_img_frames))
if "filename" not in content:
content["filename"] = f"{uuid.uuid4().hex}.mp4"
for i, image_time_stamp in enumerate(video_time_stamp):
new_line = copy.deepcopy(line)
basename, ext = os.path.splitext(content["filename"])
new_line["content"]["filename"] = f"{basename}-{i}{ext}"
new_line["content"]["video_time_stamp"] = image_time_stamp
if i == len(video_time_stamp) - 1:
new_line["content"]["is_final_grid"] = True
for last_frame_target_key in ["lens_keywords", "lens_local_keywords", "speech_to_text"]:
if last_frame_target_key in content:
new_line["content"][last_frame_target_key] = content[last_frame_target_key]
new_vlm_chat.append(new_line)
else:
new_vlm_chat.append(line)
return new_vlm_chat, all_images, is_video_list
def process_video(video_bytesio, max_num_grids, max_image_cnt, vit_input_size):
"""
Processes a video file and extracts frames suitable for vision transformer (ViT) input.
The function reads video data from a BytesIO object, extracts a limited number of frames
based on `max_num_grids` and `max_image_cnt`, and resizes them to the appropriate ViT input size.
Args:
video_bytesio (io.BytesIO): A BytesIO object containing the raw video file data.
max_num_grids (int): The maximum number of grids allowed (e.g., for tiling or patching).
max_image_cnt (int): The maximum number of frames to extract from the video.
vit_input_size (int): The desired input size (height and width) for the ViT model.
Returns:
Tuple:
pil_img_frames (List[PIL.Image.Image]): Grid-combined frame images sized for ViT input.
video_time_stamp (List[str]): Time-span labels (e.g., "0.00s~1.20s") for each combined image.
"""
frames, time_interval = video_decoder(
video_bytesio, max_num_grids=max_num_grids, max_image_cnt=max_image_cnt, default_interval=0.4
)
pil_img_frames, video_time_stamp = combine_frames_into_images(
frames, time_interval, max_grid_shape=(max_num_grids, 1), vit_input_size=vit_input_size
)
return pil_img_frames, video_time_stamp
def load_image(image_src):
"""
Loads an image from various sources (file path, URL, base64 string, or raw bytes)
and returns it as a PIL Image object.
Args:
image_src (str or bytes): The image source. It can be:
- A local file path
- A URL
- A base64-encoded string
- Raw image bytes
Returns:
PIL.Image.Image: The loaded image as a PIL Image object.
Raises:
ValueError: If the image cannot be loaded or the format is unsupported.
TypeError: If the input is not of type str or bytes.
"""
try:
# 1. If input is bytes type
if isinstance(image_src, bytes):
return Image.open(io.BytesIO(image_src))
# 2. If input is str type (path, URL, base64)
if isinstance(image_src, str):
# 2a. Check if it's a Base64 data URI format ('data:image/...')
if image_src.startswith("data:image"):
try:
# Remove the 'data:image/...;base64,' part and decode
header, encoded = image_src.split(",", 1)
image_bytes = base64.b64decode(encoded)
return Image.open(io.BytesIO(image_bytes))
except (ValueError, base64.binascii.Error) as e:
raise ValueError(f"Invalid base64 data URI format: {e}") from e
# 2b. Check if it's a URL format ('http://' or 'https://')
elif image_src.startswith("http://") or image_src.startswith("https://"):
try:
response = requests.get(image_src, stream=True, timeout=10)
response.raise_for_status() # Raise an exception for HTTP errors
image_bytes = response.content
return Image.open(io.BytesIO(image_bytes))
except requests.exceptions.RequestException as e:
raise ValueError(f"Error loading image from URL '{image_src}': {e}") from e
# 2c. Assume it's a local file path
else:
return Image.open(image_src)
else:
raise TypeError(f"Unsupported image_src type: {type(image_src)}")
# Common exception handling
except FileNotFoundError:
raise ValueError(f"Image loading error: File not found '{image_src}'")
except UnidentifiedImageError:
raise ValueError("Image loading error: Cannot identify image file format.")
except IOError as e:
raise ValueError(f"Image loading error (I/O): {e}") from e
except Exception as e:
raise ValueError(f"Unexpected error during image loading: {e}") from e
def load_video_to_bytesio(video_src):
"""
Loads video data from various sources (file path, URL, base64 string, or raw bytes)
and returns an `io.BytesIO` object containing the raw video content.
Args:
video_src (str or bytes): The video source. Supported formats include:
- Local file path
- URL
- Base64-encoded data URI string
- Raw video bytes
Returns:
io.BytesIO: A `BytesIO` object containing the loaded video data.
Raises:
ValueError: If the video cannot be loaded due to issues such as an invalid path,
URL failure, malformed base64 string, or unsupported format.
TypeError: If the input is not a `str` or `bytes` object.
"""
video_bytes = None
try:
# 1. If input is bytes type
if isinstance(video_src, bytes):
video_bytes = video_src
# 2. If input is str type (path, URL, base64)
elif isinstance(video_src, str):
# 2a. Check if it's a Base64 data URI format ('data:video/...')
if video_src.startswith("data:video"):
try:
# Remove the 'data:video/...;base64,' part and decode
header, encoded = video_src.split(",", 1)
video_bytes = base64.b64decode(encoded)
except (ValueError, base64.binascii.Error) as e:
raise ValueError(f"Invalid base64 data URI format: {e}") from e
# 2b. Check if it looks like a URL
elif urlparse(video_src).scheme in ("http", "https"):
try:
response = requests.get(
video_src, stream=True, timeout=30
) # Increased timeout for potentially large videos
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
# Read all content from the stream into bytes
video_bytes = response.content
except requests.exceptions.MissingSchema:
# If urlparse thinks it's a scheme but requests disagrees (e.g., "http:/example.com")
# Treat it as a potential file path below.
pass
except requests.exceptions.RequestException as e:
raise ValueError(f"Error loading video from URL '{video_src}': {e}") from e
# 2c. Assume it's a local file path if not base64 or confirmed URL
if video_bytes is None: # Only attempt file read if not already loaded as base64 or URL failed gracefully
# Check if it could potentially be a file path
# Note: This check is basic. A string like "http:/path/file" might incorrectly be treated as a path here
# if the requests call failed due to MissingSchema. More robust path validation could be added.
if (
os.path.exists(video_src) or "/" in video_src or "\\" in video_src
): # Basic check if it resembles a path
try:
with open(video_src, "rb") as f:
video_bytes = f.read()
except FileNotFoundError:
raise ValueError(f"Video loading error: File not found at path '{video_src}'")
except IsADirectoryError:
raise ValueError(f"Video loading error: Path '{video_src}' is a directory, not a file.")
except IOError as e:
raise ValueError(f"Video loading error (I/O) for path '{video_src}': {e}") from e
else:
# If it's not base64, not a valid downloadable URL, and doesn't look like a path/doesn't exist
raise ValueError(f"Unsupported string input format or resource not found: '{video_src}'")
# 3. If the type is unsupported
else:
raise TypeError(f"Unsupported video_src type: {type(video_src)}")
# Final check if video_bytes was successfully obtained
if video_bytes is None:
raise ValueError(f"Could not load video data from the provided source: {video_src}")
# Return the bytes wrapped in BytesIO
return io.BytesIO(video_bytes)
# Catch specific exceptions first for better error reporting
except FileNotFoundError as e: # Should be caught above, but as a safeguard
raise ValueError(f"Video loading error: File not found '{video_src}'") from e
except requests.exceptions.RequestException as e: # Already handled, but for clarity
raise ValueError(f"Video loading error (Network): {e}") from e
except (ValueError, TypeError) as e: # Re-raise ValueErrors/TypeErrors raised intentionally within the try block
raise e
except Exception as e:
# Catch any other unexpected errors during processing
raise ValueError(f"Unexpected error during video loading from source '{video_src}': {e}") from e
def video_decoder(video_bytesio, max_num_grids, max_image_cnt, default_interval=0.4):
"""
Decodes video data from a BytesIO object and returns a list of extracted frames.
Args:
video_bytesio (io.BytesIO): A BytesIO object containing the raw video data.
max_num_grids (int): Maximum number of grids allowed per image. Used to determine how many frames to extract.
max_image_cnt (int): Maximum number of frames to extract from the video.
default_interval (float, optional): Default time interval (in seconds) between frames. Used when frame rate info is unavailable. TODO: make configurable.
Returns:
Tuple:
frames (List[PIL.Image.Image]): A list of extracted frames as PIL Images.
time_interval (float): Time interval (in seconds) between selected frames.
"""
error_messages = []
frames = []
# 1. Try decoding the video using Decord.
try:
vr = VideoReader(video_bytesio, ctx=cpu(0), num_threads=8)
fps = vr.get_avg_fps()
play_time = len(vr) / fps
total_frames = len(vr)
frame_indices, time_interval = extract_frame_indices(
play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
) # Sample every 0.4 seconds; if the video is too long, apply uniform sampling instead.
if frame_indices is None:
frame_indices = range(len(vr)) # Convert all frames.
batch_frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(frame).convert("RGB") for frame in batch_frames]
return frames, time_interval
except Exception as e:
print("error with decord")
error_messages.append(f"Decord 실패: {e}")
# 2. Fallback: Try decoding the video using PyAV.
try:
container = av.open(video_bytesio)
video_stream = container.streams.video[0]
fps = video_stream.average_rate
total_frames = video_stream.frames  # PyAV containers do not support len(); use the stream's frame count.
play_time = total_frames / fps
frame_indices, time_interval = extract_frame_indices(
play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
) # Sample frames every 0.4 seconds. If the video is long, use uniform sampling to limit the number of frames.
# Even if frame_indices were assigned using Decord, reprocess them to be compatible with PyAV.
target_indices = None if frame_indices is None else set(frame_indices)
frames = []
for i, frame in enumerate(container.decode(video=0)):
if target_indices is not None and i not in target_indices:
continue # Skip frames that are not in the required indices.
pil_frame = Image.fromarray(frame.to_ndarray(format="rgb24")).convert("RGB")
frames.append(pil_frame)
if frames:
return frames, time_interval
else:
raise Exception("Decoding with PyAV succeeded, but no frames were extracted.")
except Exception as e:
error_messages.append(f"PyAV failed: {e}")
# 3. Fallback: Try decoding the video using OpenCV.
try:
# cv2.VideoCapture cannot decode from an in-memory buffer, so write the bytes to a temporary file first.
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
    tmp_file.write(video_bytesio.getvalue())
    tmp_video_path = tmp_file.name
cap = cv2.VideoCapture(tmp_video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
play_time = total_frames / fps
frame_indices, time_interval = extract_frame_indices(
play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
) # Sample frames every 0.4 seconds; if the video is too long, apply uniform sampling to limit the total number of frames.
if frame_indices is None:
frame_indices = range(total_frames) # Convert all frames.
index_set = set(frame_indices) # Convert to a set for faster lookup.
current_index = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if current_index in index_set:
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert("RGB"))
current_index += 1
if current_index > max(index_set): # Stop processing once all required indices have been handled.
break
cap.release()
os.remove(tmp_video_path)  # Clean up the temporary file created for OpenCV decoding.
if frames:
return frames, time_interval
except Exception as e:
error_messages.append(f"OpenCV failed: {e}")
if error_messages:
raise Exception(f"All decoding attempts have failed.: {error_messages}")
def convert_format_for_multi_image(img, json, convert_key_list=["words", "text", "objects", "entities"]):
"""
Converts the format of image and annotation data from a single-image dataset to a multi-image dataset format.
Single-image datasets typically return a single image and its associated annotation as individual objects.
This function wraps them in a dictionary format used by multi-image datasets.
Args:
img: The input image (e.g., a PIL Image or NumPy array).
json: The annotation data associated with the image.
convert_key_list (List[str], optional): A list of keys to extract and convert from the original JSON.
Defaults to ["words", "text", "objects", "entities"].
Returns:
Tuple[bool, Dict, Dict]:
- is_multi_image_dataset (bool): Whether the input was already in multi-image format.
- A dictionary mapping image IDs to images (e.g., {"00": img}).
- The annotation JSON with the listed keys (and any "region" keys) wrapped per image ID.
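Example (illustrative; any image object may stand in for `img`):
>>> ann = {"words": ["hello"], "region_caption": ["a cat"]}
>>> is_multi, imgs, ann = convert_format_for_multi_image("<img>", ann)
>>> is_multi, sorted(imgs.keys()), ann["words"]
(False, ['00'], {'00': ['hello']})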
"""
is_multi_image_dataset = isinstance(img, dict)
if not is_multi_image_dataset:
img = {"00": img}
for convert_key in convert_key_list:
if convert_key in json:
json[convert_key] = {"00": json[convert_key]}
for json_key in json:
if "region" in json_key:
json[json_key] = {"00": json[json_key]}
return is_multi_image_dataset, img, json
def convert_tags_for_video(img, json):
"""
Converts <video_00> tags to <image_xx> tags based on the number of video frames.
In video datasets, annotations often use a generic <video_00> tag. This function replaces that tag
with frame-specific tags such as <image_00>, <image_01>, ..., <image_NN> based on the number of frames in `img`.
Args:
img: A list of video frames (e.g., list of PIL Images or NumPy arrays).
json: The annotation data containing <video_00> tags to be replaced.
Returns:
Tuple: The unchanged frames and the updated annotation JSON with frame-specific <image_xx> tags.
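Example (illustrative; the frame list only needs the correct length, so placeholders are used here):
>>> frames = [None, None]  # two frames
>>> ann = {"qa_pairs": [["<video_00> What happens in the video?", "The cat jumps."]]}
>>> _, ann = convert_tags_for_video(frames, ann)
>>> ann["qa_pairs"][0][0]
'<image_00><image_01> What happens in the video?'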
"""
image_tag = "".join([f"<image_{idx:02d}>" for idx in range(len(img))])
# image_tag = "<image_00>" # Use this format to construct and insert image-specific tags.
for json_key in json:
if "qa_pairs" in json_key:
new_qa_pairs = []
for qa_pair in json[json_key]:
question = qa_pair[0]
# Replace <video_00> tags with corresponding <image_xx> tags.
question = question.replace("<video_00>", image_tag)
new_qa_pairs.append([question, qa_pair[1]])
json[json_key] = new_qa_pairs
return img, json
def split_list(input_list, split_value):
"""
Splits a list into sublists using a specified delimiter value.
Each time `split_value` is encountered in `input_list`, a new sublist is started.
The delimiter itself is not included in the output.
Args:
input_list (List[Any]): The input list to split.
split_value (Any): The value used as the delimiter for splitting.
Returns:
List[List[Any]]: A list of sublists, split by the specified delimiter.
Example:
>>> split_list(["a", "b", "|", "c", "d", "|", "e"], "|")
[['a', 'b'], ['c', 'd'], ['e']]
"""
temp_list = []
result = []
for value in input_list:
if value == split_value:
result.append(temp_list)
temp_list = []
else:
temp_list.append(value)
result.append(temp_list)
return result
def combine_frames_into_images(frames, time_interval, max_grid_shape=(3, 1), vit_input_size=378):
"""
Combines a sequence of video frames into grid-based images and generates corresponding time range labels.
Frames are grouped and arranged into a grid (e.g., 3x3) such that each combined image contains up to
`max_grid_shape[0] * max_grid_shape[1]` frames. Each combined image is resized to the given ViT input size.
Args:
frames (List[PIL.Image.Image]): A list of frames extracted from a video.
time_interval (float): Time interval (in seconds) between consecutive frames.
max_grid_shape (Tuple[int, int], optional): Grid shape as (num_cols, num_rows). Frames are concatenated
horizontally into a single row, so the second element must be 1. Defaults to (3, 1).
vit_input_size (int, optional): The target size (height and width) for the Vision Transformer input. Defaults to 378.
Returns:
Tuple:
image_list (List[PIL.Image.Image]): A list of grid-combined images.
image_time_stamps (List[str]): A list of time span labels for each combined image,
e.g., ["0.00s~1.50s", "1.50s~3.00s", ...].
"""
# grid_size = int(np.sqrt(max_num_grids))
# assert grid_size**2 == max_num_grids, "max_num_grids must be a perfect square."
max_num_grids = max_grid_shape[0] * max_grid_shape[1]
assert max_grid_shape[1] == 1, "For video processing, frames are concatenated horizontally into a single-row (N x 1) wide image."
# List to store the resulting combined images.
image_list = []
# Calculate the number of canvases needed.
num_frames = len(frames)
num_canvases = num_frames // max_num_grids
leftover_frames = num_frames % max_num_grids
time_stamp = 0 # second
image_time_stamps = []
for canvas_idx in range(num_canvases):
# Initialize the current canvas.
combined_image = Image.new(
"RGB", (vit_input_size * max_grid_shape[0], vit_input_size * max_grid_shape[1]), color=(0, 0, 0)
)
# Determine the frames to fill in the current canvas.
start_idx = canvas_idx * max_num_grids
end_idx = min(start_idx + max_num_grids, num_frames)
for idx in range(start_idx, end_idx):
img = frames[idx]
# Resize each frame to a square shape.
img_resized = img.resize((vit_input_size, vit_input_size))
# Calculate the (row, column) position to place the frame within the grid layout.
local_idx = idx - start_idx
x_offset = (local_idx % max_grid_shape[0]) * vit_input_size
y_offset = (local_idx // max_grid_shape[0]) * vit_input_size
# Calculate the position to place the frame in the grid.
combined_image.paste(img_resized, (x_offset, y_offset))
# Append the current canvas to the result list.
image_list.append(combined_image)
frame_cnt = end_idx - start_idx
image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
time_stamp += frame_cnt * time_interval
if leftover_frames > 0:
# Start a new canvas index for the leftover frames (num_canvases full canvases have already been emitted).
canvas_idx = num_canvases
# Add the remaining frames to the final canvas.
combined_image = Image.new("RGB", (vit_input_size * leftover_frames, vit_input_size * 1), color=(0, 0, 0))
for idx in range(leftover_frames):
img = frames[num_canvases * max_num_grids + idx]
# Resize the frame to a square (equal width and height).
img_resized = img.resize((vit_input_size, vit_input_size))
# Calculate the (row, column) position to place the frame within the grid layout.
x_offset = (idx % leftover_frames) * vit_input_size
y_offset = (idx // leftover_frames) * vit_input_size
# Calculate the position to place the frame within the grid layout.
combined_image.paste(img_resized, (x_offset, y_offset))
# Add the current canvas to the list of combined images.
image_list.append(combined_image)
frame_cnt = leftover_frames
image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
time_stamp += frame_cnt * time_interval
return image_list, image_time_stamps
def extract_frame_indices(play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=0.4):
"""
Extracts specific frame indices from a video based on duration, frame count, and sampling strategy.
The function determines which frames to extract given the video duration (`play_time`),
total frame count, and frame rate. It samples frames at regular intervals (default: 0.4s),
but if the number of frames exceeds the limit defined by `max_num_grids * max_image_cnt`,
it performs uniform sampling to stay within that limit.
Args:
play_time (float): Total play time of the video in seconds.
total_frames (int): Total number of frames in the video.
fps (float): Frames per second of the video.
max_num_grids (int): Maximum number of grids per combined image.
max_image_cnt (int): Maximum number of combined images; together with max_num_grids this caps the number of extracted frames at max_num_grids * max_image_cnt.
default_interval (float, optional): Interval in seconds between frame samples. Defaults to 0.4.
Returns:
Tuple:
frame_indices (List[int]): A list of selected frame indices.
time_interval (float): Time interval between selected frames (in seconds).
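Example (illustrative: a 10-second, 30-fps clip sampled roughly every 0.4 s):
>>> indices, interval = extract_frame_indices(
...     play_time=10.0, total_frames=300, fps=30, max_num_grids=3, max_image_cnt=12
... )
>>> indices[:4], round(interval, 2)
([0, 12, 24, 36], 0.4)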
"""
# Calculate how many frames to extract at the default interval (at least one, to avoid division by zero for very short clips).
default_frame_count = max(int(play_time / default_interval), 1)
# Maximum frames allowed based on max_num_grids and max_image_cnt
max_frames_allowed = max_num_grids * max_image_cnt
# Determine whether we can use the default interval or need uniform sampling
if default_frame_count <= max_frames_allowed:
# Default interval is sufficient, extract frames every 0.4 seconds
frame_interval = max(int(total_frames / default_frame_count), 1)
else:
# Use uniform sampling to fit within max_frames_allowed
frame_interval = max(int(total_frames / max_frames_allowed), 1)
# Extract frame indices at the calculated interval
selected_indices = list(range(0, total_frames, frame_interval))
time_interval = frame_interval / fps
# Ensure the number of selected indices does not exceed max_frames_allowed
return selected_indices[:max_frames_allowed], time_interval