# Author: SkalskiP
# Commit 9fdc53a: "migrating SAM2 space from T4 to ZERO" (file size 5.75 kB)
from typing import Optional
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from gradio_image_prompter import ImagePrompter
from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
MASK_GENERATION_MODE, BOX_PROMPT_MODE
import spaces
# Page header rendered with gr.Markdown at the top of the demo: badge links to
# the SAM 2 GitHub repository, the Colab notebook, the Roboflow blog post and a
# YouTube video, followed by a short description of the model.
MARKDOWN = """
# Segment Anything Model 2 🔥
<div>
<a href="https://github.com/facebookresearch/segment-anything-2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/what-is-segment-anything-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
</a>
</div>
Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable
visual segmentation in both images and videos. **Video segmentation will be available
soon.**
"""
# Images offered as clickable examples in the demo.
_EXAMPLE_IMAGE_URLS = (
    "https://media.roboflow.com/notebooks/examples/dog-2.jpeg",
    "https://media.roboflow.com/notebooks/examples/dog-3.jpeg",
    "https://media.roboflow.com/notebooks/examples/dog-4.jpeg",
)

# One row per example, ordered like the inputs of `process`:
# [checkpoint name, mode, upload image, image-prompter value].
# Every row uses the "tiny" checkpoint in whole-image mask generation mode,
# so the image-prompter slot is left empty.
EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, url, None]
    for url in _EXAMPLE_IMAGE_URLS
]
# Module-level inference setup. Everything below runs once at import time and
# assumes a CUDA device is available.
DEVICE = torch.device('cuda')
# Enter a bfloat16 autocast context and never exit it.
# NOTE(review): PyTorch autocast state is thread-local, so this likely only
# affects the importing thread. That is presumably why `process` also carries its
# own @torch.autocast decorator. Confirm before removing either one.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Allow TF32 kernels for matmul and cuDNN on GPUs with compute capability >= 8
# (Ampere and newer).
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Shared mask annotator. ColorLookup.INDEX presumably colours each mask by its
# detection index.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
# Both collections are indexed by checkpoint name (the CHECKPOINT_NAMES values
# offered in the checkpoint dropdown) inside `process`.
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    checkpoint_dropdown,
    mode_dropdown,
    image_input,
    image_prompter_input
) -> Optional[Image.Image]:
    """Segment the submitted image with SAM 2 and return an annotated copy.

    Dispatches on ``mode_dropdown``:

    * ``BOX_PROMPT_MODE``: segment the objects inside the boxes drawn in the
      image prompter (``image_prompter_input``).
    * ``MASK_GENERATION_MODE``: generate masks for the whole uploaded image
      (``image_input``).

    Args:
        checkpoint_dropdown: Checkpoint name, used as the key into
            ``IMAGE_PREDICTORS`` / ``MASK_GENERATORS``.
        mode_dropdown: ``BOX_PROMPT_MODE`` or ``MASK_GENERATION_MODE``.
        image_input: PIL image from the plain upload component (mask
            generation mode). May be ``None`` when nothing was uploaded.
        image_prompter_input: Image-prompter value, a mapping with ``"image"``
            and ``"points"`` keys (box prompt mode). May be ``None``.

    Returns:
        The annotated image. In box prompt mode with no boxes drawn, the
        prompt image is returned unchanged. ``None`` is returned when the
        required image is missing or the mode is not recognised.
    """
    if mode_dropdown == BOX_PROMPT_MODE:
        return _segment_boxes(checkpoint_dropdown, image_prompter_input)
    if mode_dropdown == MASK_GENERATION_MODE:
        return _generate_masks(checkpoint_dropdown, image_input)
    return None


def _segment_boxes(checkpoint_name, prompter_value) -> Optional[Image.Image]:
    """Segment every user-drawn box with the checkpoint's image predictor."""
    # Submit can be pressed before an image is loaded. Return an empty output
    # instead of crashing on a missing value.
    if prompter_value is None:
        return None
    image_input = prompter_value["image"]
    if image_input is None:
        return None
    prompt = prompter_value["points"]
    # No boxes drawn (empty list or None): echo the prompt image back unchanged.
    if not prompt:
        return image_input

    model = IMAGE_PREDICTORS[checkpoint_name]
    image = np.array(image_input.convert("RGB"))
    # Each prompt entry has six values. Values 0-1 and 3-4 are used as the
    # (x1, y1) and (x2, y2) box corners.
    # NOTE(review): the ignored values at index 2 and 5 look like prompter
    # type flags. Point clicks are not filtered out and would be treated as
    # boxes; confirm this is intended.
    box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])
    model.set_image(image)
    masks, _, _ = model.predict(box=box, multimask_output=False)
    # A 4-D result is presumably (num_boxes, 1, H, W), because
    # multimask_output=False yields one mask per box. Drop only that mask
    # axis. A bare np.squeeze would also drop num_boxes == 1 and leave a
    # 2-D (H, W) array that sv.mask_to_xyxy cannot handle.
    if masks.ndim == 4:
        masks = masks.squeeze(axis=1)
    masks = masks.astype(bool)
    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks=masks),
        mask=masks
    )
    return MASK_ANNOTATOR.annotate(image_input, detections)


def _generate_masks(checkpoint_name, image_input) -> Optional[Image.Image]:
    """Generate masks for the whole image with the checkpoint's mask generator."""
    # Nothing uploaded yet: return an empty output instead of crashing.
    if image_input is None:
        return None
    model = MASK_GENERATORS[checkpoint_name]
    image = np.array(image_input.convert("RGB"))
    result = model.generate(image)
    detections = sv.Detections.from_sam(result)
    return MASK_ANNOTATOR.annotate(image_input, detections)
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    # Mode selected when the page first loads. The initial visibility of the two
    # input components is derived from it with the same rule that
    # on_mode_dropdown_change applies later, so the two can never disagree.
    initial_mode = MODE_NAMES[0]
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            label="Checkpoint", info="Select a SAM2 checkpoint to use.",
            interactive=True
        )
        mode_dropdown_component = gr.Dropdown(
            choices=MODE_NAMES,
            value=initial_mode,
            label="Mode",
            info="Select a mode to use. `box prompt` if you want to generate masks for "
                 "selected objects, `mask generation` if you want to generate masks "
                 "for the whole image.",
            interactive=True
        )
    with gr.Row():
        with gr.Column():
            # Plain upload, read by `process` in mask generation mode.
            image_input_component = gr.Image(
                type='pil', label='Upload image',
                visible=initial_mode == MASK_GENERATION_MODE)
            # Image with box drawing, read by `process` in box prompt mode.
            image_prompter_input_component = ImagePrompter(
                type='pil', label='Image prompt',
                visible=initial_mode == BOX_PROMPT_MODE)
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
    with gr.Row():
        # Clicking an example fills the four inputs and runs `process` at once.
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[
                checkpoint_dropdown_component,
                mode_dropdown_component,
                image_input_component,
                image_prompter_input_component,
            ],
            outputs=[image_output_component],
            run_on_click=True
        )

    def on_mode_dropdown_change(text):
        """Show only the input component that the selected mode reads."""
        return [
            gr.Image(visible=text == MASK_GENERATION_MODE),
            ImagePrompter(visible=text == BOX_PROMPT_MODE)
        ]

    mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[
            image_input_component,
            image_prompter_input_component
        ]
    )
    submit_button_component.click(
        fn=process,
        inputs=[
            checkpoint_dropdown_component,
            mode_dropdown_component,
            image_input_component,
            image_prompter_input_component,
        ],
        outputs=[image_output_component]
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so that importing the module does not block on a
# running server.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)