# amaanwanie's picture
# added multi-segmentation support
# c969ea5 verified
import os
import torch
import cv2
import numpy as np
from PIL import Image
import gradio as gr
import requests
import os
import sys
import subprocess
# Clone the Grounded-Segment-Anything repo (with its submodules) at runtime
# when it is not already present in the working directory.
if not os.path.exists("Grounded-Segment-Anything"):
    # check=True makes a failed clone raise CalledProcessError right here
    # instead of letting the script continue without the cloned sources.
    subprocess.run(
        [
            "git",
            "clone",
            "--recurse-submodules",
            "https://github.com/IDEA-Research/Grounded-Segment-Anything.git",
        ],
        check=True,
    )
# Add submodules to path so their packages become importable.
sys.path.append("Grounded-Segment-Anything/GroundingDINO")
sys.path.append("Grounded-Segment-Anything/segment_anything")
# ---------------------------
# Download helper
# ---------------------------
def download_if_missing(url, dest_path, timeout=60):
    """Download ``url`` to ``dest_path`` unless that file already exists.

    The parent directory of ``dest_path`` is created when the path has one.
    The body is streamed into a temporary ``<dest_path>.part`` file that is
    renamed into place only after the download finished, so an interrupted
    download is never mistaken for a complete file on the next run.

    Args:
        url: Address fetched with an HTTP GET request.
        dest_path: Where the file is stored; may be a bare file name.
        timeout: Value forwarded to ``requests.get`` as its ``timeout``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    file_name = os.path.basename(dest_path)
    if os.path.exists(dest_path):
        print(f"{file_name} already exists. Skipping.")
        return

    # os.makedirs("") raises, so only create a directory when there is one.
    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print(f"Downloading {file_name}...")
    partial_path = f"{dest_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the file only once it has been written completely.
        os.replace(partial_path, dest_path)
    finally:
        # After a successful rename the partial file is gone; this only
        # removes leftovers from a failed or interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Saved to {dest_path}")
# ---------------------------
# Download models
# ---------------------------
# Fetch, in order, the SAM checkpoint, the Grounding DINO checkpoint and the
# Grounding DINO config file; each entry is (remote URL, local path).
for _remote_url, _local_path in (
    (
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "checkpoints/sam_vit_h_4b8939.pth",
    ),
    (
        "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth",
        "checkpoints/groundingdino_swinb_cogcoor.pth",
    ),
    (
        "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinB_cfg.py",
        "checkpoints/GroundingDINO_SwinB_cfg.py",
    ),
):
    download_if_missing(_remote_url, _local_path)
# ---------------------------
# Device setup
# ---------------------------
# Restrict CUDA to GPU 0 *before* the first CUDA query: the variable is read
# when CUDA initializes, so assigning it after torch.cuda.is_available() has
# run comes too late to have any effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---------------------------
# Load models
# ---------------------------
from segment_anything import build_sam, SamPredictor
from diffusers import StableDiffusionInpaintPipeline
from groundingdino.util.inference import Model
import supervision as sv

# SAM: build from the downloaded checkpoint, move it to DEVICE and wrap it
# in a predictor.
sam = build_sam(checkpoint="checkpoints/sam_vit_h_4b8939.pth")
sam_predictor = SamPredictor(sam.to(device=DEVICE))

# Grounding DINO: config and weights come from the checkpoints directory.
dino_model = Model(
    model_config_path="checkpoints/GroundingDINO_SwinB_cfg.py",
    model_checkpoint_path="checkpoints/groundingdino_swinb_cogcoor.pth",
    device=DEVICE,
)

# Stable Diffusion inpainting: float32 on the CPU, float16 on any other device.
dtype = torch.float32 if DEVICE.type == "cpu" else torch.float16
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=dtype,
)
# The pipeline is only moved explicitly when DEVICE is not the CPU.
if not DEVICE.type == "cpu":
    pipe = pipe.to(DEVICE)
# ---------------------------
# Inference Functions
# ---------------------------
def detection_fn(image, prompt):
    """Draw the Grounding DINO detections for ``prompt`` onto ``image``.

    Args:
        image: Input picture; converted with ``np.array`` and treated as RGB.
        prompt: Caption handed to Grounding DINO.

    Returns:
        The annotated picture, converted back from BGR to RGB.
    """
    bgr_frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    found, _phrases = dino_model.predict_with_caption(
        image=bgr_frame,
        caption=prompt,
        box_threshold=0.35,
        text_threshold=0.25,
    )
    # Every detection is assigned class id 0 before annotating.
    found.class_id = np.zeros(len(found), dtype=int)
    drawn = sv.BoxAnnotator().annotate(scene=bgr_frame, detections=found)
    return cv2.cvtColor(drawn, cv2.COLOR_BGR2RGB)
def segmentation_fn(image, prompt):
    """Paint every object matching ``prompt`` green on a copy of ``image``.

    Grounding DINO proposes boxes for ``prompt``; SAM is queried once per box
    and the highest-scoring mask of each query is kept. The union of those
    masks is filled with solid green.

    Args:
        image: Input picture; converted with ``np.array`` and treated as RGB.
        prompt: Caption handed to Grounding DINO.

    Returns:
        Array where masked pixels are green and all others come from the
        input picture.

    Raises:
        ValueError: If no mask was collected.
    """
    rgb = np.array(image)
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    # Detect with Grounding DINO
    found, _ = dino_model.predict_with_caption(
        image=bgr,
        caption=prompt,
        box_threshold=0.35,
        text_threshold=0.25,
    )
    sam_predictor.set_image(rgb)

    # One SAM query per detected box; keep the best-scoring candidate.
    best_masks = []
    for xyxy in found.xyxy:
        candidates, quality, _ = sam_predictor.predict(
            box=xyxy.reshape(1, 4),
            multimask_output=True,
        )
        if candidates is None:
            continue
        best_masks.append(candidates[np.argmax(quality)])
    if len(best_masks) == 0:
        raise ValueError("No masks found")

    # Union of all masks as a 0/255 image, repeated over three channels.
    union = np.any(best_masks, axis=0).astype(np.uint8) * 255
    union_rgb = np.stack([union, union, union], axis=-1)
    green = np.array([0, 255, 0], dtype=np.uint8)
    return np.where(union_rgb, green, rgb)
def inpainting_fn(image, prompt):
    """Repaint the objects matching ``prompt`` with Stable Diffusion inpainting.

    Grounding DINO proposes boxes for ``prompt``. SAM is queried once per box
    and the highest-scoring mask of each query is kept; the union of those
    masks marks the area the inpainting pipeline redraws, guided by the same
    ``prompt``. Handling the boxes one at a time mirrors ``segmentation_fn``
    and keeps the function working when several objects are detected.

    Args:
        image: PIL image to edit.
        prompt: Text used both as detection caption and as inpainting prompt.

    Returns:
        The inpainted PIL image, resized back to the size of the input.

    Raises:
        ValueError: If nothing was detected, so there is no mask to inpaint.
    """
    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    boxes = detections.xyxy
    # Fail with a clear message instead of handing SAM an empty box array.
    if len(boxes) == 0:
        raise ValueError("No masks found")

    sam_predictor.set_image(image_np)
    # SamPredictor.predict takes a single length-4 box, so query per box and
    # keep the best-scoring candidate of each query.
    best_masks = []
    for box in boxes:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        if masks is not None and len(masks) > 0:
            best_masks.append(masks[np.argmax(scores)])
    if not best_masks:
        raise ValueError("No masks found")
    # Union of the per-object masks; with one detection this is that mask.
    mask = np.any(best_masks, axis=0)

    image_pil = image.convert("RGB")
    mask_img = Image.fromarray((mask.astype(np.uint8) * 255)).convert("L")
    # The pipeline is run on 512x512 inputs; the result is scaled back below.
    image_resized = image_pil.resize((512, 512))
    mask_resized = mask_img.resize((512, 512))
    inpainted = pipe(prompt=prompt, image=image_resized, mask_image=mask_resized).images[0]
    return inpainted.resize(image_pil.size)
# ---------------------------
# Gradio Interface
# ---------------------------
def _build_tab(title, prompt_default, button_label, handler):
    """Create one tab with an image input, prompt box, image output and button.

    The button runs ``handler`` on the image and prompt and shows the result
    in the output image. Returns the four components in creation order.
    """
    with gr.TabItem(title):
        image_in = gr.Image(type="pil")
        prompt_in = gr.Textbox(label="Prompt", value=prompt_default)
        image_out = gr.Image()
        run_button = gr.Button(button_label)
        run_button.click(handler, inputs=[image_in, prompt_in], outputs=image_out)
    return image_in, prompt_in, image_out, run_button


# One tab per task: detection, segmentation and inpainting.
with gr.Blocks() as demo:
    gr.Markdown("# Grounded Segment Anything + SAM + Stable Diffusion")
    with gr.Tabs():
        img, txt, out, btn = _build_tab(
            "Detection", "bench", "Run Detection", detection_fn
        )
        img2, txt2, out2, btn2 = _build_tab(
            "Segmentation", "bench", "Run Segmentation", segmentation_fn
        )
        img3, txt3, out3, btn3 = _build_tab(
            "Inpainting",
            "A sofa, cyberpunk style, colorful",
            "Run Inpainting",
            inpainting_fn,
        )

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()