"""Gradio demo that runs the SAM2 automatic mask generator on an image.

The uploaded image is segmented and the resulting masks are blended on
top of it as semi-transparent, randomly coloured overlays.
"""

import os
import subprocess
from glob import glob

import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import cv2
import gradio as gr

# Identifier handed to the ``gdown`` command-line tool to fetch the checkpoint.
CHECKPOINT_FILE_ID = '1RHSO8lHko3IK3dmABOzFDJuq7wmKVcun'


def show_example(path):
    """Load the image at *path* and return it as an RGB array.

    Parameters:
        path: str — location of the image file on disk

    Returns:
        np.ndarray — the image with channels converted from OpenCV's BGR
        order to RGB

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from *path*.
    """
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising;
    # fail here with the offending path rather than inside cvtColor.
    if bgr is None:
        raise FileNotFoundError(f"Could not read example image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def overlay_masks_on_image(image, anns, borders=True, alpha=0.5):
    """
    Overlays segmentation masks from 'anns' on top of 'image'.

    Parameters:
        image: np.ndarray (H, W, 3) — source RGB image; values are divided
            by 255 internally, so a 0-255 value range is assumed
        anns: list of dicts — each with a 'segmentation' key containing a
            mask (cast to bool, indexed against the image's H x W plane)
            and an 'area' key used to order the masks
        borders: bool — whether to draw contours around each mask
        alpha: float — opacity of the mask colour in the blend
            (0 leaves the image untouched, 1 paints the solid colour)

    Returns:
        masked_image: np.ndarray (H, W, 3), dtype uint8 — image with
        overlays. When 'anns' is empty the input 'image' object itself is
        returned unchanged.
    """
    if len(anns) == 0:
        return image

    # Work on a normalised float copy so the caller's array is not modified.
    masked_image = image.copy().astype(np.float32) / 255.0

    # Paint the largest masks first so smaller ones end up on top.
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)

    for ann in sorted_anns:
        m = ann['segmentation'].astype(bool)
        color_mask = np.random.random(3)  # random RGB colour in [0, 1)

        # Blend the colour into all three channels of the masked pixels.
        masked_image[m] = (1 - alpha) * masked_image[m] + alpha * color_mask

        if borders:
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # Simplify each outline with a tolerance of 1% of its perimeter.
            contours = [
                cv2.approxPolyDP(
                    contour,
                    epsilon=0.01 * cv2.arcLength(contour, True),
                    closed=True,
                )
                for contour in contours
            ]
            # The canvas is normalised RGB, so (0, 0, 1) is pure blue.
            cv2.drawContours(
                masked_image, contours, -1, color=(0, 0, 1), thickness=1
            )

    return (masked_image * 255).astype(np.uint8)


def get_response(image):
    """Segment *image* and return it with the masks overlaid.

    Relies on the module-level ``mask_generator``, which is only created
    when this file is executed as a script.

    Parameters:
        image: object exposing ``convert("RGB")`` (the Gradio input is
            configured with ``type="pil"``)

    Returns:
        np.ndarray — result of ``overlay_masks_on_image`` for the
        generated masks
    """
    image = np.array(image.convert("RGB"))
    masks = mask_generator.generate(image)
    return overlay_masks_on_image(image, masks)


def download_checkpoint():
    """Download the model checkpoint by invoking the ``gdown`` CLI.

    Raises:
        FileNotFoundError: if the ``gdown`` executable is not on PATH.
        subprocess.CalledProcessError: if ``gdown`` exits with a non-zero
            status.
    """
    # Argument list (no shell) plus check=True: a failed download stops
    # here instead of surfacing later as a missing-file error on load.
    subprocess.run(['gdown', CHECKPOINT_FILE_ID], check=True)


if __name__ == "__main__":
    iface = gr.Interface(
        cache_examples=False,
        fn=get_response,
        inputs=[gr.Image(type="pil")],  # Accepts image input
        examples=[
            [show_example('test-images/5fc8c5b53c.png')],
            [show_example('test-images/80719af02f.png')],
            [show_example('test-images/f32c7bd62b.png')],
        ],
        outputs=[gr.Image(type="numpy")],
        title="Segmenting Microscopic images with Segment Anything",
        description="Segmenting Microscopic images with Meta Segment Anything",
    )

    model_path = 'model.pth'
    if not os.path.exists(model_path):
        print('Downloading model with weights')
        download_checkpoint()
        print('Model with weights Downloaded')

    # NOTE(security): weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code embedded in the file. Only acceptable
    # while the downloaded checkpoint comes from a trusted source.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    mask_generator = SAM2AutomaticMaskGenerator(model)
    iface.launch()