vk
hydra dependency removed
b596a1b
import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import cv2
from glob import glob
import gradio as gr
import os
def show_example(path):
    """Load the image at *path* and return it as an RGB ``np.ndarray``.

    OpenCV reads images in BGR channel order, so the result is converted to
    RGB before being handed to Gradio.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from *path*
            (``cv2.imread`` signals failure by returning ``None`` rather than
            raising).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # Without this check cvtColor would fail with an opaque assertion
        # error that does not mention which file was missing.
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def overlay_masks_on_image(image, anns, borders=True, *, alpha=0.5, rng=None):
    """Overlay segmentation masks from *anns* on top of *image*.

    Each mask is alpha-blended with a random colour. Masks are drawn from the
    largest ``'area'`` to the smallest, so smaller masks end up on top.

    Parameters:
        image: np.ndarray (H, W, 3) — source RGB image; values assumed to be in
            0..255 (e.g. uint8, as produced by ``np.array(PIL_image)``).
        anns: list of dicts — each with a ``'segmentation'`` mask of shape
            (H, W) (converted to bool) and a numeric ``'area'`` used for the
            drawing order.
        borders: bool — whether to draw simplified outer contours of each mask.
        alpha: float — blend weight of the mask colour
            (0 = invisible, 1 = opaque). Defaults to 0.5.
        rng: optional ``np.random.Generator`` used to pick mask colours; when
            ``None`` the global NumPy RNG is used, so colours are not
            reproducible.

    Returns:
        np.ndarray (H, W, 3) of dtype uint8 — the image with overlays. If
        *anns* is empty, *image* itself is returned unchanged.
    """
    if not anns:
        return image
    # astype() returns a new array, so the caller's image is never modified.
    masked_image = image.astype(np.float32) / 255.0
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    for ann in sorted_anns:
        m = ann['segmentation'].astype(bool)
        color_mask = rng.random(3) if rng is not None else np.random.random(3)  # RGB in [0, 1)
        # Blend all three channels of the masked pixels in one vectorised step.
        masked_image[m] = (1 - alpha) * masked_image[m] + alpha * color_mask
        if borders:
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # Simplify each outline with a tolerance of 1% of its perimeter.
            contours = [
                cv2.approxPolyDP(contour, epsilon=0.01 * cv2.arcLength(contour, True), closed=True)
                for contour in contours
            ]
            # The working image is float RGB in [0, 1], so (0, 0, 1) is blue.
            cv2.drawContours(masked_image, contours, -1, color=(0, 0, 1), thickness=1)
    # Clip defensively so out-of-range inputs cannot wrap around in uint8.
    return np.clip(masked_image * 255, 0, 255).astype(np.uint8)
def get_response(image):
    """Segment *image* with the module-level SAM2 mask generator and return an overlay.

    Parameters:
        image: PIL.Image.Image or None — the upload from the Gradio input;
            Gradio may pass ``None`` when the user submits without an image.

    Returns:
        np.ndarray (H, W, 3) uint8 image with the masks drawn on it, or
        ``None`` when no image was supplied (leaves the output empty).

    Note:
        Relies on the global ``mask_generator``, which this script only
        creates inside its ``__main__`` block.
    """
    if image is None:
        return None
    # Normalise to 3-channel RGB (drops alpha, expands greyscale/palette).
    rgb = np.array(image.convert("RGB"))
    masks = mask_generator.generate(rgb)
    return overlay_masks_on_image(rgb, masks)
def download_checkpoint(file_id='1RHSO8lHko3IK3dmABOzFDJuq7wmKVcun', output=None):
    """Download the model checkpoint from Google Drive with the ``gdown`` CLI.

    Parameters:
        file_id: Google Drive file ID of the checkpoint.
        output: optional destination path passed to ``gdown -O``. When
            ``None``, gdown picks the file name — presumably ``model.pth``,
            which is what the ``__main__`` block expects (TODO confirm).

    Raises:
        FileNotFoundError: if the ``gdown`` executable is not installed.
        subprocess.CalledProcessError: if ``gdown`` exits with a non-zero
            status.
    """
    import subprocess  # local import: only needed when a download is triggered

    cmd = ["gdown", file_id]
    if output is not None:
        cmd += ["-O", str(output)]
    # Passing an argument list (no shell) avoids quoting problems. check=True
    # surfaces a failed download immediately; os.system silently ignored it.
    subprocess.run(cmd, check=True)
if __name__ == "__main__":
    # Build the Gradio UI; example images are loaded eagerly as RGB arrays.
    example_paths = [
        'test-images/5fc8c5b53c.png',
        'test-images/80719af02f.png',
        'test-images/f32c7bd62b.png',
    ]
    iface = gr.Interface(
        cache_examples=False,
        fn=get_response,
        inputs=[gr.Image(type="pil")],  # Accepts image input
        examples=[[show_example(path)] for path in example_paths],
        outputs=[gr.Image(type="numpy")],
        title="Segmenting Microscopic images with Segment Anything",
        description="Segmenting Microscopic images with Meta Segment Anything")
    model_path = 'model.pth'
    if not os.path.exists(model_path):
        print('Downloading model with weights')
        download_checkpoint()
        # The download tool chooses the file name; fail loudly here rather
        # than with a less obvious error inside torch.load.
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"Expected checkpoint at {model_path!r} after download, but it was not found"
            )
        print('Model with weights Downloaded')
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects from a file fetched over the network. It is required here because
    # the checkpoint stores a whole pickled model, so only use trusted sources.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    # Module-level global: get_response() looks this name up at call time.
    mask_generator = SAM2AutomaticMaskGenerator(model)
    iface.launch()