|
""" |
|
Utility functions for the DIE demo. |
|
""" |
|
|
|
|
|
import torch |
|
from PIL import Image |
|
from torch import Tensor |
|
from torchvision import transforms |
|
|
|
|
|
def resize_image(
    image: Image.Image,
    max_size: int = 1024
) -> Image.Image:
    """
    Resizing images by keeping the ratios

    The larger side of the image becomes ``max_size`` and the other side is
    scaled by the same factor (rounded to the nearest pixel). When both sides
    are equal, both become ``max_size``.

    :param image: PIL image
    :param max_size: size of the new image larger side (must be at least 1)
    :return: the resized PIL image
    :raises ValueError: if ``max_size`` is smaller than 1
    """

    if max_size < 1:
        raise ValueError(f"max_size must be at least 1, got {max_size}.")

    width, height = image.size

    # For extreme aspect ratios the scaled smaller side can round down to 0,
    # which is not a valid image dimension, so it is clamped to one pixel.
    if height >= width:
        height_new = max_size
        width_new = max(1, round((height_new / height) * width))
    else:
        width_new = max_size
        height_new = max(1, round((width_new / width) * height))

    return image.resize((width_new, height_new))
|
|
|
|
|
def make_image_square(
    image: Image.Image,
    image_size: int = 1024
) -> Image.Image:
    """
    Padding the input image onto a white square canvas.

    The image is pasted at the top-left corner (0, 0) of the canvas, so the
    padding ends up on the right and/or bottom. The canvas side is
    ``image_size``, unless the image's larger side exceeds it, in which case
    the larger side is used instead.

    :param image: PIL image in mode 'L' or 'RGB'
    :param image_size: minimum side length of the square image
    :return: the square-sized PIL image
    :raises NotImplementedError: if the image mode is neither 'L' nor 'RGB'
    """

    # White fill value for each supported mode.
    white_fill_by_mode = {
        'L': (255,),
        'RGB': (255, 255, 255),
    }

    mode = image.mode
    if mode not in white_fill_by_mode:
        raise NotImplementedError("Not implemented image mode.")

    # The canvas never shrinks below the image's own larger side.
    side = max(image_size, *image.size)

    canvas = Image.new(mode, (side, side), white_fill_by_mode[mode])
    canvas.paste(image, (0, 0))

    return canvas
|
|
|
|
|
def cast_pil_image_to_torch_tensor_with_4_channel_dim(
    image: Image.Image,
    device: str | None = None
) -> Tensor:
    """
    Casting PIL image to a float32 torch tensor with an extra grayscale channel.

    The image is converted to RGB and to grayscale ('L'); both conversions are
    turned into tensors, cast to float32, divided by 255.0 and concatenated
    along dimension 0 (RGB first, grayscale last).

    :param image: input image
    :param device: target passed to ``Tensor.to``; if None the tensor is not moved
    :return: torch tensor (4 channel dim)
    """

    pil_to_tensor = transforms.PILToTensor()

    def _as_scaled_float(pil_image: Image.Image) -> Tensor:
        # Cast before dividing so the scaling happens in float32.
        return pil_to_tensor(pil_image).to(torch.float32) / 255.0

    rgb_channels = _as_scaled_float(image.convert('RGB'))
    gray_channel = _as_scaled_float(image.convert('L'))

    stacked = torch.cat((rgb_channels, gray_channel), dim=0)

    return stacked if device is None else stacked.to(device)
|
|
|
|
|
def _image_width_height(image: Image.Image | Tensor) -> tuple[int, int]:
    """Return ``(width, height)`` of a PIL image or of a tensor whose first two dims are (height, width)."""
    if isinstance(image, Image.Image):
        return image.size
    height, width = image.shape[:2]
    return width, height


def _resize_tensor_height_width(tensor: Tensor, height: int, width: int) -> Tensor:
    """Resize an (H, W) or (H, W, C) tensor to ``height`` x ``width`` by interpolation."""
    if tensor.ndim not in (2, 3):
        raise ValueError(f"Expected a 2D or 3D tensor, got {tensor.ndim} dimensions.")

    has_channels = tensor.ndim == 3
    # interpolate expects a (batch, channel, height, width) layout.
    batched = tensor.permute(2, 0, 1)[None] if has_channels else tensor[None, None]

    if tensor.is_floating_point():
        resized = torch.nn.functional.interpolate(
            batched, size=(height, width), mode='bilinear', align_corners=False
        )
    else:
        # Nearest-neighbour keeps integer values (e.g. labels) intact; the
        # round trip through float32 is only needed to run the interpolation.
        resized = torch.nn.functional.interpolate(
            batched.to(torch.float32), size=(height, width), mode='nearest'
        ).to(tensor.dtype)

    return resized[0].permute(1, 2, 0) if has_channels else resized[0, 0]


def remove_square_padding(
    original_image: Image.Image | Tensor,
    square_image: Image.Image | Tensor,
    resize_back_to_original: bool = False
):
    """
    Removing the square padding added to the original image to make square.

    The content is assumed to sit in the top-left corner of the square image,
    so the padding is removed by cropping from (0, 0). For non-PIL inputs the
    first two dimensions are read as (height, width).

    :param original_image: the image with the original size
    :param square_image: the image with the square size
    :param resize_back_to_original: defines if we want to resize the square image back to the original size
    :return: square image with the original size ratio (or with the original size
        if ``resize_back_to_original`` is True)
    :raises ValueError: if a tensor that is neither 2D nor 3D has to be resized
    """

    original_width, original_height = _image_width_height(original_image)
    square_width, square_height = _image_width_height(square_image)

    # round() (not int()) mirrors how resize_image computes the scaled side;
    # truncating here could cut off the last row/column of real content.
    # The result is kept at one pixel or more so the crop is never empty.
    if original_width > original_height:
        new_width = square_width
        new_height = max(1, round((square_width / original_width) * original_height))
    else:
        new_height = square_height
        new_width = max(1, round((square_height / original_height) * original_width))

    if isinstance(square_image, Image.Image):
        square_image_with_original_ratio = square_image.crop((0, 0, new_width, new_height))
    else:
        square_image_with_original_ratio = square_image[:new_height, :new_width]

    if resize_back_to_original:
        if isinstance(square_image_with_original_ratio, Image.Image):
            square_image_with_original_ratio = square_image_with_original_ratio.resize(
                (original_width, original_height)
            )
        else:
            # Tensor.resize only reinterprets the underlying storage and never
            # interpolates, so tensors are resized explicitly instead.
            square_image_with_original_ratio = _resize_tensor_height_width(
                square_image_with_original_ratio, original_height, original_width
            )

    return square_image_with_original_ratio
|
|