Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2024 The Google Research Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """A pipeline for segmenting objects using the SAM model.""" | |
| # Copyright 2024 The Google Research Authors. | |
| # This file is based on the SAM (Segment Anything) and HQ-SAM. | |
| # | |
| # https://github.com/facebookresearch/segment-anything | |
| # https://github.com/SysCV/sam-hq/tree/main | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # pylint: disable=all | |
| # pylint: disable=g-importing-member | |
| import os | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from sam.utils import show_anns | |
| from sam.utils import show_box | |
| from sam.utils import show_mask | |
| from sam.utils import show_points | |
| from segment_anything_hq import sam_model_registry | |
| from segment_anything_hq import SamAutomaticMaskGenerator | |
| from segment_anything_hq import SamPredictor | |
class SAMPipeline:
    """Pipeline for segmenting objects with an HQ-SAM model.

    One model is built from ``checkpoint`` and shared by two front ends:

    * ``self.mask_generator`` (``SamAutomaticMaskGenerator``) for prompt-free
      segmentation, created eagerly in ``__init__``.
    * ``self.predictor`` (``SamPredictor``) for point/box prompted
      segmentation, created on first use or by an explicit ``load_sam()``.
    """

    def __init__(
        self,
        checkpoint,
        model_type,
        device="cuda:0",
        points_per_side=32,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        box_nms_thresh=0.7,
    ):
        """Build the model, move it to ``device`` and set up the mask generator.

        Args:
            checkpoint: Checkpoint passed to the ``sam_model_registry`` builder.
            model_type: Key into ``sam_model_registry`` selecting the builder.
            device: Device the model is moved to.
            points_per_side: Forwarded to ``SamAutomaticMaskGenerator``.
            pred_iou_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
            stability_score_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
            box_nms_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
        """
        self.checkpoint = checkpoint
        self.model_type = model_type
        self.device = device
        self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
        self.sam.to(device=self.device)
        # Created lazily (see ``load_sam``) so that ``segment_image_single``
        # works even when the caller never invoked ``load_sam()`` explicitly.
        self.predictor = None
        self.load_mask_generator(
            points_per_side=points_per_side,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            box_nms_thresh=box_nms_thresh,
        )
        # Default Prompt Args
        self.click_args = {"k": 5, "order": "max", "how_filter": "median"}
        self.box_args = None

    def load_sam(self):
        """Create ``self.predictor`` for prompted segmentation.

        Reuses the model already held in ``self.sam`` instead of loading the
        checkpoint a second time; the checkpoint is only read here when no
        model is attached to the instance yet.
        """
        print("Loading SAM")
        sam = getattr(self, "sam", None)
        if sam is None:
            sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
            sam.to(device=self.device)
            self.sam = sam
        self.predictor = SamPredictor(sam)
        print("Loading Done")

    def load_mask_generator(
        self,
        points_per_side,
        pred_iou_thresh,
        stability_score_thresh,
        box_nms_thresh,
    ):
        """(Re)build ``self.mask_generator`` on top of ``self.sam``.

        All four arguments are forwarded unchanged to
        ``SamAutomaticMaskGenerator``; cropping is disabled
        (``crop_n_layers=0``).
        """
        print("Loading SAM")
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=points_per_side,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            box_nms_thresh=box_nms_thresh,
            crop_n_layers=0,
            crop_n_points_downscale_factor=1,
        )
        print("Loading Done")

    @staticmethod
    def _read_rgb(image_path):
        """Read ``image_path`` with OpenCV and return it converted to RGB.

        Raises:
            OSError: If OpenCV could not read the file (``cv2.imread``
                returned ``None``).
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise OSError(f"Could not read image from {image_path!r}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _ensure_save_dir(save_path):
        """Validate ``save_path`` and create the directory if it is missing.

        An empty string is left alone so that files land in the current
        working directory.

        Raises:
            ValueError: If ``save_path`` is ``None``.
        """
        if save_path is None:
            raise ValueError("save_path must be provided when visualize=True")
        if save_path:
            os.makedirs(save_path, exist_ok=True)

    # segment single object
    def segment_image_single(
        self,
        image_path,
        input_point=None,
        input_label=None,
        input_box=None,
        input_mask=None,
        multimask_output=True,
        visualize=False,
        save_path=None,
        fname="",
        image=None,
    ):
        """Segment one object from point and/or box prompts.

        Args:
            image_path: Path read with OpenCV when ``image`` is ``None``.
            input_point: Passed to the predictor as ``point_coords``.
            input_label: Passed to the predictor as ``point_labels``.
            input_box: Passed to the predictor as ``box``.
            input_mask: Only drawn in the visualization; the predictor is
                always called with ``mask_input=None``.
            multimask_output: Forwarded to the predictor.
            visualize: If true, save one figure per returned mask.
            save_path: Output directory for the figures.
            fname: Filename prefix for the figures.
            image: Optional already-loaded image, used as given (no colour
                conversion is applied to it).

        Returns:
            The ``(masks, scores, logits)`` tuple from the predictor.
        """
        if image is None:
            image = self._read_rgb(image_path)
        if getattr(self, "predictor", None) is None:
            self.load_sam()
        self.predictor.set_image(image)
        masks, scores, logits = self.predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            box=input_box,
            mask_input=None,
            multimask_output=multimask_output,
        )
        if visualize:
            self.visualize(
                image,
                masks,
                scores,
                save_path,
                input_point=input_point,
                input_label=input_label,
                input_box=input_box,
                input_mask=input_mask,
                fname=fname,
            )
        return masks, scores, logits

    def segment_automask(
        self,
        image_path,
        visualize=False,
        save_path=None,
        image=None,
        fname="automask.jpg",
    ):
        """Segment everything in an image with the automatic mask generator.

        Args:
            image_path: Path read with OpenCV when ``image`` is ``None``.
            visualize: If true, save an overlay of all masks.
            save_path: Output directory for the overlay.
            image: Optional already-loaded image, used as given.
            fname: Filename of the saved overlay.

        Returns:
            ``(masks_arr, bbox_arr, masks)``: arrays built from the
            ``"segmentation"`` and ``"bbox"`` entries of every annotation,
            plus the raw annotation list returned by the generator.
        """
        if image is None:
            image = self._read_rgb(image_path)
        masks = self.mask_generator.generate(image)
        mask_list = [ann["segmentation"] for ann in masks]
        bbox_list = [ann["bbox"] for ann in masks]
        if visualize:
            self.visualize_automask(image, masks, save_path, fname=fname)
        return np.array(mask_list), np.array(bbox_list), masks

    def visualize_automask(self, image, masks, save_path, fname="mask.jpg"):
        """Save ``image`` overlaid with all ``masks`` to ``save_path/fname``.

        Raises:
            ValueError: If ``save_path`` is ``None``.
        """
        self._ensure_save_dir(save_path)
        fig = plt.figure(figsize=(20, 20))
        try:
            plt.imshow(image)
            show_anns(masks)
            plt.axis("off")
            plt.savefig(os.path.join(save_path, fname))
        finally:
            # Release the figure; pyplot otherwise keeps it alive.
            plt.close(fig)

    def visualize(
        self,
        image,
        masks,
        scores,
        save_path,
        input_point=None,
        input_label=None,
        input_box=None,
        input_mask=None,
        fname="",
    ):
        """Save one figure per mask, with the prompts drawn on top.

        Figure ``i`` is written to ``save_path/{fname}{i}.jpg``.

        Returns:
            ``(input_point, input_label, input_box, input_mask)`` unchanged.

        Raises:
            ValueError: If ``save_path`` is ``None``.
        """
        self._ensure_save_dir(save_path)
        for i, (mask, score) in enumerate(zip(masks, scores)):
            fig = plt.figure(figsize=(10, 10))
            try:
                plt.imshow(image)
                axes = plt.gca()
                show_mask(mask, axes)
                if input_point is not None:
                    show_points(input_point, input_label, axes)
                if input_box is not None:
                    show_box(input_box, axes)
                if input_mask is not None:
                    show_mask(input_mask[0], axes, True)
                plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
                plt.axis("off")
                plt.savefig(os.path.join(save_path, f"{fname}{i}.jpg"))
            finally:
                # Release the figure; pyplot otherwise keeps it alive.
                plt.close(fig)
        return input_point, input_label, input_box, input_mask