File size: 4,369 Bytes
a9d81c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Utility functions for the DIE demo.

Includes helpers to resize PIL images while keeping their aspect ratio, pad them
to a white square canvas, convert them to a 4-channel (RGB + grayscale) float
tensor, and crop the square padding back off a processed result.
"""


import torch
from PIL import Image
from torch import Tensor
from torchvision import transforms


def resize_image(
    image: Image.Image,
    max_size: int = 1024
) -> Image.Image:
    """
    Resize an image so that its larger side equals ``max_size``, keeping the aspect ratio.
    When width and height are equal, the height is treated as the larger side.
    :param image: PIL image
    :param max_size: target length, in pixels, of the larger side of the resized image
    :return: the resized PIL image
    :raises ValueError: if the image has a zero width or height
    """

    width, height = image.size
    if width <= 0 or height <= 0:
        # an empty image has no aspect ratio to preserve (and would divide by zero below)
        raise ValueError(f"Cannot resize an empty image of size {image.size}.")

    # The shorter side is scaled by the same factor as the larger one. It is clamped to
    # at least 1 pixel because round() can yield 0 for extreme aspect ratios, which would
    # request an empty target image.
    if height >= width:
        height_new = max_size
        width_new = max(1, round((height_new / height) * width))
    else:
        width_new = max_size
        height_new = max(1, round((width_new / width) * height))

    return image.resize((width_new, height_new))


def make_image_square(
    image: Image.Image,
    image_size: int = 1024
) -> Image.Image:
    """
    Pad an image onto a white square canvas, with the content anchored at the top-left corner.
    :param image: PIL image in mode 'L' or 'RGB'
    :param image_size: side length of the square canvas; it is enlarged to the image's
        larger side when the image would not fit otherwise
    :return: the square-sized PIL image
    :raises NotImplementedError: if the image mode is neither 'L' nor 'RGB'
    """

    # the canvas must be at least as large as the image's bigger side
    side = max(image_size, *image.size)

    # white background value per supported mode
    white_fill = {
        'L': (255,),
        'RGB': (255, 255, 255),
    }
    if image.mode not in white_fill:
        raise NotImplementedError("Not implemented image mode.")

    canvas = Image.new(image.mode, (side, side), white_fill[image.mode])
    canvas.paste(image, (0, 0))
    return canvas


def cast_pil_image_to_torch_tensor_with_4_channel_dim(
    image: Image.Image,
    device: str | None = None
) -> Tensor:
    """
    Convert a PIL image to a float32 torch tensor with four channels: R, G, B and grayscale.
    Pixel values are scaled from [0, 255] to [0, 1]; channels are stacked along dim 0.
    :param image: input PIL image (converted to 'RGB' and to 'L' internally)
    :param device: optional torch device to move the result to; left on CPU when None
    :return: torch tensor (4 channel dim)
    """

    to_tensor = transforms.PILToTensor()

    def _as_unit_float(pil_image):
        # uint8 pixel values -> float32 in [0, 1]
        return to_tensor(pil_image).to(torch.float32) / 255.0

    rgb_channels = _as_unit_float(image.convert('RGB'))
    gray_channel = _as_unit_float(image.convert('L'))
    stacked = torch.cat((rgb_channels, gray_channel), dim=0)

    return stacked if device is None else stacked.to(device)


def _resize_tensor_hw(tensor: Tensor, height: int, width: int) -> Tensor:
    """
    Spatially resample a (H, W) or (H, W, C) tensor to (height, width), keeping its layout and dtype.
    Floating-point tensors use bilinear interpolation; integer/bool tensors (e.g. label masks)
    use nearest-neighbour so their values stay exact.
    :raises ValueError: if the tensor is not 2-D or 3-D
    """
    if tensor.dim() == 2:
        batched = tensor[None, None]              # (1, 1, H, W)
    elif tensor.dim() == 3:
        batched = tensor.permute(2, 0, 1)[None]   # (1, C, H, W)
    else:
        raise ValueError(
            f"Expected a (H, W) or (H, W, C) tensor, got shape {tuple(tensor.shape)}."
        )

    if tensor.is_floating_point():
        resized = torch.nn.functional.interpolate(
            batched, size=(height, width), mode="bilinear", align_corners=False
        )
    else:
        # interpolate() works on floats; nearest-neighbour round-trips integer values exactly
        resized = torch.nn.functional.interpolate(
            batched.float(), size=(height, width), mode="nearest"
        ).to(tensor.dtype)

    if tensor.dim() == 2:
        return resized[0, 0]
    return resized[0].permute(1, 2, 0)


def remove_square_padding(
    original_image: Image.Image | Tensor,
    square_image: Image.Image | Tensor,
    resize_back_to_original: bool = False
) -> Image.Image | Tensor:
    """
    Removing the square padding added to the original image to make it square.

    The crop assumes the original content spans the full square along its larger side and
    sits in the top-left corner, i.e. the square was produced by resizing (``resize_image``)
    and then padding (``make_image_square``).
    Tensors are read as (H, W) or (H, W, C). NOTE(review): confirm against callers, because a
    (C, H, W) tensor would have its channel axis misread as the height.

    :param original_image: the image with the original size
    :param square_image: the image with the square size
    :param resize_back_to_original: defines if we want to resize the cropped image back to the original size
    :return: square image cropped to the original size ratio (resized to the original size if requested)
    """

    if isinstance(original_image, Image.Image):
        original_width, original_height = original_image.size
    else:
        original_height, original_width = original_image.shape[:2]

    if isinstance(square_image, Image.Image):
        square_width, square_height = square_image.size
    else:
        square_height, square_width = square_image.shape[:2]

    # Use the same expression and round() as resize_image so the crop keeps exactly the
    # pixels that were resized. int() truncation could drop the last content row/column.
    # The result is clamped so the crop is never empty.
    if original_width > original_height:
        new_width = square_width
        new_height = max(1, round((square_width / original_width) * original_height))
    else:
        new_height = square_height
        new_width = max(1, round((square_height / original_height) * original_width))

    # cutting size of the square image to the original ratio
    if isinstance(square_image, Image.Image):
        square_image_with_original_ratio = square_image.crop((0, 0, new_width, new_height))
    else:
        square_image_with_original_ratio = square_image[:new_height, :new_width]

    if resize_back_to_original:
        if isinstance(square_image_with_original_ratio, Image.Image):
            square_image_with_original_ratio = square_image_with_original_ratio.resize(
                (original_width, original_height)
            )
        else:
            # Tensor.resize() reshapes storage instead of resampling pixels, so interpolate explicitly
            square_image_with_original_ratio = _resize_tensor_hw(
                square_image_with_original_ratio, original_height, original_width
            )

    return square_image_with_original_ratio