import os
import time

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
def dilate_image_mask(image_mask: Image.Image, dilate_size: int = 50) -> Image.Image:
    """Dilates a binary PIL mask with a square dilate_size x dilate_size kernel."""
    # Convert the PIL image to a NumPy array
    image_np = np.array(image_mask)
    kernel = np.ones((dilate_size, dilate_size), np.uint8)
    dilated_image_np = cv2.dilate(image_np, kernel, iterations=1)
    # Convert the dilated NumPy array back to PIL format
    dilated_image = Image.fromarray(dilated_image_np)
    return dilated_image
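
# Quick sanity check of the dilation (hypothetical toy data, not model output):
# a single white pixel grows into a dilate_size x dilate_size white block.
#
#   m = np.zeros((100, 100), dtype=np.uint8)
#   m[50, 50] = 255
#   grown = np.array(dilate_image_mask(Image.fromarray(m), dilate_size=50))
#   grown.sum() / 255  # -> 2500.0, i.e. a 50x50 white block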
def get_foreground_image(image: Image.Image, mask_array: np.ndarray) -> Image.Image:
    """Returns a PIL RGBA image with the mask applied to the original image."""
    # Resize the binary mask to the original image size
    resized_mask = Image.fromarray(mask_array.astype(np.uint8)).resize(image.size)
    resized_mask = np.array(resized_mask)
    # Force RGB so the 3-channel assignment below also works for RGBA/grayscale inputs
    image_array = np.array(image.convert('RGB'))
    # Apply the binary mask element-wise to each color channel
    fg_array = image_array * resized_mask[:, :, np.newaxis]
    # Create a new ndarray with 4 channels (R, G, B, A)
    result_array = np.zeros((*fg_array.shape[:2], 4), dtype=np.uint8)
    # Assign RGB values from the masked image
    result_array[:, :, :3] = fg_array
    # Assign alpha values from the resized mask
    result_array[:, :, 3] = resized_mask * 255
    result_image = Image.fromarray(result_array, mode='RGBA')
    return result_image
def overlay_mask_on_image(image: Image.Image, mask_array: np.ndarray, alpha=0.5) -> Image.Image:
    """Overlays a semi-transparent green mask on top of the original image."""
    original_image = image
    overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 0))
    # Resize the overlay mask to the original image size
    overlay_mask = Image.fromarray(mask_array.astype(np.uint8) * 255).resize(original_image.size, resample=Image.LANCZOS)
    # Dilate the mask a bit to cover the edges of the objects
    overlay_mask = dilate_image_mask(overlay_mask, dilate_size=50)
    # Apply the overlay color through the mask
    overlay_color = (0, 240, 0, int(255 * alpha))  # RGBA
    draw = ImageDraw.Draw(overlay_image)
    draw.bitmap((0, 0), overlay_mask, fill=overlay_color)
    result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
    return result_image
def filter_segment_classes(segmentation, filter_classes, mode='filt_out') -> np.ndarray:
    """Returns a boolean mask selecting pixels of the segmentation array by class.
    mode: 'filt_out' - keeps every class except those in filter_classes
          'filt_in'  - keeps only the classes in filter_classes
    """
    if mode == 'filt_out':
        overlay_mask = ~np.isin(segmentation, filter_classes)
    elif mode == 'filt_in':
        overlay_mask = np.isin(segmentation, filter_classes)
    else:
        raise ValueError(f'Invalid mode: {mode}')
    return overlay_mask
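
# A quick illustration of the two modes (hypothetical toy class ids, not model output):
#
#   seg = np.array([[0, 7], [7, 3]])
#   filter_segment_classes(seg, [0, 3], mode='filt_out')
#   # -> array([[False,  True], [ True, False]])
#   filter_segment_classes(seg, [0, 3], mode='filt_in')
#   # -> array([[ True, False], [False,  True]])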
class Mask2FormerSegmenter:
    def __init__(self):
        self.processor = None
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # TODO - train a classifier to learn this from the dataset
        #      - classes that appear much less frequently are good candidates
        self.filter_classes = [0, 1, 2, 3, 5, 6, 10, 11, 12, 13, 14, 15, 18, 19, 22, 24, 36, 38, 40, 45, 46, 47, 69, 105, 128]

    def load_models(self, checkpoint_name):
        """Loads the image processor and Mask2Former model from a checkpoint."""
        self.processor = AutoImageProcessor.from_pretrained(checkpoint_name)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint_name)
        self.model.to(self.device)
    def run_semantic_inference(self, image, model, processor) -> torch.Tensor:
        """Runs semantic segmentation inference on a single image."""
        if (model is None) or (processor is None):
            raise ValueError('Model or Processor not loaded.')
        funcstart_time = time.time()
        inputs = processor(image, return_tensors="pt")
        inputs = inputs.to(self.device)
        # Forward pass - to segment the image (no gradients needed at inference time)
        with torch.no_grad():
            outputs = model(**inputs)
        # Measures the time taken for the preprocessing and forward pass
        model_time = time.time() - funcstart_time
        print(f'Model time: {model_time:.2f}s')
        # Post-processing - semantic segmentation map (H x W tensor of class ids)
        semantic_segmentation = processor.post_process_semantic_segmentation(outputs)[0]
        return semantic_segmentation
    def batch_inference_demo(self, dirpath):
        """Runs inference on every image in a directory and saves the results."""
        # List image files in the input directory
        image_files = [file for file in os.listdir(dirpath) if file.lower().endswith(('.jpg', '.jpeg', '.png'))]
        # Saves the output images in the results folder
        outp_folder = 'results/mask2former_masked'
        os.makedirs(outp_folder, exist_ok=True)
        for file in tqdm(image_files, desc="Processing images"):
            filepath = os.path.join(dirpath, file)
            image = Image.open(filepath).convert('RGB')
            semantic_segmentation = self.run_semantic_inference(image, self.model, self.processor)
            # Move the result to the CPU before converting it to NumPy
            semantic_segmentation = semantic_segmentation.cpu()
            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'{file}: {valid_ids}')
            # Filter out the classes in filter_classes
            binary_mask = filter_segment_classes(semantic_segmentation.numpy(), self.filter_classes)
            overlaid_img = overlay_mask_on_image(image, binary_mask)
            foreground_img = get_foreground_image(image, binary_mask)
            mask_img = Image.fromarray(binary_mask.astype(np.uint8) * 255).resize(image.size)
            # Dilates the mask a bit
            mask_img = dilate_image_mask(mask_img, dilate_size=50)
            stem = os.path.splitext(file)[0]
            overlaid_img.save(f"{outp_folder}/{stem}_overlay.png")
            foreground_img.save(f"{outp_folder}/{stem}_foreground.png")
            mask_img.save(f"{outp_folder}/{stem}_mask.png")
    def retrieve_fg_image_and_mask(self, input_image: Image.Image,
                                   dilate_size=50,
                                   verbose=False
                                   ) -> tuple[Image.Image, Image.Image]:
        """Generates an RGBA image with the foreground objects of the input image
        and a binary mask for the given image.
        input_image: PIL image
        dilate_size: size in pixels of the dilation kernel to apply on the objects' mask
        verbose: if True, prints the list of classes in the image that have not been filtered
        returns: foreground_img (RGBA), mask_img (L)
        """
        # Runs the semantic segmentation model
        semantic_segmentation = self.run_semantic_inference(input_image,
                                                            self.model,
                                                            self.processor)
        semantic_segmentation = semantic_segmentation.cpu()
        if verbose:
            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'valid classes detected: {valid_ids}')
        # Filter out the classes in filter_classes
        binary_mask = filter_segment_classes(semantic_segmentation.numpy(),
                                             self.filter_classes)
        foreground_img = get_foreground_image(input_image, binary_mask)
        mask_img = Image.fromarray(binary_mask.astype(np.uint8) * 255).resize(input_image.size, resample=Image.LANCZOS)
        # Dilates the mask a bit to cover the edges of the objects; this helps the inpainting model
        mask_img = dilate_image_mask(mask_img, dilate_size=dilate_size)
        return foreground_img, mask_img
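
# Minimal usage sketch (not part of the original module). The checkpoint name is
# an assumption: "facebook/mask2former-swin-large-ade-semantic" is an ADE20K
# semantic checkpoint consistent with the class-id range in filter_classes, but
# any Mask2Former semantic-segmentation checkpoint should work the same way.
# "input.jpg" is a hypothetical input file.
if __name__ == "__main__":
    segmenter = Mask2FormerSegmenter()
    segmenter.load_models("facebook/mask2former-swin-large-ade-semantic")
    image = Image.open("input.jpg").convert("RGB")
    foreground_img, mask_img = segmenter.retrieve_fg_image_and_mask(image,
                                                                    dilate_size=50,
                                                                    verbose=True)
    foreground_img.save("foreground.png")  # RGBA cut-out of the kept classes
    mask_img.save("mask.png")              # dilated binary mask (mode 'L')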