Spaces:
Sleeping
Sleeping
import random | |
from typing import Optional, Tuple | |
import numpy as np | |
import torch | |
import torchvision.transforms as T | |
from PIL import Image, ImageDraw, ImageFilter | |
from torch.utils.data import Dataset | |
def generate_image(
    size: int = 32,
    contrast: Tuple[int, int] = (90, 110),
    blur_radius: Tuple[float, float] = (0.5, 1.5),
    shape: Optional[str] = None,
    max_background_intensity: int = 128,
    min_shape_intensity: Optional[int] = None,
    shape_size: Optional[int] = None,
    location: str = 'random',
    random_intensity: bool = False
) -> Tuple[Image.Image, str]:
    """
    Generate a grayscale image with a shape (circle or square) on a background.

    :param size: side length of the square image, in pixels
    :param contrast: inclusive (low, high) range from which the offset between
        the background intensity and the minimum shape intensity is drawn;
        only used when min_shape_intensity is None
    :param blur_radius: (low, high) range of the Gaussian blur radius applied
        to the shape mask; a falsy value applies ImageFilter.SMOOTH instead
    :param shape: shape type ('circle' or 'square'); picked at random if falsy
    :param max_background_intensity: maximum intensity of the background
    :param min_shape_intensity: minimum intensity of the shape (0 is a valid
        explicit value); derived from the background and contrast if None
    :param shape_size: side length of the shape's bounding box; picked at
        random if falsy
    :param location: 'random' for a random position; any other value centers
        the shape
    :param random_intensity: whether to invert the whole image with
        probability 0.5
    :return: tuple of (generated image, name of the drawn shape)
    :raises AssertionError: if shape is given but is not 'circle' or 'square'
    """
    background_intensity = random.randint(0, max_background_intensity)
    background = Image.new('L', (size, size), background_intensity)

    if shape:
        # Raised explicitly (not via `assert`) so the check survives `python -O`
        # while callers still see the same exception type and message.
        if shape not in ('circle', 'square'):
            raise AssertionError("Wrong shape type")
    else:
        shape = random.choice(['circle', 'square'])

    # `is None` rather than truthiness: an explicit 0 must be honored.
    if min_shape_intensity is None:
        random_contrast = random.randint(*contrast)
        min_shape_intensity = min(background_intensity + random_contrast, 255)
    shape_intensity = random.randint(min_shape_intensity, 255)

    mask = Image.new('L', (size, size), 0)
    draw = ImageDraw.Draw(mask)

    if not shape_size:
        max_size = size // 2
        # Keep the range non-empty for images smaller than 16 px.
        min_size = min(8, max_size)
        shape_size = random.randint(min_size, max_size)

    if location == 'random':
        # Clamp so a shape as large as the image does not produce an empty range.
        max_pos = max(size - shape_size - 1, 0)
        top_left_x = random.randint(0, max_pos)
        top_left_y = random.randint(0, max_pos)
    else:
        top_left_x = (size - shape_size) // 2
        top_left_y = (size - shape_size) // 2

    bounding_box = [
        top_left_x,
        top_left_y,
        top_left_x + shape_size,
        top_left_y + shape_size,
    ]
    if shape == 'square':
        draw.rectangle(bounding_box, fill=255)
    else:
        draw.ellipse(bounding_box, fill=255)

    # Soften the mask edges so the shape blends into the background.
    if blur_radius:
        random_blur_radius = random.uniform(*blur_radius)
        mask = mask.filter(ImageFilter.GaussianBlur(radius=random_blur_radius))
    else:
        mask = mask.filter(ImageFilter.SMOOTH)

    # The mask selects shape_img where it is bright and background elsewhere.
    shape_img = Image.new('L', (size, size), shape_intensity)
    img = Image.composite(shape_img, background, mask)

    if random_intensity and random.random() < 0.5:
        img = Image.eval(img, lambda x: 255 - x)

    return img, shape
