from typing import Any, List, Mapping, Optional, Tuple

import torch
import torchvision
from torchvision.transforms.v2 import functional as F

from utils import (
    find_supported_resolutions,
    get_canvas_best_fit,
    resize_with_pad,
    tile_crop,
)

from torchtitan.tools.logging import logger


class CLIPTransform:
    """
    This class accepts images of any size and dynamically resizes, pads, normalizes
    and tiles them based on the image aspect ratio and the number of image tiles we
    allow.

    The algorithm will NOT distort the image to fit a certain aspect ratio, because
    that leads to a significant degradation in image quality.

    The user can choose whether to allow upscaling via the flag ``resize_to_max_canvas``.

    For example, if an input image is of size 300x800, and we want to allow
    a maximum of 16 image tiles, with side 224px, then:

    If ``resize_to_max_canvas=False``, then:
    best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling
    image is NOT resized
    image is padded (300, 800) -> (448, 896)
    image is tiled 2x4, for a final output shape of (8, 3, 224, 224)

    If ``resize_to_max_canvas=True``, then:
    best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles
    image is resized without distortion (300, 800) -> (448, 1194) # 448 is the limiting side for the resize
    image is padded (448, 1194) -> (448, 1344)
    image is tiled 2x6, for a final output shape of (12, 3, 224, 224)

    Args:
        image_mean (Optional[List[float]]): Mean values of each channel, used for normalization.
            Should be the same used for the pre-trained model. If None, no normalization is performed. Default None.
        image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization.
            Should be the same used for the pre-trained model. If None, no normalization is performed. Default None.
        possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width),
            where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``.
            If None, this will be calculated using max_num_tiles and tile_size. Default None.
        tile_size (int): Size of the tiles to divide the image into. Default 224.
        max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given.
            Maximum number of tiles to break an image into.
            This will be used to generate possible_resolutions,
            e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224.
            Default 4.
        dtype (torch.dtype): Data type of the output image. Default torch.bfloat16.
        resample (str): Resampling method used when resizing images. Supports any enum of
            ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic".
            Default "bilinear".
        resize_to_max_canvas (bool): If True, the image will be upscaled without distortion to fit the largest possible
            resolution from possible_resolutions.
            If False, it will pick the resolution that minimizes downscaling, including no downscaling at all.
            In this case, the image will only be upscaled if its size < tile_size. Default False.

    Examples:
        >>> image_transform = CLIPTransform(
        ...     image_mean=None,
        ...     image_std=None,
        ...     tile_size=224,
        ...     possible_resolutions=None,
        ...     max_num_tiles=4,
        ...     resample="bilinear",
        ...     resize_to_max_canvas=True,
        ... )
        >>> # create a random image
        >>> image = torch.randint(0, 256, (3, 100, 200), dtype=torch.uint8)
        >>> output = image_transform(image)
        >>> output["image"].shape  # [num_tiles, num_channels, tile_size, tile_size]
        torch.Size([2, 3, 224, 224])
        >>> output["aspect_ratio"]  # image best fits the canvas 224x448
        tensor([1, 2])
    """

    def __init__(
        self,
        *,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        possible_resolutions: Optional[List[Tuple[int, int]]] = None,
        tile_size: int = 224,
        max_num_tiles: Optional[int] = 4,
        dtype: torch.dtype = torch.bfloat16,
        resample: str = "bilinear",
        resize_to_max_canvas: bool = False,
    ) -> None:
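        # Validate configuration: either an explicit list of canvases or the
        # (max_num_tiles, tile_size) pair to derive one must be provided.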
        assert (
            possible_resolutions is not None or max_num_tiles is not None
        ), f"Either possible_resolutions or max_num_tiles must be given. Got {possible_resolutions} and {max_num_tiles}"

        if not possible_resolutions and max_num_tiles:
            possible_resolutions = find_supported_resolutions(
                max_num_tiles=max_num_tiles, tile_size=tile_size
            )
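
        # For example, with max_num_tiles=2 and tile_size=224 this yields
        # [(224, 224), (224, 448), (448, 224)] (see the class docstring).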
        self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)
        logger.debug(
            f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit."
        )

        self.resize_to_max_canvas = resize_to_max_canvas

        assert (image_mean is None) == (
            image_std is None
        ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}"
        self.mean = image_mean
        self.std = image_std

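        # With resize_to_max_canvas=False, upscaling is capped at tile_size:
        # the image is only enlarged if it is smaller than a single tile.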
        self.max_size = None if resize_to_max_canvas else tile_size
        self.dtype = dtype
        self.resample = torchvision.transforms.InterpolationMode[resample.upper()]
        self.tile_size = tile_size

    def __call__(self, image: torch.Tensor) -> Mapping[str, Any]:
        """
        Apply the resize, pad, normalize, and tile transformations to an image.

        Args:
            image (torch.Tensor): The image tensor to transform.

        Returns:
            Mapping[str, Any]: A mapping with an "image" field containing the
                transformed tiles and an "aspect_ratio" field describing the
                tile grid of the best-fit canvas.
        """
        assert isinstance(image, torch.Tensor), "Input image must be a torch.Tensor."

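        # Normalize layout and dtype: ensure a channels-first image tensor,
        # expand grayscale to RGB, and rescale values into [0, 1] in self.dtype.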
        image = F.to_image(image)
        image = F.grayscale_to_rgb_image(image)
        image = F.to_dtype(image, dtype=self.dtype, scale=True)

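        # Pick the canvas from possible_resolutions that best fits the image,
        # e.g. a 300x800 image with up to 16 tiles of side 224 maps to
        # (448, 896) when resize_to_max_canvas=False (see the class docstring).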
        best_resolution = get_canvas_best_fit(
            image=image,
            possible_resolutions=self.possible_resolutions,
            resize_to_max_canvas=self.resize_to_max_canvas,
        )

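        # Resize without distortion (honoring max_size when upscaling beyond
        # the tile size is disallowed), then pad up to the chosen canvas.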
        image = resize_with_pad(
            image=image,
            target_size=best_resolution,
            resample=self.resample,
            max_size=self.max_size,
        )

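        # Normalization is optional; __init__ guarantees mean and std are
        # either both set or both None.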
        if self.mean:
            image = F.normalize(image, mean=self.mean, std=self.std)

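        # Split the canvas into non-overlapping tile_size x tile_size tiles.
        # The aspect ratio records the tile grid, e.g. (448, 896) // 224 -> [2, 4].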
        image = tile_crop(image=image, tile_size=self.tile_size)
        aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size

        return {
            "image": image,
            "aspect_ratio": aspect_ratio,
        }
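

# --- Usage sketch (illustrative) ---
# A minimal end-to-end example, assuming the `utils` helpers imported above are
# available on the path. The 300x800 input mirrors the worked example in the
# class docstring; the expected outputs below follow from that example.
if __name__ == "__main__":
    transform = CLIPTransform(
        image_mean=None,
        image_std=None,
        possible_resolutions=None,
        tile_size=224,
        max_num_tiles=16,
        resample="bilinear",
        resize_to_max_canvas=False,
    )
    # Random 300x800 RGB image in channels-first (C, H, W) layout.
    image = torch.randint(0, 256, (3, 300, 800), dtype=torch.uint8)
    out = transform(image)
    print(out["image"].shape)   # expected: torch.Size([8, 3, 224, 224])
    print(out["aspect_ratio"])  # expected: tensor([2, 4])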