"""Gradio demo combining Grounding DINO, Segment Anything (SAM) and Stable Diffusion.

The script exposes three tabs:

* Detection    -- draw Grounding DINO boxes for a text prompt.
* Segmentation -- turn those boxes into SAM masks and overlay them in green.
* Inpainting   -- repaint the masked region with Stable Diffusion 2 inpainting.

Everything below runs at import time (repository clone, checkpoint downloads,
model loading) because the Gradio callbacks use the loaded models as
module-level globals.
"""
import os
import subprocess
import sys

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image

# Clone the repo manually at runtime
REPO_DIR = "Grounded-Segment-Anything"
if not os.path.exists(REPO_DIR):
    # check=True: fail here with the git error instead of later with an
    # unrelated-looking ModuleNotFoundError on the imports below.
    subprocess.run(
        [
            "git",
            "clone",
            "--recurse-submodules",
            "https://github.com/IDEA-Research/Grounded-Segment-Anything.git",
        ],
        check=True,
    )

# Add submodules to path
sys.path.append("Grounded-Segment-Anything/GroundingDINO")
sys.path.append("Grounded-Segment-Anything/segment_anything")

# (connect, read) timeouts in seconds. The read timeout applies to each socket
# read, not to the whole transfer, so large checkpoints still download.
DOWNLOAD_TIMEOUT = (10, 60)
DOWNLOAD_CHUNK_SIZE = 1 << 20  # 1 MiB; the checkpoints are multi-GB files

# Grounding DINO thresholds shared by all three tabs.
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25


# ---------------------------
# Download helper
# ---------------------------
def download_if_missing(url, dest_path):
    """Download *url* to *dest_path* unless the file already exists.

    The payload is streamed to ``<dest_path>.part`` and renamed into place only
    after the transfer completes, so an interrupted download never leaves a
    truncated file that a later run would mistake for a finished one.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Target file path; its parent directory is created if needed.

    Raises:
        requests.RequestException: On connection errors, timeouts or a
            non-2xx response status.
        OSError: If the file cannot be written or renamed.
    """
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:  # os.makedirs("") raises, so skip it for bare file names
        os.makedirs(dest_dir, exist_ok=True)

    if os.path.exists(dest_path):
        print(f"{os.path.basename(dest_path)} already exists. Skipping.")
        return

    print(f"Downloading {os.path.basename(dest_path)}...")
    partial_path = f"{dest_path}.part"
    try:
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    f.write(chunk)
        os.replace(partial_path, dest_path)
    finally:
        # After a successful rename the partial file is gone and this is a
        # no-op; after a failure it removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Saved to {dest_path}")


# ---------------------------
# Download models
# ---------------------------
download_if_missing(
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "checkpoints/sam_vit_h_4b8939.pth",
)
download_if_missing(
    "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth",
    "checkpoints/groundingdino_swinb_cogcoor.pth",
)
download_if_missing(
    "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinB_cfg.py",
    "checkpoints/GroundingDINO_SwinB_cfg.py",
)

# ---------------------------
# Device setup
# ---------------------------
# Must be set before the first CUDA query: once the CUDA runtime has been
# initialised, changes to this variable are ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Load models
# ---------------------------
# These imports depend on the sys.path entries added above (and on the clone
# having completed), so they cannot be moved to the top of the file.
from segment_anything import build_sam, SamPredictor  # noqa: E402
from diffusers import StableDiffusionInpaintPipeline  # noqa: E402
from groundingdino.util.inference import Model  # noqa: E402
import supervision as sv  # noqa: E402

# SAM
sam = build_sam(checkpoint="checkpoints/sam_vit_h_4b8939.pth")
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)

# Grounding DINO
dino_model = Model(
    model_config_path="checkpoints/GroundingDINO_SwinB_cfg.py",
    model_checkpoint_path="checkpoints/groundingdino_swinb_cogcoor.pth",
    device=DEVICE,
)

# Stable Diffusion Inpainting
# Half precision only off-CPU; on CPU the pipeline is loaded in float32.
dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=dtype,
)
if DEVICE.type != "cpu":
    pipe = pipe.to(DEVICE)


# ---------------------------
# Inference Functions
# ---------------------------
def _to_rgb_array(image):
    """Return *image* (a PIL image) as an RGB numpy array.

    Raises:
        ValueError: If no image was supplied (e.g. the button was clicked
            before uploading one).
    """
    if image is None:
        raise ValueError("No image provided")
    return np.array(image.convert("RGB"))


