#!/usr/bin/env python3
"""
SAM 2.1 Interface
"""
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
from transformers import Sam2Model, Sam2Processor
import warnings
import io
import base64
import os
from datetime import datetime

# Grounding DINO will be imported dynamically in the initialization function
warnings.filterwarnings("ignore")
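
# NOTE: besides the imports above, this script assumes `opencv-python` (cv2)
# is installed for the mask post-processing helpers below, and that the
# installed `transformers` version ships Sam2Model/Sam2Processor as well as
# the zero-shot object-detection classes used for Grounding DINO.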

# Global model instance to avoid reloading
MODEL = None
PROCESSOR = None
DEVICE = None

# Global Grounding DINO instance
GROUNDING_DINO = None

# Global state for saving
CURRENT_MASK = None
CURRENT_IMAGE_NAME = None
CURRENT_POINTS = None


def initialize_sam(model_size="small"):
    """Initialize SAM model once"""
    global MODEL, PROCESSOR, DEVICE
    if MODEL is None:
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing SAM 2.1 {model_size} on {DEVICE}...")
        model_name = f"facebook/sam2-hiera-{model_size}"
        MODEL = Sam2Model.from_pretrained(model_name).to(DEVICE)
        PROCESSOR = Sam2Processor.from_pretrained(model_name)
        print("✅ Model loaded successfully!")
    return MODEL, PROCESSOR, DEVICE


def initialize_grounding_dino():
    """Initialize Grounding DINO model once"""
    global GROUNDING_DINO, DEVICE
    if GROUNDING_DINO is None:
        if DEVICE is None:
            DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing Grounding DINO on {DEVICE}...")
        try:
            # Use the Hugging Face checkpoint for Grounding DINO
            from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
            model_id = "IDEA-Research/grounding-dino-base"
            GROUNDING_DINO = {
                'processor': AutoProcessor.from_pretrained(model_id),
                'model': AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
            }
            print("✅ Grounding DINO loaded successfully!")
        except Exception as e:
            print(f"❌ Failed to load Grounding DINO: {e}")
            print("Note: Falling back to manual point selection only")
            GROUNDING_DINO = None
    return GROUNDING_DINO


def detect_objects_with_text(image, text_prompt, confidence_threshold=0.25):
    """Use Grounding DINO to detect objects based on a text prompt"""
    global GROUNDING_DINO
    try:
        # Initialize Grounding DINO if needed
        grounding_dino = initialize_grounding_dino()
        if grounding_dino is None:
            return None, "❌ Grounding DINO not available"

        # Fix image format
        pil_image = fix_image_array(image)

        # Prepare inputs for Grounding DINO
        processor = grounding_dino['processor']
        model = grounding_dino['model']

        # Process inputs
        inputs = processor(images=pil_image, text=text_prompt, return_tensors="pt").to(DEVICE)

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process results
        results = processor.post_process_grounded_object_detection(
            outputs,
            input_ids=inputs.input_ids,
            threshold=confidence_threshold,
            text_threshold=0.25,
            target_sizes=[pil_image.size[::-1]]  # PIL size is (width, height); the model expects (height, width)
        )[0]

        if len(results['boxes']) == 0:
            return None, f"No objects found for prompt: '{text_prompt}'"

        # Convert boxes to the format expected by SAM: [x1, y1, x2, y2]
        detected_boxes = []
        for box in results['boxes']:
            x1, y1, x2, y2 = box.tolist()
            detected_boxes.append([int(x1), int(y1), int(x2), int(y2)])

        return detected_boxes, f"✅ Found {len(detected_boxes)} object(s) for '{text_prompt}'"
    except Exception as e:
        return None, f"❌ Detection failed: {str(e)}"
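
# Usage sketch (illustrative; assumes a local file "example.jpg" exists):
#   boxes, status = detect_objects_with_text("example.jpg", "a dog", confidence_threshold=0.3)
#   if boxes:
#       print(status)  # boxes is e.g. [[34, 50, 212, 301], ...] in pixel coordinates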


def fix_image_array(image):
    """Fix image input for SAM processing - handles filepath, numpy array, or PIL Image"""
    if isinstance(image, str):
        # Handle filepath input from Gradio
        return Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        # Make sure the array is contiguous
        if not image.flags['C_CONTIGUOUS']:
            image = np.ascontiguousarray(image)
        # Ensure uint8 dtype
        if image.dtype != np.uint8:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)
        # Convert to a PIL Image to avoid any stride issues
        return Image.fromarray(image).convert("RGB")
    elif isinstance(image, Image.Image):
        return image.convert("RGB")
    else:
        raise ValueError(f"Unsupported image type: {type(image)}")


def apply_mask_post_processing(mask, stability_threshold=0.95):
    """Apply post-processing to refine mask size and quality.

    Note: `stability_threshold` is accepted for API compatibility but is not
    currently used by this function.
    """
    import cv2

    # Convert to a binary mask
    binary_mask = (mask > 0).astype(np.uint8)

    # Adaptive kernel size: ~1% of the image height, at least 3 px
    kernel_size = max(3, int(mask.shape[0] * 0.01))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    # Close small holes
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
    # Remove small noise
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)

    return binary_mask.astype(np.float32)
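
# Morphology sketch (illustrative): with, say, a 5x5 elliptical kernel,
# MORPH_CLOSE (dilate then erode) fills holes smaller than the kernel inside
# the mask, while MORPH_OPEN (erode then dilate) removes isolated specks
# smaller than the kernel outside it; the object's overall outline is
# approximately preserved.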


def apply_erosion_dilation(mask, erosion_dilation_value):
    """Apply erosion or dilation to adjust mask size"""
    import cv2

    if erosion_dilation_value == 0:
        return mask

    binary_mask = (mask > 0).astype(np.uint8)
    kernel_size = abs(erosion_dilation_value)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    if erosion_dilation_value > 0:
        # Dilate (make larger)
        binary_mask = cv2.dilate(binary_mask, kernel, iterations=1)
    else:
        # Erode (make smaller)
        binary_mask = cv2.erode(binary_mask, kernel, iterations=1)
    return binary_mask.astype(np.float32)
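
# Rule of thumb (illustrative): one dilate/erode pass grows (positive value)
# or shrinks (negative value) the mask boundary by roughly half the kernel
# size in pixels, so erosion_dilation_value=+4 expands the outline by ~2 px.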


def save_binary_mask(mask, image_name, points, mask_threshold, erosion_dilation, save_low_res=False, custom_folder_name=None):
    """Save binary mask to an organized folder structure"""
    global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS
    try:
        # Store current state for saving
        CURRENT_MASK = mask
        CURRENT_IMAGE_NAME = image_name
        CURRENT_POINTS = points

        # Extract the image name without extension and sanitize it
        if image_name:
            base_name = os.path.splitext(os.path.basename(image_name))[0]
            # Remove any path separators and special characters
            base_name = base_name.replace('/', '_').replace('\\', '_').replace(':', '_').replace(' ', '_')
        else:
            base_name = f"image_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Choose a folder tag: the user-provided name if available, else 'default'
        if custom_folder_name and str(custom_folder_name).strip():
            folder_tag = str(custom_folder_name).strip().replace(' ', '_')
        else:
            folder_tag = "default"

        # Create the folder structure: masks/<image_base>/<folder_tag>/
        folder_name = f"masks/{base_name}/{folder_tag}"
        os.makedirs(folder_name, exist_ok=True)

        # Create a binary mask (0 and 255 values)
        binary_mask = (mask > 0).astype(np.uint8) * 255

        # Calculate low-resolution dimensions if requested
        original_height, original_width = binary_mask.shape
        if save_low_res:
            # Square sqrt-based resolution
            sqrt_factor = int(np.sqrt(max(original_width, original_height)))
            low_res_width = sqrt_factor
            low_res_height = sqrt_factor
            print(f"Original mask size: {original_width}x{original_height}")
            print(f"Low-res mask size: {low_res_width}x{low_res_height}")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Sanitize filename fragments - replace problematic characters
        threshold_str = f"{mask_threshold:.2f}".replace('.', 'p')  # 0.30 -> 0p30
        adj_str = f"{erosion_dilation:+d}".replace('+', 'plus').replace('-', 'minus')  # +2 -> plus2, -2 -> minus2
        saved_paths = []

        # Save the full-resolution mask as JPEG with a simple filename
        mask_filename = "image.jpg"
        mask_path = os.path.join(folder_name, mask_filename)
        mask_image = Image.fromarray(binary_mask, mode='L')
        mask_image.save(mask_path, format="JPEG", quality=95, optimize=True)
        saved_paths.append(mask_path)

        # Save a tensor mask (.pt) as a float tensor (0.0/1.0)
        tensor_filename = "image.pt"
        tensor_path = os.path.join(folder_name, tensor_filename)
        torch.save(torch.from_numpy((mask > 0).astype(np.float32)), tensor_path)
        saved_paths.append(tensor_path)

        # Save a low-resolution mask if requested
        if save_low_res:
            low_res_mask = mask_image.resize((low_res_width, low_res_height), Image.Resampling.NEAREST)
            low_res_filename = f"mask_lowres_{sqrt_factor}x{sqrt_factor}_t{threshold_str}_adj{adj_str}_{timestamp}.png"
            low_res_path = os.path.join(folder_name, low_res_filename)
            low_res_mask.save(low_res_path)
            saved_paths.append(low_res_path)

        # Also save metadata
        metadata = {
            "timestamp": timestamp,
            "points": points,
            "mask_threshold": mask_threshold,
            "erosion_dilation": erosion_dilation,
            "image_name": image_name,
            "original_resolution": f"{original_width}x{original_height}",
            "saved_paths": saved_paths,
            "low_resolution_saved": save_low_res
        }
        if save_low_res:
            metadata["low_resolution"] = f"{low_res_width}x{low_res_height}"
            metadata["sqrt_factor"] = sqrt_factor

        import json
        metadata_path = os.path.join(folder_name, f"metadata_{timestamp}.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        # Return an appropriate message
        if save_low_res:
            return f"✅ Masks saved:\n📁 Full: {os.path.basename(mask_path)}\n📁 Low-res: {os.path.basename(low_res_path)}"
        else:
            return f"✅ Mask saved to: {os.path.basename(mask_path)}"
    except Exception as e:
        return f"❌ Save failed: {str(e)}"
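
# Resulting on-disk layout (derived from the code above; filenames are fixed):
#   masks/<image_base>/<folder_tag>/image.jpg            binary mask, 0/255, full resolution
#   masks/<image_base>/<folder_tag>/image.pt             float32 tensor, 0.0/1.0
#   masks/<image_base>/<folder_tag>/mask_lowres_*.png    optional square low-res copy
#   masks/<image_base>/<folder_tag>/metadata_<ts>.json   prompt points and settings
# Note that repeated saves into the same folder overwrite image.jpg and image.pt.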


def process_sam_segmentation(image, points_data, bbox_data, mode, image_name=None, top_k=3, mask_threshold=0.0, stability_score_threshold=0.95, erosion_dilation=0, text_prompt=None, confidence_threshold=0.25):
    """Main processing function with mask size controls - supports points, bounding boxes, and text prompts.

    Always returns a 5-tuple: (input_visualization, mask_visualization,
    status, processed_masks, mask_scores); the last two are empty lists on
    failure, so callers can unpack unconditionally.
    """
    global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS

    if image is None:
        return None, None, "Please upload an image first.", [], []

    # Check input based on mode
    if mode == "Points":
        if not points_data or len(points_data) == 0:
            return None, None, "Please click on the image to select points.", [], []
    elif mode == "Bounding Box":
        if bbox_data is None:
            return None, None, "Please click two corners to define a bounding box.", [], []
    elif mode == "Text Prompt":
        if not text_prompt or not text_prompt.strip():
            return None, None, "Please enter a text prompt to detect objects.", [], []

    try:
        # Initialize the model
        model, processor, device = initialize_sam()

        # Fix the image
        pil_image = fix_image_array(image)

        # Prepare SAM inputs based on mode
        input_points = None
        input_labels = None
        input_boxes = None
        points = None

        if mode == "Points":
            # Extract points with positive/negative labels
            points = []
            labels = []
            for point_info in points_data:
                if isinstance(point_info, dict):
                    points.append([point_info.get("x", 0), point_info.get("y", 0)])
                    labels.append(1 if point_info.get("positive", True) else 0)  # 1 = positive, 0 = negative
                elif isinstance(point_info, (list, tuple)) and len(point_info) >= 2:
                    points.append([point_info[0], point_info[1]])
                    labels.append(1)  # Default to positive for the old format
            if not points:
                return None, None, "No valid points found.", [], []
            print(f"Processing {len(points)} points: {points} with labels: {labels}")
            input_points = [[points]]
            input_labels = [[labels]]
            # If a bounding box was also supplied (e.g. from text detection),
            # pass it along so the points refine the segmentation inside it
            if bbox_data is not None:
                print(f"Also using bounding box: {bbox_data}")
                input_boxes = [[bbox_data]]
        elif mode == "Bounding Box":
            # Use the bounding box
            bbox = bbox_data  # [x1, y1, x2, y2]
            print(f"Processing bounding box: {bbox}")
            input_boxes = [[bbox]]
            # For visualization, store the bbox corners as points
            points = [[bbox[0], bbox[1]], [bbox[2], bbox[3]]]
        elif mode == "Text Prompt":
            # Use Grounding DINO to detect objects from the text prompt
            detected_boxes, detection_status = detect_objects_with_text(pil_image, text_prompt, confidence_threshold)
            if detected_boxes is None:
                return None, None, detection_status, [], []
            # Use the first detected bounding box
            bbox = detected_boxes[0]
            print(f"Using detected bounding box: {bbox}")
            input_boxes = [[bbox]]
            # For visualization, store the bbox corners as points
            points = [[bbox[0], bbox[1]], [bbox[2], bbox[3]]]

        # Process with SAM
        processor_inputs = {
            "images": pil_image,
            "return_tensors": "pt"
        }
        # Add points and/or boxes based on what is available
        if input_points is not None:
            processor_inputs["input_points"] = input_points
            processor_inputs["input_labels"] = input_labels
        if input_boxes is not None:
            processor_inputs["input_boxes"] = input_boxes

        inputs = processor(**processor_inputs).to(device)

        # Generate masks with multiple outputs for better control
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=True)

        # Get masks and scores
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"]
        )[0]
        scores = outputs.iou_scores.cpu().numpy().flatten()

        # Take the top-k masks by score and post-process each of them
        top_indices = np.argsort(scores)[::-1][:top_k]
        processed_masks = []
        mask_scores = []
        for idx in top_indices:
            mask = masks[0, idx].numpy()
            score = scores[idx]
            # Apply the threshold to control mask size
            if mask_threshold > 0:
                mask = (mask > mask_threshold).astype(np.float32)
            # Morphological clean-up for size control
            mask = apply_mask_post_processing(mask, stability_score_threshold)
            # Apply erosion/dilation for fine size control
            if erosion_dilation != 0:
                mask = apply_erosion_dilation(mask, erosion_dilation)
            processed_masks.append(mask)
            mask_scores.append(score)

        # Store current state for saving (use the first mask as default)
        CURRENT_MASK = processed_masks[0]
        CURRENT_IMAGE_NAME = image_name
        CURRENT_POINTS = points

        # Create visualizations for the first mask
        original_with_input = create_original_with_input_visualization(pil_image, points, bbox_data, mode)
        mask_result = create_mask_visualization(pil_image, processed_masks[0], mask_scores[0], mask_threshold)

        status = f"✅ Generated {len(processed_masks)} masks\n🔄 Use navigation to browse masks"
        return original_with_input, mask_result, status, processed_masks, mask_scores
    except Exception as e:
        print(f"Error in processing: {e}")
        return None, None, f"Error: {str(e)}", [], []
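
# Note on the prompt layout above: the transformers SAM-style processors nest
# prompts per image and per object, which is why a single click list becomes
# input_points=[[points]] (a batch of 1 image with 1 object) and a single box
# becomes input_boxes=[[bbox]].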


def create_original_with_input_visualization(pil_image, points, bbox, mode, negative_points=None):
    """Create a visualization of the original image with the input points/bbox overlaid"""
    # Convert PIL to numpy for matplotlib
    img_array = np.array(pil_image)
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    # Show the original image
    ax.imshow(img_array)

    # Show the input visualization based on mode
    if mode == "Points":
        # Show positive points (green)
        if points:
            for point in points:
                ax.plot(point[0], point[1], 'go', markersize=12, markeredgewidth=3, markerfacecolor='lime')
        # Show negative points (red)
        if negative_points:
            for point in negative_points:
                ax.plot(point[0], point[1], 'ro', markersize=12, markeredgewidth=3, markerfacecolor='red')
        pos_count = len(points) if points else 0
        neg_count = len(negative_points) if negative_points else 0
        title_suffix = f"Points: {pos_count}+ {neg_count}-" if neg_count > 0 else f"Points: {pos_count}"
    elif mode == "Bounding Box" and bbox:
        # Show the bounding box
        x1, y1, x2, y2 = bbox
        width = x2 - x1
        height = y2 - y1
        # Draw the bounding box rectangle
        from matplotlib.patches import Rectangle
        rect = Rectangle((x1, y1), width, height, linewidth=3, edgecolor='lime', facecolor='none')
        ax.add_patch(rect)
        # Show the corner points
        ax.plot([x1, x2], [y1, y2], 'go', markersize=8, markeredgewidth=2, markerfacecolor='lime')
        title_suffix = f"BBox: {int(width)}×{int(height)}"
    else:
        title_suffix = "No input"

    ax.set_title(f"Input Selection ({title_suffix})", fontsize=14)
    ax.axis('off')

    # Convert the rendered figure to a numpy array
    fig.canvas.draw()
    buf = fig.canvas.buffer_rgba()
    result_array = np.asarray(buf)
    # Convert RGBA to RGB
    result_array = result_array[:, :, :3]
    plt.close(fig)
    return result_array


def create_mask_visualization(pil_image, mask, score, mask_threshold=0.0):
    """Create a clean mask visualization without input overlays"""
    # Convert PIL to numpy for matplotlib
    img_array = np.array(pil_image)
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    # Show the original image
    ax.imshow(img_array)

    # Overlay the mask in red
    mask_overlay = np.zeros((*mask.shape, 4))
    mask_overlay[mask > 0] = [1, 0, 0, 0.6]  # Red with transparency
    ax.imshow(mask_overlay)

    ax.set_title(f"Generated Mask (Score: {float(score):.3f}, Threshold: {mask_threshold:.2f})", fontsize=14)
    ax.axis('off')

    # Convert the rendered figure to a numpy array
    fig.canvas.draw()
    buf = fig.canvas.buffer_rgba()
    result_array = np.asarray(buf)
    # Convert RGBA to RGB
    result_array = result_array[:, :, :3]
    plt.close(fig)
    return result_array
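
# Both visualization helpers return the rendered matplotlib canvas, so the
# output arrays are at figure resolution (figsize x dpi, i.e. roughly 800x600
# at matplotlib's default 100 dpi), not at the original image resolution.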


def create_interface():
    """Create a simplified single-image annotator interface."""
    with gr.Blocks(title="SAM 2.1 - Simple Annotator", theme=gr.themes.Soft(), css="""
        .negative-mode-checkbox label {
            color: #d00000 !important;
            font-weight: 800 !important;
            font-size: 16px !important;
        }
    """) as interface:
        gr.HTML("""
        <div style="text-align: center;">
            <h1>🎯 AI-Powered Image Segmentation</h1>
            <h2>SAM 2.1 + Grounding DINO</h2>
            <p><strong>✨ Just type what you want to segment!</strong> Try "person", "face", "car", "dog" - or click points manually.</p>
            <p>🔄 Generate multiple mask options and pick your favorite!</p>
            <hr style="margin: 20px 0;">
            <p style="font-size: 12px; color: #666;">
                <strong>Acknowledgment:</strong> This is a GUI interface for research by Meta AI (SAM 2.1) and IDEA Research (Grounding DINO).<br>
                All credit goes to the original researchers. This tool only provides an easy-to-use web interface.
            </p>
        </div>
        """)

        # Image input (single image) - kept hidden; users upload via the
        # annotatable image component below, which doubles as the uploader.
        image_input = gr.Image(
            label=None,
            type="filepath",
            height=0,
            visible=False
        )

        # Text prompt input with a clear button
        with gr.Row():
            text_prompt_input = gr.Textbox(
                label="🔍 Text Prompt (Optional)",
                placeholder="Type what to segment (e.g., 'person', 'car', 'dog') and press Enter",
                value="snoopy",
                interactive=True,
                info="💡 Text = auto-detection | Empty + clicking = manual points | Text takes priority if both provided",
                scale=4
            )
            clear_text_btn = gr.Button("🗑️ Clear Text", variant="secondary", scale=1)

        # Number of masks to generate
        num_masks = gr.Slider(
            minimum=1,
            maximum=5,
            value=3,
            step=1,
            label="🔢 Number of Masks to Generate",
            info="Generate multiple mask options to choose from"
        )

        # Main layout: selected points on the left, annotatable image in the center, preview on the right
        with gr.Row():
            with gr.Column(scale=1):
                clear_points_btn = gr.Button("🗑️ Clear Points", variant="secondary", size="sm")
                points_display = gr.JSON(label="📍 Selected Points", value=[], visible=True)
            with gr.Column(scale=3):
                # Negative mode toggle with clear red styling
                negative_point_mode = gr.Checkbox(
                    label="❌ NEGATIVE POINT MODE",
                    value=False,
                    info="🔴 Enable to add negative points (shown in red)",
                    interactive=True,
                    elem_classes="negative-mode-checkbox"
                )
                original_with_input = gr.Image(
                    label="📍 Click to Annotate (toggle negative mode to exclude)",
                    height=640,
                    interactive=True,
                    value="data/snoopy.jpg"
                )
            with gr.Column(scale=1):
                points_overlay = gr.Image(label="👁️ Points Preview (green=positive, red=negative)", height=720, interactive=False)

        # Action buttons
        with gr.Row():
            generate_btn = gr.Button("🎯 Generate Mask", variant="primary", size="lg")

        # Mask result with navigation
        with gr.Row():
            mask_result = gr.Image(label="🎭 Generated Mask", height=512)

        # Mask navigation controls
        with gr.Row():
            prev_mask_btn = gr.Button("⬅️ Previous", variant="secondary", size="sm")
            mask_info = gr.Textbox(
                label="Mask Info",
                value="No masks generated yet",
                interactive=False,
                scale=2
            )
            next_mask_btn = gr.Button("➡️ Next", variant="secondary", size="sm")

        # Save controls under the mask
        with gr.Row():
            mask_name_input = gr.Textbox(label="Folder name (optional)", placeholder="e.g., Glasses", value="Glasses", scale=2)
            format_selector = gr.Radio(
                choices=["PNG", "JPG", "PT"],
                value="PNG",
                label="📄 Download Format",
                scale=1
            )
            save_btn = gr.Button("💾 Prepare for saving", variant="stop", size="lg", scale=1)

        # Status and download
        with gr.Row():
            status_text = gr.Textbox(label="📊 Status", interactive=False, lines=3, scale=2)
            download_file = gr.File(label="📥 Download", visible=False, scale=1)

        # State to store points and masks
        points_state = gr.State([])
        masks_data = gr.State({"masks": [], "scores": [], "image": None})  # Store all mask data
        current_mask_index = gr.State(0)  # Current mask being viewed

        # Event handlers
        def on_image_click(image, current_points, negative_mode, evt: gr.SelectData):
            """Handle clicks on the image for point annotations only."""
            if evt.index is not None and image is not None:
                x, y = evt.index
                try:
                    pil_image = fix_image_array(image)
                    is_negative = negative_mode
                    new_point = {"x": int(x), "y": int(y), "positive": not is_negative}
                    updated_points = current_points + [new_point]
                    positive_points = [[p["x"], p["y"]] for p in updated_points if p.get("positive", True)]
                    negative_points = [[p["x"], p["y"]] for p in updated_points if not p.get("positive", True)]
                    updated_visualization = create_original_with_input_visualization(
                        pil_image, positive_points, None, "Points", negative_points
                    )
                    point_type = "negative" if is_negative else "positive"
                    pos_count = len(positive_points)
                    neg_count = len(negative_points)
                    return updated_points, updated_points, updated_visualization, (
                        f"Added {point_type} point at ({x}, {y}). Total: {pos_count} positive, {neg_count} negative points."
                    )
                except Exception as e:
                    print(f"Error in visualization: {e}")
                    return current_points, current_points, None, f"Error updating visualization: {str(e)}"
            return current_points, current_points, None, "Click on the image to add points."

        def on_image_upload(image):
            """Handle image upload and show it for annotation."""
            if image is not None:
                try:
                    pil_image = fix_image_array(image)
                    img_array = np.array(pil_image)
                    # Populate both the annotation image (left) and the points preview (right)
                    return img_array, img_array, [], [], "Image uploaded. Click on the left image to add points (enable negative mode for exclusion)."
                except Exception as e:
                    return None, None, [], [], f"Error loading image: {str(e)}"
            return None, None, [], [], "No image uploaded."

        def clear_all_points(image):
            """Clear points and keep the image visible for annotation."""
            try:
                if image is not None:
                    pil_image = fix_image_array(image)
                    img_array = np.array(pil_image)
                    return [], [], img_array, img_array, None, "All points cleared. You can continue annotating."
            except Exception:
                pass
            return [], [], None, None, None, "All points cleared."

        def clear_text_prompt():
            """Clear the text prompt."""
            return "", "Text prompt cleared. You can now use manual points."

        def generate_segmentation(image, points, text_prompt, num_masks_to_generate):
            """Generate multiple segmentation masks - auto-detects the input type."""
            # Determine the image name
            if isinstance(image, str):
                image_name = os.path.basename(image)
            else:
                # Prefer an explicit friendly default if metadata lacks a good name
                image_name = None
                if hasattr(image, 'orig_name'):
                    image_name = image.orig_name
                elif isinstance(image, dict) and 'orig_name' in image:
                    image_name = image['orig_name']
                elif hasattr(image, 'name'):
                    image_name = image.name
                if not image_name or 'tmp' in str(image_name).lower() or 'uploaded_image' in str(image_name).lower():
                    image_name = "michael_phelps_bottom_left.jpg"

            def package_results(mask_img, status, masks, scores):
                """Bundle outputs for the UI, handling the empty-mask case safely."""
                if not masks:
                    return mask_img, status, {"masks": [], "scores": [], "image": None}, 0, "No masks generated"
                masks_data_dict = {"masks": masks, "scores": scores, "image": image}
                return mask_img, status, masks_data_dict, 0, f"Mask 1 of {len(masks)} (Score: {scores[0]:.3f})"

            # Auto-detect the input type and run segmentation
            has_text = text_prompt and text_prompt.strip()
            has_points = points and len(points) > 0

            if has_text and has_points:
                # Combine text detection with manual point refinement
                status_info = "🎯 Combining text detection with manual point refinement"
                # First, detect with text to get an initial bounding box
                detected_boxes, detection_status = detect_objects_with_text(image, text_prompt, 0.25)
                if detected_boxes:
                    # Use the detected bounding box AND manual points together;
                    # the points refine the segmentation within the detected area
                    bbox = detected_boxes[0]  # Use the first detection as guidance
                    _, mask_img, status, masks, scores = process_sam_segmentation(
                        image, points, bbox, "Points", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, None, 0.25
                    )
                    status = f"{status_info}\n✅ Text: {detection_status}\n✅ Using {len(points)} manual points for refinement\n{status}"
                    return package_results(mask_img, status, masks, scores)
                else:
                    # Fall back to points only if text detection fails
                    _, mask_img, status, masks, scores = process_sam_segmentation(
                        image, points, None, "Points", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, None, 0.25
                    )
                    status = f"⚠️ Text detection failed, using {len(points)} manual points only\n{status}"
                    return package_results(mask_img, status, masks, scores)
            elif has_text:
                # Use the text prompt
                _, mask_img, status, masks, scores = process_sam_segmentation(
                    image, None, None, "Text Prompt", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, text_prompt, 0.25
                )
                return package_results(mask_img, status, masks, scores)
            elif has_points:
                # Use points
                _, mask_img, status, masks, scores = process_sam_segmentation(
                    image, points, None, "Points", image_name, int(num_masks_to_generate), 0.0, 0.95, 0, None, 0.25
                )
                return package_results(mask_img, status, masks, scores)
            else:
                return None, "❌ Please either enter a text prompt or click points on the image.", {"masks": [], "scores": [], "image": None}, 0, "No masks generated"

        def navigate_mask(direction, current_index, masks_data):
            """Navigate through the generated masks"""
            masks = masks_data.get("masks", [])
            scores = masks_data.get("scores", [])
            image = masks_data.get("image", None)

            if not masks or len(masks) == 0:
                return None, current_index, "No masks available"

            # Calculate the new index (wrapping around)
            if direction == "next":
                new_index = (current_index + 1) % len(masks)
            else:  # previous
                new_index = (current_index - 1) % len(masks)

            # Get the mask at the new index
            mask = masks[new_index]
            score = scores[new_index]

            # Update global state for saving
            global CURRENT_MASK
            CURRENT_MASK = mask

            # Create the visualization
            if image is not None:
                pil_image = fix_image_array(image)
                mask_visualization = create_mask_visualization(pil_image, mask, score, 0.0)
            else:
                mask_visualization = None

            mask_info_text = f"Mask {new_index + 1} of {len(masks)} (Score: {score:.3f})"
            return mask_visualization, new_index, mask_info_text

        def save_and_download_mask(custom_folder_name, download_format):
            """Save the mask locally and prepare a download for the user."""
            global CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS

            if CURRENT_MASK is None:
                return "❌ No mask to save. Generate a mask first.", None
            if CURRENT_POINTS is None:
                return "❌ No points available. Generate a mask first.", None

            try:
                # Save locally (keeping the existing folder hierarchy)
                local_save_status = save_binary_mask(
                    CURRENT_MASK, CURRENT_IMAGE_NAME, CURRENT_POINTS,
                    0.0, 0, False, custom_folder_name=(custom_folder_name or None)
                )

                # Create the download file
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                base_name = os.path.splitext(os.path.basename(CURRENT_IMAGE_NAME or "mask"))[0]

                if download_format == "PNG":
                    binary_mask = (CURRENT_MASK > 0).astype(np.uint8) * 255
                    mask_image = Image.fromarray(binary_mask, mode='L')
                    download_path = f"/tmp/mask_{base_name}_{timestamp}.png"
                    mask_image.save(download_path, format="PNG")
                elif download_format == "JPG":
                    binary_mask = (CURRENT_MASK > 0).astype(np.uint8) * 255
                    mask_image = Image.fromarray(binary_mask, mode='L')
                    download_path = f"/tmp/mask_{base_name}_{timestamp}.jpg"
                    mask_image.save(download_path, format="JPEG", quality=95)
                elif download_format == "PT":
                    # PyTorch tensor download
                    download_path = f"/tmp/mask_{base_name}_{timestamp}.pt"
                    torch.save(torch.from_numpy((CURRENT_MASK > 0).astype(np.float32)), download_path)

                # Make the download visible and return the file
                download_status = f"📥 Download ready: {download_format} format\n{local_save_status}"
                return download_status, gr.File(value=download_path, visible=True)
            except Exception as e:
                return f"❌ Save/download failed: {str(e)}", None

        # Wire events
        # Let the annotatable image also handle uploads (drag & drop / click upload)
        original_with_input.upload(
            on_image_upload,
            inputs=[original_with_input],
            outputs=[original_with_input, points_overlay, points_state, points_display, status_text]
        )
        original_with_input.select(
            on_image_click,
            inputs=[original_with_input, points_state, negative_point_mode],
            outputs=[points_state, points_display, points_overlay, status_text]
        )

        # Generate button
        generate_btn.click(
            generate_segmentation,
            inputs=[original_with_input, points_state, text_prompt_input, num_masks],
            outputs=[mask_result, status_text, masks_data, current_mask_index, mask_info]
        )
        # Enter-key support for the text prompt
        text_prompt_input.submit(
            generate_segmentation,
            inputs=[original_with_input, points_state, text_prompt_input, num_masks],
            outputs=[mask_result, status_text, masks_data, current_mask_index, mask_info]
        )

        # Mask navigation
        prev_mask_btn.click(
            lambda idx, data: navigate_mask("prev", idx, data),
            inputs=[current_mask_index, masks_data],
            outputs=[mask_result, current_mask_index, mask_info]
        )
        next_mask_btn.click(
            lambda idx, data: navigate_mask("next", idx, data),
            inputs=[current_mask_index, masks_data],
            outputs=[mask_result, current_mask_index, mask_info]
        )

        clear_points_btn.click(
            clear_all_points,
            inputs=[original_with_input],
            outputs=[points_state, points_display, points_overlay, original_with_input, mask_result, status_text]
        )
        clear_text_btn.click(
            clear_text_prompt,
            outputs=[text_prompt_input, status_text]
        )
        save_btn.click(
            save_and_download_mask,
            inputs=[mask_name_input, format_selector],
            outputs=[status_text, download_file]
        )

    return interface


def main():
    """Main function"""
    print("🚀 Starting Fixed SAM 2.1 Interface...")
    interface = create_interface()
    print("🌐 Launching web interface...")
    print("📍 Click on objects in images to segment them!")
    interface.launch(
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
        share=True,            # Enable public sharing
        inbrowser=False,       # Don't auto-open a browser in a server environment
        show_error=True,
        server_name="0.0.0.0",  # Allow external connections
        auth=None              # No authentication for public access
    )


if __name__ == "__main__":
    main()