Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import gradio as gr
|
7 |
+
import requests
|
8 |
+
|
9 |
+
# ---------------------------
|
10 |
+
# Download helper
|
11 |
+
# ---------------------------
|
12 |
+
def download_if_missing(url, dest_path, timeout=60):
    """Download ``url`` to ``dest_path`` unless that file already exists.

    The payload is streamed into a temporary ``.part`` file next to the
    destination and renamed into place only after the transfer finished, so an
    interrupted download is never mistaken for a complete file on a later run.

    Args:
        url: HTTP(S) address of the file to fetch.
        dest_path: Local path to store the file at. Missing parent directories
            are created; a bare file name without a directory part is allowed.
        timeout: Value forwarded to ``requests.get`` as its ``timeout``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    file_name = os.path.basename(dest_path)
    if os.path.exists(dest_path):
        print(f"{file_name} already exists. Skipping.")
        return

    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(parent_dir, exist_ok=True)

    print(f"Downloading {file_name}...")
    partial_path = f"{dest_path}.part"
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, dest_path)
    finally:
        # Only present here when the download did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Saved to {dest_path}")
|
24 |
+
|
25 |
+
# ---------------------------
|
26 |
+
# Download models
|
27 |
+
# ---------------------------
|
28 |
+
# (source URL, local destination) for every file the app needs at start-up:
# the SAM checkpoint, the Grounding DINO checkpoint and its config module.
_CHECKPOINT_SOURCES = (
    (
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "checkpoints/sam_vit_h_4b8939.pth",
    ),
    (
        "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth",
        "checkpoints/groundingdino_swinb_cogcoor.pth",
    ),
    (
        "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinB_cfg.py",
        "checkpoints/GroundingDINO_SwinB_cfg.py",
    ),
)

# Fetch them one after another, in the order listed above.
for _source_url, _local_path in _CHECKPOINT_SOURCES:
    download_if_missing(_source_url, _local_path)
|
42 |
+
|
43 |
+
# ---------------------------
|
44 |
+
# Device setup
|
45 |
+
# ---------------------------
|
46 |
+
# Restrict the process to the first GPU. This is assigned before the first
# CUDA query below; setting it after torch.cuda.is_available() has already
# probed the driver may come too late for the mask to apply.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
48 |
+
|
49 |
+
# ---------------------------
|
50 |
+
# Load models
|
51 |
+
# ---------------------------
|
52 |
+
from segment_anything import build_sam, SamPredictor
|
53 |
+
from diffusers import StableDiffusionInpaintPipeline
|
54 |
+
from groundingdino.util.inference import Model
|
55 |
+
import supervision as sv
|
56 |
+
|
57 |
+
# SAM
|
58 |
+
# Build SAM from the downloaded checkpoint, move it to DEVICE and wrap it in a
# predictor.
sam = build_sam(checkpoint="checkpoints/sam_vit_h_4b8939.pth").to(device=DEVICE)
sam_predictor = SamPredictor(sam)

# Grounding DINO
dino_model = Model(
    model_config_path="checkpoints/GroundingDINO_SwinB_cfg.py",
    model_checkpoint_path="checkpoints/groundingdino_swinb_cogcoor.pth",
    device=DEVICE,
)

# Stable Diffusion Inpainting
# Full precision on the CPU, half precision on any other device.
dtype = torch.float32 if DEVICE.type == "cpu" else torch.float16
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=dtype,
)
# The pipeline is only moved explicitly when DEVICE is not the CPU.
if DEVICE.type != "cpu":
    pipe = pipe.to(DEVICE)
|
77 |
+
|
78 |
+
# ---------------------------
|
79 |
+
# Inference Functions
|
80 |
+
# ---------------------------
|
81 |
+
def detection_fn(image, prompt):
    """Detect objects matching ``prompt`` and draw their bounding boxes.

    Args:
        image: Input picture (the Gradio input is configured with
            ``type="pil"``); ``None`` when nothing was uploaded.
        prompt: Text caption handed to Grounding DINO.

    Returns:
        The annotated picture as an RGB ``numpy`` array.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("No image provided")
    if isinstance(image, Image.Image):
        # Normalise grayscale / RGBA / palette uploads to 3-channel RGB so the
        # RGB->BGR conversion below always gets the layout it expects.
        image = image.convert("RGB")

    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    # Give every detection the same class id (0).
    # NOTE(review): presumably the annotator needs class ids to pick colours;
    # confirm against the installed supervision version.
    detections.class_id = np.zeros(len(detections), dtype=int)
    annotated = sv.BoxAnnotator().annotate(scene=image_cv, detections=detections)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
