# cels/src/dataset.py
# Author: alexandraroze
# Last commit: 50bd1fc ("solution")
import random
from typing import Optional, Tuple
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFilter
from torch.utils.data import Dataset
def generate_image(
    size: int = 32,
    contrast: Tuple[int, int] = (90, 110),
    blur_radius: Tuple[float, float] = (0.5, 1.5),
    shape: Optional[str] = None,
    max_background_intensity: int = 128,
    min_shape_intensity: Optional[int] = None,
    shape_size: Optional[int] = None,
    location: str = 'random',
    random_intensity: bool = False
) -> Tuple[Image.Image, str]:
    """
    Generate a grayscale image with one bright shape (circle or square) on a flat background.

    :param size: width and height of the square output image, in pixels
    :param contrast: inclusive (low, high) range from which an intensity offset is drawn;
        the shape's minimum intensity becomes ``background + offset`` (capped at 255).
        Ignored when ``min_shape_intensity`` is given.
    :param blur_radius: inclusive (low, high) range for the Gaussian blur radius applied to
        the shape mask; if falsy (e.g. ``None``), a fixed ``ImageFilter.SMOOTH`` is used instead
    :param shape: ``'circle'`` or ``'square'``; chosen at random when falsy
    :param max_background_intensity: inclusive upper bound of the uniform background intensity
    :param min_shape_intensity: inclusive lower bound of the shape intensity; derived from
        ``contrast`` when ``None``
    :param shape_size: extent of the shape's bounding box; drawn from
        ``[min(8, size // 2), size // 2]`` when ``None``
    :param location: ``'random'`` places the shape anywhere fully inside the image;
        any other value centers it
    :param random_intensity: if True, invert the whole image with probability 0.5
    :return: ``(image, shape)`` -- a mode ``'L'`` PIL image and the name of the drawn shape
    :raises ValueError: if ``shape`` is not ``'circle'``/``'square'``, or if ``shape_size``
        does not fit inside the image when ``location == 'random'``
    """
    background_intensity = random.randint(0, max_background_intensity)
    background = Image.new('L', (size, size), background_intensity)

    if not shape:
        shape = random.choice(['circle', 'square'])
    elif shape not in ('circle', 'square'):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(f"Wrong shape type: {shape!r} (expected 'circle' or 'square')")

    # ``is None`` rather than truthiness so an explicit 0 is honoured
    # instead of being silently replaced by a random value.
    if min_shape_intensity is None:
        random_contrast = random.randint(*contrast)
        min_shape_intensity = min(background_intensity + random_contrast, 255)
    shape_intensity = random.randint(min_shape_intensity, 255)

    mask = Image.new('L', (size, size), 0)
    draw = ImageDraw.Draw(mask)

    if shape_size is None:
        max_size = size // 2
        # Clamp the lower bound so images smaller than 16 px do not yield an empty range.
        min_size = min(8, max_size)
        shape_size = random.randint(min_size, max_size)

    if location == 'random':
        # PIL bounding boxes include both corners, so the shape spans shape_size + 1
        # pixels; the extra -1 keeps its far edge inside the image.
        max_pos = size - shape_size - 1
        if max_pos < 0:
            raise ValueError(
                f"shape_size={shape_size} does not fit in an image of size {size}"
            )
        top_left_x = random.randint(0, max_pos)
        top_left_y = random.randint(0, max_pos)
    else:
        top_left_x = top_left_y = (size - shape_size) // 2

    bbox = [top_left_x, top_left_y, top_left_x + shape_size, top_left_y + shape_size]
    if shape == 'square':
        draw.rectangle(bbox, fill=255)
    else:
        draw.ellipse(bbox, fill=255)

    # Soften the mask edges so the shape blends into the background.
    if blur_radius:
        random_blur_radius = random.uniform(*blur_radius)
        mask = mask.filter(ImageFilter.GaussianBlur(radius=random_blur_radius))
    else:
        mask = mask.filter(ImageFilter.SMOOTH)

    shape_img = Image.new('L', (size, size), shape_intensity)
    # The (softened) mask acts as a per-pixel blend weight between shape and background.
    img = Image.composite(shape_img, background, mask)

    if random_intensity and random.random() < 0.5:
        img = Image.eval(img, lambda x: 255 - x)
    return img, shape
