| | import os |
| | from pycocotools import mask as mask_util |
| | import json |
| | import numpy as np |
| | import cv2 |
| | from distinctipy import distinctipy |
| | import matplotlib.pyplot as plt |
| | from PIL import Image |
| | from types import MethodType |
| | import json |
| | import random |
| |
|
| | import torch |
| | import torchvision |
| | from detectron2.data import MetadataCatalog |
| | from detectron2.structures import BitMasks, PolygonMasks |
| | from detectron2.utils.visualizer import ColorMode, Visualizer |
| | from detectron2.data.detection_utils import read_image |
| |
|
| | from third_parts.APE.build_ape import build_ape_predictor |
| | from third_parts.recognize_anything.build_ram_plus import build_ram_predictor |
| | from third_parts.segment_anything import build_sam_vit_h, SamPredictor, SamAutomaticMaskGenerator |
| |
|
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on the matplotlib axis ``ax`` as a translucent RGBA layer.

    The layer color is a random RGB triple when ``random_color`` is True and a
    fixed blue (30, 144, 255) otherwise; alpha is 0.6 in both cases.
    ``mask.shape[-2:]`` is taken as (height, width), so ``mask`` must hold
    exactly height * width elements.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) color to get an RGBA image.
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))
| | |
| |
|
| |
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Return the (cols, rows) candidate whose cols/rows is nearest to ``aspect_ratio``.

    Candidates are scanned in order. A strictly smaller distance always wins.
    On an exact tie, the later candidate replaces the current pick only when
    ``width * height`` exceeds half of ``image_size**2 * cols * rows``.
    Returns ``(1, 1)`` if no candidate is ever selected (e.g. an empty list).
    """
    pixel_area = width * height
    chosen, smallest_gap = (1, 1), float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            chosen, smallest_gap = candidate, gap
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen
| |
|
def sample_points(box, mask, min_points=3, max_points=16):
    """Pick positive point prompts that lie on ``mask`` inside ``box``.

    Two point sets are concatenated:
      1. Centres of a cols x rows grid laid over ``box``. The grid shape is the
         candidate with ``min_points <= cols * rows <= max_points`` whose
         aspect ratio best matches the box (via find_closest_aspect_ratio).
         Only centres that land on the mask are kept.
      2. Foreground pixels of the mask inside ``box``, scanned column by column
         (x outer, y inner). They are subsampled with a stride of
         ``len // max_points`` (stride 1 when there are fewer than
         ``max_points`` pixels), and the first and last samples are dropped.

    Args:
        box: (x0, y0, w, h) in mask pixel coordinates. Values may be Python
            numbers or 0-d tensors.
        mask: 2-D numpy array or torch tensor; pixels > 0 are foreground.
        min_points: lower bound on the grid cell count.
        max_points: upper bound on the grid cell count; also sets the scan stride.

    Returns:
        List of (x, y) tuples. Grid points keep the type produced by the box
        arithmetic; scanned points are Python ints.
    """
    x0, y0, w, h = box
    # A zero-height box would divide by zero. Treat it as infinitely wide:
    # with area w * h == 0 the aspect-ratio search then keeps the (1, 1) grid.
    aspect_ratio = w / h if h != 0 else float('inf')

    # Candidate grid shapes, ordered by cell count.
    target_ratios = set(
        (i, j) for n in range(min_points, max_points + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_points and i * j >= min_points)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    cols, rows = find_closest_aspect_ratio(aspect_ratio, target_ratios, w, h, 50)
    width_bin = w / cols
    height_bin = h / rows

    # Work on a numpy view so the dense scan below can be vectorized.
    if isinstance(mask, torch.Tensor):
        mask_np = mask.detach().cpu().numpy()
    else:
        mask_np = np.asarray(mask)

    ret_points = []
    for wi in range(cols):
        xi = x0 + (wi + 0.5) * width_bin
        for hi in range(rows):
            yi = y0 + (hi + 0.5) * height_bin
            if mask_np[int(yi), int(xi)] > 0:
                ret_points.append((xi, yi))

    # Foreground pixels inside the box, clamped to the mask extent.
    x_start, x_end = max(int(x0), 0), min(int(x0 + w), mask_np.shape[1])
    y_start, y_end = max(int(y0), 0), min(int(y0 + h), mask_np.shape[0])
    if x_end > x_start and y_end > y_start:
        # Transposing the crop makes nonzero() enumerate x-major (x outer, y inner).
        xs, ys = np.nonzero(mask_np[y_start:y_end, x_start:x_end].T > 0)
        temp_points = list(zip((xs + x_start).tolist(), (ys + y_start).tolist()))
    else:
        temp_points = []

    stride = len(temp_points) // max_points
    uniform_indices = range(0, len(temp_points), stride) if stride >= 1 else range(len(temp_points))
    additional_points = [temp_points[k] for k in uniform_indices[1:-1]]

    return ret_points + additional_points
| |
|
def mask_iou(masks, chunk_size=50, chunk_mode=False):
    """Pairwise intersection and union pixel counts within one stack of masks.

    Args:
        masks: tensor of shape (N, H, W). It is cast with ``.char()`` (int8),
            so it is expected to hold 0/1 values.
        chunk_size: block edge used when ``chunk_mode`` is True.
        chunk_mode: compute in (chunk_size x chunk_size) blocks of mask pairs to
            bound peak memory. On CUDA out-of-memory it retries with
            chunk_size // 2 and then chunk_size // 4 (never below 1).

    Returns:
        (intersection, union): (N, N) tensors of pixel counts. Entry [i, j]
        compares masks[i] with masks[j].

    Raises:
        torch.cuda.OutOfMemoryError: if every chunk size runs out of memory.
    """
    masks1 = masks.unsqueeze(1).char()  # (N, 1, H, W)
    masks2 = masks.unsqueeze(0).char()  # (1, N, H, W)

    def _pair_counts(a, b):
        # Broadcast to (rows, cols, H, W) and reduce the two pixel dimensions.
        inter = a * b
        union = (a + b - inter).sum(-1).sum(-1)
        return inter.sum(-1).sum(-1), union

    num_masks = masks1.shape[0]
    # An empty stack has nothing to chunk; the one-shot path returns (0, 0) tensors.
    if not chunk_mode or num_masks == 0:
        return _pair_counts(masks1, masks2)

    def _chunked(step):
        row_inter, row_union = [], []
        for r0 in range(0, num_masks, step):
            block = masks1[r0:r0 + step]
            pieces = [_pair_counts(block, masks2[:, c0:c0 + step]) for c0 in range(0, num_masks, step)]
            row_inter.append(torch.cat([inter for inter, _ in pieces], dim=1))
            row_union.append(torch.cat([union for _, union in pieces], dim=1))
        return torch.cat(row_inter, dim=0), torch.cat(row_union, dim=0)

    for step in (chunk_size, chunk_size // 2, chunk_size // 4):
        try:
            return _chunked(max(1, step))
        except torch.cuda.OutOfMemoryError:
            # The exception is not kept, so its frames (and their tensors) are
            # released before retrying with a smaller block.
            continue
    raise torch.cuda.OutOfMemoryError("mask_iou: out of memory even at the smallest chunk size")
| |
|
def mask_iou_v2(masks1, masks2, chunk_size=50, chunk_mode=False):
    """Pairwise intersection and union pixel counts between two mask stacks.

    Args:
        masks1: tensor of shape (N1, H, W).
        masks2: tensor of shape (N2, H, W) with the same H, W as ``masks1``.
            Both are cast with ``.char()`` (int8), so they are expected to hold
            0/1 values.
        chunk_size: block edge used when ``chunk_mode`` is True.
        chunk_mode: compute in (chunk_size x chunk_size) blocks of mask pairs to
            bound peak memory. On CUDA out-of-memory it retries with
            chunk_size // 2 and then chunk_size // 4 (never below 1).

    Returns:
        (intersection, union): (N1, N2) tensors of pixel counts. Entry [i, j]
        compares masks1[i] with masks2[j].

    Raises:
        torch.cuda.OutOfMemoryError: if every chunk size runs out of memory.
    """
    rows = masks1.unsqueeze(1).char()  # (N1, 1, H, W)
    cols = masks2.unsqueeze(0).char()  # (1, N2, H, W)

    def _pair_counts(a, b):
        # Broadcast to (rows, cols, H, W) and reduce the two pixel dimensions.
        inter = a * b
        union = (a + b - inter).sum(-1).sum(-1)
        return inter.sum(-1).sum(-1), union

    # N2 lives on dim 1 after unsqueeze(0); dim 0 of `cols` is always 1.
    num_rows, num_cols = rows.shape[0], cols.shape[1]
    if not chunk_mode or num_rows == 0 or num_cols == 0:
        return _pair_counts(rows, cols)

    def _chunked(step):
        row_inter, row_union = [], []
        for r0 in range(0, num_rows, step):
            block = rows[r0:r0 + step]
            pieces = [_pair_counts(block, cols[:, c0:c0 + step]) for c0 in range(0, num_cols, step)]
            row_inter.append(torch.cat([inter for inter, _ in pieces], dim=1))
            row_union.append(torch.cat([union for _, union in pieces], dim=1))
        return torch.cat(row_inter, dim=0), torch.cat(row_union, dim=0)

    for step in (chunk_size, chunk_size // 2, chunk_size // 4):
        try:
            return _chunked(max(1, step))
        except torch.cuda.OutOfMemoryError:
            # Not kept, so the failed attempt's tensors are freed before retrying.
            continue
    raise torch.cuda.OutOfMemoryError("mask_iou_v2: out of memory even at the smallest chunk size")
| |
|
| |
|
def mask_area(masks, chunk_size=50, chunk_mode=False):
    """Sum each mask in a stack over its last two dimensions (pixel count for 0/1 masks).

    Args:
        masks: tensor of shape (N, H, W).
        chunk_size: number of masks reduced per step when ``chunk_mode`` is True.
        chunk_mode: reduce ``chunk_size`` masks at a time and concatenate the
            results. The values are the same as the one-shot reduction.

    Returns:
        Tensor of shape (N,).
    """
    # An empty stack has nothing to chunk (torch.cat rejects an empty list);
    # the one-shot reduction already yields the correct empty result.
    if not chunk_mode or masks.shape[0] == 0:
        return masks.sum(-1).sum(-1)
    areas = [chunk.sum(-1).sum(-1) for chunk in torch.split(masks, chunk_size, dim=0)]
    return torch.cat(areas, dim=0)
| | |
| |
|
| |
|
| | from detectron2.utils.visualizer import GenericMask |
| | import matplotlib.colors as mplc |
def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """Draw precomputed instance masks with text labels on this visualizer's image.

    Meant to be bound to a detectron2 ``Visualizer`` with ``types.MethodType``
    in place of ``draw_instance_predictions``, so masks can be drawn without an
    ``Instances`` object. Boxes and keypoints are not drawn.

    Args:
        labels: one label per mask, forwarded to ``overlay_instances``.
        np_masks: iterable of per-instance binary masks. Each is wrapped in
            ``GenericMask`` at the output image size.
        jittering: in ``ColorMode.SEGMENTATION`` with ``thing_colors`` metadata,
            randomly jitter per-class colors to tell instances of the same
            class apart.

    Returns:
        VisImage: ``self.output`` after drawing.
    """
    masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]

    # Only text labels are available here. Class ids are recovered by looking
    # the labels up in metadata.thing_classes, and per-class colors are used
    # only when every label resolves to an index that thing_colors covers.
    thing_classes = list(self.metadata.get("thing_classes") or [])
    thing_colors = self.metadata.get("thing_colors")
    classes = None
    if labels is not None and all(label in thing_classes for label in labels):
        classes = [thing_classes.index(label) for label in labels]

    if (
        self._instance_mode == ColorMode.SEGMENTATION
        and thing_colors
        and classes is not None
        and all(c < len(thing_colors) for c in classes)
    ):
        if jittering:
            colors = [self._jitter([x / 255 for x in thing_colors[c]]) for c in classes]
        else:
            colors = [tuple(mplc.to_rgb([x / 255 for x in thing_colors[c]])) for c in classes]
        alpha = 0.8
    else:
        colors = None
        alpha = 0.5

    self.overlay_instances(
        masks=masks,
        boxes=None,
        labels=labels,
        keypoints=None,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
| |
|
| |
|
def merge_sa1b_image(image_file, anno_file, save_path, generated_annos, visualize=False):
    """Decode SAM masks, suppress near-duplicates, and optionally save the result.

    Args:
        image_file: path of the source image. Its base name up to the first '.'
            names the output files.
        anno_file: JSON file whose ``"annotations"`` list replaces
            ``generated_annos``, or None.
        save_path: output root used when ``visualize`` is True.
        generated_annos: list of annotation dicts with ``"segmentation"`` (RLE)
            and ``"predicted_iou"``. Used when ``anno_file`` is None.
        visualize: if False, return the kept masks. If True, write
            ``<save_path>/sam_merge_out/<name>.jpg`` and
            ``<save_path>/sam_merge_json_out/<name>.json`` and return None.

    Returns:
        Stacked full-resolution tensor of kept masks when ``visualize`` is
        False, otherwise None.

    Raises:
        ValueError: if no annotations are available.
    """
    file_name = os.path.basename(image_file).split('.')[0]

    if anno_file is not None:
        with open(anno_file, 'r') as f:
            json_results = json.load(f)
        generated_annos = json_results["annotations"]
    if generated_annos is None:
        raise ValueError("Provide valid annotation file or generated_annos from sam automatic generator.")

    decoded_masks, predicted_iou_scores = [], []
    for object_anno in generated_annos:
        object_mask = object_anno["segmentation"]
        if isinstance(object_mask["counts"], list):
            # List-valued counts (uncompressed RLE) are converted before decoding.
            object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1])
        mask = mask_util.decode(object_mask).astype(np.uint8).squeeze()
        decoded_masks.append(torch.from_numpy(mask))
        predicted_iou_scores.append(object_anno['predicted_iou'])

    # Highest predicted IoU first: the suppression passes below favour earlier masks.
    sorted_idx = sorted(range(len(predicted_iou_scores)), key=lambda k: predicted_iou_scores[k], reverse=True)
    all_sam_masks = torch.stack([decoded_masks[idx] for idx in sorted_idx])

    # Overlap statistics are computed at half resolution on the GPU.
    ori_height, ori_width = all_sam_masks.shape[-2:]
    downsampled_sam_masks = torch.nn.functional.interpolate(
        all_sam_masks[None].to(torch.float32), size=(ori_height // 2, ori_width // 2), mode="bilinear")
    downsampled_sam_masks = (downsampled_sam_masks[0] > 0.5).to(all_sam_masks.dtype).to("cuda")

    intersection, union = mask_iou(downsampled_sam_masks, chunk_size=50, chunk_mode=True)
    area = mask_area(downsampled_sam_masks, chunk_mode=True)
    # Threshold on the device once, then scan plain Python lists instead of
    # synchronizing with the GPU for every (i, j) pair.
    is_duplicate = (intersection / union > 0.8).tolist()
    is_contained = (intersection / area[:, None] > 0.8).tolist()

    num_instances = len(is_duplicate)
    keep = [True] * num_instances
    # Pass 1: each kept mask suppresses every later mask with IoU > 0.8.
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(ins_i + 1, num_instances):
            if is_duplicate[ins_i][ins_j]:
                keep[ins_j] = False
    # Pass 2: drop a mask when more than 80% of its area lies inside another still-kept mask.
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(num_instances):
            if ins_j != ins_i and keep[ins_j] and is_contained[ins_i][ins_j]:
                keep[ins_i] = False
                break

    left_masks = [all_sam_masks[ins_i] for ins_i in range(num_instances) if keep[ins_i]]
    left_tags = ['object' for _ in range(len(left_masks))]

    unique_tags = list(set(left_tags))
    text_prompt = ','.join(unique_tags)
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags

    if not visualize:
        return torch.stack(left_masks)

    result_masks = torch.stack(left_masks).cpu().numpy()

    input_image = read_image(image_file, format="BGR")
    visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
    visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
    vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=result_masks)
    output_image = Image.fromarray(vis_output.get_image())

    final_out_path = os.path.join(save_path, 'sam_merge_out')
    os.makedirs(final_out_path, exist_ok=True)
    output_image.save(os.path.join(final_out_path, file_name + '.jpg'))

    save_json_results = []
    for tag, mask in zip(left_tags, result_masks):
        rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
        rle["counts"] = rle["counts"].decode("utf-8")
        save_json_results.append({
            "tag": tag,
            "segmentation": rle,
        })

    json_path = os.path.join(save_path, 'sam_merge_json_out')
    os.makedirs(json_path, exist_ok=True)
    with open(os.path.join(json_path, file_name + '.json'), 'w') as f:
        json.dump(save_json_results, f)
| |
|
| |
|
| |
|
| |
|
def run_on_image_v2(image_file, anno_file, save_path, ram_predictor, ape_predictor, sam_predictor, sam_auto_mask_generator):
    """Refine SAM "everything" masks with APE masks and save the fused result.

    Steps:
      1. SAM masks come from ``anno_file`` when it exists, otherwise from
         ``sam_auto_mask_generator`` run on the image. Both go through
         ``merge_sa1b_image`` for duplicate suppression.
      2. APE masks come from ``run_on_image``, which also calls
         ``sam_predictor.set_image`` on this image. Every
         ``sam_predictor.predict`` call below relies on that.
      3. Each SAM mask is either kept or replaced by a point-prompted SAM
         re-prediction, depending on its IoU / coverage against the APE masks.
      4. Each APE mask that no single SAM mask covers by at least 20% adds
         either a SAM re-prediction or the APE mask itself.
      5. All candidates are sorted by area and near-duplicates are suppressed.
         The result is written to ``<save_path>/fine_out_0118/<name>.jpg`` and
         ``<save_path>/fine_json_out_0118/<name>.json``.

    Overlap statistics are computed on half-resolution copies on the GPU.

    Returns:
        None; results are written to disk. Returns early when ``image_file``
        does not exist or no SAM masks are obtained.
    """
    if not os.path.exists(image_file):
        return None
    file_name = os.path.basename(image_file).split('.')[0]

    if (anno_file is None) or (not os.path.exists(anno_file)):
        image = cv2.imread(image_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        generated_annos = sam_auto_mask_generator.generate(image)
        sam_masks = merge_sa1b_image(image_file, None, save_path, generated_annos, visualize=False)
    else:
        sam_masks = merge_sa1b_image(image_file, anno_file, save_path, None, visualize=False)
    if sam_masks is None:
        # merge_sa1b_image can return None instead of masks; nothing to refine then.
        return None

    ape_masks, _ = run_on_image(image_file, save_path, ram_predictor, ape_predictor, sam_predictor, visualize=False)

    ori_height, ori_width = sam_masks.shape[-2:]
    half_size = (ori_height // 2, ori_width // 2)

    def _downsample(masks):
        # Half-resolution, re-binarized GPU copy used for all overlap statistics.
        resized = torch.nn.functional.interpolate(masks[None].to(torch.float32), size=half_size, mode="bilinear")
        return (resized[0] > 0.5).to(masks.dtype).to("cuda")

    def _box_to_xywh(box):
        # masks_to_boxes yields (x1, y1, x2, y2); sample_points takes (x, y, w, h).
        x1, y1, x2, y2 = box
        return [x1, y1, x2 - x1, y2 - y1]

    def _predict(points, multimask_output):
        # Positive-only point prompt on the image previously set on sam_predictor.
        masks, scores, _ = sam_predictor.predict(
            point_coords=np.array(points),
            point_labels=np.array([1 for _ in range(len(points))]),
            multimask_output=multimask_output,
        )
        return torch.from_numpy(masks), scores

    def _best_match(downsampled_reference, candidates):
        # (index, IoU) of the candidate closest to one downsampled reference mask.
        inter, union = mask_iou_v2(downsampled_reference[None], _downsample(candidates))
        ious = inter / union
        best = int(torch.argmax(ious, dim=1)[0])
        return best, ious[0, best]

    downsampled_sam_masks = _downsample(sam_masks)
    downsampled_ape_masks = _downsample(ape_masks)

    # Statistics indexed [sam_idx, ape_idx].
    sam_ape_intersection, sam_ape_union = mask_iou_v2(
        downsampled_sam_masks, downsampled_ape_masks, chunk_size=50, chunk_mode=True)
    sam_ape_masks_iou = sam_ape_intersection / sam_ape_union
    sam_area = mask_area(downsampled_sam_masks, chunk_mode=True)
    # Fraction of each SAM mask lying inside each APE mask.
    sam_masks_roc = sam_ape_intersection / sam_area[:, None]

    sam_boxes = torchvision.ops.masks_to_boxes(sam_masks)
    ape_boxes = torchvision.ops.masks_to_boxes(ape_masks)

    first_round_masks = []
    # Plain ints index CPU tensors and numpy arrays regardless of the GPU source.
    iou_target_indices = torch.argmax(sam_ape_masks_iou, dim=1).tolist()
    roc_target_indices = torch.argmax(sam_masks_roc, dim=1).tolist()
    for sam_idx in range(downsampled_sam_masks.shape[0]):
        sam_mask = sam_masks[sam_idx]
        iou_tgt_idx = iou_target_indices[sam_idx]
        roc_tgt_idx = roc_target_indices[sam_idx]

        if sam_ape_masks_iou[sam_idx, iou_tgt_idx] > 0.8:
            # Already agrees with an APE mask: keep as is.
            first_round_masks.append(sam_mask)
            continue

        ret_points = sample_points(_box_to_xywh(sam_boxes[sam_idx]), sam_mask, min_points=1, max_points=3)
        if len(ret_points) == 0:
            first_round_masks.append(sam_mask)
            continue
        temp_masks, scores = _predict(ret_points, multimask_output=True)

        if sam_masks_roc[sam_idx, roc_tgt_idx] > 0.8:
            # Mostly inside one APE mask: take the SAM candidate closest to that APE mask.
            best, best_iou = _best_match(downsampled_ape_masks[roc_tgt_idx], temp_masks)
            accept = best_iou > 0.8 and scores[best] > 0.9
        else:
            # No APE guidance: take the largest SAM candidate.
            best = int(torch.argmax(temp_masks.sum(-1).sum(-1)))
            accept = scores[best] > 0.9
        first_round_masks.append(temp_masks[best] if accept else sam_mask)

    ape_sam_intersection = sam_ape_intersection.transpose(0, 1)
    ape_area = mask_area(downsampled_ape_masks, chunk_mode=True)
    # Fraction of each APE mask lying inside each SAM mask, indexed [ape_idx, sam_idx].
    ape_masks_roc = ape_sam_intersection / ape_area[:, None]
    ape_roc_target_indices = torch.argmax(ape_masks_roc, dim=1).tolist()
    for ape_idx in range(ape_masks.shape[0]):
        roc_tgt_idx = ape_roc_target_indices[ape_idx]
        # Only APE masks that no SAM mask covers by at least 20% are considered.
        # Written as `not (<)` so NaN coverage (empty downsampled mask) is skipped.
        if not (ape_masks_roc[ape_idx, roc_tgt_idx] < 0.2):
            continue
        ape_mask = ape_masks[ape_idx]
        ape_box = _box_to_xywh(ape_boxes[ape_idx])

        if sam_masks_roc[roc_tgt_idx, ape_idx] < 0.2:
            # The best-matching SAM mask also lies less than 20% inside this APE mask.
            ret_points = sample_points(ape_box, ape_mask, min_points=3, max_points=16)
            if len(ret_points) == 0:
                first_round_masks.append(ape_mask)
                continue
            temp_masks, scores = _predict(ret_points, multimask_output=False)
            first_round_masks.append(temp_masks[0] if scores[0] > 0.9 else ape_mask)
        else:
            # Add the APE mask once if any single-point SAM prediction reproduces it.
            ret_points = sample_points(ape_box, ape_mask, min_points=3, max_points=8)
            for point in ret_points:
                temp_masks, _ = _predict([point], multimask_output=True)
                _, best_iou = _best_match(downsampled_ape_masks[ape_idx], temp_masks)
                if best_iou > 0.8:
                    first_round_masks.append(ape_mask)
                    break

    # Largest masks first: the suppression passes below favour earlier masks.
    stacked_first_round_masks = torch.stack(first_round_masks)
    first_round_areas = mask_area(stacked_first_round_masks, chunk_mode=True).tolist()
    sorted_idx = sorted(range(len(first_round_masks)), key=lambda k: first_round_areas[k], reverse=True)
    sorted_first_round_masks = stacked_first_round_masks[sorted_idx]
    downsampled_first_round_masks = _downsample(sorted_first_round_masks)

    intersection, union = mask_iou(downsampled_first_round_masks, chunk_mode=True)
    area = mask_area(downsampled_first_round_masks, chunk_mode=True)
    # Threshold on the device once, then scan plain Python lists.
    is_duplicate = (intersection / union > 0.8).tolist()
    is_contained = (intersection / area[:, None] > 0.5).tolist()

    num_instances = len(is_duplicate)
    keep = [True] * num_instances
    # Pass 1: each kept mask suppresses every later mask with IoU > 0.8.
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(ins_i + 1, num_instances):
            if is_duplicate[ins_i][ins_j]:
                keep[ins_j] = False
    # Pass 2: drop a mask when more than half of its area lies inside another still-kept mask.
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(num_instances):
            if ins_j != ins_i and keep[ins_j] and is_contained[ins_i][ins_j]:
                keep[ins_i] = False
                break

    left_masks = [sorted_first_round_masks[ins_i] for ins_i in range(num_instances) if keep[ins_i]]
    left_tags = ['object' for _ in range(len(left_masks))]

    unique_tags = list(set(left_tags))
    text_prompt = ','.join(unique_tags)
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags

    result_masks = torch.stack(left_masks).cpu().numpy()

    input_image = read_image(image_file, format="BGR")
    visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
    visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
    vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=result_masks)
    output_image = Image.fromarray(vis_output.get_image())

    final_out_path = os.path.join(save_path, 'fine_out_0118')
    os.makedirs(final_out_path, exist_ok=True)
    output_image.save(os.path.join(final_out_path, file_name + '.jpg'))

    save_json_results = []
    for tag, mask in zip(left_tags, result_masks):
        rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
        rle["counts"] = rle["counts"].decode("utf-8")
        save_json_results.append({
            "tag": tag,
            "segmentation": rle,
        })

    json_path = os.path.join(save_path, 'fine_json_out_0118')
    os.makedirs(json_path, exist_ok=True)
    with open(os.path.join(json_path, file_name + '.json'), 'w') as f:
        json.dump(save_json_results, f)
| |
|
| |
|
| |
|
| |
|
| |
|
| | def run_on_image(image_file, save_path, ram_predictor, ape_predictor, sam_predictor, visualize=False): |
| | res = ram_predictor.run_on_image(image_file_path=image_file, dynamic_resolution=True) |
| | tag_list = [] |
| | for tag_string in res[0]: |
| | tags = tag_string.split(' | ') |
| | tag_list += tags |
| | tags = list(set(tag_list)) |
| | text_prompt = ','.join(tags) |
| | |
| | output_image, json_results = ape_predictor.run_on_image( |
| | image_file, |
| | input_text=text_prompt, |
| | visualize=True, |
| | score_threhold=0.1, |
| | output_type=["instance segmentation"], |
| | ) |
| | |
| | if visualize: |
| | file_name = os.path.basename(image_file).split('.')[0] |
| | raw_ape_out_path = os.path.join(save_path, 'raw_ape_out_0116') |
| | if not os.path.exists(raw_ape_out_path): |
| | os.makedirs(raw_ape_out_path) |
| | output_image.save(os.path.join(raw_ape_out_path, file_name+'.jpg')) |
| |
|
| | |
| | |
| | sam_image = cv2.imread(image_file) |
| | ori_height, ori_width = sam_image.shape[:2] |
| | sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB) |
| | sam_predictor.set_image(sam_image) |
| |
|
| | new_masks_from_sam = [] |
| | correspondding_tags = [] |
| | correspondding_scores = [] |
| | for idx, item in enumerate(json_results): |
| | object_mask = item["segmentation"] |
| | if isinstance(object_mask["counts"], list): |
| | object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1]) |
| | mask = mask_util.decode(object_mask) |
| | mask = mask.astype(np.uint8).squeeze() |
| | |
| | box = item["bbox"] |
| |
|
| | ret_points = sample_points(box, mask) |
| |
|
| | if len(ret_points) == 0: |
| | continue |
| | |
| | mask_h, mask_w = object_mask["size"] |
| | input_point, input_label = [], [] |
| | for point in ret_points: |
| | _x = point[0] / mask_w * ori_width |
| | _y = point[1] / mask_h * ori_height |
| | input_point.append([int(_x), int(_y)]) |
| | input_label.append(1) |
| | |
| | masks, scores, logits = sam_predictor.predict( |
| | point_coords=np.array(input_point), |
| | point_labels=np.array(input_label), |
| | multimask_output=False |
| | ) |
| |
|
| | new_masks_from_sam.append(torch.from_numpy(masks)) |
| | correspondding_tags.append(item["category_name"]) |
| | correspondding_scores.append(item["score"]) |
| | |
| | new_masks_from_sam = torch.cat(new_masks_from_sam) |
| | downsampled_new_masks_from_sam = torch.nn.functional.interpolate(new_masks_from_sam[None].to(torch.float32), size=(ori_height//2, ori_width//2), mode="bilinear") |
| | downsampled_new_masks_from_sam = (downsampled_new_masks_from_sam[0] > 0.5).to(new_masks_from_sam.dtype).to("cuda") |
| |
|
| | intersection, union = mask_iou(downsampled_new_masks_from_sam, chunk_mode=True) |
| | mask_iou_matrix = intersection / union |
| |
|
| | |
| | num_instances = len(mask_iou_matrix) |
| | keep = [True] * num_instances |
| | for ins_i in range(num_instances): |
| | if not keep[ins_i]: |
| | continue |
| | for ins_j in range(ins_i, num_instances): |
| | if ins_j == ins_i: |
| | continue |
| | if mask_iou_matrix[ins_i, ins_j] > 0.8: |
| | keep[ins_j] = False |
| |
|
| |
|
| | |
| | |
| | area = mask_area(downsampled_new_masks_from_sam, chunk_mode=True) |
| | roc = intersection / area[:, None] |
| | for ins_i in range(num_instances): |
| | if not keep[ins_i]: |
| | continue |
| | for ins_j in range(num_instances): |
| | if ins_i == ins_j: |
| | continue |
| | if not keep[ins_j]: |
| | continue |
| | if roc[ins_i, ins_j] > 0.8: |
| | keep[ins_i] = False |
| | break |
| | |
| | left_masks = [new_masks_from_sam[ins_i] for ins_i in range(len(keep)) if keep[ins_i]] |
| | left_masks = torch.stack(left_masks) |
| | left_boxes = torchvision.ops.masks_to_boxes(left_masks) |
| | left_tags = [correspondding_tags[ins_i] for ins_i in range(len(keep)) if keep[ins_i]] |
| | |
| | |
| | result_mask_list = [] |
| | result_tag_list = [] |
| | ori_image = Image.open(image_file) |
| | for ins_i, ins_box in enumerate(left_boxes): |
| | ins_box = ins_box.numpy().tolist() |
| | box_w = ins_box[2] - ins_box[0] |
| | box_h = ins_box[3] - ins_box[1] |
| | loose_box_x0 = int(ins_box[0] - box_w // 4) |
| | loose_box_y0 = int(ins_box[1] - box_h // 4) |
| | loose_box_x1 = int(ins_box[2] + box_w // 4) |
| | loose_box_y1 = int(ins_box[3] + box_h // 4) |
| | loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0 |
| | loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0 |
| | loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width |
| | loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height |
| | |
| | loose_box_w = loose_box_x1 - loose_box_x0 |
| | loose_box_h = loose_box_y1 - loose_box_y0 |
| | assert loose_box_w >= box_w and loose_box_h >= box_h |
| |
|
| | if loose_box_w < 256: |
| | padded_length_w = 256 - loose_box_w |
| | left_padded = padded_length_w // 2 |
| | right_padded = padded_length_w - left_padded |
| | if loose_box_x0 - left_padded < 0: |
| | right_padded = right_padded + left_padded - loose_box_x0 |
| | left_padded = loose_box_x0 |
| | if loose_box_x1 + right_padded > ori_width: |
| | left_padded = left_padded + loose_box_x1 + right_padded - ori_width |
| | right_padded = ori_width - loose_box_x1 |
| | loose_box_x0 = int(loose_box_x0 - left_padded) |
| | loose_box_x1 = int(loose_box_x1 + right_padded) |
| | loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0 |
| | loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width |
| | if loose_box_h < 256: |
| | padded_length_h = 256 - loose_box_h |
| | top_padded = padded_length_h // 2 |
| | bottom_padded = padded_length_h - top_padded |
| | if loose_box_y0 - top_padded < 0: |
| | bottom_padded = bottom_padded + top_padded - loose_box_y0 |
| | top_padded = loose_box_y0 |
| | if loose_box_y1 + bottom_padded > ori_height: |
| | top_padded = top_padded + loose_box_y1 + bottom_padded - ori_height |
| | bottom_padded = ori_height - loose_box_y1 |
| | loose_box_y0 = int(loose_box_y0 - top_padded) |
| | loose_box_y1 = int(loose_box_y1 + bottom_padded) |
| | loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0 |
| | loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height |
| | |
| | loose_box_w = loose_box_x1 - loose_box_x0 |
| | loose_box_h = loose_box_y1 - loose_box_y0 |
| | if loose_box_w > loose_box_h: |
| | padded_length_h = loose_box_w - loose_box_h |
| | top_padded = padded_length_h // 2 |
| | bottom_padded = padded_length_h - top_padded |
| | if loose_box_y0 - top_padded < 0: |
| | bottom_padded = bottom_padded + top_padded - loose_box_y0 |
| | top_padded = loose_box_y0 |
| | if loose_box_y1 + bottom_padded > ori_height: |
| | top_padded = top_padded + loose_box_y1 + bottom_padded - ori_height |
| | bottom_padded = ori_height - loose_box_y1 |
| | loose_box_y0 = int(loose_box_y0 - top_padded) |
| | loose_box_y1 = int(loose_box_y1 + bottom_padded) |
| | loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0 |
| | loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height |
| | elif loose_box_h > loose_box_w: |
| | padded_length_w = loose_box_h - loose_box_w |
| | left_padded = padded_length_w // 2 |
| | right_padded = padded_length_w - left_padded |
| | if loose_box_x0 - left_padded < 0: |
| | right_padded = right_padded + left_padded - loose_box_x0 |
| | left_padded = loose_box_x0 |
| | if loose_box_x1 + right_padded > ori_width: |
| | left_padded = left_padded + loose_box_x1 + right_padded - ori_width |
| | right_padded = ori_width - loose_box_x1 |
| | loose_box_x0 = int(loose_box_x0 - left_padded) |
| | loose_box_x1 = int(loose_box_x1 + right_padded) |
| | loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0 |
| | loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width |
| |
|
| | image_patch = ori_image.crop((loose_box_x0, loose_box_y0, loose_box_x1, loose_box_y1)) |
| | image_patch_w, image_patch_h = image_patch.size |
| |
|
| | res = ram_predictor.run_on_image(image_file_path=image_patch, dynamic_resolution=False) |
| | tag_list = [] |
| | for tag_string in res[0]: |
| | tags = tag_string.split(' | ') |
| | tag_list += tags |
| | tags = list(set(tag_list)) |
| | text_prompt = ','.join(tags) |
| | |
| | if image_patch_w > image_patch_h: |
| | rescaled_image_patch_w = 1024 |
| | rescaled_image_patch_h = int(image_patch_h / image_patch_w * 1024) |
| | else: |
| | rescaled_image_patch_h = 1024 |
| | rescaled_image_patch_w = int(image_patch_w / image_patch_h * 1024) |
| |
|
| | image_patch = image_patch.resize((rescaled_image_patch_w, rescaled_image_patch_h)) |
| | output_image, json_results = ape_predictor.run_on_image( |
| | image_patch, |
| | input_text=text_prompt, |
| | visualize=True, |
| | score_threhold=0.1, |
| | output_type=["instance segmentation"], |
| | ) |
| | |
| | all_masks, all_tags = [], [] |
| | for idx, item in enumerate(json_results): |
| | object_mask = item["segmentation"] |
| | if isinstance(object_mask["counts"], list): |
| | object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1]) |
| | mask = mask_util.decode(object_mask) |
| | mask = torch.as_tensor(mask.astype(np.uint8)) |
| | all_masks.append(mask) |
| | all_tags.append(item['category_name']) |
| | |
| | |
| | if len(all_masks) == 0: |
| | result_mask_list.append(left_masks[ins_i]) |
| | result_tag_list.append(left_tags[ins_i]) |
| | continue |
| | |
| | all_masks = torch.stack(all_masks) |
| |
|
| | all_masks_ori_size = torch.nn.functional.interpolate(all_masks.unsqueeze(0), size=(image_patch_h, image_patch_w), |
| | mode='bilinear') |
| | all_masks_ori_size = all_masks_ori_size > 0.4 |
| | |
| | ori_mask_crop = left_masks[ins_i, loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] |
| |
|
| | |
| |
|
| | |
| | masks1 = ori_mask_crop[None, None, :, :].char().to('cuda') |
| | masks2 = all_masks_ori_size.char().to('cuda') |
| | intersection = (masks1 * masks2) |
| | union = (masks1 + masks2 - intersection).sum(-1).sum(-1) |
| | intersection = intersection.sum(-1).sum(-1) |
| | area = masks2.sum(-1).sum(-1) |
| | |
| | masks_iou = intersection / union |
| | target_idx = torch.argmax(masks_iou, dim=1) |
| |
|
| | if masks_iou[0, target_idx] < 0.8: |
| | temp_result_mask_list = [] |
| | temp_result_tag_list = [] |
| | for ins_j, mask_j_iou in enumerate(masks_iou[0]): |
| | if mask_j_iou < 0.1: |
| | continue |
| | roc_j = intersection[0, ins_j] / area[0, ins_j] |
| | if roc_j < 0.8: |
| | continue |
| | result_mask = torch.zeros((ori_height, ori_width)).to(all_masks.dtype) |
| | result_mask[loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] = all_masks_ori_size[0, ins_j] |
| | temp_result_mask_list.append(result_mask) |
| | temp_result_tag_list.append(all_tags[ins_j]) |
| | if len(temp_result_mask_list) > 1: |
| | result_mask_list.extend(temp_result_mask_list) |
| | result_tag_list.extend(temp_result_tag_list) |
| | else: |
| | result_mask_list.append(left_masks[ins_i]) |
| | result_tag_list.append(left_tags[ins_i]) |
| | else: |
| | result_mask = torch.zeros((ori_height, ori_width)).to(all_masks.dtype) |
| | result_mask[loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] = all_masks_ori_size[0, target_idx.item()] |
| | result_mask_list.append(result_mask) |
| | result_tag_list.append(all_tags[target_idx]) |
| | |
| | unique_tags = list(set(result_tag_list)) |
| | text_prompt = ','.join(unique_tags) |
| | metadata = MetadataCatalog.get("__unused_ape_" + text_prompt) |
| | metadata.thing_classes = unique_tags |
| | metadata.stuff_classes = unique_tags |
| |
|
| | if not visualize: |
| | return torch.stack(result_mask_list), result_tag_list |
| |
|
| | result_masks = torch.stack(result_mask_list).cpu().numpy() |
| | |
| | input_image = read_image(image_file, format="BGR") |
| | visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE) |
| | visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer) |
| | vis_output = visualizer.draw_instance_predictions(labels=result_tag_list, np_masks=result_masks) |
| | output_image = vis_output.get_image() |
| | output_image = Image.fromarray(output_image) |
| |
|
| | final_out_path = os.path.join(save_path, 'final_out_0116') |
| | if not os.path.exists(final_out_path): |
| | os.makedirs(final_out_path) |
| | output_image.save(os.path.join(final_out_path, file_name+'.jpg')) |
| | |
| | save_json_results = [] |
| | for tag, mask in zip(result_tag_list, result_masks): |
| | rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0] |
| | rle["counts"] = rle["counts"].decode("utf-8") |
| | save_json_results.append({ |
| | "tag": tag, |
| | "segmentation": rle, |
| | }) |
| |
|
| | json_path = os.path.join(save_path, 'final_json_out_0116') |
| | if not os.path.exists(json_path): |
| | os.makedirs(json_path) |
| | with open(os.path.join(json_path, file_name+'.json'), 'w') as f: |
| | json.dump(save_json_results, f) |
| |
|
| |
|
def main(data_dir="/mnt/bn/zhnagtao-lq/xiangtai-mnt/sam/sa_000001",
         save_path="third_parts/zhouyik/zt_any_visual_prompt/",
         max_images=300):
    """Run the RAM++ / APE re-labelling pipeline over part of an SA-1B shard.

    Builds the RAM++ tagger and the APE predictor, then calls
    ``merge_sa1b_image`` (with ``visualize=True``) on up to ``max_images``
    ``.jpg`` files from ``data_dir``. Each image is paired with the
    ``.json`` annotation file that has the same stem in the same directory.

    Args:
        data_dir: Directory containing ``<name>.jpg`` / ``<name>.json`` pairs.
        save_path: Output root forwarded to ``merge_sa1b_image``.
        max_images: Maximum number of images to process (300 by default);
            ``None`` processes every ``.jpg`` in ``data_dir``.
    """
    ram_predictor = build_ram_predictor(override_ckpt_file="third_parts/recognize_anything/xinyu1205/recognize-anything-plus-model/ram_plus_swin_large_14m.pth")
    ape_predictor = build_ape_predictor(which_categories='COCO',
                                        override_ckpt_file="third_parts/APE/shenyunhang/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth")
    # NOTE(review): both predictors are locals of main() and are not passed to
    # merge_sa1b_image below, yet that function's body calls
    # `ram_predictor.run_on_image` / `ape_predictor.run_on_image`. Confirm how
    # it obtains them (module-level globals? a parameter this call leaves at
    # its default?). Otherwise the models built here are never used.

    # os.listdir order is arbitrary. Sorting makes the processed subset
    # reproducible across runs and machines.
    sam_images = sorted(name for name in os.listdir(data_dir) if name.endswith('.jpg'))
    if max_images is not None:
        sam_images = sam_images[:max_images]

    for idx, image_file_name in enumerate(sam_images):
        # splitext keeps any inner dots in the stem (e.g. "a.b.jpg" -> "a.b").
        image_name = os.path.splitext(image_file_name)[0]
        sam_image_file = os.path.join(data_dir, image_file_name)
        sam_anno_file = os.path.join(data_dir, image_name + ".json")
        merge_sa1b_image(sam_image_file, sam_anno_file, save_path, None, visualize=True)
        print(f"processed {idx+1}'th image: {sam_image_file}")


if __name__ == "__main__":
    main()