amaanwanie committed on
Commit
b49cc80
·
verified ·
1 Parent(s): b169cde

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import gradio as gr
7
+ import requests
8
+
9
+ # ---------------------------
10
+ # Download helper
11
+ # ---------------------------
12
def download_if_missing(url, dest_path):
    """Download ``url`` to ``dest_path`` unless that file already exists.

    Args:
        url: HTTP(S) location of the file to fetch.
        dest_path: Local path to write to. Missing parent directories are
            created; a bare file name (no directory part) is also accepted.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
        OSError: if the destination cannot be written.
    """
    parent_dir = os.path.dirname(dest_path)
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    file_name = os.path.basename(dest_path)
    if os.path.exists(dest_path):
        print(f"{file_name} already exists. Skipping.")
        return

    print(f"Downloading {file_name}...")
    # Stream into a sibling ".part" file and rename only on success, so an
    # interrupted download never leaves a truncated file that a later run
    # would mistake for a complete one.
    partial_path = dest_path + ".part"
    try:
        # (connect, read) timeouts: the read timeout applies per chunk, not
        # to the whole transfer, so large files are still fine.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_path, dest_path)
    finally:
        # After a successful rename the partial file is gone and this is a
        # no-op; on any failure it removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Saved to {dest_path}")
24
+
25
+ # ---------------------------
26
+ # Download models
27
+ # ---------------------------
28
# Model artifacts fetched at startup, as (source URL, local destination)
# pairs. They are processed in order: SAM weights, Grounding DINO weights,
# then the Grounding DINO config file.
_MODEL_DOWNLOADS = (
    (
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "checkpoints/sam_vit_h_4b8939.pth",
    ),
    (
        "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth",
        "checkpoints/groundingdino_swinb_cogcoor.pth",
    ),
    (
        "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinB_cfg.py",
        "checkpoints/GroundingDINO_SwinB_cfg.py",
    ),
)

for _source_url, _local_path in _MODEL_DOWNLOADS:
    download_if_missing(_source_url, _local_path)
42
+
43
+ # ---------------------------
44
+ # Device setup
45
+ # ---------------------------
46
# Restrict CUDA to the first GPU. This has to happen before the first CUDA
# query: the CUDA runtime reads CUDA_VISIBLE_DEVICES when it initializes, so
# assigning it after torch.cuda.is_available() comes too late to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+
49
+ # ---------------------------
50
+ # Load models
51
+ # ---------------------------
52
+ from segment_anything import build_sam, SamPredictor
53
+ from diffusers import StableDiffusionInpaintPipeline
54
+ from groundingdino.util.inference import Model
55
+ import supervision as sv
56
+
57
# Segment Anything: load the downloaded checkpoint, move it to the selected
# device, and wrap it in a predictor.
sam = build_sam(checkpoint="checkpoints/sam_vit_h_4b8939.pth")
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)

# Grounding DINO: built from the downloaded config and checkpoint files.
dino_model = Model(
    model_config_path="checkpoints/GroundingDINO_SwinB_cfg.py",
    model_checkpoint_path="checkpoints/groundingdino_swinb_cogcoor.pth",
    device=DEVICE,
)

# Stable Diffusion inpainting: full precision on CPU, half precision on any
# other device.
dtype = torch.float32 if DEVICE.type == "cpu" else torch.float16
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=dtype,
)
# Half precision is chosen exactly when DEVICE is not the CPU, which is the
# only case where the pipeline is moved.
if dtype is torch.float16:
    pipe = pipe.to(DEVICE)
