Yan commited on
Commit
2ca44ee
·
1 Parent(s): 0227876

added test script and data for local handler testing, fixed syntax error in handler script

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. handler.py +54 -28
  3. test.png +3 -0
  4. test.py +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
handler.py CHANGED
@@ -4,12 +4,10 @@ from PIL import Image
4
  from io import BytesIO
5
  import numpy as np
6
  import os
7
- import requests
8
  import torch
9
  import torchvision.transforms as T
10
  from transformers import AutoProcessor, AutoModelForVision2Seq
11
  import cv2
12
- import ast
13
 
14
  # set device
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -18,15 +16,43 @@ if device.type != 'cuda':
18
  # set mixed precision dtype
19
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  class EndpointHandler():
23
  def __init__(self, path=""):
24
  self.ckpt_id = "ydshieh/kosmos-2-patch14-224"
25
 
26
- self.model = AutoModelForVision2Seq.from_pretrained(ckpt_id, trust_remote_code=True).to("cuda")
27
- self.processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
 
 
 
 
 
28
 
29
- def draw_entity_boxes_on_image(image, entities, show=False, save_path=None, entity_index=-1):
30
  """_summary_
31
  Args:
32
  image (_type_): image or image path
@@ -56,17 +82,17 @@ class EndpointHandler():
56
  image = np.array(pil_img)[:, :, [2, 1, 0]]
57
  else:
58
  raise ValueError(f"invaild image format, {type(image)} for {image}")
59
-
60
  if len(entities) == 0:
61
  return image
62
-
63
  indices = list(range(len(entities)))
64
  if entity_index >= 0:
65
  indices = [entity_index]
66
-
67
  # Not to show too many bboxes
68
  entities = entities[:len(color_map)]
69
-
70
  new_image = image.copy()
71
  previous_bboxes = []
72
  # size of text
@@ -78,10 +104,10 @@ class EndpointHandler():
78
  base_height = int(text_height * 0.675)
79
  text_offset_original = text_height - base_height
80
  text_spaces = 3
81
-
82
  # num_bboxes = sum(len(x[-1]) for x in entities)
83
  used_colors = colors # random.sample(colors, k=num_bboxes)
84
-
85
  color_id = -1
86
  for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
87
  color_id += 1
@@ -91,37 +117,37 @@ class EndpointHandler():
91
  # if start is None and bbox_id > 0:
92
  # color_id += 1
93
  orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
94
-
95
  # draw bbox
96
  # random color
97
  color = used_colors[color_id] # tuple(np.random.randint(0, 255, size=3).tolist())
98
  new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
99
-
100
  l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
101
-
102
  x1 = orig_x1 - l_o
103
  y1 = orig_y1 - l_o
104
-
105
  if y1 < text_height + text_offset_original + 2 * text_spaces:
106
  y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
107
  x1 = orig_x1 + r_o
108
-
109
  # add text background
110
  (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
111
  text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
112
-
113
  for prev_bbox in previous_bboxes:
114
- while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
115
  text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
116
  text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
117
  y1 += (text_height + text_offset_original + 2 * text_spaces)
118
-
119
  if text_bg_y2 >= image_h:
120
  text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
121
  text_bg_y2 = image_h
122
  y1 = image_h
123
  break
124
-
125
  alpha = 0.5
126
  for i in range(text_bg_y1, text_bg_y2):
127
  for j in range(text_bg_x1, text_bg_x2):
@@ -133,19 +159,19 @@ class EndpointHandler():
133
  # white
134
  bg_color = [255, 255, 255]
135
  new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)
136
-
137
  cv2.putText(
138
  new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
139
  )
140
  # previous_locations.append((x1, y1))
141
  previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
142
-
143
  pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
144
  if save_path:
145
  pil_image.save(save_path)
146
  if show:
147
  pil_image.show()
148
-
149
  return pil_image
150
 
151
 
@@ -161,13 +187,13 @@ class EndpointHandler():
161
  # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
162
  user_image_path = "/tmp/user_input_test_image.jpg"
163
  image_input.save(user_image_path)
164
-
165
  # This might give different results from the original argument `image_input`
166
  image_input = Image.open(user_image_path)
167
  text_input = "<grounding>Describe this image in detail:"
168
  #text_input = f"<grounding>{text_input}"
169
 
170
- inputs = processor(text=text_input, images=image_input, return_tensors="pt")
171
 
172
  generated_ids = self.model.generate(
173
  pixel_values=inputs["pixel_values"].to("cuda"),
@@ -181,7 +207,7 @@ class EndpointHandler():
181
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
182
 
183
  # By default, the generated text is cleanup and the entities are extracted.
184
- processed_text, entities = processor.post_process_generation(generated_text)
185
 
186
  annotated_image = self.draw_entity_boxes_on_image(image_input, entities, show=False)
187
 
@@ -213,10 +239,10 @@ class EndpointHandler():
213
  colored_text.append((processed_text[end:len(processed_text)], None))
214
 
215
  return annotated_image, colored_text, str(filtered_entities)
216
-
217
  # helper to decode input image
218
  def decode_base64_image(self, image_string):
219
  base64_image = base64.b64decode(image_string)
220
  buffer = BytesIO(base64_image)
221
  image = Image.open(buffer)
222
- return image
 
4
  from io import BytesIO
5
  import numpy as np
6
  import os
 
7
  import torch
8
  import torchvision.transforms as T
9
  from transformers import AutoProcessor, AutoModelForVision2Seq
10
  import cv2
 
11
 
12
  # set device
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
16
  # set mixed precision dtype
17
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
18
 
19
+ colors = [
20
+ (0, 255, 0),
21
+ (0, 0, 255),
22
+ (255, 255, 0),
23
+ (255, 0, 255),
24
+ (0, 255, 255),
25
+ (114, 128, 250),
26
+ (0, 165, 255),
27
+ (0, 128, 0),
28
+ (144, 238, 144),
29
+ (238, 238, 175),
30
+ (255, 191, 0),
31
+ (0, 128, 0),
32
+ (226, 43, 138),
33
+ (255, 0, 255),
34
+ (0, 215, 255),
35
+ (255, 0, 0),
36
+ ]
37
+
38
+ color_map = {
39
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for color_id, color in enumerate(colors)
40
+ }
41
+
42
 
43
  class EndpointHandler():
44
  def __init__(self, path=""):
45
  self.ckpt_id = "ydshieh/kosmos-2-patch14-224"
46
 
47
+ self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt_id, trust_remote_code=True).to("cuda")
48
+ self.processor = AutoProcessor.from_pretrained(self.ckpt_id, trust_remote_code=True)
49
+
50
+ def is_overlapping(self, rect1, rect2):
51
+ x1, y1, x2, y2 = rect1
52
+ x3, y3, x4, y4 = rect2
53
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
54
 
55
+ def draw_entity_boxes_on_image(self, image, entities, show=False, save_path=None, entity_index=-1):
56
  """_summary_
