"""
Utility functions for the DIE demo.
"""

import torch
from PIL import Image
from torch import Tensor
from torchvision import transforms


def resize_image(
        image: Image.Image,
        max_size: int = 1024
) -> Image.Image:
    """
    Resize an image so that its larger side equals ``max_size``, keeping the aspect ratio.

    :param image: PIL image
    :param max_size: size of the new image larger side
    :return: the resized PIL image
    """
    width, height = image.size

    # The larger side is pinned to max_size and the other side is scaled by the
    # same factor. The scaled side is kept at >= 1 px so that an extreme aspect
    # ratio does not ask PIL for a zero-sized image.
    if height >= width:
        height_new = max_size
        width_new = max(1, round((height_new / height) * width))
    else:
        width_new = max_size
        height_new = max(1, round((width_new / width) * height))

    return image.resize((width_new, height_new))


def make_image_square(
        image: Image.Image,
        image_size: int = 1024
) -> Image.Image:
    """
    Pad the input image with white to a square canvas.

    The original content is pasted into the top-left corner. If the image is
    larger than ``image_size`` the canvas grows to the image's larger side, so
    the content is never cropped.

    :param image: PIL image (mode 'L' or 'RGB')
    :param image_size: defines the size of the square image
    :return: the square-sized PIL image
    :raises NotImplementedError: if the image mode is neither 'L' nor 'RGB'
    """
    # never make the canvas smaller than the image itself
    image_size = max(image_size, max(image.size))

    # white background in the colour format of the image mode
    if image.mode == 'L':
        background = (255,)
    elif image.mode == 'RGB':
        background = (255, 255, 255)
    else:
        raise NotImplementedError("Not implemented image mode.")

    image_square = Image.new(image.mode, (image_size, image_size), background)

    # copying the original content onto the blank image (top-left anchored)
    image_square.paste(image, (0, 0))

    return image_square


def cast_pil_image_to_torch_tensor_with_4_channel_dim(
        image: Image.Image,
        device: str | None = None
) -> Tensor:
    """
    Cast a PIL image to a float32 torch tensor with 4 channels.

    The first three channels are the RGB version of the image, the 4th channel
    is its grayscale ('L') version. All values are divided by 255.

    :param image: input image
    :param device: target device (e.g. a cuda device); the tensor is not moved if None
    :return: torch tensor (4 channel dim)
    """
    to_tensor = transforms.PILToTensor()

    # both views of the image, normalized by 255
    rgb_tensor = to_tensor(image.convert('RGB')).to(torch.float32) / 255.0
    gray_tensor = to_tensor(image.convert('L')).to(torch.float32) / 255.0

    # gray channel is appended after the RGB channels (channel dim is dim 0)
    final_image_tensor = torch.cat((rgb_tensor, gray_tensor), dim=0)

    # moving tensor to the requested device if required
    if device is not None:
        final_image_tensor = final_image_tensor.to(device)

    return final_image_tensor


def _width_and_height(image: Image.Image | Tensor) -> tuple[int, int]:
    """ Return (width, height); non-PIL inputs are read as (H, W, ...) shaped. """
    if isinstance(image, Image.Image):
        return image.size
    height, width = image.shape[:2]
    return width, height


def _resize_hw_tensor(tensor: Tensor, height: int, width: int) -> Tensor:
    """ Bilinearly resize a (H, W) or (H, W, C) tensor to (height, width). """
    if tensor.dim() == 2:
        batched = tensor[None, None]
    elif tensor.dim() == 3:
        batched = tensor.permute(2, 0, 1)[None]
    else:
        raise ValueError("Expected a (H, W) or (H, W, C) tensor.")

    # bilinear interpolation needs floating point input; integer tensors are
    # converted for the computation and cast back afterwards
    needs_cast = not tensor.is_floating_point()
    if needs_cast:
        batched = batched.to(torch.float32)

    resized = torch.nn.functional.interpolate(
        batched, size=(height, width), mode='bilinear', align_corners=False
    )

    if needs_cast:
        resized = resized.round().to(tensor.dtype)

    if tensor.dim() == 2:
        return resized[0, 0]
    return resized[0].permute(1, 2, 0)


def remove_square_padding(
        original_image: Image.Image | Tensor,
        square_image: Image.Image | Tensor,
        resize_back_to_original: bool = False
) -> Image.Image | Tensor:
    """
    Remove the square padding that was added to the original image.

    The square image is cropped from its top-left corner to the aspect ratio of
    the original image. Tensor inputs are read as (H, W, ...): their first two
    dimensions are taken as height and width.
    NOTE(review): this differs from the channel-first output of
    cast_pil_image_to_torch_tensor_with_4_channel_dim - confirm that callers
    pass height-first tensors.

    :param original_image: the image with the original size
    :param square_image: the image with the square size
    :param resize_back_to_original: defines if we want to resize the square image back to the original size
    :return: square image with the original size ratio (same kind of object as ``square_image``)
    """
    original_width, original_height = _width_and_height(original_image)
    square_width, square_height = _width_and_height(square_image)

    # The larger side fills the square, the other side is scaled by the same
    # ratio. round() mirrors resize_image, so the crop covers exactly the area
    # the resized content occupies (int() truncation could drop a pixel line).
    if original_width > original_height:
        ratio = square_width / original_width
        new_width = square_width
        new_height = max(1, min(round(ratio * original_height), square_height))
    else:
        ratio = square_height / original_height
        new_height = square_height
        new_width = max(1, min(round(ratio * original_width), square_width))

    # cutting size of the square image to the original ratio
    if isinstance(square_image, Image.Image):
        square_image_with_original_ratio = square_image.crop((0, 0, new_width, new_height))
    else:
        square_image_with_original_ratio = square_image[:new_height, :new_width]

    if resize_back_to_original:
        if isinstance(square_image_with_original_ratio, Tensor):
            # Tensor.resize only reshapes the storage; interpolate instead
            square_image_with_original_ratio = _resize_hw_tensor(
                square_image_with_original_ratio, original_height, original_width
            )
        else:
            square_image_with_original_ratio = square_image_with_original_ratio.resize(
                (original_width, original_height)
            )

    return square_image_with_original_ratio