class RandomPairDataset(Dataset):
    """
    Dataset of generated image pairs labelled by whether the shapes match.

    In training mode every access generates a fresh random pair; otherwise the
    pairs are generated once (or supplied by the caller) and served by index.
    """

    def __init__(
        self,
        shape_params: Optional[dict] = None,
        num_samples: int = 1000,
        train: bool = True,
        fixed_test_data: Optional[list] = None
    ):
        """
        Dataset for training a model to compare two images.

        :param shape_params: parameters for generate_image function
        :param num_samples: number of samples in the dataset; when
            fixed_test_data is used it is capped at len(fixed_test_data)
        :param train: whether to generate training or test data
        :param fixed_test_data: fixed test data (optional); a list of
            (img1, shape1, img2, shape2, label) tuples, used only when
            train is False
        """
        super().__init__()
        self.train = train
        self.num_samples = num_samples
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5,), std=(0.5,))
        ])
        self.shape_params = shape_params if shape_params else {}
        if not self.train:
            if fixed_test_data is None:
                self.data = [self._generate_pair() for _ in range(num_samples)]
            else:
                self.data = fixed_test_data
                # Never advertise more samples than the supplied data holds,
                # otherwise indexing up to len(self) would run off the list.
                self.num_samples = min(num_samples, len(fixed_test_data))

    def __len__(self) -> int:
        """Return the number of samples served by the dataset."""
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Return (image 1, image 2, label) as tensors.

        :param idx: sample index; in training mode it only bounds the dataset,
            the pair itself is generated at random
        :raises IndexError: if idx is outside the dataset
        """
        # Without this check, training mode never raises IndexError and plain
        # iteration over the dataset (which relies on it) would never stop.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f'index {idx} is out of range for dataset of size {self.num_samples}'
            )
        if self.train:
            img1, _, img2, _, label = self._generate_pair()
        else:
            img1, _, img2, _, label = self.data[idx]
        img1 = self.transform(img1)
        img2 = self.transform(img2)
        return img1, img2, torch.tensor(label, dtype=torch.float32)

    def _generate_pair(self) -> Tuple[Image.Image, str, Image.Image, str, int]:
        """Generate two images; the label is 1 if their shapes match, else 0."""
        img1, shape1 = generate_image(**self.shape_params)
        img2, shape2 = generate_image(**self.shape_params)
        label = 1 if shape1 == shape2 else 0
        return img1, shape1, img2, shape2, label
class RandomAugmentedDataset(Dataset):
    """
    Dataset yielding two independently augmented views of a generated image.

    In training mode every access generates a fresh random image; otherwise
    the images are generated once (or supplied by the caller) and served by
    index.
    """

    def __init__(
        self,
        augmentations: T.Compose,
        shape_params: Optional[dict] = None,
        num_samples: int = 1000,
        train: bool = True,
        fixed_test_data: Optional[list] = None
    ):
        """
        Dataset for training a model with contrastive learning.

        :param augmentations: augmentations to apply to the images
        :param shape_params: parameters for generate_image function
        :param num_samples: number of samples in the dataset; when
            fixed_test_data is used it is capped at len(fixed_test_data)
        :param train: whether to generate training or test data
        :param fixed_test_data: fixed test data (optional); a list of
            (image, label) tuples, used only when train is False
        """
        super().__init__()
        self.train = train
        self.num_samples = num_samples
        self.augmentations = augmentations
        self.shape_params = shape_params if shape_params else {}
        if not self.train:
            if fixed_test_data is None:
                self.data = [self._generate_single() for _ in range(num_samples)]
            else:
                self.data = fixed_test_data
                # Never advertise more samples than the supplied data holds,
                # otherwise indexing up to len(self) would run off the list.
                self.num_samples = min(num_samples, len(fixed_test_data))

    def __len__(self) -> int:
        """Return the number of samples served by the dataset."""
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return two augmented views of the same image; the label is discarded.

        :param idx: sample index; in training mode it only bounds the dataset,
            the image itself is generated at random
        :raises IndexError: if idx is outside the dataset
        """
        # Without this check, training mode never raises IndexError and plain
        # iteration over the dataset (which relies on it) would never stop.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f'index {idx} is out of range for dataset of size {self.num_samples}'
            )
        if self.train:
            img, _ = self._generate_single()
        else:
            img, _ = self.data[idx]
        # The augmentations are applied twice to obtain two separate views.
        view_1, view_2 = self.augmentations(img), self.augmentations(img)
        return view_1, view_2

    def _generate_single(self) -> Tuple[Image.Image, int]:
        """Generate one image; the label is 1 for a circle and 0 otherwise."""
        img, shape = generate_image(**self.shape_params)
        label = 1 if shape == "circle" else 0
        return img, label
