jixin0101 committed on
Commit
7d4b8c8
·
0 Parent(s):

Clean history

.gitattributes ADDED
@@ -0,0 +1,38 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
Logo.png ADDED

Git LFS Details

  • SHA256: 8cfd430ff41ed80e783027809fabbb2dcd742c76e7f96469da4f7274d003f514
  • Pointer size: 130 Bytes
  • Size of remote file: 52 kB
README.md ADDED
@@ -0,0 +1,12 @@
1
+ ---
2
+ title: ObjectClear
3
+ emoji: 🪄
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.30.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,677 @@
1
+ import gradio as gr
2
+ import os
3
+ from PIL import Image
4
+ import torch
5
+ from diffusers.utils import load_image, check_min_version
6
+ from pipeline_objectclear import ObjectClearPipeline
7
+ from tools.download_util import load_file_from_url
8
+ from tools.painter import mask_painter
9
+ import argparse
10
+ from safetensors.torch import load_file
11
+ from model import CLIPImageEncoder, PostfuseModule
12
+ import numpy as np
13
+ import torchvision.transforms.functional as TF
14
+ from scipy.ndimage import convolve, zoom
15
+ import cv2
16
+ import time
17
+ from huggingface_hub import hf_hub_download
18
+ import spaces
19
+
20
+ from tools.interact_tools import SamControler
21
+ from tools.misc import get_device
22
+ import json
23
+
24
+ check_min_version("0.30.2")
25
+
26
+
27
+ def parse_augment():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument('--device', type=str, default=None)
30
+ parser.add_argument('--sam_model_type', type=str, default="vit_h")
31
+ parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
32
+ args = parser.parse_args()
33
+
34
+ if not args.device:
35
+ args.device = str(get_device())
36
+
37
+ return args
38
+
39
+
40
+ def pad_to_multiple(image: np.ndarray, multiple: int = 8):
41
+ h, w = image.shape[:2]
42
+ pad_h = (multiple - h % multiple) % multiple
43
+ pad_w = (multiple - w % multiple) % multiple
44
+ if image.ndim == 3:
45
+ padded = np.pad(image, ((0, pad_h), (0, pad_w), (0,0)), mode='reflect')
46
+ else:
47
+ padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
48
+ return padded, h, w
49
+
50
+ def crop_to_original(image: np.ndarray, h: int, w: int):
51
+ return image[:h, :w]
52
+
53
+ def wavelet_blur_np(image: np.ndarray, radius: int):
54
+ kernel = np.array([
55
+ [0.0625, 0.125, 0.0625],
56
+ [0.125, 0.25, 0.125],
57
+ [0.0625, 0.125, 0.0625]
58
+ ], dtype=np.float32)
59
+
60
+ blurred = np.empty_like(image)
61
+ for c in range(image.shape[0]):
62
+ blurred_c = convolve(image[c], kernel, mode='nearest')
63
+ if radius > 1:
64
+ blurred_c = zoom(zoom(blurred_c, 1 / radius, order=1), radius, order=1)
65
+ blurred[c] = blurred_c
66
+ return blurred
67
+
68
+ def wavelet_decomposition_np(image: np.ndarray, levels=5):
69
+ high_freq = np.zeros_like(image)
70
+ for i in range(levels):
71
+ radius = 2 ** i
72
+ low_freq = wavelet_blur_np(image, radius)
73
+ high_freq += (image - low_freq)
74
+ image = low_freq
75
+ return high_freq, low_freq
76
+
77
+ def wavelet_reconstruction_np(content_feat: np.ndarray, style_feat: np.ndarray):
78
+ content_high, _ = wavelet_decomposition_np(content_feat)
79
+ _, style_low = wavelet_decomposition_np(style_feat)
80
+ return content_high + style_low
81
+
82
+ def wavelet_color_fix_np(fused: np.ndarray, mask: np.ndarray) -> np.ndarray:
83
+ fused_np = fused.astype(np.float32) / 255.0
84
+ mask_np = mask.astype(np.float32) / 255.0
85
+
86
+ fused_np = fused_np.transpose(2, 0, 1)
87
+ mask_np = mask_np.transpose(2, 0, 1)
88
+
89
+ result_np = wavelet_reconstruction_np(fused_np, mask_np)
90
+
91
+ result_np = result_np.transpose(1, 2, 0)
92
+ result_np = np.clip(result_np * 255.0, 0, 255).astype(np.uint8)
93
+
94
+ return result_np
95
+
96
+ def fuse_with_wavelet(ori: np.ndarray, removed: np.ndarray, attn_map: np.ndarray, multiple: int = 8):
97
+ H, W = ori.shape[:2]
98
+ attn_map = attn_map.astype(np.float32)
99
+ _, attn_map = cv2.threshold(attn_map, 128, 255, cv2.THRESH_BINARY)
100
+ am = attn_map.astype(np.float32)
101
+ am = am/255.0
102
+ am_up = cv2.resize(am, (W, H), interpolation=cv2.INTER_NEAREST)
103
+
104
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21,21))
105
+ am_d = cv2.dilate(am_up, kernel, iterations=1)
106
+ am_d = cv2.GaussianBlur(am_d.astype(np.float32), (9,9), sigmaX=2)
107
+
108
+ am_merged = np.maximum(am_up, am_d)
109
+ am_merged = np.clip(am_merged, 0, 1)
110
+
111
+ attn_up_3c = np.stack([am_merged]*3, axis=-1)
112
+ attn_up_ori_3c = np.stack([am_up]*3, axis=-1)
113
+
114
+ ori_out = ori * (1 - attn_up_ori_3c)
115
+ rem_out = removed * (1 - attn_up_ori_3c)
116
+
117
+ ori_pad, h0, w0 = pad_to_multiple(ori_out, multiple)
118
+ rem_pad, _, _ = pad_to_multiple(rem_out, multiple)
119
+
120
+ wave_rgb = wavelet_color_fix_np(ori_pad, rem_pad)
121
+ wave = crop_to_original(wave_rgb, h0, w0)
122
+ # fusion
123
+ fused = (wave * (1 - attn_up_3c) + removed * attn_up_3c).astype(np.uint8)
124
+ return fused
125
+
126
+
127
+ def resize_by_short_side(image, target_short=512, resample=Image.BICUBIC):
128
+ w, h = image.size
129
+ if w < h:
130
+ new_w = target_short
131
+ new_h = int(h * target_short / w)
132
+ new_h = (new_h + 15) // 16 * 16
133
+ else:
134
+ new_h = target_short
135
+ new_w = int(w * target_short / h)
136
+ new_w = (new_w + 15) // 16 * 16
137
+ return image.resize((new_w, new_h), resample=resample)
138
+
139
+ # convert points input to prompt state
140
+ def get_prompt(click_state, click_input):
141
+ inputs = json.loads(click_input)
142
+ points = click_state[0]
143
+ labels = click_state[1]
144
+ for input in inputs:
145
+ points.append(input[:2])
146
+ labels.append(input[2])
147
+ click_state[0] = points
148
+ click_state[1] = labels
149
+ prompt = {
150
+ "prompt_type":["click"],
151
+ "input_point":click_state[0],
152
+ "input_label":click_state[1],
153
+ "multimask_output":"True",
154
+ }
155
+ return prompt
156
+
157
+ # use sam to get the mask
158
+ @spaces.GPU
159
+ def sam_refine(image_state, point_prompt, click_state, evt:gr.SelectData):
160
+ if point_prompt == "Positive":
161
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
162
+ else:
163
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
164
+
165
+ # prompt for sam model
166
+ model.samcontroler.sam_controler.reset_image()
167
+ model.samcontroler.sam_controler.set_image(image_state["origin_image"])
168
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
169
+
170
+ mask, logit, painted_image = model.first_frame_click(
171
+ image=image_state["origin_image"],
172
+ points=np.array(prompt["input_point"]),
173
+ labels=np.array(prompt["input_label"]),
174
+ multimask=prompt["multimask_output"],
175
+ )
176
+ image_state["mask"] = mask
177
+ image_state["logit"] = logit
178
+ image_state["painted_image"] = painted_image
179
+
180
+ return painted_image, image_state, click_state
181
+
182
+
183
+ def add_multi_mask(image_state, interactive_state, mask_dropdown):
184
+ mask = image_state["mask"]
185
+ interactive_state["masks"].append(mask)
186
+ interactive_state["mask_names"].append("mask_{:03d}".format(len(interactive_state["masks"])))
187
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["masks"])))
188
+ select_frame = show_mask(image_state, interactive_state, mask_dropdown)
189
+
190
+ return interactive_state, gr.update(choices=interactive_state["mask_names"], value=mask_dropdown), select_frame, [[],[]]
191
+
192
+ def clear_click(image_state, click_state):
193
+ click_state = [[],[]]
194
+ input_image = image_state["origin_image"]
195
+ return input_image, click_state
196
+
197
+ def remove_multi_mask(interactive_state, click_state, image_state):
198
+ interactive_state["mask_names"]= []
199
+ interactive_state["masks"] = []
200
+ click_state = [[],[]]
201
+ input_image = image_state["origin_image"]
202
+
203
+ return interactive_state, gr.update(choices=[],value=[]), input_image, click_state
204
+
205
+ def show_mask(image_state, interactive_state, mask_dropdown):
206
+ mask_dropdown.sort()
207
+ if image_state["origin_image"] is not None:
208
+ select_frame = image_state["origin_image"]
209
+ for i in range(len(mask_dropdown)):
210
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
211
+ mask = interactive_state["masks"][mask_number]
212
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
213
+
214
+ return select_frame
215
+
216
+ @spaces.GPU
217
+ def upload_and_reset(image_input, interactive_state):
218
+ click_state = [[], []]
219
+
220
+ interactive_state["mask_names"]= []
221
+ interactive_state["masks"] = []
222
+
223
+ image_state, image_info, image_input = update_image_state_on_upload(image_input)
224
+
225
+ return (
226
+ image_state,
227
+ image_info,
228
+ image_input,
229
+ interactive_state,
230
+ click_state,
231
+ gr.update(choices=[], value=[]),
232
+ )
233
+
234
+ def update_image_state_on_upload(image_input):
235
+ frame = image_input
236
+
237
+ image_size = (frame.size[1], frame.size[0])
238
+
239
+ frame_np = np.array(frame)
240
+
241
+ image_state = {
242
+ "origin_image": frame_np,
243
+ "painted_image": frame_np.copy(),
244
+ "mask": np.zeros((image_size[0], image_size[1]), np.uint8),
245
+ "logit": None,
246
+ }
247
+
248
+ image_info = f"Image Name: uploaded.png,\nImage Size: {image_size}"
249
+
250
+ model.samcontroler.sam_controler.reset_image()
251
+ model.samcontroler.sam_controler.set_image(frame_np)
252
+
253
+ return image_state, image_info, image_input
254
+
255
+
256
+
257
+ # SAM generator
258
+ class MaskGenerator():
259
+ def __init__(self, sam_checkpoint, args):
260
+ self.args = args
261
+ self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
262
+
263
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
264
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
265
+ return mask, logit, painted_image
266
+
267
+
268
+ # command-line args (adapted from track_anything.py)
269
+ args = parse_augment()
270
+ sam_checkpoint_url_dict = {
271
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
272
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
273
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
274
+ }
275
+ checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')
276
+
277
+ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)
278
+ # initialize the SAM mask generator
279
+ model = MaskGenerator(sam_checkpoint, args)
280
+
281
+ # Build pipeline
282
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
283
+ pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
284
+ "jixin0101/ObjectClear",
285
+ torch_dtype=torch.float16,
286
+ save_cross_attn=True,
287
+ cache_dir="/home/jovyan/shared/jixinzhao/models",
288
+ )
289
+
290
+ pipe.to(device)
291
+
292
+ @spaces.GPU
293
+ def process(image_state, interactive_state, mask_dropdown, guidance_scale, seed, num_inference_steps, strength
294
+ ):
295
+ generator = torch.Generator(device="cuda").manual_seed(seed)
296
+ image_np = image_state["origin_image"]
297
+ image = Image.fromarray(image_np)
298
+ if interactive_state["masks"]:
299
+ if len(mask_dropdown) == 0:
300
+ mask_dropdown = ["mask_001"]
301
+ mask_dropdown.sort()
302
+ template_mask = interactive_state["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
303
+ for i in range(1,len(mask_dropdown)):
304
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
305
+ template_mask = np.clip(template_mask+interactive_state["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
306
+ image_state["mask"]= template_mask
307
+ else:
308
+ template_mask = image_state["mask"]
309
+ mask = Image.fromarray((template_mask).astype(np.uint8) * 255)
310
+ image_or = image.copy()
311
+
312
+ image = image.convert("RGB")
313
+ mask = mask.convert("RGB")
314
+
315
+ image = resize_by_short_side(image, 512, resample=Image.BICUBIC)
316
+ mask = resize_by_short_side(mask, 512, resample=Image.NEAREST)
317
+
318
+ w, h = image.size
319
+
320
+ result = pipe(
321
+ prompt="remove the instance of object",
322
+ image=image,
323
+ mask_image=mask,
324
+ generator=generator,
325
+ num_inference_steps=num_inference_steps,
326
+ strength=strength,
327
+ guidance_scale=guidance_scale,
328
+ height=h,
329
+ width=w,
330
+ )
331
+
332
+ inpainted_img = result[0].images[0]
333
+ attn_map = result[1]
334
+ attn_np = attn_map.mean(dim=1)[0].cpu().numpy() * 255.
335
+
336
+ fused_img = fuse_with_wavelet(np.array(image), np.array(inpainted_img), attn_np)
337
+ fused_img_pil = Image.fromarray(fused_img.astype(np.uint8))
338
+
339
+ return fused_img_pil.resize((image_or.size[:2])), (image.resize((image_or.size[:2])), fused_img_pil.resize((image_or.size[:2])))
340
+
341
+ import base64
342
+ with open("./Logo.png", "rb") as f:
343
+ img_bytes = f.read()
344
+ img_b64 = base64.b64encode(img_bytes).decode()
345
+
346
+ html_img = f'''
347
+ <div style="display:flex; justify-content:center; align-items:center; width:100%;">
348
+ <img src="data:image/png;base64,{img_b64}" style="border:none; width:200px; height:auto;"/>
349
+ </div>
350
+ '''
351
+
352
+ tutorial_url = "https://github.com/zjx0101/ObjectClear/releases/download/media/tutorial.mp4"
353
+ assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
354
+ load_file_from_url(tutorial_url, assets_path)
355
+
356
+ description = r"""
357
+ <b>Official Gradio demo</b> for <a href='https://github.com/zjx0101/ObjectClear' target='_blank'><b>ObjectClear: Complete Object Removal via Object-Effect Attention</b></a>.<br>
358
+ 🔥 ObjectClear is an object removal model that can jointly eliminate the target object and its associated effects by leveraging Object-Effect Attention, while preserving background consistency.<br>
359
+ 🖼️ Drop in your image, assign the target masks with a few clicks, and get the object removal result!<br>
360
+
361
+ *Note: Due to online GPU memory constraints, all input images are resized during inference so that the shortest side is 512 pixels.*<br>
362
+ """
363
+
364
+ article = r"""<h3>
365
+ <b>If ObjectClear is helpful, please help star the <a href='https://github.com/zjx0101/ObjectClear' target='_blank'>GitHub repo</a>. Thanks!</b></h3>
366
+ <hr>
367
+
368
+ 📑 **Citation**
369
+ <br>
370
+ If our work is useful for your research, please consider citing:
371
+ ```bibtex
372
+ @InProceedings{zhao2025ObjectClear,
373
+ title = {{ObjectClear}: Complete Object Removal via Object-Effect Attention},
374
+ author = {Zhao, Jixin and Zhou, Shangchen and Wang, Zhouxia and Yang, Peiqing and Loy, Chen Change},
375
+ booktitle = {arXiv preprint arXiv:2505.22636},
376
+ year = {2025}
377
+ }
378
+ ```
379
+ 📧 **Contact**
380
+ <br>
381
+ If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
382
+ <br>
383
+ 👏 **Acknowledgement**
384
+ <br>
385
+ This demo is adapted from [MatAnyone](https://github.com/pq-yang/MatAnyone) and leverages segmentation capabilities from [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome work!
386
+ """
387
+
388
+ custom_css = """
389
+ #input-image {
390
+ aspect-ratio: 1 / 1;
391
+ width: 100%;
392
+ max-width: 100%;
393
+ height: auto;
394
+ display: flex;
395
+ align-items: center;
396
+ justify-content: center;
397
+ }
398
+
399
+ #input-image img {
400
+ max-width: 100%;
401
+ max-height: 100%;
402
+ object-fit: contain;
403
+ display: block;
404
+ }
405
+
406
+ #main-columns {
407
+ gap: 60px;
408
+ }
409
+
410
+ #main-columns > .gr-column {
411
+ flex: 1;
412
+ }
413
+
414
+ #compare-image {
415
+ width: 100%;
416
+ aspect-ratio: 1 / 1;
417
+ display: flex;
418
+ align-items: center;
419
+ justify-content: center;
420
+ margin: 0;
421
+ padding: 0;
422
+ max-width: 100%;
423
+ box-sizing: border-box;
424
+ }
425
+
426
+ #compare-image svg.svelte-zyxd38 {
427
+ position: absolute !important;
428
+ top: 50% !important;
429
+ left: 50% !important;
430
+ transform: translate(-50%, -50%) !important;
431
+ }
432
+
433
+ #compare-image .icon.svelte-1oiin9d {
434
+ position: absolute;
435
+ top: 50%;
436
+ left: 50%;
437
+ transform: translate(-50%, -50%);
438
+ }
439
+
440
+ #compare-image {
441
+ position: relative;
442
+ overflow: hidden;
443
+ }
444
+
445
+ .new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;}
446
+ .new_button:hover {background-color: #4b4b4b !important;}
447
+
448
+ #start-button {
449
+ background: linear-gradient(135deg, #2575fc 0%, #6a11cb 100%);
450
+ color: white;
451
+ border: none;
452
+ padding: 12px 24px;
453
+ font-size: 16px;
454
+ font-weight: bold;
455
+ border-radius: 12px;
456
+ cursor: pointer;
457
+ box-shadow: 0 0 12px rgba(100, 100, 255, 0.7);
458
+ transition: all 0.3s ease;
459
+ }
460
+ #start-button:hover {
461
+ transform: scale(1.05);
462
+ box-shadow: 0 0 20px rgba(100, 100, 255, 1);
463
+ }
464
+
465
+ <style>
466
+ .button-wrapper {
467
+ width: 30%;
468
+ text-align: center;
469
+ }
470
+ .wide-button {
471
+ width: 83% !important;
472
+ background-color: black !important;
473
+ color: white !important;
474
+ border: none !important;
475
+ padding: 8px 0 !important;
476
+ font-size: 16px !important;
477
+ display: inline-block;
478
+ margin: 30px 0px 0px 50px ;
479
+ }
480
+ .wide-button:hover {
481
+ background-color: #656262 !important;
482
+ }
483
+ </style>
484
+ """
485
+
486
+
487
+ with gr.Blocks(css=custom_css) as demo:
488
+ gr.HTML(html_img)
489
+ gr.Markdown(description)
490
+ with gr.Group(elem_classes="gr-monochrome-group", visible=True):
491
+ with gr.Row():
492
+ with gr.Accordion('SAM Settings (click to expand)', open=False):
493
+ with gr.Row():
494
+ point_prompt = gr.Radio(
495
+ choices=["Positive", "Negative"],
496
+ value="Positive",
497
+ label="Point Prompt",
498
+ info="Click to add positive or negative point for target mask",
499
+ interactive=True,
500
+ min_width=100,
501
+ scale=1)
502
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2")
503
+
504
+ with gr.Row(elem_id="main-columns"):
505
+ with gr.Column():
506
+
507
+ click_state = gr.State([[],[]])
508
+
509
+ interactive_state = gr.State(
510
+ {
511
+ "mask_names": [],
512
+ "masks": []
513
+ }
514
+ )
515
+
516
+ image_state = gr.State(
517
+ {
518
+ "origin_image": None,
519
+ "painted_image": None,
520
+ "mask": None,
521
+ "logit": None
522
+ }
523
+ )
524
+
525
+ image_info = gr.Textbox(label="Image Info", visible=False)
526
+ input_image = gr.Image(
527
+ label='Input',
528
+ type='pil',
529
+ sources=["upload"],
530
+ image_mode='RGB',
531
+ interactive=True,
532
+ elem_id="input-image"
533
+ )
534
+
535
+ with gr.Row(equal_height=True, elem_classes="mask_button_group"):
536
+ clear_button_click = gr.Button(value="Clear Clicks",elem_classes="new_button", min_width=100)
537
+ add_mask_button = gr.Button(value="Add Mask", elem_classes="new_button", min_width=100)
538
+ remove_mask_button = gr.Button(value="Delete Mask", elem_classes="new_button", min_width=100)
539
+
540
+ submit_button_component = gr.Button(
541
+ value='Start ObjectClear', elem_id="start-button"
542
+ )
543
+
544
+ with gr.Accordion('ObjectClear Settings', open=True):
545
+ strength = gr.Radio(
546
+ choices=[0.99, 1.0],
547
+ value=0.99,
548
+ label="Strength",
549
+ info="0.99 better preserves the background and color; use 1.0 if object/shadow is not fully removed (default: 0.99)"
550
+ )
551
+
552
+ guidance_scale = gr.Slider(
553
+ minimum=1, maximum=10, step=0.5, value=2.5,
554
+ label="Guidance Scale",
555
+ info="Higher = stronger removal; lower = better background preservation (default: 2.5)"
556
+ )
557
+
558
+ seed = gr.Slider(
559
+ minimum=0, maximum=1000000, step=1, value=300000,
560
+ label="Seed Value",
561
+ info="Different seeds can lead to noticeably different object removal results (default: 300000)"
562
+ )
563
+
564
+ num_inference_steps = gr.Slider(
565
+ minimum=1, maximum=40, step=1, value=20,
566
+ label="Num Inference Steps",
567
+ info="Higher values may improve quality but take longer (default: 20)"
568
+ )
569
+
570
+
571
+ with gr.Column():
572
+ output_image_component = gr.Image(
573
+ type='pil', image_mode='RGB', label='Output', format="png", elem_id="input-image")
574
+
575
+ output_compare_image_component = gr.ImageSlider(
576
+ label="Comparison",
577
+ type="pil",
578
+ format='png',
579
+ elem_id="compare-image"
580
+ )
581
+
582
+ input_image.upload(
583
+ fn=upload_and_reset,
584
+ inputs=[input_image, interactive_state],
585
+ outputs=[
586
+ image_state,
587
+ image_info,
588
+ input_image,
589
+ interactive_state,
590
+ click_state,
591
+ mask_dropdown,
592
+ ]
593
+ )
594
+
595
+ # click select image to get mask using sam
596
+ input_image.select(
597
+ fn=sam_refine,
598
+ inputs=[image_state, point_prompt, click_state],
599
+ outputs=[input_image, image_state, click_state]
600
+ )
601
+
602
+ # add different mask
603
+ add_mask_button.click(
604
+ fn=add_multi_mask,
605
+ inputs=[image_state, interactive_state, mask_dropdown],
606
+ outputs=[interactive_state, mask_dropdown, input_image, click_state]
607
+ )
608
+
609
+ remove_mask_button.click(
610
+ fn=remove_multi_mask,
611
+ inputs=[interactive_state, click_state, image_state],
612
+ outputs=[interactive_state, mask_dropdown, input_image, click_state]
613
+ )
614
+
615
+ # points clear
616
+ clear_button_click.click(
617
+ fn = clear_click,
618
+ inputs = [image_state, click_state,],
619
+ outputs = [input_image, click_state],
620
+ )
621
+
622
+ submit_button_component.click(
623
+ fn=process,
624
+ inputs=[
625
+ image_state,
626
+ interactive_state,
627
+ mask_dropdown,
628
+ guidance_scale,
629
+ seed,
630
+ num_inference_steps,
631
+ strength
632
+ ],
633
+ outputs=[
634
+ output_image_component, output_compare_image_component
635
+ ]
636
+ )
637
+
638
+ with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
639
+ with gr.Row():
640
+ gr.Video(value="/home/user/app/hugging_face/assets/tutorial.mp4", elem_classes="video")
641
+
642
+ gr.Markdown("---")
643
+ gr.Markdown("## Examples")
644
+
645
+ example_images = [
646
+ os.path.join(os.path.dirname(__file__), "examples", f"test{i}.png")
647
+ for i in range(10)
648
+ ]
649
+
650
+ examples_data = [
651
+ [example_images[i], None] for i in range(len(example_images))
652
+ ]
653
+
654
+ examples = gr.Examples(
655
+ examples=examples_data,
656
+ inputs=[input_image, interactive_state],
657
+ outputs=[image_state, image_info, input_image,
658
+ interactive_state, click_state, mask_dropdown],
659
+ fn=upload_and_reset,
660
+ run_on_click=True,
661
+ cache_examples=False,
662
+ label="Click below to load example images"
663
+ )
664
+
665
+ gr.Markdown(article)
666
+
667
+ def pre_update_input_image():
668
+ return gr.update(value=None)
669
+
670
+ demo.load(
671
+ fn=pre_update_input_image,
672
+ inputs=[],
673
+ outputs=[input_image]
674
+ )
675
+
676
+
677
+ demo.launch(debug=True, show_error=True)
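
A minimal sketch of driving the same removal pipeline headlessly, mirroring the calls app.py makes above; the file paths are placeholders and the parameter values simply reuse the UI defaults. Note that app.py additionally resizes inputs so the shortest side is 512 px and both dimensions are multiples of 16 before calling the pipeline.

```python
# Illustrative sketch (assumptions: placeholder paths, UI-default parameters).
import torch
from PIL import Image
from pipeline_objectclear import ObjectClearPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
    "jixin0101/ObjectClear",
    torch_dtype=torch.float16,
    save_cross_attn=True,  # app.py uses the returned attention map for wavelet fusion
)
pipe.to(device)

image = Image.open("input.png").convert("RGB")   # placeholder path
mask = Image.open("mask.png").convert("RGB")     # white = object (and its effects) to remove

result = pipe(
    prompt="remove the instance of object",
    image=image,
    mask_image=mask,
    generator=torch.Generator(device=device).manual_seed(300000),
    num_inference_steps=20,
    strength=0.99,
    guidance_scale=2.5,
    height=image.height,   # app.py first resizes to a 512-short-side, /16-aligned size
    width=image.width,
)
result[0].images[0].save("output.png")  # result[1] is the cross-attention map
```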
examples/test0.png ADDED

Git LFS Details

  • SHA256: 66cb4a2ef645cdb1e2e9c68892b6e94c38211673f97c9eaa09c6f1998788cee4
  • Pointer size: 131 Bytes
  • Size of remote file: 724 kB
examples/test1.png ADDED

Git LFS Details

  • SHA256: 097677dbe298b5b20f580ed7f42684b5ea2d8a2b011c567a88a346897ddd2b1a
  • Pointer size: 131 Bytes
  • Size of remote file: 617 kB
examples/test2.png ADDED

Git LFS Details

  • SHA256: 738793c28578dd0acf7fcf1111d58b307045984a1e5dbdedc65f6ce1644f11dc
  • Pointer size: 131 Bytes
  • Size of remote file: 467 kB
examples/test3.png ADDED

Git LFS Details

  • SHA256: b5dbd3dccc28294bcdf719b7b5a1e098f46d0e74cf5fbdc05ee3419f3d9ffd2c
  • Pointer size: 131 Bytes
  • Size of remote file: 817 kB
examples/test4.png ADDED

Git LFS Details

  • SHA256: 348a0175866d26b31b4035fb9864efe5578a483b66b6e951879756b5c04c7190
  • Pointer size: 131 Bytes
  • Size of remote file: 602 kB
examples/test5.png ADDED

Git LFS Details

  • SHA256: 59e3857564e9aafd6d1d3aceb3944da0a2e94682e47460315934dfdf623ce758
  • Pointer size: 131 Bytes
  • Size of remote file: 522 kB
examples/test6.png ADDED

Git LFS Details

  • SHA256: 0ac8ba5cfe64e48caa296640a83980f3b4e177432bc3e3512e715c736895100f
  • Pointer size: 131 Bytes
  • Size of remote file: 548 kB
examples/test7.png ADDED

Git LFS Details

  • SHA256: 3dac86ced73d33ae75143978112757695afd763a134bb9b7bde344fe22d46897
  • Pointer size: 131 Bytes
  • Size of remote file: 570 kB
examples/test8.png ADDED

Git LFS Details

  • SHA256: dfbe01aa72c61b03ae5c9f57bdb31abd95460ec20b9dc03260847b0cc668ec85
  • Pointer size: 131 Bytes
  • Size of remote file: 291 kB
examples/test9.png ADDED

Git LFS Details

  • SHA256: ba42d15d3c7a4419bc0dad7896edacec91d80c8dff8dc35a0ea74251957eb5e1
  • Pointer size: 131 Bytes
  • Size of remote file: 849 kB
model.py ADDED
@@ -0,0 +1,115 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as T
5
+ from transformers.models.clip.modeling_clip import (
6
+ CLIPTextTransformer,
7
+ CLIPPreTrainedModel,
8
+ CLIPModel,
9
+ )
10
+
11
+
12
+ class CLIPImageEncoder(CLIPPreTrainedModel):
13
+ @staticmethod
14
+ def from_pretrained(
15
+ global_model_name_or_path,
16
+ cache_dir
17
+ ):
18
+ model = CLIPModel.from_pretrained(
19
+ global_model_name_or_path,
20
+ subfolder="image_prompt_encoder",
21
+ cache_dir=cache_dir
22
+ )
23
+ vision_model = model.vision_model
24
+ visual_projection = model.visual_projection
25
+ vision_processor = T.Normalize(
26
+ (0.48145466, 0.4578275, 0.40821073),
27
+ (0.26862954, 0.26130258, 0.27577711),
28
+ )
29
+ return CLIPImageEncoder(
30
+ vision_model,
31
+ visual_projection,
32
+ vision_processor,
33
+ )
34
+
35
+ def __init__(
36
+ self,
37
+ vision_model,
38
+ visual_projection,
39
+ vision_processor,
40
+ ):
41
+ super().__init__(vision_model.config)
42
+ self.vision_model = vision_model
43
+ self.visual_projection = visual_projection
44
+ self.vision_processor = vision_processor
45
+
46
+ self.image_size = vision_model.config.image_size
47
+
48
+ def forward(self, object_pixel_values):
49
+ b, c, h, w = object_pixel_values.shape
50
+
51
+ if h != self.image_size or w != self.image_size:
52
+ h, w = self.image_size, self.image_size
53
+ object_pixel_values = F.interpolate(
54
+ object_pixel_values, (h, w), mode="bilinear", antialias=True
55
+ )
56
+
57
+ object_pixel_values = self.vision_processor(object_pixel_values)
58
+ object_embeds = self.vision_model(object_pixel_values)[1]
59
+ object_embeds = self.visual_projection(object_embeds)
60
+ object_embeds = object_embeds.view(b, 1, -1)
61
+ return object_embeds
62
+
63
+
64
+ class MLP(nn.Module):
65
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
66
+ super().__init__()
67
+ if use_residual:
68
+ assert in_dim == out_dim
69
+ self.layernorm = nn.LayerNorm(in_dim)
70
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
71
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
72
+ self.use_residual = use_residual
73
+ self.act_fn = nn.GELU()
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+ x = self.layernorm(x)
78
+ x = self.fc1(x)
79
+ x = self.act_fn(x)
80
+ x = self.fc2(x)
81
+ if self.use_residual:
82
+ x = x + residual
83
+ return x
84
+
85
+ class PostfuseModule(nn.Module):
86
+ def __init__(self, embed_dim, embed_dim_img):
87
+ super().__init__()
88
+ self.mlp1 = MLP(embed_dim_img, embed_dim, embed_dim, use_residual=False)
89
+ self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
90
+ self.layer_norm = nn.LayerNorm(embed_dim)
91
+
92
+ @property
93
+ def dtype(self):
94
+ try:
95
+ return next(self.parameters()).dtype
96
+ except StopIteration:
97
+ return torch.float32
98
+
99
+ def fuse_fn(self, object_embeds):
100
+ text_object_embeds = self.mlp1(object_embeds)
101
+ text_object_embeds = self.mlp2(text_object_embeds)
102
+ text_object_embeds = self.layer_norm(text_object_embeds)
103
+ return text_object_embeds
104
+
105
+ def forward(
106
+ self,
107
+ text_embeds,
108
+ object_embeds,
109
+ fuse_index,
110
+ ) -> torch.Tensor:
111
+ text_object_embed = self.fuse_fn(object_embeds)
112
+ text_embeds_new = text_embeds.clone()
113
+ text_embeds_new[:, fuse_index, :] = text_object_embed.squeeze(1)
114
+
115
+ return text_embeds_new
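
At inference time the two modules above are combined roughly as follows: CLIPImageEncoder turns an object crop into a single embedding, and PostfuseModule injects that embedding into the prompt embedding at one token position. Below is a shape-level sketch with random tensors; the dimensions and fuse_index are assumptions for illustration, not values read from the checkpoint.

```python
# Shape-level sketch of PostfuseModule; all dimensions and fuse_index are assumed values.
import torch
from model import PostfuseModule

embed_dim, embed_dim_img, seq_len = 1024, 768, 77   # assumed text dim, image dim, prompt length
postfuse = PostfuseModule(embed_dim=embed_dim, embed_dim_img=embed_dim_img)

text_embeds = torch.randn(2, seq_len, embed_dim)     # prompt embeddings (B, L, D_text)
object_embeds = torch.randn(2, 1, embed_dim_img)     # CLIPImageEncoder output (B, 1, D_img)
fuse_index = 5                                        # hypothetical token slot to overwrite

fused = postfuse(text_embeds, object_embeds, fuse_index)
assert fused.shape == text_embeds.shape              # only the embedding at fuse_index changes
```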
pipeline_objectclear.py ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch==2.2.0
2
+ torchvision
3
+ numpy==1.26.4
4
+ opencv-python
5
+ pillow
6
+ transformers
7
+ scipy
8
+ diffusers
9
+ segment-anything
10
+ matplotlib
tools/__init__.py ADDED
File without changes
tools/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes).
 
tools/__pycache__/base_segmenter.cpython-310.pyc ADDED
Binary file (4.11 kB).
 
tools/__pycache__/download_util.cpython-310.pyc ADDED
Binary file (3.5 kB).
 
tools/__pycache__/interact_tools.cpython-310.pyc ADDED
Binary file (2.49 kB).
 
tools/__pycache__/mask_painter.cpython-310.pyc ADDED
Binary file (6.52 kB).
 
tools/__pycache__/misc.cpython-310.pyc ADDED
Binary file (4.34 kB).
 
tools/__pycache__/painter.cpython-310.pyc ADDED
Binary file (4.81 kB).
 
tools/base_segmenter.py ADDED
@@ -0,0 +1,129 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter
11
+
12
+
13
+ class BaseSegmenter:
14
+ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
15
+ """
16
+ device: model device
17
+ SAM_checkpoint: path of SAM checkpoint
18
+ model_type: vit_b, vit_l, vit_h
19
+ """
20
+ print(f"Initializing BaseSegmenter to {device}")
21
+ assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
22
+
23
+ self.device = device
24
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
25
+ self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
26
+ self.model.to(device=self.device)
27
+ self.predictor = SamPredictor(self.model)
28
+ self.embedded = False
29
+
30
+ @torch.no_grad()
31
+ def set_image(self, image: np.ndarray):
32
+ # PIL.open(image_path) 3channel: RGB
33
+ # image embedding: avoid encoding the same image multiple times
34
+ self.orignal_image = image
35
+ if self.embedded:
36
+ print('repeat embedding, please reset_image.')
37
+ return
38
+ self.predictor.set_image(image)
39
+ self.embedded = True
40
+ return
41
+
42
+ @torch.no_grad()
43
+ def reset_image(self):
44
+ # reset image embedding
45
+ self.predictor.reset_image()
46
+ self.embedded = False
47
+
48
+ def predict(self, prompts, mode, multimask=True):
49
+ """
50
+ image: numpy array, h, w, 3
51
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
52
+ prompts['point_coords']: numpy array [N,2]
53
+ prompts['point_labels']: numpy array [1,N]
54
+ prompts['mask_input']: numpy array [1,256,256]
55
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
56
+ multimask: True (return 3 masks), False (return 1 mask only)
57
+ when multimask=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
58
+ """
59
+ assert self.embedded, 'prediction is called before set_image (feature embedding).'
60
+ assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
61
+
62
+ if mode == 'point':
63
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
64
+ point_labels=prompts['point_labels'],
65
+ multimask_output=multimask)
66
+ elif mode == 'mask':
67
+ masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
68
+ multimask_output=multimask)
69
+ elif mode == 'both': # both
70
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
71
+ point_labels=prompts['point_labels'],
72
+ mask_input=prompts['mask_input'],
73
+ multimask_output=multimask)
74
+ else:
75
+ raise("Not implement now!")
76
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
77
+ return masks, scores, logits
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # load and show an image
82
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
83
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
84
+
85
+ # initialise BaseSegmenter
86
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
87
+ model_type = 'vit_h'
88
+ device = "cuda:4"
89
+ base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
90
+
91
+ # image embedding (once embedded, multiple prompts can be applied)
92
+ base_segmenter.set_image(image)
93
+
94
+ # examples
95
+ # point only ------------------------
96
+ mode = 'point'
97
+ prompts = {
98
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
99
+ 'point_labels': np.array([1, 1]),
100
+ }
101
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
102
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
103
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
104
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
105
+
106
+ # both ------------------------
107
+ mode = 'both'
108
+ mask_input = logits[np.argmax(scores), :, :]
109
+ prompts = {'mask_input': mask_input [None, :, :]}
110
+ prompts = {
111
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
112
+ 'point_labels': np.array([1, 0]),
113
+ 'mask_input': mask_input[None, :, :]
114
+ }
115
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
116
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
117
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
118
+ cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
119
+
120
+ # mask only ------------------------
121
+ mode = 'mask'
122
+ mask_input = logits[np.argmax(scores), :, :]
123
+
124
+ prompts = {'mask_input': mask_input[None, :, :]}
125
+
126
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
127
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
128
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
129
+ cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
tools/download_util.py ADDED
@@ -0,0 +1,109 @@
1
+ import math
2
+ import os
3
+ import requests
4
+ from torch.hub import download_url_to_file, get_dir
5
+ from tqdm import tqdm
6
+ from urllib.parse import urlparse
7
+
8
+ def sizeof_fmt(size, suffix='B'):
9
+ """Get human readable file size.
10
+
11
+ Args:
12
+ size (int): File size.
13
+ suffix (str): Suffix. Default: 'B'.
14
+
15
+ Return:
16
+ str: Formatted file size.
17
+ """
18
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
19
+ if abs(size) < 1024.0:
20
+ return f'{size:3.1f} {unit}{suffix}'
21
+ size /= 1024.0
22
+ return f'{size:3.1f} Y{suffix}'
23
+
24
+
25
+ def download_file_from_google_drive(file_id, save_path):
26
+ """Download files from google drive.
27
+ Ref:
28
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
29
+ Args:
30
+ file_id (str): File id.
31
+ save_path (str): Save path.
32
+ """
33
+
34
+ session = requests.Session()
35
+ URL = 'https://docs.google.com/uc?export=download'
36
+ params = {'id': file_id}
37
+
38
+ response = session.get(URL, params=params, stream=True)
39
+ token = get_confirm_token(response)
40
+ if token:
41
+ params['confirm'] = token
42
+ response = session.get(URL, params=params, stream=True)
43
+
44
+ # get file size
45
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
46
+ print(response_file_size)
47
+ if 'Content-Range' in response_file_size.headers:
48
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
49
+ else:
50
+ file_size = None
51
+
52
+ save_response_content(response, save_path, file_size)
53
+
54
+
55
+ def get_confirm_token(response):
56
+ for key, value in response.cookies.items():
57
+ if key.startswith('download_warning'):
58
+ return value
59
+ return None
60
+
61
+
62
+ def save_response_content(response, destination, file_size=None, chunk_size=32768):
63
+ if file_size is not None:
64
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
65
+
66
+ readable_file_size = sizeof_fmt(file_size)
67
+ else:
68
+ pbar = None
69
+
70
+ with open(destination, 'wb') as f:
71
+ downloaded_size = 0
72
+ for chunk in response.iter_content(chunk_size):
73
+ downloaded_size += chunk_size
74
+ if pbar is not None:
75
+ pbar.update(1)
76
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
77
+ if chunk: # filter out keep-alive new chunks
78
+ f.write(chunk)
79
+ if pbar is not None:
80
+ pbar.close()
81
+
82
+
83
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
84
+ """Load file form http url, will download models if necessary.
85
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
86
+ Args:
87
+ url (str): URL to be downloaded.
88
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
89
+ Default: None.
90
+ progress (bool): Whether to show the download progress. Default: True.
91
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
92
+ Returns:
93
+ str: The path to the downloaded file.
94
+ """
95
+ if model_dir is None: # use the pytorch hub_dir
96
+ hub_dir = get_dir()
97
+ model_dir = os.path.join(hub_dir, 'checkpoints')
98
+
99
+ os.makedirs(model_dir, exist_ok=True)
100
+
101
+ parts = urlparse(url)
102
+ filename = os.path.basename(parts.path)
103
+ if file_name is not None:
104
+ filename = file_name
105
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
106
+ if not os.path.exists(cached_file):
107
+ print(f'Downloading: "{url}" to {cached_file}\n')
108
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
109
+ return cached_file
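
load_file_from_url is how app.py fetches the SAM checkpoint and the tutorial video; it downloads once and returns the cached path on later calls. A short usage sketch (the target directory is a placeholder):

```python
# Usage sketch: fetch (or reuse a cached copy of) the SAM ViT-H checkpoint.
from tools.download_util import load_file_from_url

url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
ckpt_path = load_file_from_url(url, model_dir="./pretrained_models", progress=True)
print(ckpt_path)  # absolute path to the cached file
```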
tools/interact_tools.py ADDED
@@ -0,0 +1,99 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter as mask_painter2
11
+ from .base_segmenter import BaseSegmenter
12
+ from .painter import mask_painter, point_painter
13
+ import os
14
+ import requests
15
+ import sys
16
+
17
+
18
+ mask_color = 3
19
+ mask_alpha = 0.7
20
+ contour_color = 1
21
+ contour_width = 5
22
+ point_color_ne = 8
23
+ point_color_ps = 50
24
+ point_alpha = 0.9
25
+ point_radius = 15
26
+ contour_color = 2
27
+ contour_width = 5
28
+
29
+
30
+ class SamControler():
31
+ def __init__(self, SAM_checkpoint, model_type, device):
32
+ '''
33
+ initialize sam controler
34
+ '''
35
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
36
+
37
+
38
+ # def seg_again(self, image: np.ndarray):
39
+ # '''
40
+ # it is used when interact in video
41
+ # '''
42
+ # self.sam_controler.reset_image()
43
+ # self.sam_controler.set_image(image)
44
+ # return
45
+
46
+
47
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
48
+ '''
49
+ used to segment the clicked object on a single frame
50
+ return: mask, logit, painted image(mask+point)
51
+ '''
52
+ # self.sam_controler.set_image(image)
53
+ origal_image = self.sam_controler.orignal_image
54
+ neg_flag = labels[-1]
55
+ if neg_flag==1:
56
+ #find neg
57
+ prompts = {
58
+ 'point_coords': points,
59
+ 'point_labels': labels,
60
+ }
61
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
62
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
63
+ prompts = {
64
+ 'point_coords': points,
65
+ 'point_labels': labels,
66
+ 'mask_input': logit[None, :, :]
67
+ }
68
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
69
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
70
+ else:
71
+ #find positive
72
+ prompts = {
73
+ 'point_coords': points,
74
+ 'point_labels': labels,
75
+ }
76
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
77
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
78
+
79
+
80
+ assert len(points)==len(labels)
81
+
82
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
83
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
84
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
85
+ painted_image = Image.fromarray(painted_image)
86
+
87
+ return mask, logit, painted_image
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
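
SamControler wraps BaseSegmenter for click-based prompting; app.py's sam_refine builds the point/label arrays from UI clicks and calls first_frame_click. Below is a sketch of the same flow outside the UI; the checkpoint path, device, and click coordinates are placeholders.

```python
# Sketch: one positive click -> mask, following app.py's sam_refine flow.
import numpy as np
from PIL import Image
from tools.interact_tools import SamControler

controler = SamControler("./pretrained_models/sam_vit_h_4b8939.pth", "vit_h", "cuda:0")

image = np.array(Image.open("input.png").convert("RGB"))  # placeholder image
controler.sam_controler.set_image(image)                   # embed once, prompt many times

points = np.array([[320, 240]])  # (x, y) of a positive click
labels = np.array([1])           # 1 = positive, 0 = negative
mask, logit, painted = controler.first_frame_click(image, points, labels, multimask=True)
```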
tools/mask_painter.py ADDED
@@ -0,0 +1,288 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import copy
6
+ import time
7
+
8
+
9
+ def colormap(rgb=True):
10
+ color_list = np.array(
11
+ [
12
+ 0.000, 0.000, 0.000,
13
+ 1.000, 1.000, 1.000,
14
+ 1.000, 0.498, 0.313,
15
+ 0.392, 0.581, 0.929,
16
+ 0.000, 0.447, 0.741,
17
+ 0.850, 0.325, 0.098,
18
+ 0.929, 0.694, 0.125,
19
+ 0.494, 0.184, 0.556,
20
+ 0.466, 0.674, 0.188,
21
+ 0.301, 0.745, 0.933,
22
+ 0.635, 0.078, 0.184,
23
+ 0.300, 0.300, 0.300,
24
+ 0.600, 0.600, 0.600,
25
+ 1.000, 0.000, 0.000,
26
+ 1.000, 0.500, 0.000,
27
+ 0.749, 0.749, 0.000,
28
+ 0.000, 1.000, 0.000,
29
+ 0.000, 0.000, 1.000,
30
+ 0.667, 0.000, 1.000,
31
+ 0.333, 0.333, 0.000,
32
+ 0.333, 0.667, 0.000,
33
+ 0.333, 1.000, 0.000,
34
+ 0.667, 0.333, 0.000,
35
+ 0.667, 0.667, 0.000,
36
+ 0.667, 1.000, 0.000,
37
+ 1.000, 0.333, 0.000,
38
+ 1.000, 0.667, 0.000,
39
+ 1.000, 1.000, 0.000,
40
+ 0.000, 0.333, 0.500,
41
+ 0.000, 0.667, 0.500,
42
+ 0.000, 1.000, 0.500,
43
+ 0.333, 0.000, 0.500,
44
+ 0.333, 0.333, 0.500,
45
+ 0.333, 0.667, 0.500,
46
+ 0.333, 1.000, 0.500,
47
+ 0.667, 0.000, 0.500,
48
+ 0.667, 0.333, 0.500,
49
+ 0.667, 0.667, 0.500,
50
+ 0.667, 1.000, 0.500,
51
+ 1.000, 0.000, 0.500,
52
+ 1.000, 0.333, 0.500,
53
+ 1.000, 0.667, 0.500,
54
+ 1.000, 1.000, 0.500,
55
+ 0.000, 0.333, 1.000,
56
+ 0.000, 0.667, 1.000,
57
+ 0.000, 1.000, 1.000,
58
+ 0.333, 0.000, 1.000,
59
+ 0.333, 0.333, 1.000,
60
+ 0.333, 0.667, 1.000,
61
+ 0.333, 1.000, 1.000,
62
+ 0.667, 0.000, 1.000,
63
+ 0.667, 0.333, 1.000,
64
+ 0.667, 0.667, 1.000,
65
+ 0.667, 1.000, 1.000,
66
+ 1.000, 0.000, 1.000,
67
+ 1.000, 0.333, 1.000,
68
+ 1.000, 0.667, 1.000,
69
+ 0.167, 0.000, 0.000,
70
+ 0.333, 0.000, 0.000,
71
+ 0.500, 0.000, 0.000,
72
+ 0.667, 0.000, 0.000,
73
+ 0.833, 0.000, 0.000,
74
+ 1.000, 0.000, 0.000,
75
+ 0.000, 0.167, 0.000,
76
+ 0.000, 0.333, 0.000,
77
+ 0.000, 0.500, 0.000,
78
+ 0.000, 0.667, 0.000,
79
+ 0.000, 0.833, 0.000,
80
+ 0.000, 1.000, 0.000,
81
+ 0.000, 0.000, 0.167,
82
+ 0.000, 0.000, 0.333,
83
+ 0.000, 0.000, 0.500,
84
+ 0.000, 0.000, 0.667,
85
+ 0.000, 0.000, 0.833,
86
+ 0.000, 0.000, 1.000,
87
+ 0.143, 0.143, 0.143,
88
+ 0.286, 0.286, 0.286,
89
+ 0.429, 0.429, 0.429,
90
+ 0.571, 0.571, 0.571,
91
+ 0.714, 0.714, 0.714,
92
+ 0.857, 0.857, 0.857
93
+ ]
94
+ ).astype(np.float32)
95
+ color_list = color_list.reshape((-1, 3)) * 255
96
+ if not rgb:
97
+ color_list = color_list[:, ::-1]
98
+ return color_list
99
+
100
+
101
+ color_list = colormap()
102
+ color_list = color_list.astype('uint8').tolist()
103
+
104
+
105
+ def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
106
+ background_color = np.array(background_color)
107
+ contour_color = np.array(contour_color)
108
+
109
+ # background_mask = 1 - background_mask
110
+ # contour_mask = 1 - contour_mask
111
+
112
+ for i in range(3):
113
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
114
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
115
+
116
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
117
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
118
+
119
+ return image.astype('uint8')
120
+
121
+
122
+ def mask_generator_00(mask, background_radius, contour_radius):
123
+ # no background width when '00'
124
+ # distance map
125
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
126
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
127
+ dist_map = dist_transform_fore - dist_transform_back
128
+ # ...:::!!!:::...
129
+ contour_radius += 2
130
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
131
+ contour_mask = contour_mask / np.max(contour_mask)
132
+ contour_mask[contour_mask>0.5] = 1.
133
+
134
+ return mask, contour_mask
135
+
136
+
137
+ def mask_generator_01(mask, background_radius, contour_radius):
138
+ # no background width when '01'
139
+ # distance map
140
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
141
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
142
+ dist_map = dist_transform_fore - dist_transform_back
143
+ # ...:::!!!:::...
144
+ contour_radius += 2
145
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
146
+ contour_mask = contour_mask / np.max(contour_mask)
147
+ return mask, contour_mask
148
+
149
+
150
+ def mask_generator_10(mask, background_radius, contour_radius):
151
+ # distance map
152
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
153
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
154
+ dist_map = dist_transform_fore - dist_transform_back
155
+ # .....:::::!!!!!
156
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
157
+ background_mask = (background_mask - np.min(background_mask))
158
+ background_mask = background_mask / np.max(background_mask)
159
+ # ...:::!!!:::...
160
+ contour_radius += 2
161
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
162
+ contour_mask = contour_mask / np.max(contour_mask)
163
+ contour_mask[contour_mask>0.5] = 1.
164
+ return background_mask, contour_mask
165
+
166
+
167
+ def mask_generator_11(mask, background_radius, contour_radius):
168
+ # distance map
169
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
170
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
171
+ dist_map = dist_transform_fore - dist_transform_back
172
+ # .....:::::!!!!!
173
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
174
+ background_mask = (background_mask - np.min(background_mask))
175
+ background_mask = background_mask / np.max(background_mask)
176
+ # ...:::!!!:::...
177
+ contour_radius += 2
178
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
179
+ contour_mask = contour_mask / np.max(contour_mask)
180
+ return background_mask, contour_mask
181
+
182
+
183
+ def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
184
+ """
185
+ Input:
186
+ input_image: numpy array
187
+ input_mask: numpy array
188
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
189
+ background_blur_radius: radius of background blur, must be odd number
190
+ contour_width: width of mask contour, must be odd number
191
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
192
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
193
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
194
+
195
+ Output:
196
+ painted_image: numpy array
197
+ """
198
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
199
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
200
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
201
+
202
+ # downsample input image and mask
203
+ width, height = input_image.shape[0], input_image.shape[1]
204
+ res = 1024
205
+ ratio = min(1.0 * res / max(width, height), 1.0)
206
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
207
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
208
+
209
+ # 0: background, 1: foreground
210
+ msk = np.clip(input_mask, 0, 1)
211
+
212
+ # generate masks for background and contour pixels
213
+ background_radius = (background_blur_radius - 1) // 2
214
+ contour_radius = (contour_width - 1) // 2
215
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
216
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
217
+
218
+ # paint
219
+ painted_image = vis_add_mask\
220
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
221
+
222
+ return painted_image
223
+
224
+
225
+ if __name__ == '__main__':
226
+
227
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
228
+ background_blur_radius = 31 # radius of background blur, must be odd number
229
+ contour_width = 11 # contour width, must be odd number
230
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
231
+ contour_alpha = 1 # transparency of contour, 0: no contour highlighted
232
+
233
+ # load input image and mask
234
+ input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
235
+ input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
236
+
237
+ # paint
238
+ overall_time_1 = 0
239
+ overall_time_2 = 0
240
+ overall_time_3 = 0
241
+ overall_time_4 = 0
242
+ overall_time_5 = 0
243
+
244
+ for i in range(50):
245
+ t2 = time.time()
246
+ painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
247
+ e2 = time.time()
248
+
249
+ t3 = time.time()
250
+ painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
251
+ e3 = time.time()
252
+
253
+ t1 = time.time()
254
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
255
+ e1 = time.time()
256
+
257
+ t4 = time.time()
258
+ painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
259
+ e4 = time.time()
260
+
261
+ t5 = time.time()
262
+ painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
263
+ e5 = time.time()
264
+
265
+ overall_time_1 += (e1 - t1)
266
+ overall_time_2 += (e2 - t2)
267
+ overall_time_3 += (e3 - t3)
268
+ overall_time_4 += (e4 - t4)
269
+ overall_time_5 += (e5 - t5)
270
+
271
+ print(f'average time (default mode): {overall_time_1/50}')
272
+ print(f'average time (mode 00, no blur): {overall_time_2/50}')
273
+ print(f'average time (mode 10, blur background only): {overall_time_3/50}')
274
+ print(f'average time (mode 01, blur contour only): {overall_time_4/50}')
275
+ print(f'average time (mode 11, blur both): {overall_time_5/50}')
276
+
277
+ # save
278
+ painted_image_00 = Image.fromarray(painted_image_00)
279
+ painted_image_00.save('./test_img/painter_output_image_00.png')
280
+
281
+ painted_image_10 = Image.fromarray(painted_image_10)
282
+ painted_image_10.save('./test_img/painter_output_image_10.png')
283
+
284
+ painted_image_01 = Image.fromarray(painted_image_01)
285
+ painted_image_01.save('./test_img/painter_output_image_01.png')
286
+
287
+ painted_image_11 = Image.fromarray(painted_image_11)
288
+ painted_image_11.save('./test_img/painter_output_image_11.png')
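
A minimal usage sketch for the mode-based mask_painter above. The import path and output filename are assumptions (the module name is not visible in this hunk); the argument values and their order mirror the __main__ demo.

    import numpy as np
    from PIL import Image
    # hypothetical import path -- adjust to wherever the mask_painter defined above lives
    from tools.mask_painter import mask_painter

    image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
    mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))

    # positional args: background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha
    # mode '10' blurs only the background and keeps the contour sharp
    out = mask_painter(image, mask, 0.7, 31, 11, 3, 1, mode='10')
    Image.fromarray(out).save('./test_img/painter_mode_10.png')
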
tools/misc.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import re
3
+ import random
4
+ import time
5
+ import torch
6
+ import torch.nn as nn
7
+ import logging
8
+ import numpy as np
9
+ from os import path as osp
10
+
11
+ def constant_init(module, val, bias=0):
12
+ if hasattr(module, 'weight') and module.weight is not None:
13
+ nn.init.constant_(module.weight, val)
14
+ if hasattr(module, 'bias') and module.bias is not None:
15
+ nn.init.constant_(module.bias, bias)
16
+
17
+ initialized_logger = {}
18
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
19
+ """Get the root logger.
20
+ The logger will be initialized if it has not been initialized. By default a
21
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
22
+ also be added.
23
+ Args:
24
+ logger_name (str): root logger name. Default: 'basicsr'.
25
+ log_file (str | None): The log filename. If specified, a FileHandler
26
+ will be added to the root logger.
27
+ log_level (int): The root logger level. Note that only the process of
28
+ rank 0 is affected, while other processes will set the level to
29
+ "Error" and be silent most of the time.
30
+ Returns:
31
+ logging.Logger: The root logger.
32
+ """
33
+ logger = logging.getLogger(logger_name)
34
+ # if the logger has been initialized, just return it
35
+ if logger_name in initialized_logger:
36
+ return logger
37
+
38
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
39
+ stream_handler = logging.StreamHandler()
40
+ stream_handler.setFormatter(logging.Formatter(format_str))
41
+ logger.addHandler(stream_handler)
42
+ logger.propagate = False
43
+
44
+ if log_file is not None:
45
+ logger.setLevel(log_level)
46
+ # add file handler
47
+ # file_handler = logging.FileHandler(log_file, 'w')
48
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
49
+ file_handler.setFormatter(logging.Formatter(format_str))
50
+ file_handler.setLevel(log_level)
51
+ logger.addHandler(file_handler)
52
+ initialized_logger[logger_name] = True
53
+ return logger
54
+
55
+
56
+ IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
57
+ torch.__version__)[0][:3])] >= [1, 12, 0]
58
+
59
+ def gpu_is_available():
60
+ if IS_HIGH_VERSION:
61
+ if torch.backends.mps.is_available():
62
+ return True
63
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
64
+
65
+ def get_device(gpu_id=None):
66
+ if gpu_id is None:
67
+ gpu_str = ''
68
+ elif isinstance(gpu_id, int):
69
+ gpu_str = f':{gpu_id}'
70
+ else:
71
+ raise TypeError('Input should be int value.')
72
+
73
+ if IS_HIGH_VERSION:
74
+ if torch.backends.mps.is_available():
75
+ return torch.device('mps'+gpu_str)
76
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
77
+
78
+
79
+ def set_random_seed(seed):
80
+ """Set random seeds."""
81
+ random.seed(seed)
82
+ np.random.seed(seed)
83
+ torch.manual_seed(seed)
84
+ torch.cuda.manual_seed(seed)
85
+ torch.cuda.manual_seed_all(seed)
86
+
87
+
88
+ def get_time_str():
89
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
90
+
91
+
92
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
93
+ """Scan a directory to find the files of interest.
94
+
95
+ Args:
96
+ dir_path (str): Path of the directory.
97
+ suffix (str | tuple(str), optional): File suffix that we are
98
+ interested in. Default: None.
99
+ recursive (bool, optional): If set to True, recursively scan the
100
+ directory. Default: False.
101
+ full_path (bool, optional): If set to True, include the dir_path.
102
+ Default: False.
103
+
104
+ Returns:
105
+ A generator for all the files of interest with relative paths.
106
+ """
107
+
108
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
109
+ raise TypeError('"suffix" must be a string or tuple of strings')
110
+
111
+ root = dir_path
112
+
113
+ def _scandir(dir_path, suffix, recursive):
114
+ for entry in os.scandir(dir_path):
115
+ if not entry.name.startswith('.') and entry.is_file():
116
+ if full_path:
117
+ return_path = entry.path
118
+ else:
119
+ return_path = osp.relpath(entry.path, root)
120
+
121
+ if suffix is None:
122
+ yield return_path
123
+ elif return_path.endswith(suffix):
124
+ yield return_path
125
+ else:
126
+ if recursive:
127
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
128
+ else:
129
+ continue
130
+
131
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
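
A short usage sketch for the helpers in tools/misc.py above; the log filename and results directory are illustrative placeholders.

    from tools.misc import get_root_logger, get_device, set_random_seed, scandir

    set_random_seed(42)                             # seed python, numpy and torch RNGs
    device = get_device()                           # cuda / mps if available, else cpu
    logger = get_root_logger(log_file='train.log')  # StreamHandler + appending FileHandler
    logger.info(f'running on {device}')

    # iterate over all .png files directly under ./results (relative paths)
    for path in scandir('./results', suffix='.png'):
        logger.info(path)
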
tools/painter.py ADDED
@@ -0,0 +1,215 @@
1
+ # paint masks, contours, or points on images, with specified colors
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import copy
7
+ import time
8
+
9
+
10
+ def colormap(rgb=True):
11
+ color_list = np.array(
12
+ [
13
+ 0.000, 0.000, 0.000,
14
+ 1.000, 1.000, 1.000,
15
+ 1.000, 0.498, 0.313,
16
+ 0.392, 0.581, 0.929,
17
+ 0.000, 0.447, 0.741,
18
+ 0.850, 0.325, 0.098,
19
+ 0.929, 0.694, 0.125,
20
+ 0.494, 0.184, 0.556,
21
+ 0.466, 0.674, 0.188,
22
+ 0.301, 0.745, 0.933,
23
+ 0.635, 0.078, 0.184,
24
+ 0.300, 0.300, 0.300,
25
+ 0.600, 0.600, 0.600,
26
+ 1.000, 0.000, 0.000,
27
+ 1.000, 0.500, 0.000,
28
+ 0.749, 0.749, 0.000,
29
+ 0.000, 1.000, 0.000,
30
+ 0.000, 0.000, 1.000,
31
+ 0.667, 0.000, 1.000,
32
+ 0.333, 0.333, 0.000,
33
+ 0.333, 0.667, 0.000,
34
+ 0.333, 1.000, 0.000,
35
+ 0.667, 0.333, 0.000,
36
+ 0.667, 0.667, 0.000,
37
+ 0.667, 1.000, 0.000,
38
+ 1.000, 0.333, 0.000,
39
+ 1.000, 0.667, 0.000,
40
+ 1.000, 1.000, 0.000,
41
+ 0.000, 0.333, 0.500,
42
+ 0.000, 0.667, 0.500,
43
+ 0.000, 1.000, 0.500,
44
+ 0.333, 0.000, 0.500,
45
+ 0.333, 0.333, 0.500,
46
+ 0.333, 0.667, 0.500,
47
+ 0.333, 1.000, 0.500,
48
+ 0.667, 0.000, 0.500,
49
+ 0.667, 0.333, 0.500,
50
+ 0.667, 0.667, 0.500,
51
+ 0.667, 1.000, 0.500,
52
+ 1.000, 0.000, 0.500,
53
+ 1.000, 0.333, 0.500,
54
+ 1.000, 0.667, 0.500,
55
+ 1.000, 1.000, 0.500,
56
+ 0.000, 0.333, 1.000,
57
+ 0.000, 0.667, 1.000,
58
+ 0.000, 1.000, 1.000,
59
+ 0.333, 0.000, 1.000,
60
+ 0.333, 0.333, 1.000,
61
+ 0.333, 0.667, 1.000,
62
+ 0.333, 1.000, 1.000,
63
+ 0.667, 0.000, 1.000,
64
+ 0.667, 0.333, 1.000,
65
+ 0.667, 0.667, 1.000,
66
+ 0.667, 1.000, 1.000,
67
+ 1.000, 0.000, 1.000,
68
+ 1.000, 0.333, 1.000,
69
+ 1.000, 0.667, 1.000,
70
+ 0.167, 0.000, 0.000,
71
+ 0.333, 0.000, 0.000,
72
+ 0.500, 0.000, 0.000,
73
+ 0.667, 0.000, 0.000,
74
+ 0.833, 0.000, 0.000,
75
+ 1.000, 0.000, 0.000,
76
+ 0.000, 0.167, 0.000,
77
+ 0.000, 0.333, 0.000,
78
+ 0.000, 0.500, 0.000,
79
+ 0.000, 0.667, 0.000,
80
+ 0.000, 0.833, 0.000,
81
+ 0.000, 1.000, 0.000,
82
+ 0.000, 0.000, 0.167,
83
+ 0.000, 0.000, 0.333,
84
+ 0.000, 0.000, 0.500,
85
+ 0.000, 0.000, 0.667,
86
+ 0.000, 0.000, 0.833,
87
+ 0.000, 0.000, 1.000,
88
+ 0.143, 0.143, 0.143,
89
+ 0.286, 0.286, 0.286,
90
+ 0.429, 0.429, 0.429,
91
+ 0.571, 0.571, 0.571,
92
+ 0.714, 0.714, 0.714,
93
+ 0.857, 0.857, 0.857
94
+ ]
95
+ ).astype(np.float32)
96
+ color_list = color_list.reshape((-1, 3)) * 255
97
+ if not rgb:
98
+ color_list = color_list[:, ::-1]
99
+ return color_list
100
+
101
+
102
+ color_list = colormap()
103
+ color_list = color_list.astype('uint8').tolist()
104
+
105
+
106
+ def vis_add_mask(image, mask, color, alpha):
107
+ color = np.array(color_list[color])
108
+ mask = mask > 0.5
109
+ image[mask] = image[mask] * (1-alpha) + color * alpha
110
+ return image.astype('uint8')
111
+
112
+ def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
113
+ h, w = input_image.shape[:2]
114
+ point_mask = np.zeros((h, w)).astype('uint8')
115
+ for point in input_points:
116
+ point_mask[point[1], point[0]] = 1
117
+
118
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (point_radius, point_radius))
119
+ point_mask = cv2.dilate(point_mask, kernel)
120
+
121
+ contour_radius = (contour_width - 1) // 2
122
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
123
+ dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
124
+ dist_map = dist_transform_fore - dist_transform_back
125
+ # ...:::!!!:::...
126
+ contour_radius += 2
127
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
128
+ contour_mask = contour_mask / np.max(contour_mask)
129
+ contour_mask[contour_mask>0.5] = 1.
130
+
131
+ # paint mask
132
+ painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
133
+ # paint contour
134
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
135
+ return painted_image
136
+
137
+ def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
138
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
139
+ # 0: background, 1: foreground
140
+ mask = np.clip(input_mask, 0, 1)
141
+ contour_radius = (contour_width - 1) // 2
142
+
143
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
144
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
145
+ dist_map = dist_transform_fore - dist_transform_back
146
+ # ...:::!!!:::...
147
+ contour_radius += 2
148
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
149
+ contour_mask = contour_mask / np.max(contour_mask)
150
+ contour_mask[contour_mask>0.5] = 1.
151
+
152
+ # paint mask
153
+ painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
154
+ # paint contour
155
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
156
+
157
+ return painted_image
158
+
159
+ def background_remover(input_image, input_mask):
160
+ """
161
+ input_image: H, W, 3, np.array
162
+ input_mask: H, W, np.array
163
+
164
+ image_wo_background: PIL.Image
165
+ """
166
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
167
+ # 0: background, 1: foreground
168
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
169
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
170
+ image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
171
+
172
+ return image_wo_background
173
+
174
+ if __name__ == '__main__':
175
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
176
+ input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
177
+
178
+ # example of mask painter
179
+ mask_color = 3
180
+ mask_alpha = 0.7
181
+ contour_color = 1
182
+ contour_width = 5
183
+
184
+ # save
185
+ painted_image = Image.fromarray(input_image)
186
+ painted_image.save('images/original.png')
187
+
188
+ painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
189
+ # save
190
+ painted_image = Image.fromarray(painted_image)  # save the painted overlay
191
+ painted_image.save('images/original1.png')
192
+
193
+ # example of point painter
194
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
195
+ input_points = np.array([[500, 375], [70, 600]]) # x, y
196
+ point_color = 5
197
+ point_alpha = 0.9
198
+ point_radius = 15
199
+ contour_color = 2
200
+ contour_width = 5
201
+ painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
202
+ # save
203
+ painted_image = Image.fromarray(painted_image_1)
204
+ painted_image.save('images/point_painter_1.png')
205
+
206
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
207
+ painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
208
+ # save
209
+ painted_image = Image.fromarray(painted_image_2)
210
+ painted_image.save('images/point_painter_2.png')
211
+
212
+ # example of background remover
213
+ input_image = np.array(Image.open('images/original.png').convert('RGB'))
214
+ image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
215
+ image_wo_background.save('images/image_wo_background.png')
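
The __main__ demo above calls the painters positionally; a compact keyword-argument variant is sketched below, where the output paths are placeholders.

    import numpy as np
    from PIL import Image
    from tools.painter import mask_painter, point_painter, background_remover

    image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
    mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))

    # colored mask overlay with a white contour (color index 1)
    overlay = mask_painter(image.copy(), mask, mask_color=3, contour_color=1)
    Image.fromarray(overlay).save('images/overlay.png')

    # mark two clicks; points are (x, y) pixel coordinates
    points = np.array([[500, 375], [70, 600]])
    Image.fromarray(point_painter(image.copy(), points, point_color=5)).save('images/points.png')

    # foreground cut-out as RGBA (background removed via the alpha channel)
    background_remover(image, mask).save('images/cutout.png')
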