def _detect(image_np, prompt):
    """Run Grounding DINO on an RGB array.

    Returns:
        A ``(image_bgr, detections)`` tuple: the BGR copy of the image that was
        fed to the detector, and the detections object it returned.
    """
    # The detector is fed a BGR (OpenCV-ordered) image, as in the original calls.
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_bgr,
        caption=prompt,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
    )
    return image_bgr, detections


def _segment_boxes(image_np, boxes):
    """Return one boolean mask covering every box in *boxes*.

    SAM is prompted once per box because ``SamPredictor.predict`` takes a
    single box per call. For each box the highest-scoring of its candidate
    masks is kept, and the per-box masks are OR-ed together.

    Raises:
        ValueError: If *boxes* is empty.
    """
    if len(boxes) == 0:
        raise ValueError("No masks found")

    sam_predictor.set_image(image_np)
    best_masks = []
    for box in boxes:
        masks, scores, _ = sam_predictor.predict(
            box=box.reshape(1, 4),
            multimask_output=True,
        )
        best_masks.append(masks[np.argmax(scores)])
    return np.any(best_masks, axis=0)


def detection_fn(image, prompt):
    """Detect objects matching *prompt* and return the image with boxes drawn.

    Args:
        image: Input PIL image.
        prompt: Text caption passed to Grounding DINO.

    Returns:
        RGB numpy array of the annotated image.

    Raises:
        ValueError: If *image* is ``None``.
    """
    image_cv, detections = _detect(_to_rgb_array(image), prompt)
    # Assign every detection the same class id (0) before annotating.
    detections.class_id = np.zeros(len(detections), dtype=int)
    box_annotator = sv.BoxAnnotator()
    annotated = box_annotator.annotate(scene=image_cv, detections=detections)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)


def segmentation_fn(image, prompt):
    """Segment objects matching *prompt* and overlay the mask in green.

    Args:
        image: Input PIL image.
        prompt: Text caption passed to Grounding DINO.

    Returns:
        RGB numpy array where every masked pixel is pure green and all other
        pixels are unchanged.

    Raises:
        ValueError: If *image* is ``None`` or nothing was detected.
    """
    image_np = _to_rgb_array(image)
    _, detections = _detect(image_np, prompt)
    merged_mask = _segment_boxes(image_np, detections.xyxy)

    green = np.array([0, 255, 0], dtype=np.uint8)
    # mask[..., None] broadcasts the (H, W) mask across the colour channels.
    return np.where(merged_mask[..., None], green, image_np)


def inpainting_fn(image, prompt):
    """Repaint the region matching *prompt* with Stable Diffusion inpainting.

    The same *prompt* is used both to locate the region (Grounding DINO) and
    to guide the inpainting. All detected boxes are segmented and merged, the
    same way ``segmentation_fn`` does it; previously the whole box array was
    passed to a single SAM call, which only supports one box.

    Args:
        image: Input PIL image.
        prompt: Text used for detection and as the diffusion prompt.

    Returns:
        PIL image at the input's original size.

    Raises:
        ValueError: If *image* is ``None`` or nothing was detected.
    """
    if image is None:
        raise ValueError("No image provided")
    image_pil = image.convert("RGB")
    image_np = np.array(image_pil)

    _, detections = _detect(image_np, prompt)
    mask = _segment_boxes(image_np, detections.xyxy)
    mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("L")

    # The pipeline is run at 512x512; the result is scaled back afterwards.
    image_resized = image_pil.resize((512, 512))
    mask_resized = mask_img.resize((512, 512))
    inpainted = pipe(
        prompt=prompt,
        image=image_resized,
        mask_image=mask_resized,
    ).images[0]
    return inpainted.resize(image_pil.size)


# ---------------------------
# Gradio Interface
# ---------------------------
def _build_tab(title, fn, default_prompt):
    """Add one tab with an image input, a prompt box, an output and a run button."""
    with gr.TabItem(title):
        image_input = gr.Image(type="pil")
        prompt_input = gr.Textbox(label="Prompt", value=default_prompt)
        output = gr.Image()
        button = gr.Button(f"Run {title}")
        button.click(fn, inputs=[image_input, prompt_input], outputs=output)


with gr.Blocks() as demo:
    gr.Markdown("# Grounded Segment Anything + SAM + Stable Diffusion")
    with gr.Tabs():
        _build_tab("Detection", detection_fn, "bench")
        _build_tab("Segmentation", segmentation_fn, "bench")
        _build_tab("Inpainting", inpainting_fn, "A sofa, cyberpunk style, colorful")

if __name__ == "__main__":
    demo.launch()