|
import torch |
|
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
|
from PIL import Image |
|
from matplotlib import pyplot as plt |
|
import numpy as np |
|
import cv2 |
|
from glob import glob |
|
import gradio as gr |
|
import os |
|
|
|
|
|
|
|
def show_example(path):
    """
    Load the image at ``path`` and return it in RGB channel order.

    OpenCV decodes images as BGR, so the array is converted to RGB before
    being returned (Gradio and PIL expect RGB).

    Parameters:
        path: str — path to an image file readable by OpenCV

    Returns:
        np.ndarray (H, W, 3) — the image in RGB channel order

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns None rather than raising, which would otherwise surface
            as an obscure error inside ``cv2.cvtColor``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
def overlay_masks_on_image(image, anns, borders=True, *, alpha=0.5, rng=None):
    """
    Overlay segmentation masks from ``anns`` on top of ``image``.

    Each mask is alpha-blended with a random colour. Masks are drawn from
    largest to smallest area, so smaller masks stay visible on top of
    larger ones.

    Parameters:
        image: np.ndarray (H, W, 3) — source RGB image; assumed uint8 in
            [0, 255] (it is divided by 255 internally)
        anns: list of dicts — each with a 'segmentation' key holding an
            (H, W) mask (cast to bool) and an 'area' key used for draw order
        borders: bool — whether to outline each mask with a simplified contour
        alpha: float — blend weight of the mask colour (0 = invisible,
            1 = opaque); defaults to 0.5
        rng: np.random.Generator or None — source of the mask colours;
            None uses NumPy's global random state

    Returns:
        masked_image: np.ndarray (H, W, 3), uint8 — image with overlays.
        If ``anns`` is empty, ``image`` itself is returned unchanged.
    """
    if len(anns) == 0:
        return image

    random = np.random.random if rng is None else rng.random
    # astype() already returns a new array, so the caller's image is untouched.
    masked_image = image.astype(np.float32) / 255.0

    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        m = ann['segmentation'].astype(bool)
        color_mask = random(3)

        # Blend only the masked pixels; all three channels at once.
        masked_image[m] = (1 - alpha) * masked_image[m] + alpha * color_mask

        if borders:
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Simplify each outline with a tolerance of 1% of its perimeter.
            contours = [
                cv2.approxPolyDP(c, epsilon=0.01 * cv2.arcLength(c, True), closed=True)
                for c in contours
            ]
            # Colour is in [0, 1] because the working image is scaled float;
            # (0, 0, 1) is blue in RGB order.
            cv2.drawContours(masked_image, contours, -1, color=(0, 0, 1), thickness=1)

    # Clip guards the uint8 cast against any out-of-range float values.
    return (np.clip(masked_image, 0.0, 1.0) * 255).astype(np.uint8)
|
|
|
def get_response(image):
    """
    Gradio callback: segment ``image`` and return it with masks overlaid.

    Relies on the module-level ``mask_generator``, which this script only
    creates inside its ``if __name__ == "__main__":`` section.

    Parameters:
        image: PIL.Image.Image or None — the uploaded image; may be None,
            e.g. when the form is submitted without an upload

    Returns:
        np.ndarray (H, W, 3) uint8 with mask overlays, or None when
        ``image`` is None.
    """
    if image is None:
        # Nothing to segment; avoid an AttributeError on .convert().
        return None

    rgb = np.array(image.convert("RGB"))
    masks = mask_generator.generate(rgb)
    return overlay_masks_on_image(rgb, masks)
|
|
|
def download_checkpoint(file_id='1RHSO8lHko3IK3dmABOzFDJuq7wmKVcun'):
    """
    Download the model checkpoint from Google Drive using the ``gdown`` CLI.

    The file is written to the current working directory. By gdown's
    default it keeps its Drive filename; the caller expects ``model.pth``,
    so verify that name against the Drive file.

    Parameters:
        file_id: str — Google Drive file id to fetch

    Raises:
        FileNotFoundError: if the ``gdown`` executable is not on PATH.
        subprocess.CalledProcessError: if ``gdown`` exits with a non-zero status.
    """
    import subprocess

    # Argument list (no shell) plus check=True. os.system silently ignored
    # download failures, deferring the error to torch.load.
    subprocess.run(['gdown', file_id], check=True)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    example_paths = (
        'test-images/5fc8c5b53c.png',
        'test-images/80719af02f.png',
        'test-images/f32c7bd62b.png',
    )

    iface = gr.Interface(
        cache_examples=False,
        fn=get_response,
        inputs=[gr.Image(type="pil")],
        # One single-input example row per image.
        examples=[[show_example(p)] for p in example_paths],
        outputs=[gr.Image(type="numpy")],
        title="Segmenting Microscopic images with Segment Anything",
        description="Segmenting Microscopic images with Meta Segment Anything")

    model_path = 'model.pth'

    if not os.path.exists(model_path):
        print('Downloading model with weights')
        download_checkpoint()
        if not os.path.exists(model_path):
            # Fail with a clear message rather than a bare error from torch.load.
            raise FileNotFoundError(
                f"Checkpoint download finished but {model_path!r} was not found")
        print('Model with weights Downloaded')

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects from a downloaded file. Only use with a trusted checkpoint.
    model = torch.load(model_path, map_location="cpu", weights_only=False)

    # Module-level global read by get_response().
    mask_generator = SAM2AutomaticMaskGenerator(model)

    iface.launch()