Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
from transformers import SamModel, SamProcessor | |
from gradio_image_prompter import ImagePrompter | |
# Model setup happens once at import time; the inference helpers in this file
# read these module-level names directly.
device = 'cpu'
model_id = "nielsr/slimsam-50-uniform"

# Load the checkpoint, then move the weights to the target device.
slim_sam_model = SamModel.from_pretrained(model_id)
slim_sam_model = slim_sam_model.to(device)

# The processor prepares prompts/images and post-processes predicted masks.
slim_sam_processor = SamProcessor.from_pretrained(model_id)
def sam_box_inference(image, x_min, y_min, x_max, y_max):
    """Run SlimSAM on ``image`` using a single bounding-box prompt.

    Args:
        image: The picture to segment, either a NumPy array or a
            ``PIL.Image.Image``.
        x_min: Left edge of the box prompt.
        y_min: Top edge of the box prompt.
        x_max: Right edge of the box prompt.
        y_max: Bottom edge of the box prompt.

    Returns:
        A one-element list ``[(mask, "mask")]`` where ``mask`` is the first
        post-processed mask with a new leading axis. The caller in this file
        forwards it as the annotation list of a ``gr.AnnotatedImage``.
    """
    processor, model = slim_sam_processor, slim_sam_model

    # Image.fromarray only understands array-like input, so a PIL image is
    # passed through unchanged instead of crashing.
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        pil_image = Image.fromarray(image)

    inputs = processor(
        pil_image,
        input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Take the first entry at each of the three leading levels of the
    # post-processed result (presumably batch / prompt / candidate mask --
    # confirm against the transformers SAM documentation).
    mask = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    return [(mask, "mask")]
def sam_point_inference(image, x, y):
    """Run SlimSAM on ``image`` using a single point prompt at ``(x, y)``.

    Prints the type and shape of the resulting mask, then returns a
    one-element list ``[(mask, "mask")]`` where ``mask`` is the first
    post-processed mask with a new leading axis.
    """
    encoded = slim_sam_processor(
        image,
        input_points=[[[x, y]]],
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = slim_sam_model(**encoded)

    post_processed = slim_sam_processor.post_process_masks(
        prediction.pred_masks.cpu(),
        encoded["original_sizes"].cpu(),
        encoded["reshaped_input_sizes"].cpu(),
    )
    # First entry at each of the three leading levels of the result.
    first_mask = post_processed[0][0][0].numpy()
    batched_mask = np.expand_dims(first_mask, axis=0)

    print(type(batched_mask))
    print(batched_mask.shape)
    return [(batched_mask, "mask")]
def infer_point(img):
    """Handle the point-prompt tab: locate the drawn dot and segment around it.

    Args:
        img: The image-editor value. This function reads its ``"background"``
            entry (a PIL image, converted to RGB here) and the first element
            of its ``"layers"`` entry (the user's drawing).

    Returns:
        A tuple ``(image, annotations)`` where ``image`` is the RGB background
        and ``annotations`` is the result of ``sam_point_inference``.

    Raises:
        gr.Error: If no image is present or nothing has been drawn on it.
    """
    # gr.Error only reaches the user when it is raised; merely constructing it
    # does nothing and lets the code below crash on the missing data.
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    image = img["background"].convert("RGB")

    layers = img["layers"]
    if not layers:
        raise gr.Error("Please select a point on top of the image.")

    img_arr = np.array(layers[0])
    if not np.any(img_arr):
        raise gr.Error("Please select a point on top of the image.")

    # The prompt is the centroid of every non-zero entry of the drawing
    # layer: index 0 of np.nonzero holds rows (y), index 1 holds columns (x).
    nonzero_indices = np.nonzero(img_arr)
    center_x = int(np.mean(nonzero_indices[1]))
    center_y = int(np.mean(nonzero_indices[0]))
    return (image, sam_point_inference(image, center_x, center_y))
def infer_box(prompts):
    """Handle the box-prompt tab: extract the drawn box and segment inside it.

    Args:
        prompts: The image-prompter value. This function reads its ``"image"``
            entry and the first element of its ``"points"`` entry.

    Returns:
        A tuple ``(image, annotations)`` where ``annotations`` is the result
        of ``sam_box_inference``.

    Raises:
        gr.Error: If no image was uploaded or no box was drawn.
    """
    # gr.Error only reaches the user when it is raised.
    if prompts is None or prompts["image"] is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    image = prompts["image"]

    # Check for an empty prompt list before indexing into it, otherwise a
    # missing box surfaces as an IndexError instead of a helpful message.
    all_points = prompts["points"]
    if all_points is None or len(all_points) == 0 or all_points[0] is None:
        raise gr.Error("Please draw a box before submitting.")
    points = all_points[0]

    # Indices 0/1 and 3/4 hold the two box corners; indices 2 and 5 are
    # skipped (presumably per-corner labels -- confirm against the
    # gradio_image_prompter documentation).
    x_min, y_min = points[0], points[1]
    x_max, y_max = points[3], points[4]
    return (image, sam_box_inference(image, x_min, y_min, x_max, y_max))
# Build and launch the demo UI: one tab per prompt type, each with an input
# component on the left and an annotated-image output on the right.
if __name__ == '__main__':
    with gr.Blocks(title="SlimSAM") as demo:
        gr.Markdown("# SlimSAM")
        gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
        gr.Markdown("In this demo, you can compare SlimSAM outputs in point and box prompts.")

        with gr.Tab("Box Prompt"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("To try box prompting, simply upload an image and draw a box on it.")
            with gr.Row():
                with gr.Column():
                    box_prompter = ImagePrompter()
                    btn = gr.Button("Submit")
                with gr.Column():
                    output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
            # Box inference runs only when the user presses Submit.
            btn.click(infer_box, inputs=box_prompter, outputs=[output_box_slimsam])

        with gr.Tab("Point Prompt"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("To try point prompting, simply upload an image and leave a dot on it.")
            with gr.Row():
                with gr.Column():
                    point_editor = gr.ImageEditor(
                        type="pil",
                    )
                with gr.Column():
                    output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
            # Point inference re-runs on every change of the editor value.
            point_editor.change(infer_point, inputs=point_editor, outputs=[output_slimsam])

    demo.launch(debug=True)