57
  Args:
58
  image (_type_): image or image path
 
82
  image = np.array(pil_img)[:, :, [2, 1, 0]]
83
  else:
84
  raise ValueError(f"invaild image format, {type(image)} for {image}")
85
+
86
  if len(entities) == 0:
87
  return image
88
+
89
  indices = list(range(len(entities)))
90
  if entity_index >= 0:
91
  indices = [entity_index]
92
+
93
  # Not to show too many bboxes
94
  entities = entities[:len(color_map)]
95
+
96
  new_image = image.copy()
97
  previous_bboxes = []
98
  # size of text
 
104
  base_height = int(text_height * 0.675)
105
  text_offset_original = text_height - base_height
106
  text_spaces = 3
107
+
108
  # num_bboxes = sum(len(x[-1]) for x in entities)
109
  used_colors = colors # random.sample(colors, k=num_bboxes)
110
+
111
  color_id = -1
112
  for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
113
  color_id += 1
 
117
  # if start is None and bbox_id > 0:
118
  # color_id += 1
119
  orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
120
+
121
  # draw bbox
122
  # random color
123
  color = used_colors[color_id] # tuple(np.random.randint(0, 255, size=3).tolist())
124
  new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
125
+
126
  l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
127
+
128
  x1 = orig_x1 - l_o
129
  y1 = orig_y1 - l_o
130
+
131
  if y1 < text_height + text_offset_original + 2 * text_spaces:
132
  y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
133
  x1 = orig_x1 + r_o
134
+
135
  # add text background
136
  (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
137
  text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
138
+
139
  for prev_bbox in previous_bboxes:
140
+ while self.is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
141
  text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
142
  text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
143
  y1 += (text_height + text_offset_original + 2 * text_spaces)
144
+
145
  if text_bg_y2 >= image_h:
146
  text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
147
  text_bg_y2 = image_h
148
  y1 = image_h
149
  break
150
+
151
  alpha = 0.5
152
  for i in range(text_bg_y1, text_bg_y2):
153
  for j in range(text_bg_x1, text_bg_x2):
 
159
  # white
160
  bg_color = [255, 255, 255]
161
  new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)
162
+
163
  cv2.putText(
164
  new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
165
  )
166
  # previous_locations.append((x1, y1))
167
  previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
168
+
169
  pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
170
  if save_path:
171
  pil_image.save(save_path)
172
  if show:
173
  pil_image.show()
174
+
175
  return pil_image
176
 
177
 
 
187
  # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
188
  user_image_path = "/tmp/user_input_test_image.jpg"
189
  image_input.save(user_image_path)
190
+
191
  # This might give different results from the original argument `image_input`
192
  image_input = Image.open(user_image_path)
193
  text_input = "<grounding>Describe this image in detail:"
194
  #text_input = f"<grounding>{text_input}"
195
 
196
+ inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
197
 
198
  generated_ids = self.model.generate(
199
  pixel_values=inputs["pixel_values"].to("cuda"),
 
207
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
208
 
209
  # By default, the generated text is cleanup and the entities are extracted.
210
+ processed_text, entities = self.processor.post_process_generation(generated_text)
211
 
212
  annotated_image = self.draw_entity_boxes_on_image(image_input, entities, show=False)
213
 
 
239
  colored_text.append((processed_text[end:len(processed_text)], None))
240
 
241
  return annotated_image, colored_text, str(filtered_entities)
242
+
243
  # helper to decode input image
244
  def decode_base64_image(self, image_string):
245
  base64_image = base64.b64decode(image_string)
246
  buffer = BytesIO(base64_image)
247
  image = Image.open(buffer)
248
+ return image
test.png ADDED

Git LFS Details

  • SHA256: 91905662c8deed94cf19ca41c71b48ab013eda1a88d4c5d9cdb97cda96b04f54
  • Pointer size: 132 Bytes
  • Size of remote file: 6.62 MB
test.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+ from PIL import Image
3
+ import base64
4
+
5
+ # init handler
6
+ my_handler = EndpointHandler(path=".")
7
+
8
+ # prepare sample payload
9
+ image = Image.open("test.png")
10
+ payload = {"image": base64.b64encode(image)}
11
+
12
+ pred=my_handler(payload)