import io
from typing import IO

import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
def to_file(item) -> IO[bytes]:
    """Serialize a PIL image (as PNG) or a numpy array (as .npy) into an in-memory file object."""
    file_obj = io.BytesIO()
    if isinstance(item, Image.Image):
        item.save(file_obj, format='PNG')
    elif isinstance(item, np.ndarray):
        np.save(file_obj, item)
    else:
        raise TypeError(f"Unsupported type for to_file: {type(item)}")
    # Rewind so callers can read from the beginning
    file_obj.seek(0)
    return file_obj
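# Illustrative usage sketch (not part of the original module): to_file wraps an
# in-memory object so it can be handed to anything expecting binary file content.
def _example_to_file():
    img = Image.new('RGB', (64, 64), (128, 128, 128))     # placeholder image
    png_bytes = to_file(img).read()                       # PNG-encoded bytes of the image
    arr_bytes = to_file(np.zeros((4, 4))).read()          # .npy-encoded bytes of the array
    return png_bytes, arr_bytes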
def get_sam(model_type, checkpoint_path, device=None):
    """Load a SAM model from the registry and move it to the requested (or auto-detected) device."""
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam.to(device=device)
    return sam
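# Illustrative sketch: load a SAM backbone. The model type and checkpoint filename
# below are assumptions -- substitute whichever checkpoint is available locally.
def _example_get_sam():
    sam = get_sam('vit_b', 'sam_vit_b_01ec64.pth')        # hypothetical local checkpoint path
    return sam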
def draw_mask(img: Image.Image, boolean_mask: np.ndarray, color: tuple, mask_alpha: float) -> Image.Image:
    int_alpha = int(mask_alpha * 255)
    base = img.convert('RGBA')  # alpha_composite requires both inputs to be RGBA
    color_mask = Image.new('RGBA', base.size, color=color)
    # Use the boolean mask, scaled to the requested opacity, as the overlay's alpha channel
    color_mask.putalpha(Image.fromarray(boolean_mask.astype(np.uint8) * int_alpha, mode='L'))
    return Image.alpha_composite(base, color_mask)
def random_color():
    # Cast to plain Python ints so Pillow accepts the tuple as a color
    return tuple(int(v) for v in np.random.randint(0, 256, size=3))
def draw_masks(img: Image.Image, boolean_masks: np.ndarray) -> Image.Image:
    img = img.copy()
    for boolean_mask in boolean_masks:
        img = draw_mask(img, boolean_mask, random_color(), 0.2)
    return img
def cutout(img: Image.Image, boolean_mask: np.ndarray):
    rgba_img = img.convert('RGBA')
    # Scale the boolean mask to 0/255 so it behaves as a proper 8-bit alpha channel
    mask = Image.fromarray(boolean_mask.astype(np.uint8) * 255, mode='L')
    rgba_img.putalpha(mask)
    return rgba_img
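# Illustrative sketch: overlay and cut out a synthetic square mask, without running SAM.
# The image and mask here are placeholders invented for the example.
def _example_overlay_and_cutout():
    img = Image.new('RGB', (100, 100), (200, 200, 200))   # placeholder image
    mask = np.zeros((100, 100), dtype=bool)
    mask[25:75, 25:75] = True                             # square region to highlight
    overlaid = draw_masks(img, np.array([mask]))          # translucent random-color overlay
    cut = cutout(img, mask)                               # RGBA image, transparent outside the mask
    return overlaid, cut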
def predict_conditioned(sam, pil_img, **kwargs):
    """Run SAM with prompts; kwargs are forwarded to SamPredictor.predict (e.g. point_coords, point_labels, box)."""
    rgb_arr = pil_image_to_rgb_array(pil_img)
    predictor = SamPredictor(sam)
    predictor.set_image(rgb_arr)
    masks, quality, _ = predictor.predict(**kwargs)
    return masks, quality
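# Illustrative sketch: prompt-conditioned prediction. The keyword arguments are passed
# straight to SamPredictor.predict; the image path and click coordinates below are
# assumptions for illustration only.
def _example_predict_conditioned(sam):
    img = Image.open('example.jpg')                       # hypothetical input image
    masks, quality = predict_conditioned(
        sam,
        img,
        point_coords=np.array([[100, 150]]),              # one foreground click (x, y)
        point_labels=np.array([1]),                       # 1 = foreground, 0 = background
        multimask_output=True,                            # return several candidate masks
    )
    return masks[np.argmax(quality)]                      # keep the highest-scoring mask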
def predict_all(sam, pil_img):
    """Run SAM in automatic mode and return all generated masks with their stability scores."""
    rgb_arr = pil_image_to_rgb_array(pil_img)
    mask_generator = SamAutomaticMaskGenerator(sam)
    results = mask_generator.generate(rgb_arr)
    masks = []
    quality = []
    for result in results:
        masks.append(result['segmentation'])
        quality.append(result['stability_score'])
    masks = np.array(masks)
    quality = np.array(quality)
    return masks, quality
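# Illustrative sketch: automatic ("segment everything") mode, then visualize every mask
# with a random translucent color. The image path is an assumption.
def _example_predict_all(sam):
    img = Image.open('example.jpg')                       # hypothetical input image
    masks, quality = predict_all(sam, img)                # (N, H, W) bool masks, (N,) scores
    return draw_masks(img, masks)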
def pil_image_to_rgb_array(image):
    if image.mode == "RGBA":
        rgb_image = Image.new("RGB", image.size, (255, 255, 255))
        rgb_image.paste(image, mask=image.split()[3])  # Apply alpha channel as the mask
        rgb_array = np.array(rgb_image)
    else:
        rgb_array = np.array(image.convert("RGB"))
    return rgb_array
def box_pts_to_xyxy(pt1, pt2):
    """Convert a box from two-corner format to XYXY.

    Args:
        pt1: (x1, y1) first corner of the box
        pt2: (x2, y2) second corner, diagonal to pt1

    Returns:
        xyxy: (x_min, y_min, x_max, y_max)
    """
    x1, y1 = pt1
    x2, y2 = pt2
    return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
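# Illustrative sketch: two clicked corners (in either order) become an XYXY box that can
# be passed to predict_conditioned as a box prompt. The coordinates are assumptions.
def _example_box_prompt(sam, pil_img):
    xyxy = box_pts_to_xyxy((220, 80), (40, 310))          # -> (40, 80, 220, 310)
    masks, quality = predict_conditioned(sam, pil_img, box=np.array(xyxy), multimask_output=False)
    return masks[0]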
def crop_empty(image: Image.Image):
    """Trim fully transparent borders from an image."""
    # Ensure there is an alpha channel to inspect
    np_image = np.array(image.convert("RGBA"))
    # Find non-transparent pixels
    non_transparent_pixels = np_image[:, :, 3] > 0
    # Calculate bounding box coordinates
    rows = np.any(non_transparent_pixels, axis=1)
    cols = np.any(non_transparent_pixels, axis=0)
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    # Crop the array
    cropped_image = np_image[ymin:ymax + 1, xmin:xmax + 1, :]
    # Convert the cropped array back to a PIL image
    return Image.fromarray(cropped_image)
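# Illustrative end-to-end sketch: segment with a point prompt, cut the object out, and
# trim the transparent border. The image path and click location are assumptions.
def _example_sticker(sam):
    img = Image.open('example.jpg')                       # hypothetical input image
    masks, quality = predict_conditioned(
        sam, img,
        point_coords=np.array([[120, 120]]),              # assumed foreground click
        point_labels=np.array([1]),
    )
    sticker = crop_empty(cutout(img, masks[np.argmax(quality)]))
    return to_file(sticker)                               # PNG bytes, ready to serve or save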