# die_demo / utils.py
# gabar92's picture
# add implementation scripts
# a9d81c5
"""
Utility functions for the DIE demo.
"""
import torch
from PIL import Image
from torch import Tensor
from torchvision import transforms
def resize_image(
    image: Image.Image,
    max_size: int = 1024
) -> Image.Image:
    """
    Resize an image so that its larger side equals ``max_size``, keeping the aspect ratio.

    The shorter side is scaled by the same factor and rounded to the nearest
    pixel, but never below 1 pixel, so that extremely elongated images still
    yield a valid (non-zero) target size.

    :param image: PIL image
    :param max_size: size of the larger side of the new image
    :return: the resized PIL image
    """
    width, height = image.size

    # square images (height == width) take the "height is larger" branch
    if height >= width:
        height_new = max_size
        # rounding can reach 0 for very thin images; keep at least one pixel
        width_new = max(1, round((height_new / height) * width))
    else:
        width_new = max_size
        height_new = max(1, round((width_new / width) * height))

    return image.resize((width_new, height_new))
def make_image_square(
    image: Image.Image,
    image_size: int = 1024
) -> Image.Image:
    """
    Place the input image onto a white square canvas.

    The image is pasted at the top-left corner (0, 0); the rest of the canvas
    stays white. If the larger side of the image exceeds ``image_size``, the
    canvas side grows to that larger side instead.

    :param image: PIL image in mode 'L' or 'RGB'
    :param image_size: requested side length of the square canvas
    :return: the square-sized PIL image
    :raises NotImplementedError: if the image mode is neither 'L' nor 'RGB'
    """
    # white fill value for every supported image mode
    white_fill_by_mode = {
        'L': (255,),
        'RGB': (255, 255, 255),
    }

    # the canvas must be able to hold the whole image
    side = max(image_size, *image.size)

    mode = image.mode
    if mode not in white_fill_by_mode:
        raise NotImplementedError("Not implemented image mode.")

    canvas = Image.new(mode, (side, side), white_fill_by_mode[mode])
    # copy the original content onto the blank canvas
    canvas.paste(image, (0, 0))
    return canvas
def cast_pil_image_to_torch_tensor_with_4_channel_dim(
    image: Image.Image,
    device: str | None = None
) -> Tensor:
    """
    Convert a PIL image into a float32 torch tensor with four channels.

    The image is converted to RGB and, separately, to grayscale ('L'); both
    results are turned into float32 tensors, divided by 255.0, and
    concatenated along dim 0 (RGB first, grayscale last).

    :param image: input image
    :param device: target device for the result; None leaves the tensor where it was created
    :return: torch tensor (4 channel dim)
    """
    pil_to_tensor = transforms.PILToTensor()

    def _to_scaled_float(pil_image: Image.Image) -> Tensor:
        # PIL image -> tensor, cast to float32 and divide by 255.0
        return pil_to_tensor(pil_image).to(torch.float32) / 255.0

    rgb_channels = _to_scaled_float(image.convert('RGB'))
    gray_channel = _to_scaled_float(image.convert('L'))

    # grayscale is appended after the RGB channels
    stacked = torch.cat((rgb_channels, gray_channel), dim=0)

    # move to the requested device only when one was given
    return stacked if device is None else stacked.to(device)
def _height_and_width(image: Image.Image | Tensor) -> tuple[int, int]:
    """Return (height, width) of a PIL image or of an array-like laid out as (H, W, ...)."""
    if isinstance(image, Image.Image):
        # PIL reports size as (width, height)
        width, height = image.size
    else:
        height, width = image.shape[:2]
    return height, width


def _resize_hw_tensor(tensor: Tensor, height: int, width: int) -> Tensor:
    """Bilinearly resize an (H, W) or (H, W, C) tensor to (height, width), keeping layout and dtype."""
    has_channels = tensor.dim() == 3
    # interpolate expects a floating point (N, C, H, W) input
    batched = tensor.permute(2, 0, 1) if has_channels else tensor.unsqueeze(0)
    batched = batched.unsqueeze(0)
    is_float = tensor.is_floating_point()
    resized = torch.nn.functional.interpolate(
        batched if is_float else batched.to(torch.float32),
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )
    resized = resized.squeeze(0)
    resized = resized.permute(1, 2, 0) if has_channels else resized.squeeze(0)
    if not is_float:
        # restore the caller's integer dtype after interpolating in float32
        resized = resized.round().to(tensor.dtype)
    return resized


def remove_square_padding(
    original_image: Image.Image | Tensor,
    square_image: Image.Image | Tensor,
    resize_back_to_original: bool = False
) -> Image.Image | Tensor:
    """
    Remove the square padding that was added to the original image to make it square.

    The content is assumed to sit in the top-left corner of the square image.
    The kept region spans the full square side along the larger side of the
    original image; the other side is scaled by the same ratio and rounded to
    the nearest pixel (the same rounding ``resize_image`` uses), but never
    below 1 pixel.

    Non-PIL inputs are treated as arrays whose first two dimensions are
    (height, width).

    :param original_image: the image with the original size
    :param square_image: the image with the square size
    :param resize_back_to_original: defines if we want to resize the cropped image back to the original size
    :return: square image cropped to the original size ratio (and optionally resized to the original size)
    """
    original_height, original_width = _height_and_width(original_image)
    square_height, square_width = _height_and_width(square_image)

    if original_width > original_height:
        new_width = square_width
        # round (not truncate) so the crop is not one pixel short
        new_height = max(1, round((square_width / original_width) * original_height))
    else:
        new_height = square_height
        new_width = max(1, round((square_height / original_height) * original_width))

    # cutting the square image down to the original ratio
    if isinstance(square_image, Image.Image):
        cropped = square_image.crop((0, 0, new_width, new_height))
    else:
        cropped = square_image[:new_height, :new_width]

    if resize_back_to_original:
        if isinstance(cropped, Tensor):
            # tensors have no PIL-style image resize; interpolate instead
            cropped = _resize_hw_tensor(cropped, original_height, original_width)
        else:
            cropped = cropped.resize((original_width, original_height))

    return cropped