# Scraped page header (not part of the module source): Spaces status
# "Runtime error"; file size 19,280 bytes; revision 51f6859.
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmdet.core import bbox2result, bbox2roi, bbox_xyxy_to_cxcywh
from mmdet.core.bbox.samplers import PseudoSampler
from ..builder import HEADS
from .cascade_roi_head import CascadeRoIHead
@HEADS.register_module()
class SparseRoIHead(CascadeRoIHead):
    r"""The RoIHead for `Sparse R-CNN: End-to-End Object Detection with
    Learnable Proposals <https://arxiv.org/abs/2011.12450>`_
    and `Instances as Queries <http://arxiv.org/abs/2105.01928>`_

    Args:
        num_stages (int): Number of stages in the whole iterative process.
            Defaults to 6.
        stage_loss_weights (Tuple[float]): The loss weight of each stage.
            Its length must equal ``num_stages``. By default all stages
            have the same weight 1.
        proposal_feature_channel (int): Channel number of the proposal
            (object) features. Defaults to 256.
        bbox_roi_extractor (dict): Config of box roi extractor.
        mask_roi_extractor (dict): Config of mask roi extractor.
        bbox_head (dict): Config of box head.
        mask_head (dict): Config of mask head.
        train_cfg (dict, optional): Configuration information in train stage.
            Defaults to None.
        test_cfg (dict, optional): Configuration information in test stage.
            Defaults to None.
        pretrained (str, optional): model pretrained path. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 num_stages=6,
                 stage_loss_weights=(1, 1, 1, 1, 1, 1),
                 proposal_feature_channel=256,
                 bbox_roi_extractor=dict(
                     type='SingleRoIExtractor',
                     roi_layer=dict(
                         type='RoIAlign', output_size=7, sampling_ratio=2),
                     out_channels=256,
                     featmap_strides=[4, 8, 16, 32]),
                 mask_roi_extractor=None,
                 bbox_head=dict(
                     type='DIIHead',
                     num_classes=80,
                     num_fcs=2,
                     num_heads=8,
                     num_cls_fcs=1,
                     num_reg_fcs=3,
                     feedforward_channels=2048,
                     hidden_channels=256,
                     dropout=0.0,
                     roi_feat_size=7,
                     ffn_act_cfg=dict(type='ReLU', inplace=True)),
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        assert bbox_roi_extractor is not None
        assert bbox_head is not None
        assert len(stage_loss_weights) == num_stages
        self.num_stages = num_stages
        self.stage_loss_weights = stage_loss_weights
        self.proposal_feature_channel = proposal_feature_channel
        super(SparseRoIHead, self).__init__(
            num_stages,
            stage_loss_weights,
            bbox_roi_extractor=bbox_roi_extractor,
            mask_roi_extractor=mask_roi_extractor,
            bbox_head=bbox_head,
            mask_head=mask_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        # train_cfg is None when running test.py, so the sampler check is
        # only performed when a training config is given.
        if train_cfg is not None:
            for stage in range(num_stages):
                assert isinstance(self.bbox_sampler[stage], PseudoSampler), \
                    'Sparse R-CNN and QueryInst only support `PseudoSampler`'

    def _bbox_forward(self, stage, x, rois, object_feats, img_metas):
        """Box head forward function used in both training and testing.

        Returns all regression and classification results together with
        intermediate features.

        Args:
            stage (int): The index of current stage in the iterative process.
            x (List[Tensor]): List of FPN features.
            rois (Tensor): Rois in total batch. With shape (num_proposal, 5).
                The last dimension 5 represents (img_index, x1, y1, x2, y2).
            object_feats (Tensor): The object feature extracted from
                the previous stage.
            img_metas (list[dict]): Meta information of images; only its
                length (the number of images) is used directly here, the
                list itself is forwarded to ``refine_bboxes``.

        Returns:
            dict[str, Tensor]: a dictionary of bbox head outputs,
            containing the following results:

                - cls_score (Tensor): The score of each class, has
                  shape (batch_size, num_proposals, num_classes)
                  when use focal loss or
                  (batch_size, num_proposals, num_classes+1)
                  otherwise.
                - decode_bbox_pred (Tensor): The regression results
                  with shape (batch_size, num_proposal, 4).
                  The last dimension 4 represents
                  [tl_x, tl_y, br_x, br_y].
                - object_feats (Tensor): The object feature extracted
                  from current stage.
                - attn_feats (Tensor): Intermediate features returned by
                  the bbox head; consumed by the mask head.
                - detach_cls_score_list (list[Tensor]): The detached
                  classification results, length is batch_size, and
                  each tensor has shape (num_proposal, num_classes).
                - detach_proposal_list (list[tensor]): The detached
                  regression results, length is batch_size, and each
                  tensor has shape (num_proposal, 4). The last
                  dimension 4 represents [tl_x, tl_y, br_x, br_y].
        """
        num_imgs = len(img_metas)
        bbox_roi_extractor = self.bbox_roi_extractor[stage]
        bbox_head = self.bbox_head[stage]
        bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
                                        rois)
        cls_score, bbox_pred, object_feats, attn_feats = bbox_head(
            bbox_feats, object_feats)
        proposal_list = bbox_head.refine_bboxes(
            rois,
            rois.new_zeros(len(rois)),  # dummy labels argument
            bbox_pred.view(-1, bbox_pred.size(-1)),
            # Per-image all-zero placeholders of length object_feats.size(1)
            # (presumably the per-proposal "is gt" flags — TODO confirm).
            [rois.new_zeros(object_feats.size(1)) for _ in range(num_imgs)],
            img_metas)
        bbox_results = dict(
            cls_score=cls_score,
            decode_bbox_pred=torch.cat(proposal_list),
            object_feats=object_feats,
            attn_feats=attn_feats,
            # Detached copies are used for label assignment so that the
            # assigner does not backpropagate into the head.
            detach_cls_score_list=[
                cls_score[i].detach() for i in range(num_imgs)
            ],
            detach_proposal_list=[item.detach() for item in proposal_list])
        return bbox_results

    def _mask_forward(self, stage, x, rois, attn_feats):
        """Mask head forward function used in both training and testing.

        Args:
            stage (int): The index of current stage.
            x (List[Tensor]): List of FPN features.
            rois (Tensor): Rois with shape (num_rois, 5) as produced by
                ``bbox2roi``.
            attn_feats (Tensor): Intermediate bbox-head features matching
                ``rois``.

        Returns:
            dict[str, Tensor]: ``{'mask_pred': Tensor}``.
        """
        mask_roi_extractor = self.mask_roi_extractor[stage]
        mask_head = self.mask_head[stage]
        mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
                                        rois)
        # do not support caffe_c4 model anymore
        mask_pred = mask_head(mask_feats, attn_feats)
        mask_results = dict(mask_pred=mask_pred)
        return mask_results

    def _mask_forward_train(self, stage, x, attn_feats, sampling_results,
                            gt_masks, rcnn_train_cfg):
        """Run forward function and calculate loss for mask head in training.

        Only positive samples contribute: both the rois and the per-image
        ``attn_feats`` are restricted to each sampling result's
        ``pos_inds`` before the mask head runs.

        Returns:
            dict: ``mask_pred`` plus the loss entries from the mask head.
        """
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        attn_feats = torch.cat([
            feats[res.pos_inds]
            for (feats, res) in zip(attn_feats, sampling_results)
        ])
        mask_results = self._mask_forward(stage, x, pos_rois, attn_feats)
        mask_targets = self.mask_head[stage].get_targets(
            sampling_results, gt_masks, rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head[stage].loss(mask_results['mask_pred'],
                                               mask_targets, pos_labels)
        mask_results.update(loss_mask)
        return mask_results

    def forward_train(self,
                      x,
                      proposal_boxes,
                      proposal_features,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      imgs_whwh=None,
                      gt_masks=None):
        """Forward function in training stage.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposal_boxes (Tensor): Decoded proposal bboxes, has shape
                (batch_size, num_proposals, 4).
            proposal_features (Tensor): Expanded proposal
                features, has shape
                (batch_size, num_proposals, proposal_feature_channel).
            img_metas (list[dict]): list of image info dict where
                each dict has: 'img_shape', 'scale_factor', 'flip',
                and may also contain 'filename', 'ori_shape',
                'pad_shape', and 'img_norm_cfg'. For details on the
                values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss. Currently not
                supported; it is replaced by a list of ``None``.
            imgs_whwh (Tensor): Per-image
                [img_width, img_height, img_width, img_height]. Required
                despite the ``None`` default. It is repeated along dim 1 by
                ``num_proposals`` and then indexed per image, so it is
                expected to have shape (batch_size, 1, 4) — TODO confirm
                against the caller.
            gt_masks (None | Tensor): true segmentation masks for each box,
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components of all stages,
            keyed as ``stage{i}_{loss_name}`` and already multiplied by the
            stage's loss weight.
        """
        num_imgs = len(img_metas)
        num_proposals = proposal_boxes.size(1)
        imgs_whwh = imgs_whwh.repeat(1, num_proposals, 1)
        proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
        object_feats = proposal_features
        if gt_bboxes_ignore is None:
            # TODO support ignore
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
        all_stage_loss = {}
        for stage in range(self.num_stages):
            rois = bbox2roi(proposal_list)
            bbox_results = self._bbox_forward(stage, x, rois, object_feats,
                                              img_metas)
            sampling_results = []
            cls_pred_list = bbox_results['detach_cls_score_list']
            proposal_list = bbox_results['detach_proposal_list']
            for i in range(num_imgs):
                # The assigner receives boxes normalized by image size and
                # converted to (cx, cy, w, h).
                normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(proposal_list[i] /
                                                          imgs_whwh[i])
                assign_result = self.bbox_assigner[stage].assign(
                    normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                    gt_labels[i], img_metas[i])
                sampling_result = self.bbox_sampler[stage].sample(
                    assign_result, proposal_list[i], gt_bboxes[i])
                sampling_results.append(sampling_result)
            bbox_targets = self.bbox_head[stage].get_targets(
                sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
                True)
            cls_score = bbox_results['cls_score']
            decode_bbox_pred = bbox_results['decode_bbox_pred']
            single_stage_loss = self.bbox_head[stage].loss(
                cls_score.view(-1, cls_score.size(-1)),
                decode_bbox_pred.view(-1, 4),
                *bbox_targets,
                imgs_whwh=imgs_whwh)
            if self.with_mask:
                mask_results = self._mask_forward_train(
                    stage, x, bbox_results['attn_feats'], sampling_results,
                    gt_masks, self.train_cfg[stage])
                single_stage_loss['loss_mask'] = mask_results['loss_mask']
            for key, value in single_stage_loss.items():
                all_stage_loss[f'stage{stage}_{key}'] = value * \
                    self.stage_loss_weights[stage]
            # The refined object features seed the next stage.
            object_feats = bbox_results['object_feats']
        return all_stage_loss

    def simple_test(self,
                    x,
                    proposal_boxes,
                    proposal_features,
                    img_metas,
                    imgs_whwh,
                    rescale=False):
        """Test without augmentation.

        Args:
            x (list[Tensor]): list of multi-level img features.
            proposal_boxes (Tensor): Decoded proposal bboxes, has shape
                (batch_size, num_proposals, 4).
            proposal_features (Tensor): Expanded proposal
                features, has shape
                (batch_size, num_proposals, proposal_feature_channel).
            img_metas (list[dict]): meta information of images; each dict
                must provide 'ori_shape' and 'scale_factor'.
            imgs_whwh (Tensor): [img_width, img_height, img_width,
                img_height] per image. Not used by this method.
            rescale (bool): If True, return boxes in original image
                space. Defaults to False.

        Returns:
            list[list[np.ndarray]] or list[tuple]: When no mask branch,
            it is bbox results of each image and classes with type
            `list[list[np.ndarray]]`. The outer list
            corresponds to each image. The inner list
            corresponds to each class. When the model has a mask branch,
            it is a list[tuple] that contains bbox results and mask results.
            The outer list corresponds to each image, and first element
            of tuple is bbox results, second element is mask results.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        # Decode initial proposals
        num_imgs = len(img_metas)
        num_classes = self.bbox_head[-1].num_classes
        proposal_list = [proposal_boxes[i] for i in range(num_imgs)]
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
        object_feats = proposal_features
        if all(proposal.shape[0] == 0 for proposal in proposal_list):
            # There is no proposal in the whole batch. Build independent
            # per-image lists (``[...] * num_imgs`` would alias a single
            # list across all images) and keep the documented
            # (bbox, segm) tuple format when a mask branch exists.
            bbox_results = [[
                np.zeros((0, 5), dtype=np.float32)
                for _ in range(num_classes)
            ] for _ in range(num_imgs)]
            if self.with_mask:
                segm_results = [[[] for _ in range(num_classes)]
                                for _ in range(num_imgs)]
                return list(zip(bbox_results, segm_results))
            return bbox_results

        for stage in range(self.num_stages):
            rois = bbox2roi(proposal_list)
            bbox_results = self._bbox_forward(stage, x, rois, object_feats,
                                              img_metas)
            object_feats = bbox_results['object_feats']
            cls_score = bbox_results['cls_score']
            proposal_list = bbox_results['detach_proposal_list']
            if self.with_mask:
                rois = bbox2roi(proposal_list)
                mask_results = self._mask_forward(stage, x, rois,
                                                  bbox_results['attn_feats'])
                mask_results['mask_pred'] = mask_results['mask_pred'].reshape(
                    num_imgs, -1, *mask_results['mask_pred'].size()[1:])

        det_bboxes = []
        det_labels = []
        # Top-k indices into the flattened (proposal, class) scores, kept
        # per image so the mask branch can select the matching masks.
        topk_indices_list = []
        if self.bbox_head[-1].loss_cls.use_sigmoid:
            cls_score = cls_score.sigmoid()
        else:
            # Drop the trailing background column after softmax.
            cls_score = cls_score.softmax(-1)[..., :-1]

        for img_id in range(num_imgs):
            cls_score_per_img = cls_score[img_id]
            # Select the best (proposal, class) pairs jointly; a single
            # proposal may therefore appear with several labels.
            scores_per_img, topk_indices = cls_score_per_img.flatten(
                0, 1).topk(
                    self.test_cfg.max_per_img, sorted=False)
            labels_per_img = topk_indices % num_classes
            bbox_pred_per_img = proposal_list[img_id][topk_indices //
                                                      num_classes]
            if rescale:
                scale_factor = img_metas[img_id]['scale_factor']
                bbox_pred_per_img /= bbox_pred_per_img.new_tensor(scale_factor)
            det_bboxes.append(
                torch.cat([bbox_pred_per_img, scores_per_img[:, None]], dim=1))
            det_labels.append(labels_per_img)
            topk_indices_list.append(topk_indices)

        bbox_results = [
            bbox2result(det_bboxes[i], det_labels[i], num_classes)
            for i in range(num_imgs)
        ]

        if not self.with_mask:
            return bbox_results

        if rescale and not isinstance(scale_factors[0], float):
            scale_factors = [
                torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                for scale_factor in scale_factors
            ]
        # With rescale, boxes were divided by the scale factor above; multiply
        # back so the mask head receives boxes in the network-input space.
        _bboxes = [
            det_bboxes[i][:, :4] *
            scale_factors[i] if rescale else det_bboxes[i][:, :4]
            for i in range(len(det_bboxes))
        ]
        segm_results = []
        mask_pred = mask_results['mask_pred']
        for img_id in range(num_imgs):
            # Use this image's own top-k indices; reusing the indices of the
            # last image would pick masks of the wrong proposals for every
            # other image in the batch.
            mask_pred_per_img = mask_pred[img_id].flatten(
                0, 1)[topk_indices_list[img_id]]
            mask_pred_per_img = mask_pred_per_img[:, None, ...].repeat(
                1, num_classes, 1, 1)
            segm_result = self.mask_head[-1].get_seg_masks(
                mask_pred_per_img, _bboxes[img_id], det_labels[img_id],
                self.test_cfg, ori_shapes[img_id], scale_factors[img_id],
                rescale)
            segm_results.append(segm_result)
        return list(zip(bbox_results, segm_results))

    def aug_test(self, features, proposal_list, img_metas, rescale=False):
        """Augmented testing is not supported by this head."""
        raise NotImplementedError(
            'Sparse R-CNN and QueryInst does not support `aug_test`')

    def forward_dummy(self, x, proposal_boxes, proposal_features, img_metas):
        """Dummy forward function when doing the flops computing.

        Returns:
            list[tuple]: One tuple per stage containing the bbox results
            dict and, when a mask branch exists, the mask results dict.
        """
        all_stage_bbox_results = []
        proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
        object_feats = proposal_features
        if self.with_bbox:
            for stage in range(self.num_stages):
                rois = bbox2roi(proposal_list)
                bbox_results = self._bbox_forward(stage, x, rois, object_feats,
                                                  img_metas)

                all_stage_bbox_results.append((bbox_results, ))
                proposal_list = bbox_results['detach_proposal_list']
                object_feats = bbox_results['object_feats']

                if self.with_mask:
                    rois = bbox2roi(proposal_list)
                    mask_results = self._mask_forward(
                        stage, x, rois, bbox_results['attn_feats'])
                    all_stage_bbox_results[-1] += (mask_results, )
        return all_stage_bbox_results
|