class RandomPairDataset(Dataset):
    """
    Dataset of image pairs labelled 1 when both images contain the same shape, else 0.

    In training mode every ``__getitem__`` call draws a fresh random pair (``idx`` is
    ignored). In test mode a fixed list of pairs is built once in ``__init__`` (or taken
    from ``fixed_test_data``) and indexed by ``idx``.
    """

    def __init__(
        self,
        shape_params: Optional[dict] = None,
        num_samples: int = 1000,
        train: bool = True,
        fixed_test_data: Optional[list] = None
    ):
        """
        Dataset for training a model to compare two images.
        :param shape_params: keyword arguments forwarded to ``generate_image``
        :param num_samples: number of samples reported by ``__len__``; in test mode it is
            additionally capped at the number of stored pairs
        :param train: generate pairs on the fly (True) or serve a fixed test set (False)
        :param fixed_test_data: pre-generated list of ``(img1, shape1, img2, shape2, label)``
            tuples used in test mode; ignored when ``train`` is True
        """
        self.train = train
        self.num_samples = num_samples
        # PIL 'L' image -> float tensor in [0, 1] (ToTensor) -> [-1, 1] (Normalize).
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5,), std=(0.5,))
        ])
        self.shape_params = shape_params if shape_params else {}
        if not self.train:
            if fixed_test_data is None:
                self.data = [self._generate_pair() for _ in range(num_samples)]
            else:
                self.data = fixed_test_data

    def __len__(self) -> int:
        if self.train:
            return self.num_samples
        # Never report more items than the fixed test set holds; otherwise a DataLoader
        # would index past its end and raise IndexError.
        return min(self.num_samples, len(self.data))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(img1, img2, label)``: two normalized image tensors and a float32 label."""
        if self.train:
            img1, _, img2, _, label = self._generate_pair()
        else:
            img1, _, img2, _, label = self.data[idx]
        return (
            self.transform(img1),
            self.transform(img2),
            torch.tensor(label, dtype=torch.float32),
        )

    def _generate_pair(self) -> Tuple[Image.Image, str, Image.Image, str, int]:
        """Draw two independent images; label is 1 if their shapes match, else 0."""
        img1, shape1 = generate_image(**self.shape_params)
        img2, shape2 = generate_image(**self.shape_params)
        label = 1 if shape1 == shape2 else 0
        return img1, shape1, img2, shape2, label
class RandomAugmentedDataset(Dataset):
    """
    Dataset yielding two independently augmented views of the same synthetic image.

    In training mode every ``__getitem__`` call draws a fresh image (``idx`` is ignored).
    In test mode images come from a fixed list built once in ``__init__`` (or taken from
    ``fixed_test_data``) and are indexed by ``idx``.
    """

    def __init__(
        self,
        augmentations: T.Compose,
        shape_params: Optional[dict] = None,
        num_samples: int = 1000,
        train: bool = True,
        fixed_test_data: Optional[list] = None
    ):
        """
        Dataset for training a model with contrastive learning.
        :param augmentations: callable applied twice per item to produce the two views
        :param shape_params: keyword arguments forwarded to ``generate_image``
        :param num_samples: number of samples reported by ``__len__``; in test mode it is
            additionally capped at the number of stored images
        :param train: generate images on the fly (True) or serve a fixed test set (False)
        :param fixed_test_data: fixed test data (optional); a list of ``(img, label)``
            tuples used in test mode, ignored when ``train`` is True
        """
        self.train = train
        self.num_samples = num_samples
        self.augmentations = augmentations
        self.shape_params = shape_params if shape_params else {}
        if not self.train:
            if fixed_test_data is None:
                self.data = [self._generate_single() for _ in range(num_samples)]
            else:
                self.data = fixed_test_data

    def __len__(self) -> int:
        if self.train:
            return self.num_samples
        # Never report more items than the fixed test set holds; otherwise a DataLoader
        # would index past its end and raise IndexError.
        return min(self.num_samples, len(self.data))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return two augmented views of one image; the shape label is not returned."""
        if self.train:
            img, _ = self._generate_single()
        else:
            img, _ = self.data[idx]
        # Two separate calls so each view gets its own random augmentation draw.
        view_1, view_2 = self.augmentations(img), self.augmentations(img)
        return view_1, view_2

    def _generate_single(self) -> Tuple[Image.Image, int]:
        """Generate one image and its label (1 for a circle, 0 otherwise)."""
        img, shape = generate_image(**self.shape_params)
        label = 1 if shape == "circle" else 0
        return img, label
