Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Sequence | |
| import torch | |
| from torchvision import transforms | |
class GaussianBlur(transforms.RandomApply):
    """
    Randomly apply Gaussian blur to an image.

    Wraps ``transforms.GaussianBlur`` (fixed kernel size of 9) in
    ``transforms.RandomApply`` so that the blur is only applied some of the time.

    Args:
        p: Probability of blurring the image.
        radius_min: Lower bound of the ``sigma`` range handed to ``transforms.GaussianBlur``.
        radius_max: Upper bound of the ``sigma`` range handed to ``transforms.GaussianBlur``.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        # NOTE: torchvision's RandomApply runs its transforms with probability ``p``
        # (it returns the input untouched when ``p < torch.rand(1)``), so ``p`` must be
        # forwarded as-is. Forwarding ``1 - p`` would invert the blur probability
        # (e.g. ``p=1.0`` would never blur).
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=p)
class MaybeToTensor(transforms.ToTensor):
    """
    ``ToTensor`` variant that hands back inputs which are already ``torch.Tensor``
    instances unchanged, and converts everything else via the parent class.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to be converted to tensor.
        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the converted image.
        """
        already_tensor = isinstance(pic, torch.Tensor)
        return pic if already_tensor else super().__call__(pic)
# Default per-channel normalization statistics; the constant names follow timm's.
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """
    Build a ``transforms.Normalize`` from the given statistics.

    Args:
        mean: Means passed to ``transforms.Normalize``; defaults to ``IMAGENET_DEFAULT_MEAN``.
        std: Standard deviations passed to ``transforms.Normalize``; defaults to
            ``IMAGENET_DEFAULT_STD``.
    Returns:
        The configured ``transforms.Normalize`` instance.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
# This roughly matches torchvision's preset for classification training:
#   https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the training-time preprocessing pipeline for classification.

    The pipeline is: random resized crop, optional random horizontal flip,
    tensor conversion (skipped for inputs that are already tensors), normalization.

    Args:
        crop_size: Output size passed to ``transforms.RandomResizedCrop``.
        interpolation: Interpolation mode passed to ``transforms.RandomResizedCrop``.
        hflip_prob: Probability passed to ``transforms.RandomHorizontalFlip``; the flip
            step is left out of the pipeline entirely when this is not positive.
        mean: Means used for normalization.
        std: Standard deviations used for normalization.
    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    # Only add the flip when it can actually trigger.
    if hflip_prob > 0.0:
        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
    transforms_list.extend(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
    return transforms.Compose(transforms_list)
# Roughly equivalent to torchvision's preset for classification evaluation:
#   https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the evaluation-time preprocessing pipeline for classification.

    The pipeline is: resize, center crop, tensor conversion (skipped for inputs
    that are already tensors), normalization.

    Args:
        resize_size: Size passed to ``transforms.Resize``.
        interpolation: Interpolation mode passed to ``transforms.Resize``.
        crop_size: Size passed to ``transforms.CenterCrop``.
        mean: Means used for normalization.
        std: Standard deviations used for normalization.
    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    center_crop = transforms.CenterCrop(crop_size)
    to_tensor = MaybeToTensor()
    normalize = make_normalize_transform(mean=mean, std=std)
    return transforms.Compose([resize, center_crop, to_tensor, normalize])