|
91 |
+
|
92 |
+
def segmentation_fn(image, prompt):
    """Segment the best detection for ``prompt`` and overlay its mask.

    Grounding DINO proposes boxes for the prompt; the highest-confidence box
    is handed to SAM, and SAM's best-scoring mask is blended over the picture
    in a random, semi-transparent colour.

    Args:
        image: Input picture (the Gradio input is configured with
            ``type="pil"``); ``None`` when nothing was uploaded.
        prompt: Text caption handed to Grounding DINO.

    Returns:
        RGBA ``numpy`` array of the picture with the mask composited on top.

    Raises:
        ValueError: If no image was provided, or nothing was detected or
            segmented for the prompt.
    """
    if image is None:
        raise ValueError("No image provided")
    if isinstance(image, Image.Image):
        # Work on plain 3-channel RGB regardless of the upload's mode.
        image = image.convert("RGB")

    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    # Bail out before SAM is called with an empty box array.
    if len(detections) == 0:
        raise ValueError("No masks found")

    # SamPredictor.predict takes a single box, so pick the detection the
    # detector is most confident about instead of passing the (N, 4) array.
    confidence = detections.confidence
    best_idx = int(np.argmax(confidence)) if confidence is not None else 0
    best_box = detections.xyxy[best_idx]

    sam_predictor.set_image(image_np)
    masks, scores, _ = sam_predictor.predict(box=best_box, multimask_output=True)
    if masks is None or len(masks) == 0:
        raise ValueError("No masks found")
    mask = masks[np.argmax(scores)]

    # Visualize mask
    def overlay_mask(mask, image):
        """Alpha-composite ``mask`` over ``image`` in a random RGBA colour."""
        # Random RGB with a fixed 0.8 alpha component.
        color = np.concatenate([np.random.random(3), np.array([0.8])])
        h, w = mask.shape[-2:]
        mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        image_pil = Image.fromarray(image).convert("RGBA")
        mask_pil = Image.fromarray((mask_img * 255).astype(np.uint8)).convert("RGBA")
        return np.array(Image.alpha_composite(image_pil, mask_pil))

    return overlay_mask(mask, image_np)
|
115 |
+
|
116 |
+
def inpainting_fn(image, prompt):
    """Repaint the best detection for ``prompt`` with Stable Diffusion.

    Grounding DINO proposes boxes for the prompt, SAM turns the
    highest-confidence box into a mask, and the inpainting pipeline
    regenerates the masked area using the same prompt. Inpainting runs at
    512x512; the result is scaled back to the input size.

    Args:
        image: PIL image from the Gradio input; ``None`` when nothing was
            uploaded.
        prompt: Text used both as the detection caption and as the
            inpainting prompt.

    Returns:
        The inpainted picture as a PIL image with the input's dimensions.

    Raises:
        ValueError: If no image was provided, or nothing was detected or
            segmented for the prompt.
    """
    if image is None:
        raise ValueError("No image provided")

    # Convert once up front so detector, SAM and the diffusion pipeline all
    # see the same 3-channel RGB picture.
    image_pil = image.convert("RGB")
    image_np = np.array(image_pil)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    # Bail out before SAM is called with an empty box array.
    if len(detections) == 0:
        raise ValueError("No masks found")

    # SamPredictor.predict takes a single box, so pick the detection the
    # detector is most confident about instead of passing the (N, 4) array.
    confidence = detections.confidence
    best_idx = int(np.argmax(confidence)) if confidence is not None else 0
    best_box = detections.xyxy[best_idx]

    sam_predictor.set_image(image_np)
    masks, scores, _ = sam_predictor.predict(box=best_box, multimask_output=True)
    if masks is None or len(masks) == 0:
        raise ValueError("No masks found")
    mask = masks[np.argmax(scores)]

    # Mask values become 0/255 in a single-channel image for the pipeline.
    mask_img = Image.fromarray((mask.astype(np.uint8) * 255)).convert("L")
    image_resized = image_pil.resize((512, 512))
    mask_resized = mask_img.resize((512, 512))
    inpainted = pipe(prompt=prompt, image=image_resized, mask_image=mask_resized).images[0]
    return inpainted.resize(image_pil.size)
|
135 |
+
|
136 |
+
# ---------------------------
|
137 |
+
# Gradio Interface
|
138 |
+
# ---------------------------
|
139 |
+
def _build_tab(title, handler, default_prompt, button_label):
    """Create one tab: image input, prompt box, image output and run button.

    The button's click event calls ``handler`` with the image and the prompt
    and shows the result in the output image. Returns the four components in
    creation order.
    """
    with gr.TabItem(title):
        image_in = gr.Image(type="pil")
        prompt_in = gr.Textbox(label="Prompt", value=default_prompt)
        image_out = gr.Image()
        run_button = gr.Button(button_label)
        run_button.click(handler, inputs=[image_in, prompt_in], outputs=image_out)
    return image_in, prompt_in, image_out, run_button


# Three tabs, one per inference function, all sharing the same layout.
with gr.Blocks() as demo:
    gr.Markdown("# Grounded Segment Anything + SAM + Stable Diffusion")
    with gr.Tabs():
        img, txt, out, btn = _build_tab(
            "Detection", detection_fn, "bench", "Run Detection"
        )
        img2, txt2, out2, btn2 = _build_tab(
            "Segmentation", segmentation_fn, "bench", "Run Segmentation"
        )
        img3, txt3, out3, btn3 = _build_tab(
            "Inpainting",
            inpainting_fn,
            "A sofa, cyberpunk style, colorful",
            "Run Inpainting",
        )

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|