import random
from typing import Optional, Tuple

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFilter
from torch.utils.data import Dataset


def generate_image(
    size: int = 32,
    contrast: Tuple[int, int] = (90, 110),
    blur_radius: Optional[Tuple[float, float]] = (0.5, 1.5),
    shape: Optional[str] = None,
    max_background_intensity: int = 128,
    min_shape_intensity: Optional[int] = None,
    shape_size: Optional[int] = None,
    location: str = 'random',
    random_intensity: bool = False
) -> Tuple[Image.Image, str]:
    """
    Generate a grayscale image with a shape (circle or square) on a background.

    :param size: side length of the square image, in pixels
    :param contrast: (low, high) range from which the offset between the
        background intensity and the minimum shape intensity is drawn
    :param blur_radius: (low, high) range of the Gaussian blur radius applied
        to the shape mask; a falsy value applies ``ImageFilter.SMOOTH`` instead
    :param shape: shape type ('circle' or 'square'); chosen randomly if falsy
    :param max_background_intensity: maximum intensity of the background
    :param min_shape_intensity: minimum intensity of the shape; derived from
        the background intensity and ``contrast`` when None
    :param shape_size: size of the shape; chosen randomly when falsy
    :param location: location of the shape; 'random' places it randomly, any
        other value centers it
    :param random_intensity: whether to invert the final image with
        probability 0.5
    :return: tuple of the generated mode-'L' image and the shape name
    """
    background_intensity = random.randint(0, max_background_intensity)
    background = Image.new('L', (size, size), background_intensity)

    if shape:
        assert shape in ['circle', 'square'], "Wrong shape type"
    else:
        shape = random.choice(['circle', 'square'])

    # `is None` rather than truthiness: an explicit 0 is a valid lower bound
    # and must not be replaced by the contrast-derived value.
    if min_shape_intensity is None:
        random_contrast = random.randint(*contrast)
        min_shape_intensity = min(background_intensity + random_contrast, 255)
    shape_intensity = random.randint(min_shape_intensity, 255)

    mask = Image.new('L', (size, size), 0)
    draw = ImageDraw.Draw(mask)

    # A size of 0 is not usable, so it is treated the same as "not given".
    if not shape_size:
        # Clamp the bounds so that images smaller than 16 px do not produce
        # an empty range (randint(8, size // 2) raises when size // 2 < 8).
        max_size = max(size // 2, 1)
        min_size = min(8, max_size)
        shape_size = random.randint(min_size, max_size)

    if location == 'random':
        # The drawn box spans [pos, pos + shape_size] inclusive, so the last
        # valid top-left coordinate is size - shape_size - 1.
        max_pos = size - shape_size - 1
        top_left_x = random.randint(0, max_pos)
        top_left_y = random.randint(0, max_pos)
    else:
        top_left_x = (size - shape_size) // 2
        top_left_y = (size - shape_size) // 2

    bounding_box = [
        top_left_x,
        top_left_y,
        top_left_x + shape_size,
        top_left_y + shape_size,
    ]
    if shape == 'square':
        draw.rectangle(bounding_box, fill=255)
    else:
        draw.ellipse(bounding_box, fill=255)

    # Soften the mask edges so the shape blends into the background.
    if blur_radius:
        random_blur_radius = random.uniform(*blur_radius)
        mask = mask.filter(ImageFilter.GaussianBlur(radius=random_blur_radius))
    else:
        mask = mask.filter(ImageFilter.SMOOTH)

    shape_img = Image.new('L', (size, size), shape_intensity)
    img = Image.composite(shape_img, background, mask)

    if random_intensity and random.random() < 0.5:
        img = Image.eval(img, lambda x: 255 - x)

    return img, shape


class RandomPairDataset(Dataset):
    def __init__(
        self,
        shape_params: Optional[dict] = None,
        num_samples: int = 1000,
        train: bool = True,
        fixed_test_data: Optional[list] = None
    ):
        """
        Dataset for training a model to compare two images.

        In training mode a fresh pair is generated on every access; in test
        mode the pairs are generated once (or taken from ``fixed_test_data``)
        and served by index.

        :param shape_params: parameters for generate_image function
        :param num_samples: number of samples in the dataset
        :param train: whether to generate training or test data
        :param fixed_test_data: fixed test data (optional); a list of
            ``(img1, shape1, img2, shape2, label)`` tuples
        """
        self.train = train
        self.num_samples = num_samples
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5,), std=(0.5,))
        ])
        self.shape_params = shape_params or {}

        if not self.train:
            if fixed_test_data is None:
                self.data = [self._generate_pair() for _ in range(num_samples)]
            else:
                self.data = fixed_test_data
                # Never report more samples than are actually stored,
                # otherwise indexing up to len(self) raises IndexError.
                self.num_samples = min(num_samples, len(fixed_test_data))

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Return ``(img1, img2, label)`` with both images transformed to
        normalized tensors and the label as a float32 tensor.
        """
        if self.train:
            # Training samples are generated on the fly; idx is ignored.
            img1, _, img2, _, label = self._generate_pair()
        else:
            img1, _, img2, _, label = self.data[idx]

        img1 = self.transform(img1)
        img2 = self.transform(img2)
        return img1, img2, torch.tensor(label, dtype=torch.float32)

    def _generate_pair(self) -> Tuple[Image.Image, str, Image.Image, str, int]:
        """
        Generate two images; the label is 1 when both show the same shape.
        """
        img1, shape1 = generate_image(**self.shape_params)
        img2, shape2 = generate_image(**self.shape_params)
        label = 1 if shape1 == shape2 else 0
        return img1, shape1, img2, shape2, label


class RandomAugmentedDataset(Dataset):
    def __init__(
        self,
        augmentations: T.Compose,
        shape_params: Optional[dict] = None,
        num_samples: int = 1000,
        train: bool = True,
        fixed_test_data: Optional[list] = None
    ):
        """
        Dataset for training a model with contrastive learning.

        :param augmentations: augmentations to apply to the images
        :param shape_params: parameters for generate_image function
        :param num_samples: number of samples in the dataset
        :param train: whether to generate training or test data
        :param fixed_test_data: fixed test data (optional); a list of
            ``(img, label)`` tuples
        """
        self.train = train
        self.num_samples = num_samples
        self.augmentations = augmentations
        self.shape_params = shape_params or {}

        if not self.train:
            if fixed_test_data is None:
                self.data = [self._generate_single() for _ in range(num_samples)]
            else:
                self.data = fixed_test_data
                # Never report more samples than are actually stored,
                # otherwise indexing up to len(self) raises IndexError.
                self.num_samples = min(num_samples, len(fixed_test_data))

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return two independently augmented views of the same image.
        """
        if self.train:
            # Training samples are generated on the fly; idx is ignored.
            img, _ = self._generate_single()
        else:
            img, _ = self.data[idx]

        view_1 = self.augmentations(img)
        view_2 = self.augmentations(img)
        return view_1, view_2

    def _generate_single(self) -> Tuple[Image.Image, int]:
        """
        Generate one image; the label is 1 for a circle and 0 otherwise.
        """
        img, shape = generate_image(**self.shape_params)
        label = 1 if shape == "circle" else 0
        return img, label


