MykolaL commited on
Commit
44fbca1
·
verified ·
1 Parent(s): 4a991a4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -88,11 +88,8 @@ def create_refseg_demo(model, tokenizer, device):
88
 
89
  def on_submit(image, text):
90
  # Convert PIL -> np array
91
- image_np = np.array(image).copy()
92
  transform = transforms.ToTensor()
93
  image_t = transform(image).unsqueeze(0).to(device)
94
- image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_t)
95
- image_t = torch.nn.functional.interpolate(image_t, (512, 512), mode='bilinear', align_corners=True)
96
 
97
  with torch.no_grad():
98
  out = model(image_t, text)
@@ -103,24 +100,23 @@ def create_refseg_demo(model, tokenizer, device):
103
  else:
104
  mask = out
105
 
106
- # Convert to binary mask
107
  if mask.ndim > 2:
108
  mask = np.argmax(mask, axis=0)
109
- mask = (mask > 0).astype(np.uint8)
110
 
111
- # Resize mask to original image size
112
- mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
113
 
114
- # Overlay mask
 
115
  alpha = 0.65
116
- overlay = image_np.copy()
117
- overlay[mask == 0] = (overlay[mask == 0] * alpha).astype(np.uint8)
118
 
119
  # Draw contours
120
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
121
- cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)
 
 
122
 
123
- return Image.fromarray(overlay)
124
 
125
  submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
126
  examples = gr.Examples(
 
88
 
89
  def on_submit(image, text):
90
  # Convert PIL -> np array
 
91
  transform = transforms.ToTensor()
92
  image_t = transform(image).unsqueeze(0).to(device)
 
 
93
 
94
  with torch.no_grad():
95
  out = model(image_t, text)
 
100
  else:
101
  mask = out
102
 
103
+ # If model returns multi-channel, collapse with argmax
104
  if mask.ndim > 2:
105
  mask = np.argmax(mask, axis=0)
 
106
 
107
+ mask = mask.astype(np.uint8)
 
108
 
109
+ # Overlay mask on original image
110
+ image_np = np.array(image).copy()
111
  alpha = 0.65
112
+ image_np[mask == 0] = (image_np[mask == 0] * alpha).astype(np.uint8)
 
113
 
114
  # Draw contours
115
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
116
+ cv2.drawContours(image_np, contours, -1, (0, 255, 0), 2)
117
+
118
+ return Image.fromarray(image_np)
119
 
 
120
 
121
  submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
122
  examples = gr.Examples(