Spaces: Running on CPU Upgrade
Create app.py

app.py ADDED

from transformers import SamModel, SamProcessor, pipeline
from PIL import Image
import cv2
import random
import numpy as np
import torch
from torch.nn.functional import cosine_similarity
import gradio as gr

class RoiMatching():
    def __init__(self, img1, img2, device='cuda:1', v_min=200, v_max=7000, mode='embedding'):
        """
        Initialize the matcher.
        :param img1: PIL image
        :param img2: PIL image
        :param device: torch device string, e.g. 'cpu' or 'cuda:1'
        :param v_min: minimum mask area (in pixels) kept by the filter
        :param v_max: maximum mask area (in pixels) kept by the filter
        :param mode: pairing strategy, 'embedding' or 'overlaping'
        """
        self.img1 = img1
        self.img2 = img2
        self.device = device
        self.v_min = v_min
        self.v_max = v_max
        self.mode = mode

    def _sam_everything(self, imgs):
        # Segment everything with SAM via the transformers mask-generation pipeline;
        # the output is a dict holding a 'masks' list of boolean arrays.
        generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=self.device)
        outputs = generator(imgs, points_per_batch=64, pred_iou_thresh=0.90, stability_score_thresh=0.9)
        return outputs

    def _mask_criteria(self, masks, v_min=200, v_max=7000):
        # Drop masks whose pixel area falls outside [v_min, v_max].
        remove_list = set()
        for _i, mask in enumerate(masks):
            if mask.sum() < v_min or mask.sum() > v_max:
                remove_list.add(_i)
        masks = [mask for idx, mask in enumerate(masks) if idx not in remove_list]
        # For pairs that overlap by >= 90% of the smaller mask, keep only the larger.
        n = len(masks)
        remove_list = set()
        for i in range(n):
            for j in range(i + 1, n):
                mask1, mask2 = masks[i], masks[j]
                intersection = (mask1 & mask2).sum()
                smaller_mask_area = min(masks[i].sum(), masks[j].sum())

                if smaller_mask_area > 0 and (intersection / smaller_mask_area) >= 0.9:
                    if mask1.sum() < mask2.sum():
                        remove_list.add(i)
                    else:
                        remove_list.add(j)
        return [mask for idx, mask in enumerate(masks) if idx not in remove_list]
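
    # A hypothetical sanity check: of two masks overlapping by more than 90%
    # of the smaller one, only the larger survives the filter.
    # >>> a = np.zeros((32, 32), dtype=bool); a[4:20, 4:20] = True
    # >>> b = np.zeros((32, 32), dtype=bool); b[5:19, 5:19] = True
    # >>> len(RoiMatching(None, None)._mask_criteria([a, b], v_min=1, v_max=1024))
    # 1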

    def _roi_proto(self, image, masks):
        # Embed the whole image once with SAM, then pool the embedding inside
        # each mask to get one prototype vector per ROI.
        model = SamModel.from_pretrained("facebook/sam-vit-huge").to(self.device)
        processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        inputs = processor(image, return_tensors="pt").to(self.device)
        image_embeddings = model.get_image_embeddings(inputs["pixel_values"])  # (1, 256, 64, 64)
        embs = []
        for _m in masks:
            # Convert mask to uint8, resize to the 64x64 embedding grid, then back to boolean
            tmp_m = _m.astype(np.uint8)
            tmp_m = cv2.resize(tmp_m, (64, 64), interpolation=cv2.INTER_NEAREST)
            tmp_m = torch.tensor(tmp_m.astype(bool), device=self.device,
                                 dtype=torch.float32)  # convert to tensor on the target device
            tmp_m = tmp_m.unsqueeze(0).unsqueeze(0)  # add batch and channel dims to match the embedding

            # Element-wise multiplication zeroes out features outside the ROI
            tmp_emb = image_embeddings * tmp_m  # (1, 256, 64, 64)

            # Replace zeros with NaN so the mean is taken over in-mask positions only
            tmp_emb[tmp_emb == 0] = torch.nan
            emb = torch.nanmean(tmp_emb, dim=(2, 3))
            emb[torch.isnan(emb)] = 0
            embs.append(emb)
        return embs
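
    # The NaN trick above averages only over unmasked entries, e.g.:
    # >>> t = torch.tensor([1.0, 0.0, 3.0])
    # >>> t[t == 0] = torch.nan
    # >>> torch.nanmean(t)
    # tensor(2.)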

    def _cosine_similarity(self, vec1, vec2):
        # Ensure vec1 and vec2 are 2D tensors of shape [1, N]
        vec1 = vec1.view(1, -1)
        vec2 = vec2.view(1, -1)
        return cosine_similarity(vec1, vec2).item()
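
    # Identical prototypes score 1.0 (up to floating-point rounding), e.g.:
    # >>> v = torch.tensor([1.0, 2.0, 3.0])
    # >>> RoiMatching(None, None)._cosine_similarity(v, v)
    # 1.0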

    def _similarity_matrix(self, protos1, protos2):
        # Pairwise cosine similarities between the two prototype sets
        similarity_matrix = torch.zeros(len(protos1), len(protos2), device=self.device)
        for i, vec_a in enumerate(protos1):
            for j, vec_b in enumerate(protos2):
                similarity_matrix[i, j] = self._cosine_similarity(vec_a, vec_b)
        # Min-max normalize to [0, 1] and return the normalized matrix
        similarity_matrix = (similarity_matrix - similarity_matrix.min()) / (similarity_matrix.max() - similarity_matrix.min())
        return similarity_matrix

    def _roi_match(self, matrix, masks1, masks2, sim_criteria=0.8):
        # Greedy one-to-one matching: repeatedly take the most similar pair,
        # then invalidate its row and column.
        index_pairs = []
        while torch.any(matrix > sim_criteria):
            max_idx = torch.argmax(matrix)
            max_sim_idx = (max_idx // matrix.shape[1], max_idx % matrix.shape[1])
            if matrix[max_sim_idx[0], max_sim_idx[1]] > sim_criteria:
                index_pairs.append(max_sim_idx)
            matrix[max_sim_idx[0], :] = -1
            matrix[:, max_sim_idx[1]] = -1
        masks1_new = []
        masks2_new = []
        for i, j in index_pairs:
            masks1_new.append(masks1[i])
            masks2_new.append(masks2[j])
        return masks1_new, masks2_new
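
    # Greedy matching on a toy 2x2 similarity matrix (hypothetical values):
    # >>> m = torch.tensor([[0.90, 0.10], [0.20, 0.85]])
    # >>> RoiMatching(None, None)._roi_match(m, ['a1', 'a2'], ['b1', 'b2'])
    # (['a1', 'a2'], ['b1', 'b2'])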

    def _overlap_pair(self, masks1, masks2):
        # Pair ROIs by spatial overlap, disambiguating competing candidates
        # by area first and mean intensity second.
        self.masks1_cor = []
        self.masks2_cor = []
        im1 = np.array(self.img1)
        im2 = np.array(self.img2)
        k = 0
        for mask in masks1[:-1]:
            k += 1
            print('mask1 {} is finding corresponding region mask...'.format(k))
            m1 = mask
            a1 = mask.sum()
            v1 = np.mean(np.expand_dims(m1, axis=-1) * im1)
            overlap = m1 * masks2[-1].astype(np.int64)
            if (overlap > 0).sum() / a1 > 0.3:
                counts = np.bincount(overlap.flatten())
                sorted_indices = np.argsort(counts)[::-1]
                top_two = sorted_indices[1:3]
                if top_two[-1] == 0:
                    cor_ind = 0
                elif abs(counts[top_two[-1]] - counts[top_two[0]]) / max(counts[top_two[-1]], counts[top_two[0]]) < 0.2:
                    cor_ind = 0
                else:
                    m21 = masks2[top_two[0] - 1]
                    m22 = masks2[top_two[1] - 1]
                    a21 = masks2[top_two[0] - 1].sum()
                    a22 = masks2[top_two[1] - 1].sum()
                    v21 = np.mean(np.expand_dims(m21, axis=-1) * im2)
                    v22 = np.mean(np.expand_dims(m22, axis=-1) * im2)
                    # Area criterion: pick between the two candidate masks
                    if np.abs(a21 - a1) > np.abs(a22 - a1):
                        cor_ind = 0
                    else:
                        cor_ind = 1
                    print('area judge to cor_ind {}'.format(cor_ind))
                    # Intensity criterion makes the final decision
                    if np.abs(v21 - v1) < np.abs(v22 - v1):
                        cor_ind = 0
                    else:
                        cor_ind = 1

                self.masks2_cor.append(masks2[top_two[cor_ind] - 1])
                self.masks1_cor.append(mask)

    def get_paired_roi(self):
        # Segment everything in both images, filter by size/overlap, then pair.
        self.masks1 = self._sam_everything(self.img1)  # dict with a 'masks' list
        self.masks2 = self._sam_everything(self.img2)
        self.masks1 = self._mask_criteria(self.masks1['masks'], v_min=self.v_min, v_max=self.v_max)
        self.masks2 = self._mask_criteria(self.masks2['masks'], v_min=self.v_min, v_max=self.v_max)

        match self.mode:
            case 'embedding':
                if len(self.masks1) > 0 and len(self.masks2) > 0:
                    self.embs1 = self._roi_proto(self.img1, self.masks1)
                    self.embs2 = self._roi_proto(self.img2, self.masks2)
                    self.sim_matrix = self._similarity_matrix(self.embs1, self.embs2)
                    self.masks1, self.masks2 = self._roi_match(self.sim_matrix, self.masks1, self.masks2)
            case 'overlaping':
                self._overlap_pair(self.masks1, self.masks2)
                # Expose the paired masks under the same attributes as 'embedding' mode
                self.masks1, self.masks2 = self.masks1_cor, self.masks2_cor
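
# Minimal usage sketch (paths taken from the examples list below):
# >>> im1 = Image.open('./example/prostate_2d/image1.png').convert('RGB')
# >>> im2 = Image.open('./example/prostate_2d/image2.png').convert('RGB')
# >>> RM = RoiMatching(im1, im2, device='cpu')
# >>> RM.get_paired_roi()
# >>> len(RM.masks1) == len(RM.masks2)  # ROIs are paired one-to-one
# True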

def visualize_masks(image1, masks1, image2, masks2):
    # Convert PIL images to numpy arrays
    background1 = np.array(image1)
    background2 = np.array(image2)

    # Convert RGB to BGR (OpenCV uses the BGR color format)
    background1 = cv2.cvtColor(background1, cv2.COLOR_RGB2BGR)
    background2 = cv2.cvtColor(background2, cv2.COLOR_RGB2BGR)

    # Create a blank overlay for each image
    mask1 = np.zeros_like(background1)
    mask2 = np.zeros_like(background2)

    # Fixed BGR palette, currently unused in favor of random_color() below
    distinct_colors = [
        (255, 0, 0),     # Blue
        (0, 255, 0),     # Green
        (0, 0, 255),     # Red
        (255, 255, 0),   # Cyan
        (255, 0, 255),   # Magenta
        (0, 255, 255),   # Yellow
        (128, 0, 0),     # Navy
        (0, 128, 0),     # Dark green
        (0, 0, 128),     # Maroon
        (128, 128, 0),   # Teal
        (128, 0, 128),   # Purple
        (0, 128, 128),   # Olive
        (192, 192, 192)  # Silver
    ]

    def random_color():
        """Generate a random color with high saturation and value in HSV color space."""
        hue = random.randint(0, 179)  # OpenCV hue range is 0-179
        saturation = random.randint(200, 255)  # high saturation
        value = random.randint(200, 255)  # high value (brightness)
        color = np.array([[[hue, saturation, value]]], dtype=np.uint8)
        return cv2.cvtColor(color, cv2.COLOR_HSV2BGR)[0][0]

    # Overlay each matched mask pair with the same color on both images
    for idx, (mask1_item, mask2_item) in enumerate(zip(masks1, masks2)):
        color = random_color()
        # Convert binary masks to uint8
        mask1_item = np.uint8(mask1_item)
        mask2_item = np.uint8(mask2_item)

        # Foreground masks where the binary mask is True
        fg_mask1 = np.where(mask1_item, 255, 0).astype(np.uint8)
        fg_mask2 = np.where(mask2_item, 255, 0).astype(np.uint8)

        # Paint the foreground regions of both overlays with the shared color
        mask1[fg_mask1 > 0] = color
        mask2[fg_mask2 > 0] = color

    # Blend the overlays on top of the background images
    result1 = cv2.addWeighted(background1, 1, mask1, 0.5, 0)
    result2 = cv2.addWeighted(background2, 1, mask2, 0.5, 0)

    # Convert back to RGB so downstream consumers (e.g. Gradio) display correctly
    result1 = cv2.cvtColor(result1, cv2.COLOR_BGR2RGB)
    result2 = cv2.cvtColor(result2, cv2.COLOR_BGR2RGB)

    return result1, result2

def predict(im1, im2):
    # Run the full pipeline on CPU (the Space runs without a GPU) and visualize the pairs.
    RM = RoiMatching(im1, im2, device='cpu')
    RM.get_paired_roi()
    visualized_image1, visualized_image2 = visualize_masks(im1, RM.masks1, im2, RM.masks2)
    return visualized_image1, visualized_image2

examples = [
    ['./example/prostate_2d/image1.png', './example/prostate_2d/image2.png'],
    ['./example/cardiac_2d/image1.png', './example/cardiac_2d/image2.png'],
    ['./example/pathology/1B_B7_R.png', './example/pathology/1B_B7_T.png'],
]

gradio_app = gr.Interface(
    predict,
    # predict takes two images, so the interface needs two image inputs
    inputs=[gr.Image(label="Image 1", sources=['upload', 'webcam'], type="pil"),
            gr.Image(label="Image 2", sources=['upload', 'webcam'], type="pil")],
    # ...and returns two visualizations, so it needs two image outputs
    outputs=[gr.Image(label="Paired ROIs 1"), gr.Image(label="Paired ROIs 2")],
    title="SAMReg: One Registration is Worth Two Segmentations",
    examples=examples,
    description="<p> \
        <strong>Register anything with ROI-based registration representation.</strong> <br>\
        Choose an example below 🔥 🔥 🔥 <br>\
        Or, upload by yourself: <br>\
        1. Upload the two images to be registered to 'Image 1' and 'Image 2'. <br>\
        2. Press Submit to segment both images, pair the ROIs, and view the result. <br>\
        <br> \
        • SAM segments everything in both images; matched ROI pairs share one color. <br>\
        • Examples below were never trained and are randomly selected for testing in the wild. <br>\
        • The current UI only unleashes a small part of SAMReg's capabilities, i.e., the 2D pairwise case. \
        </p>",
)
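
# A standard launch guard (assumed entry point) so the Space serves the UI when
# app.py is run as a script:
if __name__ == "__main__":
    gradio_app.launch()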