import cv2
import torch
import torch.nn as nn
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmdet.core import bbox2result
from mmdet.models import DETECTORS, BaseDetector
from projects.instance_segment_anything.models.segment_anything import sam_model_registry, SamPredictor
from .focalnet_dino.focalnet_dino_wrapper import FocalNetDINOWrapper
from .hdetr.hdetr_wrapper import HDetrWrapper


@DETECTORS.register_module()
class DetWrapperInstanceSAM(BaseDetector):
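    """Inference-only instance segmentation wrapper.

    A box detector (H-DETR or FocalNet-DINO) predicts boxes, which are then
    passed to Segment Anything (SAM) as box prompts to obtain instance masks.
    """
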
    wrapper_dict = {'hdetr': HDetrWrapper,
                    'focalnet_dino': FocalNetDINOWrapper}

    def __init__(self,
                 det_wrapper_type='hdetr',
                 det_wrapper_cfg=None,
                 det_model_ckpt=None,
                 num_classes=80,
                 model_type='vit_b',
                 sam_checkpoint=None,
                 use_sam_iou=True,
                 best_in_multi_mask=False,
                 init_cfg=None,
                 train_cfg=None,
                 test_cfg=None):
        super(DetWrapperInstanceSAM, self).__init__(init_cfg)
        # Dummy 1x1 embedding; its weight is only used below to query the
        # module's device when moving SAM.
        self.learnable_placeholder = nn.Embedding(1, 1)
        det_wrapper_cfg = Config(det_wrapper_cfg)
        assert det_wrapper_type in self.wrapper_dict.keys()
        self.det_model = self.wrapper_dict[det_wrapper_type](args=det_wrapper_cfg)
        if det_model_ckpt is not None:
            load_checkpoint(self.det_model.model,
                            filename=det_model_ckpt,
                            map_location='cpu')
        self.num_classes = num_classes
        # Segment Anything
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        _ = sam.to(device=self.learnable_placeholder.weight.device)
        self.predictor = SamPredictor(sam)
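        # SamPredictor caches the image embedding when set_image() is called,
        # so a single backbone pass serves all box prompts for that image.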
        # Whether to use SAM's predicted IoU to calibrate the confidence score.
        self.use_sam_iou = use_sam_iou
        # If True, set multimask_output=True and return the mask with the
        # highest predicted IoU; if False, set multimask_output=False and
        # return the single output mask.
        self.best_in_multi_mask = best_in_multi_mask

    def init_weights(self):
        # Both the detector and SAM are loaded from checkpoints; nothing to init.
        pass

    def simple_test(self, img, img_metas, rescale=True, ori_img=None):
        """Test without augmentation.

        Args:
            img (Tensor): A batch of images.
            img_metas (list[dict]): List of image information.
            rescale (bool): Must be True; boxes are returned at the
                original image scale.
            ori_img (ndarray, optional): Original RGB image. If None, it is
                read from ``img_metas[0]['filename']``.
        """
        assert rescale
        assert len(img_metas) == 1
        # results: List[dict(scores, labels, boxes)]
        results = self.det_model.simple_test(img,
                                             img_metas,
                                             rescale)
        # Tensor(n, 4), xyxy, original image scale
        output_boxes = results[0]['boxes']
        if ori_img is None:
            image_path = img_metas[0]['filename']
            ori_img = cv2.imread(image_path)
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(ori_img)
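        # Map the detector's boxes from original-image coordinates into
        # SAM's resized input frame before using them as prompts.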
        transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
        # mask_pred: n, 1/3, h, w
        # sam_score: n, 1/3
        mask_pred, sam_score, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=self.best_in_multi_mask,
            return_logits=True,
        )
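        # With return_logits=True, mask_pred holds raw logits; they are only
        # binarized against the model's mask_threshold further below.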
        if self.best_in_multi_mask:
            # sam_score: n
            sam_score, max_iou_idx = torch.max(sam_score, dim=1)
            # mask_pred: n, h, w
            mask_pred = mask_pred[torch.arange(mask_pred.size(0)),
                                  max_iou_idx]
        else:
            # Tensor(n, h, w), raw mask pred
            # n, 1, h, w -> n, h, w
            mask_pred = mask_pred.squeeze(1)
            # n, 1 -> n
            sam_score = sam_score.squeeze(-1)
        # Tensor(n,)
        label_pred = results[0]['labels']
        score_pred = results[0]['scores']
        # mask_pred: Tensor(n, h, w)
        # label_pred: Tensor(n,)
        # score_pred: Tensor(n,)
        # sam_score: Tensor(n,)
        mask_pred_binary = (mask_pred > self.predictor.model.mask_threshold).float()
        if self.use_sam_iou:
            # Calibrate the detector's confidence with SAM's predicted IoU.
            det_scores = score_pred * sam_score
        else:
            # Fallback: the mean mask logit over the predicted foreground
            # region serves as the mask-quality weight. Shape: n
            mask_scores_per_image = (mask_pred * mask_pred_binary).flatten(1).sum(1) / (
                    mask_pred_binary.flatten(1).sum(1) + 1e-6)
            det_scores = score_pred * mask_scores_per_image
            # det_scores = score_pred
        mask_pred_binary = mask_pred_binary.bool()
        bboxes = torch.cat([output_boxes, det_scores[:, None]], dim=-1)
        bbox_results = bbox2result(bboxes, label_pred, self.num_classes)
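        # bbox_results: per-class list of (k, 5) arrays [x1, y1, x2, y2, score].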
        # Group the binary masks by predicted class label.
        mask_results = [[] for _ in range(self.num_classes)]
        for j, label in enumerate(label_pred):
            mask = mask_pred_binary[j].detach().cpu().numpy()
            mask_results[label].append(mask)
        output_results = [(bbox_results, mask_results)]
        return output_results

    # This wrapper is inference-only; the remaining BaseDetector hooks are
    # deliberately left unimplemented.
    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        raise NotImplementedError

    async def async_simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def extract_feat(self, imgs):
        raise NotImplementedError
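
# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the checkpoint paths and the
# exact contents of det_wrapper_cfg are placeholders, not values shipped
# with this file):
#
#   from mmdet.models import build_detector
#
#   model = build_detector(dict(
#       type='DetWrapperInstanceSAM',
#       det_wrapper_type='hdetr',
#       det_wrapper_cfg=dict(...),             # wrapper-specific arguments
#       det_model_ckpt='ckpts/hdetr.pth',      # hypothetical path
#       num_classes=80,
#       model_type='vit_b',
#       sam_checkpoint='ckpts/sam_vit_b.pth',  # hypothetical path
#       use_sam_iou=True))
#   model.eval()
# ----------------------------------------------------------------------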