File size: 3,536 Bytes
bcc0f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import warnings

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import yaml
from albumentations.pytorch import ToTensorV2
from easydict import EasyDict
from PIL import Image

from src.models.unet.resunet import UNet as Model


class ResUnetInfer:
    """Run a ResUNet segmentation model on one RGB image and overlay its mask.

    Architecture and preprocessing are driven by a YAML config; the keys read
    here are ``input_size``, ``mean``, ``std``, ``decoder_config`` and
    ``nclasses``.
    """

    def __init__(self, model_path, config_path):
        """Load the config and weights and build the preprocessing pipeline.

        Args:
            model_path: Path to a checkpoint holding ``decoder_state_dict`` and
                ``output_state_dict``. If the file does not exist, a warning is
                emitted and the model keeps its constructor weights.
            config_path: Path to the YAML configuration file.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # The config must be loaded first: load_model reads self.config.
        self.config = self.load_config(config_path=config_path)
        self.model = self.load_model(model_path=model_path)

        # Resize to the network input size, normalize 0..255 pixels with the
        # configured mean/std, then convert the HWC array to a CHW tensor.
        self.transform = A.Compose(
            [
                A.Resize(self.config.input_size[0], self.config.input_size[1]),
                A.Normalize(
                    mean=self.config.mean,
                    std=self.config.std,
                    max_pixel_value=255,
                ),
                ToTensorV2(),
            ]
        )

    def load_model(self, model_path):
        """Build the model and restore its decoder/output weights.

        Only the ``decoder`` and ``output`` sub-modules are loaded from the
        checkpoint; any other weights are whatever ``Model`` is constructed
        with. Loading uses ``strict=False``, so missing or unexpected keys are
        ignored.

        Args:
            model_path: Path to the checkpoint file.

        Returns:
            The model, moved to ``self.device``.
        """
        model = Model(
            decoder_config=self.config.decoder_config, nclasses=self.config.nclasses
        ).to(self.device)

        if not os.path.isfile(model_path):
            # Stay best-effort (no crash), but make it visible that no trained
            # decoder/output weights were loaded.
            warnings.warn(
                f"Checkpoint not found at {model_path!r}; "
                "decoder/output weights were not loaded.",
                RuntimeWarning,
                stacklevel=2,
            )
            return model

        checkpoint = torch.load(model_path, map_location=self.device)
        model.decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=False)
        model.output.load_state_dict(checkpoint["output_state_dict"], strict=False)
        return model

    def load_config(self, config_path):
        """Parse the YAML file at ``config_path`` into an attribute-access dict."""
        with open(config_path, "r", encoding="utf-8") as file:
            yaml_data = yaml.safe_load(file)

        return EasyDict(yaml_data)

    def infer(self, image, image_weight=0.01):
        """Predict a mask for ``image`` and blend it with the image.

        Args:
            image: ``(H, W, 3)`` RGB array, as returned by
                ``load_image_as_array``. Pixel values are assumed to be
                0..255 (Normalize uses ``max_pixel_value=255``).
            image_weight: Weight of the original image in the blend; the mask
                gets ``1 - image_weight``.

        Returns:
            ``(H, W, 3)`` uint8 array. The sigmoid mask is drawn in the red
            channel, blended with the image, then rescaled so that its maximum
            becomes 255. A blend whose maximum is 0 is returned unscaled
            (all zeros for non-negative inputs).

        Raises:
            ValueError: If the model output does not have exactly one channel.
        """
        self.model.eval()
        height, width = image.shape[0], image.shape[1]
        input_tensor = self.transform(image=image)["image"].unsqueeze(0)

        with torch.no_grad():
            # Presumably (1, nclasses, input_size[0], input_size[1]) -- TODO confirm.
            output_tensor = self.model(input_tensor.to(self.device))

            # Probabilities resized back to the original image resolution.
            # align_corners=True reproduces the deprecated nn.UpsamplingBilinear2d.
            mask = torch.sigmoid(output_tensor)
            mask = nn.functional.interpolate(
                mask, size=(height, width), mode="bilinear", align_corners=True
            )
        mask = mask.squeeze(0)

        if mask.shape[0] != 1:
            raise ValueError(
                "infer expects a single-channel model output, got "
                f"{mask.shape[0]} channels"
            )

        # Red-only overlay: the mask goes in R, zeros in G and B.
        zero_channels = torch.zeros(
            (2, height, width), device=mask.device, dtype=mask.dtype
        )
        mask = torch.cat([mask, zero_channels], dim=0)
        mask = mask.permute(1, 2, 0).cpu().numpy()
        mask = np.uint8(255 * mask)

        # Blend the mask with the image, then stretch so the brightest value
        # maps to 255. Skip the stretch when the peak is 0 to avoid 0/0 -> NaN.
        overlay = (1 - image_weight) * mask + image_weight * image
        peak = np.max(overlay)
        if peak != 0:
            overlay = overlay / peak
        return np.uint8(255 * overlay)

    @staticmethod
    def load_image_as_array(image_path):
        """Read an image file and return it as an ``(H, W, 3)`` RGB array."""
        # The context manager closes the file handle PIL opens lazily;
        # convert() forces the pixel data to be read before it is closed.
        with Image.open(image_path) as pil_image:
            return np.array(pil_image.convert("RGB"))

    @staticmethod
    def plot_array(array: np.ndarray, figsize=(10, 10)):
        """Display ``array`` with matplotlib in a new figure of size ``figsize``."""
        plt.figure(figsize=figsize)
        plt.imshow(array)
        plt.show()

    @staticmethod
    def save_numpy_as_image(numpy_array, image_path):
        """
        Saves a NumPy array as an image.
        Args:
            numpy_array (numpy.ndarray): Array to save. Its dtype/shape must be
                accepted by ``PIL.Image.fromarray``, e.g. the uint8 output of
                ``infer``.
            image_path (str): Destination path; the file format is chosen by
                PIL, typically from the extension.
        """
        # Convert the NumPy array to a PIL image
        image = Image.fromarray(numpy_array)

        # Save the PIL image to the specified path
        image.save(image_path)