class AddGaussianNoise(object):
    """
    Transform that adds Gaussian noise to a tensor and clamps it to [0, 1].

    :param mean: mean of the noise
    :param std: standard deviation of the noise
    """

    def __init__(self, mean: float = 0.0, std: float = 0.05):
        self.mean = mean
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Return a noisy copy of the tensor; the input is not modified.

        :param tensor: tensor to perturb
        :return: tensor + noise, clamped to the range [0, 1]
        """
        # Sample on the input's device so tensors that do not live on the CPU
        # can be perturbed without a device-mismatch error.
        noise = torch.randn(tensor.size(), device=tensor.device) * self.std + self.mean
        return torch.clamp(tensor + noise, 0., 1.)

    def __repr__(self):
        return f'{self.__class__.__name__}(mean={self.mean}, std={self.std})'
class ColorInversion(object):
    """Transform that inverts the pixel values of a PIL image (v -> 255 - v)."""

    def __call__(self, image: Image.Image) -> Image.Image:
        """Return a copy of the image with every pixel value inverted."""

        def invert(value):
            return 255 - value

        return Image.eval(image, invert)

    def __repr__(self):
        class_name = self.__class__.__name__
        return f'{class_name}()'
def get_byol_transforms(image_size: int = 32) -> T.Compose:
    """
    Get augmentations for training with BYOL.

    The pipeline takes a PIL image and returns a tensor normalized with
    mean 0.5 and std 0.5.

    :param image_size: side length of the output crop; defaults to 32, the
        default image size of generate_image
    :return: composed augmentation pipeline
    """
    augmentations = T.Compose([
        # Geometric augmentations on the PIL image.
        T.RandomResizedCrop(size=image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        # Photometric augmentations on the PIL image.
        T.ColorJitter(brightness=0.2, contrast=0.2),
        T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.5),
        T.RandomApply([ColorInversion()], p=0.5),
        T.ToTensor(),
        # Noise is added after ToTensor and before Normalize, while values
        # are still in [0, 1] (AddGaussianNoise clamps to that range).
        T.RandomApply([AddGaussianNoise(mean=0.0, std=0.05)], p=0.5),
        T.Normalize(mean=(0.5,), std=(0.5,))
    ])
    return augmentations
def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a tensor to a PIL image.

    The tensor is indexed with [0] and the result is turned into a grayscale
    ('L') image, so tensor[0] must be two-dimensional.
    NOTE(review): presumably the input is a (C, H, W) image normalized with
    mean 0.5 / std 0.5, as the transforms in this file produce - confirm
    against callers.

    :param tensor: normalized image tensor
    :return: 8-bit grayscale PIL image
    """
    # detach() allows tensors that require grad (e.g. model outputs).
    img_norm = tensor.detach().cpu()[0].float()
    # Undo Normalize(mean=0.5, std=0.5), then keep the values inside [0, 1]
    # so the uint8 cast below cannot wrap around.
    img_denorm = (img_norm * 0.5 + 0.5).clamp(0.0, 1.0)
    # Round instead of truncating: float error would otherwise turn a pixel
    # such as 76.99999 into 76 on a round-trip.
    arr = (img_denorm * 255).round().numpy().astype(np.uint8)
    pil_img = Image.fromarray(arr, mode='L')
    return pil_img