from typing import Any, Dict, List, Tuple
import base64
from PIL import Image
from io import BytesIO
import numpy as np
import os
import torch
import torchvision.transforms as T
from transformers import AutoProcessor, AutoModelForVision2Seq
import cv2
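
# Custom Inference Endpoint handler for the Kosmos-2 grounded vision-language
# model: it generates a grounded caption for an input image and draws the
# predicted entity bounding boxes back onto that image.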

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
# pick a mixed-precision dtype (bfloat16 on compute capability >= 8.0, else float16)
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

colors = [
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
    (255, 0, 0),
]

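# Map each color index to its "#rrggbb" hex string (the tuples above are in
# OpenCV's BGR order, so the channels are reordered when building the hex code).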
color_map = {
    f"{color_id}": f"#{color[2]:02x}{color[1]:02x}{color[0]:02x}"
    for color_id, color in enumerate(colors)
}


class EndpointHandler():
    def __init__(self, path=""):
        print("Downloading Model!")
        self.ckpt_id = "ydshieh/kosmos-2-patch14-224"

        self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt_id, trust_remote_code=True).to("cuda")
        self.processor = AutoProcessor.from_pretrained(self.ckpt_id, trust_remote_code=True)
        print("Downloaded Model!")

    def is_overlapping(self, rect1, rect2):
        x1, y1, x2, y2 = rect1
        x3, y3, x4, y4 = rect2
        return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)

    def draw_entity_boxes_on_image(self, image, entities, show=False, save_path=None, entity_index=-1):
        """_summary_
        Args:
            image (_type_): image or image path
            collect_entity_location (_type_): _description_
        """
        if isinstance(image, Image.Image):
            image_h = image.height
            image_w = image.width
            image = np.array(image)[:, :, [2, 1, 0]]
        elif isinstance(image, str):
            if os.path.exists(image):
                pil_img = Image.open(image).convert("RGB")
                image = np.array(pil_img)[:, :, [2, 1, 0]]
                image_h = pil_img.height
                image_w = pil_img.width
            else:
                raise ValueError(f"invaild image path, {image}")
        elif isinstance(image, torch.Tensor):
            # pdb.set_trace()
            image_tensor = image.cpu()
            reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
            reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
            image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
            pil_img = T.ToPILImage()(image_tensor)
            image_h = pil_img.height
            image_w = pil_img.width
            image = np.array(pil_img)[:, :, [2, 1, 0]]
        else:
            raise ValueError(f"invaild image format, {type(image)} for {image}")

        if len(entities) == 0:
            # nothing to draw; return the image as an RGB PIL image for consistency
            return Image.fromarray(image[:, :, [2, 1, 0]])

        indices = list(range(len(entities)))
        if entity_index >= 0:
            indices = [entity_index]

        # Cap the number of drawn entities at the number of available colors
        entities = entities[:len(color_map)]

        new_image = image.copy()
        previous_bboxes = []
        # size of text
        text_size = 1
        # thickness of text
        text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
        box_line = 3
        (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
        base_height = int(text_height * 0.675)
        text_offset_original = text_height - base_height
        text_spaces = 3

        # num_bboxes = sum(len(x[-1]) for x in entities)
        used_colors = colors  # random.sample(colors, k=num_bboxes)

        color_id = -1
        for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
            color_id += 1
            if entity_idx not in indices:
                continue
            for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
                # if start is None and bbox_id > 0:
                #     color_id += 1
                orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)

                # draw bbox
                # random color
                color = used_colors[color_id]  # tuple(np.random.randint(0, 255, size=3).tolist())
                new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)

                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f"  {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1

                for prev_bbox in previous_bboxes:
                    while self.is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
                        text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                        text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                        y1 += (text_height + text_offset_original + 2 * text_spaces)

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                            text_bg_y2 = image_h
                            y1 = image_h
                            break

                alpha = 0.5
                for i in range(text_bg_y1, text_bg_y2):
                    for j in range(text_bg_x1, text_bg_x2):
                        if i < image_h and j < image_w:
                            if j < text_bg_x1 + 1.35 * c_width:
                                # original color
                                bg_color = color
                            else:
                                # white
                                bg_color = [255, 255, 255]
                            new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)

                cv2.putText(
                    new_image, f"  {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                )
                # previous_locations.append((x1, y1))
                previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))

        pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
        if save_path:
            pil_image.save(save_path)
        if show:
            pil_image.show()

        return pil_image


    def __call__(self, data: Dict[str, Any]) -> Tuple[Image.Image, List[Tuple[str, Any]], str]:
        """
        :param data: A dictionary with a text `inputs` field and a base64-encoded `image` field.
        :return: A tuple of (annotated PIL image, list of (text span, color id) pairs, stringified entities).
        """
        image = data.pop("image", None)
        image_input = self.decode_base64_image(image)

        # Save the image and load it again to match the original Kosmos-2 demo.
        # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
        user_image_path = "/tmp/user_input_test_image.jpg"
        image_input.save(user_image_path)

        # This might give different results from the original argument `image_input`
        image_input = Image.open(user_image_path)
        text_input = data.pop("inputs", None)
        if not text_input:
            text_input = "<grounding>Describe this image in detail:"
        else:
            text_input = f"<grounding>{text_input}"

        inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")

        generated_ids = self.model.generate(
            pixel_values=inputs["pixel_values"].to("cuda"),
            input_ids=inputs["input_ids"][:, :-1].to("cuda"),
            attention_mask=inputs["attention_mask"][:, :-1].to("cuda"),
            img_features=None,
            img_attn_mask=inputs["img_attn_mask"][:, :-1].to("cuda"),
            use_cache=True,
            max_new_tokens=512,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # By default, the generated text is cleaned up and the entities are extracted.
        processed_text, entities = self.processor.post_process_generation(generated_text)
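        # `entities` is a list of (phrase, (start_char, end_char), bboxes) tuples,
        # where each bbox is a normalized (x1, y1, x2, y2) in [0, 1].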

        annotated_image = self.draw_entity_boxes_on_image(image_input, entities, show=False)

        color_id = -1
        entity_info = []
        filtered_entities = []
        for entity in entities:
            entity_name, (start, end), bboxes = entity
            if start == end:
                # skip bounding boxes that have no associated phrase
                continue
            color_id += 1
            # for bbox_id, _ in enumerate(bboxes):
                # if start is None and bbox_id > 0:
                #     color_id += 1
            entity_info.append(((start, end), color_id))
            filtered_entities.append(entity)

        colored_text = []
        prev_start = 0
        end = 0
        for idx, ((start, end), color_id) in enumerate(entity_info):
            if start > prev_start:
                colored_text.append((processed_text[prev_start:start], None))
            colored_text.append((processed_text[start:end], f"{color_id}"))
            prev_start = end

        if end < len(processed_text):
            colored_text.append((processed_text[end:len(processed_text)], None))

        return annotated_image, colored_text, str(filtered_entities)

    # helper to decode input image
    def decode_base64_image(self, image_string):
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer).convert("RGB")
        return image
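

# ---------------------------------------------------------------------------
# Minimal local smoke test (illustrative sketch only, not part of the endpoint
# contract). The image path below is hypothetical; point it at any local JPEG.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:  # hypothetical example image
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    annotated_image, colored_text, entities_str = handler(
        {"inputs": "Describe this image in detail:", "image": encoded_image}
    )
    annotated_image.save("/tmp/annotated_example.jpg")
    print(colored_text)
    print(entities_str)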