77
+
78
+ # ---------------------------
79
+ # Inference Functions
80
+ # ---------------------------
81
def detection_fn(image, prompt):
    """Detect objects described by ``prompt`` and draw their boxes.

    Args:
        image: Input image from the Gradio component (a PIL image; ``None``
            when the user has not uploaded anything).
        prompt: Text caption passed to Grounding DINO.

    Returns:
        The annotated image as an RGB numpy array.

    Raises:
        ValueError: if no image was provided.
    """
    # Gradio passes None when the button is clicked without an upload; fail
    # with a clear message instead of an opaque OpenCV error.
    if image is None:
        raise ValueError("No input image provided")

    image_np = np.array(image)
    # The detector and annotator are given a BGR image (OpenCV convention).
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    # Assign every detection the same class id (0) before annotating.
    detections.class_id = np.zeros(len(detections), dtype=int)
    box_annotator = sv.BoxAnnotator()
    annotated = box_annotator.annotate(scene=image_cv, detections=detections)
    # Convert back to RGB for display.
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
91
+
92
def segmentation_fn(image, prompt):
    """Segment the objects described by ``prompt`` and overlay the mask.

    Grounding DINO proposes boxes for ``prompt``; SAM is then queried once
    per box, and the best-scoring mask of each box is merged into a single
    mask that is blended over the input image in a random colour.

    Args:
        image: Input image from the Gradio component (a PIL image; ``None``
            when the user has not uploaded anything).
        prompt: Text caption passed to Grounding DINO.

    Returns:
        RGBA numpy array of the input image with the mask composited on top.

    Raises:
        ValueError: if no image was provided, or if nothing was detected or
            segmented ("No masks found").
    """
    if image is None:
        raise ValueError("No input image provided")

    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    # Without detections there is no box to prompt SAM with.
    if len(detections) == 0:
        raise ValueError("No masks found")

    sam_predictor.set_image(image_np)
    # SamPredictor.predict accepts a single length-4 box, so passing the
    # whole (N, 4) array fails for N != 1. Query SAM per box and merge the
    # best-scoring mask of each one.
    mask = None
    for box in detections.xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        if masks is None or len(masks) == 0:
            continue
        best = masks[np.argmax(scores)]
        mask = best if mask is None else np.logical_or(mask, best)
    if mask is None:
        raise ValueError("No masks found")

    # Visualize mask
    def overlay_mask(mask, image):
        # Random RGB colour with a fixed alpha of 0.8.
        color = np.concatenate([np.random.random(3), np.array([0.8])])
        h, w = mask.shape[-2:]
        mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        image_pil = Image.fromarray(image).convert("RGBA")
        mask_pil = Image.fromarray((mask_img * 255).astype(np.uint8)).convert("RGBA")
        return np.array(Image.alpha_composite(image_pil, mask_pil))

    return overlay_mask(mask, image_np)
115
+
116
def inpainting_fn(image, prompt):
    """Inpaint the region described by ``prompt`` with Stable Diffusion.

    The same ``prompt`` is used both as the Grounding DINO caption that
    locates the region and as the text prompt for the inpainting pipeline.
    SAM is queried once per detected box, and the best-scoring mask of each
    box is merged into the inpainting mask.

    Args:
        image: Input image from the Gradio component (a PIL image; ``None``
            when the user has not uploaded anything).
        prompt: Text used for detection and for inpainting.

    Returns:
        The inpainted PIL image, resized back to the input image's size.

    Raises:
        ValueError: if no image was provided, or if nothing was detected or
            segmented ("No masks found").
    """
    if image is None:
        raise ValueError("No input image provided")

    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    # Without detections there is no box to prompt SAM with.
    if len(detections) == 0:
        raise ValueError("No masks found")

    sam_predictor.set_image(image_np)
    # SamPredictor.predict accepts a single length-4 box, so passing the
    # whole (N, 4) array fails for N != 1. Query SAM per box and merge the
    # best-scoring mask of each one.
    mask = None
    for box in detections.xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        if masks is None or len(masks) == 0:
            continue
        best = masks[np.argmax(scores)]
        mask = best if mask is None else np.logical_or(mask, best)
    if mask is None:
        raise ValueError("No masks found")

    image_pil = image.convert("RGB")
    # White (255) marks the pixels to repaint.
    mask_img = Image.fromarray((mask.astype(np.uint8) * 255)).convert("L")
    # The pipeline is run at 512x512; the result is scaled back afterwards.
    image_resized = image_pil.resize((512, 512))
    mask_resized = mask_img.resize((512, 512))
    inpainted = pipe(prompt=prompt, image=image_resized, mask_image=mask_resized).images[0]
    return inpainted.resize(image_pil.size)
135
+
136
+ # ---------------------------
137
+ # Gradio Interface
138
+ # ---------------------------
139
# Three-tab demo: each tab pairs an image input and a text prompt with one of
# the inference functions and shows the result in an output image.
with gr.Blocks() as demo:
    gr.Markdown("# Grounded Segment Anything + SAM + Stable Diffusion")

    with gr.Tabs():
        # Tab 1: box detection only.
        with gr.TabItem("Detection"):
            img = gr.Image(type="pil")
            txt = gr.Textbox(label="Prompt", value="bench")
            out = gr.Image()
            btn = gr.Button("Run Detection")
            btn.click(fn=detection_fn, inputs=[img, txt], outputs=out)

        # Tab 2: detection followed by SAM segmentation.
        with gr.TabItem("Segmentation"):
            img2 = gr.Image(type="pil")
            txt2 = gr.Textbox(label="Prompt", value="bench")
            out2 = gr.Image()
            btn2 = gr.Button("Run Segmentation")
            btn2.click(fn=segmentation_fn, inputs=[img2, txt2], outputs=out2)

        # Tab 3: detection, segmentation, then Stable Diffusion inpainting.
        with gr.TabItem("Inpainting"):
            img3 = gr.Image(type="pil")
            txt3 = gr.Textbox(label="Prompt", value="A sofa, cyberpunk style, colorful")
            out3 = gr.Image()
            btn3 = gr.Button("Run Inpainting")
            btn3.click(fn=inpainting_fn, inputs=[img3, txt3], outputs=out3)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()