Spaces:
Running
Running
test later
Browse files- README.md +1 -1
- app.py +54 -183
- app_fast.py → app_slow.py +183 -54
- requirements.txt +2 -0
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: π
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.44.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
app.py
CHANGED
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from PIL import ImageDraw, Image, ImageFont
|
@@ -9,24 +14,18 @@ import matplotlib.pyplot as plt
|
|
9 |
import torch
|
10 |
from transformers import SamModel, SamProcessor
|
11 |
|
12 |
-
import
|
13 |
|
|
|
14 |
|
15 |
-
#
|
16 |
-
path = os.getcwd()
|
17 |
-
font_path = r'{}/arial.ttf'.format(path)
|
18 |
-
print(font_path)
|
19 |
-
|
20 |
-
# Load the pre-trained model - FastSAM
|
21 |
-
# fastsam_model = FastSAM('./FastSAM-s.pt')
|
22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
#
|
27 |
-
|
28 |
-
|
29 |
-
previous_box_points = 0
|
30 |
|
31 |
# Description
|
32 |
title = "<center><strong><font size='8'> π Segment food with clicks π</font></strong></center>"
|
@@ -34,85 +33,23 @@ title = "<center><strong><font size='8'> π Segment food with clicks π</fon
|
|
34 |
instruction = """ # Instruction
|
35 |
This segmentation tool is built with HuggingFace SAM model. To use to label true mask, please follow the following steps \n
|
36 |
π₯ Step 1: Copy segmentation candidate image link and paste in 'Enter Image URL' and click 'Upload Image' \n
|
37 |
-
π₯ Step 2: Add positive (
|
38 |
π₯ Step 3: Click on 'Segment with prompts' to segment Image and see if there's a correct segmentation on the 3 options \n
|
39 |
π₯ Step 4: If not, you can repeat the process of adding prompt and segment until a correct one is generated. Prompt history will be retained unless reloading the image \n
|
40 |
π₯ Step 5: Download the satisfied segmentaion image through the icon on top right corner of the image, please name it with 'correct_seg_xxx' where xxx is the photo ID
|
41 |
"""
|
42 |
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
43 |
|
|
|
44 |
|
45 |
def read_image(url):
|
46 |
response = requests.get(url)
|
47 |
img = Image.open(BytesIO(response.content))
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
global_point_label = []
|
54 |
-
return img
|
55 |
-
|
56 |
-
# def show_mask(mask, ax, random_color=False):
|
57 |
-
# if random_color:
|
58 |
-
# color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
59 |
-
# else:
|
60 |
-
# color = np.array([30/255, 144/255, 255/255, 0.6])
|
61 |
-
# h, w = mask.shape[-2:]
|
62 |
-
# mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
63 |
-
# ax.imshow(mask_image)
|
64 |
-
|
65 |
-
# def show_points(coords, labels, ax, marker_size=375):
|
66 |
-
# pos_points = coords[labels==1]
|
67 |
-
# neg_points = coords[labels==0]
|
68 |
-
# ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
69 |
-
# ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
70 |
-
|
71 |
-
# def show_masks_and_points_on_image(raw_image, mask, input_points, input_labels, args):
|
72 |
-
# masks = masks.squeeze() if len(masks.shape) == 4 else masks.unsqueeze(0) if len(masks.shape) == 2 else masks
|
73 |
-
# scores = scores.squeeze() if (scores.shape[0] == 1) & (len(scores.shape) == 3) else scores
|
74 |
-
# #
|
75 |
-
# input_points = np.array(input_points)
|
76 |
-
# labels = np.array(input_labels)
|
77 |
-
# #
|
78 |
-
# mask = mask.cpu().detach()
|
79 |
-
# plt.imshow(np.array(raw_image))
|
80 |
-
# ax = plt.gca()
|
81 |
-
# show_mask(mask, ax)
|
82 |
-
# show_points(input_points, labels, ax, marker_size=375)
|
83 |
-
# ax.axis("off")
|
84 |
-
|
85 |
-
# save_path = args.output
|
86 |
-
# if not os.path.exists(save_path):
|
87 |
-
# os.makedirs(save_path)
|
88 |
-
# plt.axis("off")
|
89 |
-
# fig = plt.gcf()
|
90 |
-
# plt.draw()
|
91 |
-
|
92 |
-
# try:
|
93 |
-
# buf = fig.canvas.tostring_rgb()
|
94 |
-
# except AttributeError:
|
95 |
-
# fig.canvas.draw()
|
96 |
-
# buf = fig.canvas.tostring_rgb()
|
97 |
-
|
98 |
-
# cols, rows = fig.canvas.get_width_height()
|
99 |
-
# img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
100 |
-
# cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
|
101 |
-
|
102 |
-
def format_prompt_points(points, labels):
|
103 |
-
prompt_points = [xy for xy, l in zip(points, labels) if l != 9]
|
104 |
-
point_labels = [l for l in labels if l != 9]
|
105 |
-
#
|
106 |
-
prompt_boxes = None
|
107 |
-
if len(point_labels) < len(labels):
|
108 |
-
prompt_boxes = [[np.array([xy for xy, l in zip(points, labels) if l == 9]).reshape(-1, 4).tolist()]]
|
109 |
-
return prompt_points, point_labels, prompt_boxes
|
110 |
-
|
111 |
-
# def get_mask_image(raw_image, mask):
|
112 |
-
# tmp_mask = np.array(mask)
|
113 |
-
# tmp_img_arr = np.array(raw_image)
|
114 |
-
# tmp_img_arr[tmp_mask == False] = [255,255,255]
|
115 |
-
# return tmp_img_arr
|
116 |
|
117 |
def get_mask_image(raw_image, mask):
|
118 |
tmp_mask = np.array(mask * 1)
|
@@ -123,29 +60,32 @@ def get_mask_image(raw_image, mask):
|
|
123 |
tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis = 2)
|
124 |
return tmp_img_arr
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
def segment_with_points(
|
128 |
-
|
129 |
-
input_size=1024,
|
130 |
-
iou_threshold=0.7,
|
131 |
-
conf_threshold=0.25,
|
132 |
-
better_quality=False,
|
133 |
-
withContours=True,
|
134 |
-
use_retina=True,
|
135 |
-
mask_random_color=True,
|
136 |
):
|
137 |
-
global global_points
|
138 |
-
global global_point_label
|
139 |
|
140 |
-
#
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
prompt_points, point_labels, prompt_boxes = format_prompt_points(global_points, global_point_label)
|
145 |
print(prompt_points, point_labels, prompt_boxes)
|
146 |
# segment
|
147 |
-
inputs = processor(
|
148 |
-
input_boxes = prompt_boxes,
|
149 |
input_points=[[prompt_points]],
|
150 |
input_labels=[point_labels],
|
151 |
return_tensors="pt").to(device)
|
@@ -155,74 +95,15 @@ def segment_with_points(
|
|
155 |
masks = processor.image_processor.post_process_masks(
|
156 |
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
157 |
scores = outputs.iou_scores
|
158 |
-
|
159 |
-
|
160 |
-
# fig = show_masks_and_points_on_image(raw_image, masks[0][0][0], [global_points], global_point_label)
|
161 |
-
mask_images = [get_mask_image(raw_image, m) for m in masks[0][0]]
|
162 |
mask_img1, mask_img2, mask_img3 = mask_images
|
163 |
# return fig, None
|
164 |
return mask_img1, mask_img2, mask_img3
|
165 |
|
166 |
-
def find_font_size(text, font_path, image, target_width_ratio):
|
167 |
-
tested_font_size = 100
|
168 |
-
tested_font = ImageFont.truetype(font_path, tested_font_size)
|
169 |
-
observed_width = get_text_size(text, image, tested_font)
|
170 |
-
estimated_font_size = tested_font_size / (observed_width / image.width) * target_width_ratio
|
171 |
-
return round(estimated_font_size)
|
172 |
-
|
173 |
-
def get_text_size(text, image, font):
|
174 |
-
im = Image.new('RGB', (image.width, image.height))
|
175 |
-
draw = ImageDraw.Draw(im)
|
176 |
-
return draw.textlength(text, font)
|
177 |
-
|
178 |
-
|
179 |
-
def get_points_with_draw(image, label, evt: gr.SelectData):
|
180 |
-
global global_points
|
181 |
-
global global_point_label
|
182 |
-
global previous_box_points
|
183 |
-
|
184 |
-
x, y = evt.index[0], evt.index[1]
|
185 |
-
point_radius = 15
|
186 |
-
point_color = (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
|
187 |
-
global_points.append([x, y])
|
188 |
-
global_point_label.append(1 if label == 'Add Mask' else 0 if label == 'Remove Area' else 9)
|
189 |
-
|
190 |
-
print(x, y, label)
|
191 |
-
print(previous_box_points)
|
192 |
-
|
193 |
-
draw = ImageDraw.Draw(image)
|
194 |
-
if label != 'Bounding Box':
|
195 |
-
draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
|
196 |
-
else:
|
197 |
-
if (previous_box_points == 0) | (previous_box_points%2 == 0):
|
198 |
-
target_width_ratio = 0.9
|
199 |
-
text = "Please Click Another Point For Bounding Box"
|
200 |
-
font_size = find_font_size(text, font_path, image, target_width_ratio)
|
201 |
-
font = ImageFont.truetype(font_path, font_size)
|
202 |
-
draw.text((x, y), text, fill = (0,0,0), font = font)
|
203 |
-
else:
|
204 |
-
[previous_x, previous_y] = global_points[-2]
|
205 |
-
print((previous_x, previous_y), (x, y))
|
206 |
-
draw.rectangle([(previous_x, previous_y), (x, y)], outline=(0, 0, 255), width=10)
|
207 |
-
previous_box_points += 1
|
208 |
-
return image
|
209 |
-
|
210 |
def clear():
|
211 |
-
global global_points
|
212 |
-
global global_point_label
|
213 |
-
|
214 |
-
global_points = []
|
215 |
-
global_point_label = []
|
216 |
-
previous_box_points = 0
|
217 |
return None, None, None, None
|
218 |
|
219 |
-
|
220 |
-
# Configure layout
|
221 |
-
cond_img_p = gr.Image(label="Input with points", type='pil')
|
222 |
-
segm_img_p1 = gr.Image(label="Segmented Image Option 1", interactive=False, type='pil', format="png")
|
223 |
-
segm_img_p2 = gr.Image(label="Segmented Image Option 2", interactive=False, type='pil', format="png")
|
224 |
-
segm_img_p3 = gr.Image(label="Segmented Image Option 3", interactive=False, type='pil', format="png")
|
225 |
-
|
226 |
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
|
227 |
with gr.Row():
|
228 |
with gr.Column(scale=1):
|
@@ -231,42 +112,32 @@ with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
|
|
231 |
image_url = gr.Textbox(label="Enter Image URL",
|
232 |
value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
|
233 |
run_with_url = gr.Button("Upload Image")
|
|
|
|
|
234 |
with gr.Column(scale=1):
|
235 |
gr.Markdown(instruction)
|
236 |
|
237 |
# Images
|
238 |
with gr.Row(variant="panel"):
|
239 |
with gr.Column(scale=0):
|
240 |
-
|
241 |
-
|
242 |
with gr.Column(scale=0):
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
# Submit & Clear
|
247 |
-
with gr.Row():
|
248 |
-
with gr.Column():
|
249 |
-
add_or_remove = gr.Radio(["Add Mask", "Remove Area", "Bounding Box"],
|
250 |
-
value="Add Mask",
|
251 |
-
label="Point label")
|
252 |
-
with gr.Column():
|
253 |
-
segment_btn_p = gr.Button("Segment with prompts", variant='primary')
|
254 |
-
clear_btn_p = gr.Button("Clear points", variant='secondary')
|
255 |
|
256 |
# Define interaction relationship
|
257 |
run_with_url.click(read_image,
|
258 |
inputs=[image_url],
|
259 |
# outputs=[segm_img_p, cond_img_p])
|
260 |
-
outputs=[
|
261 |
|
262 |
-
|
263 |
-
|
264 |
-
segment_btn_p.click(segment_with_points,
|
265 |
-
inputs=[image_url],
|
266 |
# outputs=[segm_img_p, cond_img_p])
|
267 |
-
outputs=[
|
268 |
|
269 |
-
|
270 |
|
271 |
demo.queue()
|
272 |
demo.launch()
|
|
|
1 |
+
import os
|
2 |
+
# os.system("pip uninstall -y gradio")
|
3 |
+
# os.system("pip install gradio==4.44.1")
|
4 |
+
os.system("pip install gradio_image_prompter")
|
5 |
+
|
6 |
import gradio as gr
|
7 |
import torch
|
8 |
from PIL import ImageDraw, Image, ImageFont
|
|
|
14 |
import torch
|
15 |
from transformers import SamModel, SamProcessor
|
16 |
|
17 |
+
from gradio_image_prompter import ImagePrompter
|
18 |
|
19 |
+
import os
|
20 |
|
21 |
+
# define variables
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
# model_id = "facebook/sam-vit-huge" #60s
|
24 |
+
model_id = 'Zigeng/SlimSAM-uniform-50' #50s
|
25 |
+
# model_id = "facebook/sam-vit-base" #50s
|
26 |
+
# model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
27 |
+
model = SamModel.from_pretrained(model_id).to(device)
|
28 |
+
processor = SamProcessor.from_pretrained(model_id)
|
|
|
29 |
|
30 |
# Description
|
31 |
title = "<center><strong><font size='8'> π Segment food with clicks π</font></strong></center>"
|
|
|
33 |
instruction = """ # Instruction
|
34 |
This segmentation tool is built with HuggingFace SAM model. To use to label true mask, please follow the following steps \n
|
35 |
π₯ Step 1: Copy segmentation candidate image link and paste in 'Enter Image URL' and click 'Upload Image' \n
|
36 |
+
π₯ Step 2: Add positive (right click), negative (middle click), and bounding box (click and drag - only ONE box at most) for the food \n
|
37 |
π₯ Step 3: Click on 'Segment with prompts' to segment Image and see if there's a correct segmentation on the 3 options \n
|
38 |
π₯ Step 4: If not, you can repeat the process of adding prompt and segment until a correct one is generated. Prompt history will be retained unless reloading the image \n
|
39 |
π₯ Step 5: Download the satisfied segmentaion image through the icon on top right corner of the image, please name it with 'correct_seg_xxx' where xxx is the photo ID
|
40 |
"""
|
41 |
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
42 |
|
43 |
+
# functions
|
44 |
|
45 |
def read_image(url):
|
46 |
response = requests.get(url)
|
47 |
img = Image.open(BytesIO(response.content))
|
48 |
+
formatted_image = {
|
49 |
+
"image": np.array(img),
|
50 |
+
"points": [],
|
51 |
+
} # Create the correct format
|
52 |
+
return formatted_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
def get_mask_image(raw_image, mask):
|
55 |
tmp_mask = np.array(mask * 1)
|
|
|
60 |
tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis = 2)
|
61 |
return tmp_img_arr
|
62 |
|
63 |
+
def format_prompt_points(points):
|
64 |
+
prompt_points = []
|
65 |
+
point_labels = []
|
66 |
+
prompt_boxes = []
|
67 |
+
for point in points:
|
68 |
+
print(point)
|
69 |
+
if point[2] == 2.0 and point[5] == 3.0:
|
70 |
+
prompt_boxes.append([point[0], point[1], point[3], point[4]])
|
71 |
+
else:
|
72 |
+
prompt_points.append([point[0], point[1]])
|
73 |
+
label = 1 if point[2] == 1.0 else 0
|
74 |
+
point_labels.append(label)
|
75 |
+
return prompt_points, point_labels, prompt_boxes
|
76 |
|
77 |
def segment_with_points(
|
78 |
+
prompts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
):
|
|
|
|
|
80 |
|
81 |
+
image = np.array(prompts["image"]) # Convert the image to a numpy array
|
82 |
+
points = prompts["points"] # Get the points from prompts
|
83 |
+
#
|
84 |
+
prompt_points, point_labels, prompt_boxes = format_prompt_points(points)
|
|
|
85 |
print(prompt_points, point_labels, prompt_boxes)
|
86 |
# segment
|
87 |
+
inputs = processor(image,
|
88 |
+
input_boxes = [prompt_boxes],
|
89 |
input_points=[[prompt_points]],
|
90 |
input_labels=[point_labels],
|
91 |
return_tensors="pt").to(device)
|
|
|
95 |
masks = processor.image_processor.post_process_masks(
|
96 |
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
97 |
scores = outputs.iou_scores
|
98 |
+
#
|
99 |
+
mask_images = [get_mask_image(image, m) for m in masks[0][0]]
|
|
|
|
|
100 |
mask_img1, mask_img2, mask_img3 = mask_images
|
101 |
# return fig, None
|
102 |
return mask_img1, mask_img2, mask_img3
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
def clear():
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
return None, None, None, None
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
|
108 |
with gr.Row():
|
109 |
with gr.Column(scale=1):
|
|
|
112 |
image_url = gr.Textbox(label="Enter Image URL",
|
113 |
value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
|
114 |
run_with_url = gr.Button("Upload Image")
|
115 |
+
segment_btn = gr.Button("Segment with prompts", variant='primary')
|
116 |
+
clear_btn = gr.Button("Clear points", variant='secondary')
|
117 |
with gr.Column(scale=1):
|
118 |
gr.Markdown(instruction)
|
119 |
|
120 |
# Images
|
121 |
with gr.Row(variant="panel"):
|
122 |
with gr.Column(scale=0):
|
123 |
+
candidate_pic = ImagePrompter(show_label=False)
|
124 |
+
segpic_output1 = gr.Image(format="png")
|
125 |
with gr.Column(scale=0):
|
126 |
+
segpic_output2 = gr.Image(format="png")
|
127 |
+
segpic_output3 = gr.Image(format="png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# Define interaction relationship
|
130 |
run_with_url.click(read_image,
|
131 |
inputs=[image_url],
|
132 |
# outputs=[segm_img_p, cond_img_p])
|
133 |
+
outputs=[candidate_pic])
|
134 |
|
135 |
+
segment_btn.click(segment_with_points,
|
136 |
+
inputs=candidate_pic,
|
|
|
|
|
137 |
# outputs=[segm_img_p, cond_img_p])
|
138 |
+
outputs=[segpic_output1, segpic_output2, segpic_output3])
|
139 |
|
140 |
+
clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])
|
141 |
|
142 |
demo.queue()
|
143 |
demo.launch()
|
app_fast.py → app_slow.py
RENAMED
@@ -1,8 +1,3 @@
|
|
1 |
-
import os
|
2 |
-
# os.system("pip uninstall -y gradio")
|
3 |
-
# os.system("pip install gradio==4.44.1")
|
4 |
-
os.system("pip install gradio_image_prompter")
|
5 |
-
|
6 |
import gradio as gr
|
7 |
import torch
|
8 |
from PIL import ImageDraw, Image, ImageFont
|
@@ -14,18 +9,24 @@ import matplotlib.pyplot as plt
|
|
14 |
import torch
|
15 |
from transformers import SamModel, SamProcessor
|
16 |
|
17 |
-
from gradio_image_prompter import ImagePrompter
|
18 |
-
|
19 |
import os
|
20 |
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
#
|
27 |
-
|
28 |
-
|
|
|
29 |
|
30 |
# Description
|
31 |
title = "<center><strong><font size='8'> π Segment food with clicks π</font></strong></center>"
|
@@ -33,23 +34,85 @@ title = "<center><strong><font size='8'> π Segment food with clicks π</fon
|
|
33 |
instruction = """ # Instruction
|
34 |
This segmentation tool is built with HuggingFace SAM model. To use to label true mask, please follow the following steps \n
|
35 |
π₯ Step 1: Copy segmentation candidate image link and paste in 'Enter Image URL' and click 'Upload Image' \n
|
36 |
-
π₯ Step 2: Add positive (
|
37 |
π₯ Step 3: Click on 'Segment with prompts' to segment Image and see if there's a correct segmentation on the 3 options \n
|
38 |
π₯ Step 4: If not, you can repeat the process of adding prompt and segment until a correct one is generated. Prompt history will be retained unless reloading the image \n
|
39 |
π₯ Step 5: Download the satisfied segmentaion image through the icon on top right corner of the image, please name it with 'correct_seg_xxx' where xxx is the photo ID
|
40 |
"""
|
41 |
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
42 |
|
43 |
-
# functions
|
44 |
|
45 |
def read_image(url):
|
46 |
response = requests.get(url)
|
47 |
img = Image.open(BytesIO(response.content))
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
def get_mask_image(raw_image, mask):
|
55 |
tmp_mask = np.array(mask * 1)
|
@@ -60,32 +123,29 @@ def get_mask_image(raw_image, mask):
|
|
60 |
tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis = 2)
|
61 |
return tmp_img_arr
|
62 |
|
63 |
-
def format_prompt_points(points):
|
64 |
-
prompt_points = []
|
65 |
-
point_labels = []
|
66 |
-
prompt_boxes = []
|
67 |
-
for point in points:
|
68 |
-
print(point)
|
69 |
-
if point[2] == 2.0 and point[5] == 3.0:
|
70 |
-
prompt_boxes.append([point[0], point[1], point[3], point[4]])
|
71 |
-
else:
|
72 |
-
prompt_points.append([point[0], point[1]])
|
73 |
-
label = 1 if point[2] == 1.0 else 0
|
74 |
-
point_labels.append(label)
|
75 |
-
return prompt_points, point_labels, prompt_boxes
|
76 |
|
77 |
def segment_with_points(
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
):
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
85 |
print(prompt_points, point_labels, prompt_boxes)
|
86 |
# segment
|
87 |
-
inputs = processor(
|
88 |
-
input_boxes =
|
89 |
input_points=[[prompt_points]],
|
90 |
input_labels=[point_labels],
|
91 |
return_tensors="pt").to(device)
|
@@ -95,15 +155,74 @@ def segment_with_points(
|
|
95 |
masks = processor.image_processor.post_process_masks(
|
96 |
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
97 |
scores = outputs.iou_scores
|
98 |
-
|
99 |
-
|
|
|
|
|
100 |
mask_img1, mask_img2, mask_img3 = mask_images
|
101 |
# return fig, None
|
102 |
return mask_img1, mask_img2, mask_img3
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
def clear():
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
return None, None, None, None
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
|
108 |
with gr.Row():
|
109 |
with gr.Column(scale=1):
|
@@ -112,32 +231,42 @@ with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
|
|
112 |
image_url = gr.Textbox(label="Enter Image URL",
|
113 |
value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
|
114 |
run_with_url = gr.Button("Upload Image")
|
115 |
-
segment_btn = gr.Button("Segment with prompts", variant='primary')
|
116 |
-
clear_btn = gr.Button("Clear points", variant='secondary')
|
117 |
with gr.Column(scale=1):
|
118 |
gr.Markdown(instruction)
|
119 |
|
120 |
# Images
|
121 |
with gr.Row(variant="panel"):
|
122 |
with gr.Column(scale=0):
|
123 |
-
|
124 |
-
|
125 |
with gr.Column(scale=0):
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# Define interaction relationship
|
130 |
run_with_url.click(read_image,
|
131 |
inputs=[image_url],
|
132 |
# outputs=[segm_img_p, cond_img_p])
|
133 |
-
outputs=[
|
134 |
|
135 |
-
|
136 |
-
|
|
|
|
|
137 |
# outputs=[segm_img_p, cond_img_p])
|
138 |
-
outputs=[
|
139 |
|
140 |
-
|
141 |
|
142 |
demo.queue()
|
143 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from PIL import ImageDraw, Image, ImageFont
|
|
|
9 |
import torch
|
10 |
from transformers import SamModel, SamProcessor
|
11 |
|
|
|
|
|
12 |
import os
|
13 |
|
14 |
+
|
15 |
+
# Define variables
|
16 |
+
path = os.getcwd()
|
17 |
+
font_path = r'{}/arial.ttf'.format(path)
|
18 |
+
print(font_path)
|
19 |
+
|
20 |
+
# Load the pre-trained model - FastSAM
|
21 |
+
# fastsam_model = FastSAM('./FastSAM-s.pt')
|
22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
24 |
+
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
25 |
+
|
26 |
+
# Points
|
27 |
+
global_points = []
|
28 |
+
global_point_label = []
|
29 |
+
previous_box_points = 0
|
30 |
|
31 |
# Description
|
32 |
title = "<center><strong><font size='8'> π Segment food with clicks π</font></strong></center>"
|
|
|
34 |
instruction = """ # Instruction
|
35 |
This segmentation tool is built with HuggingFace SAM model. To use to label true mask, please follow the following steps \n
|
36 |
π₯ Step 1: Copy segmentation candidate image link and paste in 'Enter Image URL' and click 'Upload Image' \n
|
37 |
+
π₯ Step 2: Add positive (Add mask), negative (Remove Area), and bounding box for the food \n
|
38 |
π₯ Step 3: Click on 'Segment with prompts' to segment Image and see if there's a correct segmentation on the 3 options \n
|
39 |
π₯ Step 4: If not, you can repeat the process of adding prompt and segment until a correct one is generated. Prompt history will be retained unless reloading the image \n
|
40 |
π₯ Step 5: Download the satisfied segmentaion image through the icon on top right corner of the image, please name it with 'correct_seg_xxx' where xxx is the photo ID
|
41 |
"""
|
42 |
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
43 |
|
|
|
44 |
|
45 |
def read_image(url):
|
46 |
response = requests.get(url)
|
47 |
img = Image.open(BytesIO(response.content))
|
48 |
+
|
49 |
+
global global_points
|
50 |
+
global global_point_label
|
51 |
+
|
52 |
+
global_points = []
|
53 |
+
global_point_label = []
|
54 |
+
return img
|
55 |
+
|
56 |
+
# def show_mask(mask, ax, random_color=False):
|
57 |
+
# if random_color:
|
58 |
+
# color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
59 |
+
# else:
|
60 |
+
# color = np.array([30/255, 144/255, 255/255, 0.6])
|
61 |
+
# h, w = mask.shape[-2:]
|
62 |
+
# mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
63 |
+
# ax.imshow(mask_image)
|
64 |
+
|
65 |
+
# def show_points(coords, labels, ax, marker_size=375):
|
66 |
+
# pos_points = coords[labels==1]
|
67 |
+
# neg_points = coords[labels==0]
|
68 |
+
# ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
69 |
+
# ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
70 |
+
|
71 |
+
# def show_masks_and_points_on_image(raw_image, mask, input_points, input_labels, args):
|
72 |
+
# masks = masks.squeeze() if len(masks.shape) == 4 else masks.unsqueeze(0) if len(masks.shape) == 2 else masks
|
73 |
+
# scores = scores.squeeze() if (scores.shape[0] == 1) & (len(scores.shape) == 3) else scores
|
74 |
+
# #
|
75 |
+
# input_points = np.array(input_points)
|
76 |
+
# labels = np.array(input_labels)
|
77 |
+
# #
|
78 |
+
# mask = mask.cpu().detach()
|
79 |
+
# plt.imshow(np.array(raw_image))
|
80 |
+
# ax = plt.gca()
|
81 |
+
# show_mask(mask, ax)
|
82 |
+
# show_points(input_points, labels, ax, marker_size=375)
|
83 |
+
# ax.axis("off")
|
84 |
+
|
85 |
+
# save_path = args.output
|
86 |
+
# if not os.path.exists(save_path):
|
87 |
+
# os.makedirs(save_path)
|
88 |
+
# plt.axis("off")
|
89 |
+
# fig = plt.gcf()
|
90 |
+
# plt.draw()
|
91 |
+
|
92 |
+
# try:
|
93 |
+
# buf = fig.canvas.tostring_rgb()
|
94 |
+
# except AttributeError:
|
95 |
+
# fig.canvas.draw()
|
96 |
+
# buf = fig.canvas.tostring_rgb()
|
97 |
+
|
98 |
+
# cols, rows = fig.canvas.get_width_height()
|
99 |
+
# img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
100 |
+
# cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
|
101 |
+
|
102 |
+
def format_prompt_points(points, labels):
|
103 |
+
prompt_points = [xy for xy, l in zip(points, labels) if l != 9]
|
104 |
+
point_labels = [l for l in labels if l != 9]
|
105 |
+
#
|
106 |
+
prompt_boxes = None
|
107 |
+
if len(point_labels) < len(labels):
|
108 |
+
prompt_boxes = [[np.array([xy for xy, l in zip(points, labels) if l == 9]).reshape(-1, 4).tolist()]]
|
109 |
+
return prompt_points, point_labels, prompt_boxes
|
110 |
+
|
111 |
+
# def get_mask_image(raw_image, mask):
|
112 |
+
# tmp_mask = np.array(mask)
|
113 |
+
# tmp_img_arr = np.array(raw_image)
|
114 |
+
# tmp_img_arr[tmp_mask == False] = [255,255,255]
|
115 |
+
# return tmp_img_arr
|
116 |
|
117 |
def get_mask_image(raw_image, mask):
|
118 |
tmp_mask = np.array(mask * 1)
|
|
|
123 |
tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis = 2)
|
124 |
return tmp_img_arr
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
def segment_with_points(
|
128 |
+
input,
|
129 |
+
input_size=1024,
|
130 |
+
iou_threshold=0.7,
|
131 |
+
conf_threshold=0.25,
|
132 |
+
better_quality=False,
|
133 |
+
withContours=True,
|
134 |
+
use_retina=True,
|
135 |
+
mask_random_color=True,
|
136 |
):
|
137 |
+
global global_points
|
138 |
+
global global_point_label
|
139 |
|
140 |
+
# read image
|
141 |
+
raw_image = Image.open(requests.get(input, stream=True).raw).convert("RGB")
|
142 |
+
|
143 |
+
# get prompts
|
144 |
+
prompt_points, point_labels, prompt_boxes = format_prompt_points(global_points, global_point_label)
|
145 |
print(prompt_points, point_labels, prompt_boxes)
|
146 |
# segment
|
147 |
+
inputs = processor(raw_image,
|
148 |
+
input_boxes = prompt_boxes,
|
149 |
input_points=[[prompt_points]],
|
150 |
input_labels=[point_labels],
|
151 |
return_tensors="pt").to(device)
|
|
|
155 |
masks = processor.image_processor.post_process_masks(
|
156 |
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
157 |
scores = outputs.iou_scores
|
158 |
+
|
159 |
+
# only show the first mask
|
160 |
+
# fig = show_masks_and_points_on_image(raw_image, masks[0][0][0], [global_points], global_point_label)
|
161 |
+
mask_images = [get_mask_image(raw_image, m) for m in masks[0][0]]
|
162 |
mask_img1, mask_img2, mask_img3 = mask_images
|
163 |
# return fig, None
|
164 |
return mask_img1, mask_img2, mask_img3
|
165 |
|
166 |
+
def find_font_size(text, font_path, image, target_width_ratio):
|
167 |
+
tested_font_size = 100
|
168 |
+
tested_font = ImageFont.truetype(font_path, tested_font_size)
|
169 |
+
observed_width = get_text_size(text, image, tested_font)
|
170 |
+
estimated_font_size = tested_font_size / (observed_width / image.width) * target_width_ratio
|
171 |
+
return round(estimated_font_size)
|
172 |
+
|
173 |
+
def get_text_size(text, image, font):
|
174 |
+
im = Image.new('RGB', (image.width, image.height))
|
175 |
+
draw = ImageDraw.Draw(im)
|
176 |
+
return draw.textlength(text, font)
|
177 |
+
|
178 |
+
|
179 |
+
def get_points_with_draw(image, label, evt: gr.SelectData):
|
180 |
+
global global_points
|
181 |
+
global global_point_label
|
182 |
+
global previous_box_points
|
183 |
+
|
184 |
+
x, y = evt.index[0], evt.index[1]
|
185 |
+
point_radius = 15
|
186 |
+
point_color = (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
|
187 |
+
global_points.append([x, y])
|
188 |
+
global_point_label.append(1 if label == 'Add Mask' else 0 if label == 'Remove Area' else 9)
|
189 |
+
|
190 |
+
print(x, y, label)
|
191 |
+
print(previous_box_points)
|
192 |
+
|
193 |
+
draw = ImageDraw.Draw(image)
|
194 |
+
if label != 'Bounding Box':
|
195 |
+
draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
|
196 |
+
else:
|
197 |
+
if (previous_box_points == 0) | (previous_box_points%2 == 0):
|
198 |
+
target_width_ratio = 0.9
|
199 |
+
text = "Please Click Another Point For Bounding Box"
|
200 |
+
font_size = find_font_size(text, font_path, image, target_width_ratio)
|
201 |
+
font = ImageFont.truetype(font_path, font_size)
|
202 |
+
draw.text((x, y), text, fill = (0,0,0), font = font)
|
203 |
+
else:
|
204 |
+
[previous_x, previous_y] = global_points[-2]
|
205 |
+
print((previous_x, previous_y), (x, y))
|
206 |
+
draw.rectangle([(previous_x, previous_y), (x, y)], outline=(0, 0, 255), width=10)
|
207 |
+
previous_box_points += 1
|
208 |
+
return image
|
209 |
+
|
210 |
def clear():
|
211 |
+
global global_points
|
212 |
+
global global_point_label
|
213 |
+
|
214 |
+
global_points = []
|
215 |
+
global_point_label = []
|
216 |
+
previous_box_points = 0
|
217 |
return None, None, None, None
|
218 |
|
219 |
+
|
220 |
+
# Configure layout
|
221 |
+
cond_img_p = gr.Image(label="Input with points", type='pil')
|
222 |
+
segm_img_p1 = gr.Image(label="Segmented Image Option 1", interactive=False, type='pil', format="png")
|
223 |
+
segm_img_p2 = gr.Image(label="Segmented Image Option 2", interactive=False, type='pil', format="png")
|
224 |
+
segm_img_p3 = gr.Image(label="Segmented Image Option 3", interactive=False, type='pil', format="png")
|
225 |
+
|
226 |
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
|
227 |
with gr.Row():
|
228 |
with gr.Column(scale=1):
|
|
|
231 |
image_url = gr.Textbox(label="Enter Image URL",
|
232 |
value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
|
233 |
run_with_url = gr.Button("Upload Image")
|
|
|
|
|
234 |
with gr.Column(scale=1):
|
235 |
gr.Markdown(instruction)
|
236 |
|
237 |
# Images
|
238 |
with gr.Row(variant="panel"):
|
239 |
with gr.Column(scale=0):
|
240 |
+
cond_img_p.render()
|
241 |
+
segm_img_p2.render()
|
242 |
with gr.Column(scale=0):
|
243 |
+
segm_img_p1.render()
|
244 |
+
segm_img_p3.render()
|
245 |
+
|
246 |
+
# Submit & Clear
|
247 |
+
with gr.Row():
|
248 |
+
with gr.Column():
|
249 |
+
add_or_remove = gr.Radio(["Add Mask", "Remove Area", "Bounding Box"],
|
250 |
+
value="Add Mask",
|
251 |
+
label="Point label")
|
252 |
+
with gr.Column():
|
253 |
+
segment_btn_p = gr.Button("Segment with prompts", variant='primary')
|
254 |
+
clear_btn_p = gr.Button("Clear points", variant='secondary')
|
255 |
|
256 |
# Define interaction relationship
|
257 |
run_with_url.click(read_image,
|
258 |
inputs=[image_url],
|
259 |
# outputs=[segm_img_p, cond_img_p])
|
260 |
+
outputs=[cond_img_p])
|
261 |
|
262 |
+
cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
|
263 |
+
|
264 |
+
segment_btn_p.click(segment_with_points,
|
265 |
+
inputs=[image_url],
|
266 |
# outputs=[segm_img_p, cond_img_p])
|
267 |
+
outputs=[segm_img_p1, segm_img_p2, segm_img_p3])
|
268 |
|
269 |
+
clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p1, segm_img_p2, segm_img_p3])
|
270 |
|
271 |
demo.queue()
|
272 |
demo.launch()
|
requirements.txt
CHANGED
@@ -4,6 +4,8 @@ torch==2.2.2
|
|
4 |
opencv-python
|
5 |
transformers==4.49.0
|
6 |
pillow==10.4.0
|
|
|
|
|
7 |
|
8 |
|
9 |
|
|
|
4 |
opencv-python
|
5 |
transformers==4.49.0
|
6 |
pillow==10.4.0
|
7 |
+
gradio>=4.0.0,<5
|
8 |
+
gradio_image_prompter @ git+https://github.com/PhyscalX/gradio-image-prompter.
|
9 |
|
10 |
|
11 |
|