class AddGaussianNoise(object):
    """
    Tensor transform that adds i.i.d. Gaussian noise and clamps the result to [0, 1].

    The clamp assumes inputs already lie in [0, 1] (e.g. the output of ``ToTensor``),
    so this should run before any ``Normalize`` step.
    """

    def __init__(self, mean: float = 0.0, std: float = 0.05):
        """
        :param mean: mean of the added noise
        :param std: standard deviation of the added noise
        """
        self.mean = mean
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        # Allocate the noise on the input's device: a CPU noise tensor added to a CUDA
        # input raises a device-mismatch error. On CPU the sampled values are unchanged.
        noise = torch.randn(tensor.size(), device=tensor.device) * self.std + self.mean
        return torch.clamp(tensor + noise, 0., 1.)

    def __repr__(self):
        return f'{self.__class__.__name__}(mean={self.mean}, std={self.std})'
class ColorInversion:
    """PIL transform mapping every pixel value ``v`` to ``255 - v`` (intended for 8-bit images)."""

    def __call__(self, image: Image.Image) -> Image.Image:
        # Image.eval(image, fn) is implemented as image.point(fn); call it directly.
        return image.point(lambda value: 255 - value)

    def __repr__(self):
        return f'{type(self).__name__}()'
def get_byol_transforms(size: int = 32) -> T.Compose:
    """
    Build the stochastic augmentation pipeline used to create BYOL views.

    The pipeline takes a PIL image and returns a float tensor in [-1, 1].

    :param size: output side length of ``RandomResizedCrop``; should match the ``size``
        the images were generated with (default 32, the same as ``generate_image``)
    :return: a ``T.Compose`` of PIL-level augmentations, ``ToTensor``, optional Gaussian
        noise, and ``Normalize(mean=0.5, std=0.5)``
    """
    return T.Compose([
        # Geometric augmentations (PIL in, PIL out).
        T.RandomResizedCrop(size=size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        # Photometric augmentations.
        T.ColorJitter(brightness=0.2, contrast=0.2),
        T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.5),
        # p=0.5 is RandomApply's default; written out so the probability is visible.
        T.RandomApply([ColorInversion()], p=0.5),
        T.ToTensor(),
        # AddGaussianNoise clamps to [0, 1], so it must run before Normalize.
        T.RandomApply([AddGaussianNoise(mean=0.0, std=0.05)], p=0.5),
        T.Normalize(mean=(0.5,), std=(0.5,))
    ])
def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a normalized image tensor back to a grayscale PIL image.

    Undoes ``Normalize(mean=0.5, std=0.5)``, mapping [-1, 1] back to [0, 255]. Only
    ``tensor[0]`` is used, so for a channel-first tensor this is the first channel.

    :param tensor: tensor whose first entry along dim 0 is a 2-D map -- presumably a
        ``(1, H, W)`` tensor produced by this module's transforms (TODO confirm for other
        callers); it may require grad or live on a GPU
    :return: mode ``'L'`` PIL image
    """
    # detach(): Tensor.numpy() refuses tensors that require grad (e.g. model outputs).
    img_norm = tensor.detach().cpu()[0]
    # Undo Normalize(0.5, 0.5), then clamp so out-of-range values saturate instead of
    # wrapping around in the uint8 cast.
    img_denorm = (img_norm * 0.5 + 0.5).clamp(0.0, 1.0)
    # Round rather than truncate: float error in the normalize/denormalize round trip can
    # turn an exact pixel value v into v - 1e-5, which truncation would map to v - 1.
    arr = np.rint(img_denorm.numpy() * 255).astype(np.uint8)
    return Image.fromarray(arr, mode='L')