class AddGaussianNoise:
    """
    Add Gaussian noise to a tensor and clamp the result to [0, 1].
    """

    def __init__(self, mean: float = 0.0, std: float = 0.05):
        self.mean = mean
        self.std = std

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        # Create the noise on the input's device so the addition also works
        # for tensors that do not live on the default device.
        noise = torch.randn(tensor.size(), device=tensor.device) * self.std + self.mean
        return torch.clamp(tensor + noise, 0., 1.)

    def __repr__(self):
        return f'{self.__class__.__name__}(mean={self.mean}, std={self.std})'


class ColorInversion:
    """
    Invert the intensities of a PIL image (x -> 255 - x).
    """

    def __call__(self, image: Image.Image) -> Image.Image:
        return Image.eval(image, lambda x: 255 - x)

    def __repr__(self):
        return f'{self.__class__.__name__}()'


def get_byol_transforms(size: int = 32) -> T.Compose:
    """
    Get augmentations for training with BYOL.

    :param size: output size of the random resized crop
    :return: composed augmentation pipeline mapping a PIL image to a
        normalized tensor
    """
    augmentations = T.Compose([
        T.RandomResizedCrop(size=size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        T.ColorJitter(brightness=0.2, contrast=0.2),
        T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.5),
        T.RandomApply([ColorInversion()]),
        T.ToTensor(),
        # Noise is added before normalization because AddGaussianNoise
        # clamps its output to [0, 1].
        T.RandomApply([AddGaussianNoise(mean=0.0, std=0.05)], p=0.5),
        T.Normalize(mean=(0.5,), std=(0.5,))
    ])
    return augmentations


def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a tensor to a PIL image.

    Only the first entry along the leading axis is used, and the
    mean=0.5 / std=0.5 normalization is undone before conversion.

    :param tensor: normalized image tensor; presumably of shape (1, H, W) as
        produced by the transforms in this file
    :return: mode-'L' PIL image
    """
    # detach() so tensors that require grad can be converted to numpy.
    img_norm = tensor.detach().cpu()[0]
    # Clamp so out-of-range values saturate instead of wrapping in uint8.
    img_denorm = torch.clamp(img_norm * 0.5 + 0.5, 0.0, 1.0)
    # Round instead of truncating so float error such as 199.99998 does not
    # darken a pixel by one level.
    arr = np.rint(img_denorm.numpy() * 255).astype(np.uint8)
    pil_img = Image.fromarray(arr, mode='L')
    return pil_img