# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ...core.bbox.assigners import AscendMaxIoUAssigner
from ...core.bbox.samplers import PseudoSampler
from ...utils import (batch_images_to_levels, get_max_num_gt_division_factor,
                      masked_fill)
from ..builder import HEADS
from .anchor_head import AnchorHead


@HEADS.register_module()
class AscendAnchorHead(AnchorHead):
    """Ascend Anchor-based head (RetinaNet, SSD, etc.).

    Variant of :class:`AnchorHead` that computes training targets in a fully
    batched, static-shape fashion (all images' gt boxes padded to a common
    size), which suits the Ascend NPU's preference for fixed tensor shapes.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        anchor_generator (dict): Config dict for anchor generator
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Default False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8, 16, 32],
                     ratios=[0.5, 1.0, 2.0],
                     strides=[4, 8, 16, 32, 64]),
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=(.0, .0, .0, .0),
                     target_stds=(1.0, 1.0, 1.0, 1.0)),
                 reg_decoded_bbox=False,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
        # All construction is delegated to AnchorHead; this subclass only
        # overrides the target-computation path below.
        super(AscendAnchorHead, self).__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            feat_channels=feat_channels,
            anchor_generator=anchor_generator,
            bbox_coder=bbox_coder,
            reg_decoded_bbox=reg_decoded_bbox,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

    def get_batch_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device,
                            max_gt_labels):
        """Get ground truth bboxes of all image, padded to a common size.

        Args:
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            num_images (int): The num of images.
            gt_nums (list[int]): The ground truth bboxes num of each image.
            device (torch.device | str): Device for returned tensors.
            max_gt_labels (int): The max ground truth bboxes num of all
                images; every image's boxes are padded to this length.

        Returns:
            batch_gt_bboxes: (Tensor): Ground truth bboxes of all image,
                shape (num_images, max_gt_labels, 4).
        """
        # Static ground-truth box template, cached per padded size so it is
        # allocated only once. Related to Ascend: reusing a static tensor
        # helps improve performance on the NPU.
        if not hasattr(self, 'batch_gt_bboxes'):
            self.batch_gt_bboxes = {}
        # Degenerate "min anchor" coordinates used to fill the excess
        # (padding) gt slots so they cannot overlap any real anchor.
        if not hasattr(self, 'min_anchor'):
            self.min_anchor = (-1354, -1344)
        if gt_bboxes_list is None:
            batch_gt_bboxes = None
        else:
            if self.batch_gt_bboxes.get(max_gt_labels) is None:
                # First time this padded size is seen: build the template
                # filled entirely with the sentinel box.
                batch_gt_bboxes = torch.zeros((num_images, max_gt_labels, 4),
                                              dtype=gt_bboxes_list[0].dtype,
                                              device=device)
                batch_gt_bboxes[:, :, :2] = self.min_anchor[0]
                batch_gt_bboxes[:, :, 2:] = self.min_anchor[1]
                # Cache a clone so the cached template is never mutated by
                # the per-image copies below.
                self.batch_gt_bboxes[max_gt_labels] = batch_gt_bboxes.clone()
            else:
                batch_gt_bboxes = self.batch_gt_bboxes.get(
                    max_gt_labels).clone()
            # Overwrite the leading slots of each row with the real boxes;
            # the tail keeps the sentinel padding.
            for index_imgs, gt_bboxes in enumerate(gt_bboxes_list):
                batch_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes
        return batch_gt_bboxes

    def get_batch_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images,
                                   gt_nums, device):
        """Ground truth bboxes to be ignored of all image.

        Args:
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            num_images (int): The num of images. Currently unused.
            gt_nums (list[int]): The ground truth bboxes num of each image.
                Currently unused.
            device (torch.device | str): Device for returned tensors.
                Currently unused.

        Returns:
            batch_gt_bboxes_ignore: (Tensor): Ground truth bboxes to be
                ignored of all image. Always None at present.
        """
        # TODO: support gt_bboxes_ignore_list
        if gt_bboxes_ignore_list is None:
            batch_gt_bboxes_ignore = None
        else:
            raise RuntimeError('gt_bboxes_ignore not support yet')
        return batch_gt_bboxes_ignore

    def get_batch_gt_labels(self, gt_labels_list, num_images, gt_nums, device,
                            max_gt_labels):
        """Get ground truth labels of all image, padded to a common size.

        Args:
            gt_labels_list (list[Tensor]): Ground truth labels.
            num_images (int): The num of images.
            gt_nums (list[int]): The ground truth bboxes num of each image.
            device (torch.device | str): Device for returned tensors.
            max_gt_labels (int): The max ground truth bboxes num of all
                images; every image's labels are padded to this length.

        Returns:
            batch_gt_labels: (Tensor): Ground truth labels of all image,
                shape (num_images, max_gt_labels). Padding slots are 0.
        """
        if gt_labels_list is None:
            batch_gt_labels = None
        else:
            batch_gt_labels = torch.zeros((num_images, max_gt_labels),
                                          dtype=gt_labels_list[0].dtype,
                                          device=device)
            for index_imgs, gt_labels in enumerate(gt_labels_list):
                batch_gt_labels[index_imgs, :gt_nums[index_imgs]] = gt_labels
        return batch_gt_labels

    def _get_targets_concat(self,
                            batch_anchors,
                            batch_valid_flags,
                            batch_gt_bboxes,
                            batch_gt_bboxes_ignore,
                            batch_gt_labels,
                            img_metas,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in all
        images.

        Args:
            batch_anchors (Tensor): anchors of all image, which are
                concatenated into a single tensor of
                shape (num_imgs, num_anchors ,4).
            batch_valid_flags (Tensor): valid flags of all image,
                which are concatenated into a single tensor of
                shape (num_imgs, num_anchors,).
            batch_gt_bboxes (Tensor): Ground truth bboxes of all image,
                shape (num_imgs, max_gt_nums, 4).
            batch_gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_imgs, num_ignored_gts, 4).
            batch_gt_labels (Tensor): Ground truth labels of each box,
                shape (num_imgs, max_gt_nums,).
            img_metas (list[dict]): Meta info of each image.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple:
                batch_labels (Tensor): Labels of all level
                batch_label_weights (Tensor): Label weights of all level
                batch_bbox_targets (Tensor): BBox targets of all level
                batch_bbox_weights (Tensor): BBox weights of all level
                batch_pos_mask (Tensor): Positive samples mask in all images
                batch_neg_mask (Tensor): Negative samples mask in all images
                sampling_result (Sampling): The result of sampling,
                    default: None.
        """
        num_imgs, num_anchors, _ = batch_anchors.size()
        # assign gt and sample batch_anchors; the Ascend assigner consumes
        # the whole batch at once and reports results as per-anchor masks
        # and indices rather than per-image AssignResult objects.
        assign_result = self.assigner.assign(
            batch_anchors,
            batch_gt_bboxes,
            batch_gt_bboxes_ignore,
            None if self.sampling else batch_gt_labels,
            batch_bboxes_ignore_mask=batch_valid_flags)
        # TODO: support sampling_result
        sampling_result = None
        batch_pos_mask = assign_result.batch_pos_mask
        batch_neg_mask = assign_result.batch_neg_mask
        batch_anchor_gt_indes = assign_result.batch_anchor_gt_indes
        batch_anchor_gt_labels = assign_result.batch_anchor_gt_labels

        # Gather, per image, the gt box assigned to every anchor.
        # NOTE(review): assumes batch_anchor_gt_indes holds a valid row
        # index for every anchor (non-positive anchors are masked out of
        # the targets below) — confirm against AscendMaxIoUAssigner.
        batch_anchor_gt_bboxes = torch.zeros(
            batch_anchors.size(),
            dtype=batch_anchors.dtype,
            device=batch_anchors.device)
        for index_imgs in range(num_imgs):
            batch_anchor_gt_bboxes[index_imgs] = torch.index_select(
                batch_gt_bboxes[index_imgs], 0,
                batch_anchor_gt_indes[index_imgs])

        batch_bbox_targets = torch.zeros_like(batch_anchors)
        batch_bbox_weights = torch.zeros_like(batch_anchors)
        # Labels default to `num_classes`, i.e. the background class.
        batch_labels = batch_anchors.new_full((num_imgs, num_anchors),
                                              self.num_classes,
                                              dtype=torch.int)
        batch_label_weights = batch_anchors.new_zeros((num_imgs, num_anchors),
                                                      dtype=torch.float)

        if not self.reg_decoded_bbox:
            # Regress encoded deltas between anchors and assigned gts.
            batch_pos_bbox_targets = self.bbox_coder.encode(
                batch_anchors, batch_anchor_gt_bboxes)
        else:
            # Regress the decoded (absolute) boxes directly.
            batch_pos_bbox_targets = batch_anchor_gt_bboxes
        # Write targets/weights only where the anchor is a positive sample.
        batch_bbox_targets = masked_fill(batch_bbox_targets,
                                         batch_pos_mask.unsqueeze(2),
                                         batch_pos_bbox_targets)
        batch_bbox_weights = masked_fill(batch_bbox_weights,
                                         batch_pos_mask.unsqueeze(2), 1.0)
        if batch_gt_labels is None:
            # NOTE(review): presumably the class-agnostic (RPN-style) path,
            # where every positive anchor is labeled as class 0 — confirm
            # against callers that pass gt_labels=None.
            batch_labels = masked_fill(batch_labels, batch_pos_mask, 0.0)
        else:
            batch_labels = masked_fill(batch_labels, batch_pos_mask,
                                       batch_anchor_gt_labels)
        if self.train_cfg.pos_weight <= 0:
            batch_label_weights = masked_fill(batch_label_weights,
                                              batch_pos_mask, 1.0)
        else:
            batch_label_weights = masked_fill(batch_label_weights,
                                              batch_pos_mask,
                                              self.train_cfg.pos_weight)
        # Negative samples always get unit classification weight.
        batch_label_weights = masked_fill(batch_label_weights, batch_neg_mask,
                                          1.0)
        return (batch_labels, batch_label_weights, batch_bbox_targets,
                batch_bbox_weights, batch_pos_mask, batch_neg_mask,
                sampling_result)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False,
                    return_level=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors, )
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored. Must be None (not supported yet).
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Must be True.
            return_sampling_results (bool): Whether to return the result of
                sample. Must be False.
            return_level (bool): Whether to map outputs back to the levels
                of feature map sizes.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each
                  level.
                - bbox_weights_list (list[Tensor]): BBox weights of each
                  level.
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.

            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to properties at each feature map (i.e. having HxW
                dimension). The results will be concatenated after the end
        """
        # This batched implementation only supports a restricted
        # configuration; fail fast on anything else.
        assert gt_bboxes_ignore_list is None
        assert unmap_outputs is True
        assert return_sampling_results is False
        assert self.train_cfg.allowed_border < 0
        assert isinstance(self.assigner, AscendMaxIoUAssigner)
        assert isinstance(self.sampler, PseudoSampler)
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs
        device = anchor_list[0][0].device
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # Flatten the per-level anchors/flags of each image, then stack the
        # images into single (num_imgs, num_anchors, ...) tensors.
        batch_anchor_list = []
        batch_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            batch_anchor_list.append(torch.cat(anchor_list[i]))
            batch_valid_flag_list.append(torch.cat(valid_flag_list[i]))
        batch_anchors = torch.cat(
            [torch.unsqueeze(anchor, 0) for anchor in batch_anchor_list], 0)
        batch_valid_flags = torch.cat([
            torch.unsqueeze(batch_valid_flag, 0)
            for batch_valid_flag in batch_valid_flag_list
        ], 0)

        # Pad all images' gts to a shared size (rounded up by a division
        # factor to limit the number of distinct static shapes).
        gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list]
        max_gt_nums = get_max_num_gt_division_factor(gt_nums)
        batch_gt_bboxes = self.get_batch_gt_bboxes(gt_bboxes_list, num_imgs,
                                                   gt_nums, device,
                                                   max_gt_nums)
        batch_gt_bboxes_ignore = self.get_batch_gt_bboxes_ignore(
            gt_bboxes_ignore_list, num_imgs, gt_nums, device)
        batch_gt_labels = self.get_batch_gt_labels(gt_labels_list, num_imgs,
                                                   gt_nums, device,
                                                   max_gt_nums)

        results = self._get_targets_concat(
            batch_anchors,
            batch_valid_flags,
            batch_gt_bboxes,
            batch_gt_bboxes_ignore,
            batch_gt_labels,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)

        (batch_labels, batch_label_weights, batch_bbox_targets,
         batch_bbox_weights, batch_pos_mask, batch_neg_mask,
         sampling_result) = results[:7]
        rest_results = list(results[7:])  # user-added return values

        # sampled anchors of all images; clamp each image's count to >= 1 so
        # downstream loss normalization never divides by zero.
        min_num = torch.ones((num_imgs, ),
                             dtype=torch.long,
                             device=batch_pos_mask.device)
        num_total_pos = torch.sum(
            torch.max(torch.sum(batch_pos_mask, dim=1), min_num))
        num_total_neg = torch.sum(
            torch.max(torch.sum(batch_neg_mask, dim=1), min_num))
        if return_level is True:
            # Split the flat per-image tensors back into per-level lists.
            labels_list = batch_images_to_levels(batch_labels,
                                                 num_level_anchors)
            label_weights_list = batch_images_to_levels(
                batch_label_weights, num_level_anchors)
            bbox_targets_list = batch_images_to_levels(batch_bbox_targets,
                                                       num_level_anchors)
            bbox_weights_list = batch_images_to_levels(batch_bbox_weights,
                                                       num_level_anchors)
            res = (labels_list, label_weights_list, bbox_targets_list,
                   bbox_weights_list, num_total_pos, num_total_neg)
            # NOTE(review): dead branch — return_sampling_results is asserted
            # False above; kept for parity with AnchorHead.get_targets.
            if return_sampling_results:
                res = res + (sampling_result, )
            for i, r in enumerate(rest_results):  # user-added return values
                rest_results[i] = batch_images_to_levels(r, num_level_anchors)
            return res + tuple(rest_results)
        else:
            # Return the flat batched tensors directly.
            res = (batch_labels, batch_label_weights, batch_bbox_targets,
                   batch_bbox_weights, batch_pos_mask, batch_neg_mask,
                   sampling_result, num_total_pos, num_total_neg,
                   batch_anchors)
            return res