import os
import sys
import subprocess

import cv2
import numpy as np
import requests
import torch
import gradio as gr
from PIL import Image
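
# Clone Grounded-Segment-Anything on first run and make its GroundingDINO and
# segment_anything submodules importable.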
if not os.path.exists("Grounded-Segment-Anything"):
    subprocess.run(
        ["git", "clone", "--recurse-submodules",
         "https://github.com/IDEA-Research/Grounded-Segment-Anything.git"],
        check=True,
    )

sys.path.append("Grounded-Segment-Anything/GroundingDINO")
sys.path.append("Grounded-Segment-Anything/segment_anything")
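
# Download a file to dest_path unless it is already present locally.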
def download_if_missing(url, dest_path):
    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    if not os.path.exists(dest_path):
        print(f"Downloading {os.path.basename(dest_path)}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Saved to {dest_path}")
    else:
        print(f"{os.path.basename(dest_path)} already exists. Skipping.")
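
# Fetch the SAM ViT-H checkpoint plus the GroundingDINO SwinB checkpoint and config.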
download_if_missing(
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "checkpoints/sam_vit_h_4b8939.pth"
)

download_if_missing(
    "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth",
    "checkpoints/groundingdino_swinb_cogcoor.pth"
)

download_if_missing(
    "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinB_cfg.py",
    "checkpoints/GroundingDINO_SwinB_cfg.py"
)
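
# Device selection. CUDA_VISIBLE_DEVICES must be set before CUDA is first
# initialized, so it comes before the torch.cuda.is_available() check.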
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
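
# These imports depend on the cloned repository being on sys.path, so they
# happen only after the setup above.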
from segment_anything import build_sam, SamPredictor
from diffusers import StableDiffusionInpaintPipeline
from groundingdino.util.inference import Model
import supervision as sv
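
# SAM: build the ViT-H model and wrap it in a predictor for box-prompted masks.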
sam = build_sam(checkpoint="checkpoints/sam_vit_h_4b8939.pth")
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)
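
# GroundingDINO: open-vocabulary detector that turns a text prompt into boxes.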
dino_model = Model(
    model_config_path="checkpoints/GroundingDINO_SwinB_cfg.py",
    model_checkpoint_path="checkpoints/groundingdino_swinb_cogcoor.pth",
    device=DEVICE
)
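
# Stable Diffusion inpainting pipeline: fp16 on GPU, fp32 on CPU.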
dtype = torch.float16 if DEVICE.type != "cpu" else torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=dtype
)
if DEVICE.type != "cpu":
    pipe = pipe.to(DEVICE)
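
# Tab 1: detect regions matching the text prompt and draw their bounding boxes.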
def detection_fn(image, prompt):
    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    detections.class_id = np.zeros(len(detections), dtype=int)
    box_annotator = sv.BoxAnnotator()
    annotated = box_annotator.annotate(scene=image_cv, detections=detections)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
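
# Tab 2: detect, segment each detection with SAM, and overlay the merged mask.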
def segmentation_fn(image, prompt):
    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )

    boxes = detections.xyxy
    sam_predictor.set_image(image_np)

    # Prompt SAM with each detected box and keep its highest-scoring mask.
    all_masks = []
    for box in boxes:
        box = box.reshape(1, 4)
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        if masks is not None:
            all_masks.append(masks[np.argmax(scores)])

    if not all_masks:
        raise ValueError("No masks found")

    # Union of the per-detection masks.
    merged_mask = np.any(all_masks, axis=0).astype(np.uint8) * 255

    def overlay_mask(mask, image):
        # Paint masked pixels solid green on top of the original image.
        color = np.array([0, 255, 0], dtype=np.uint8)
        mask_rgb = np.stack([mask] * 3, axis=-1)
        overlay = np.where(mask_rgb, color, image)
        return overlay

    return overlay_mask(merged_mask, image_np)
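
# Tab 3: segment the prompted object, then repaint the masked region with
# Stable Diffusion using the same prompt.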
def inpainting_fn(image, prompt):
    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    detections, _ = dino_model.predict_with_caption(
        image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
    )
    boxes = detections.xyxy
    if len(boxes) == 0:
        raise ValueError("No objects detected for the given prompt")

    sam_predictor.set_image(image_np)
    # SamPredictor.predict expects a single box prompt, so use the first detection.
    masks, scores, _ = sam_predictor.predict(box=boxes[0], multimask_output=True)
    if masks is None or len(masks) == 0:
        raise ValueError("No masks found")
    mask = masks[np.argmax(scores)]

    # Stable Diffusion 2 inpainting works at 512x512; resize in, then back out.
    image_pil = image.convert("RGB")
    mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("L")
    image_resized = image_pil.resize((512, 512))
    mask_resized = mask_img.resize((512, 512))
    inpainted = pipe(prompt=prompt, image=image_resized, mask_image=mask_resized).images[0]
    return inpainted.resize(image_pil.size)
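
# Gradio UI: one tab per stage (detection, segmentation, inpainting).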
with gr.Blocks() as demo:
    gr.Markdown("# Grounded Segment Anything + SAM + Stable Diffusion")
    with gr.Tabs():
        with gr.TabItem("Detection"):
            img = gr.Image(type="pil")
            txt = gr.Textbox(label="Prompt", value="bench")
            out = gr.Image()
            btn = gr.Button("Run Detection")
            btn.click(detection_fn, inputs=[img, txt], outputs=out)

        with gr.TabItem("Segmentation"):
            img2 = gr.Image(type="pil")
            txt2 = gr.Textbox(label="Prompt", value="bench")
            out2 = gr.Image()
            btn2 = gr.Button("Run Segmentation")
            btn2.click(segmentation_fn, inputs=[img2, txt2], outputs=out2)

        with gr.TabItem("Inpainting"):
            img3 = gr.Image(type="pil")
            txt3 = gr.Textbox(label="Prompt", value="A sofa, cyberpunk style, colorful")
            out3 = gr.Image()
            btn3 = gr.Button("Run Inpainting")
            btn3.click(inpainting_fn, inputs=[img3, txt3], outputs=out3)
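
# launch() starts a local server; pass share=True for a temporary public URL.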
if __name__ == "__main__":
    demo.launch()