# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32

from ..builder import HEADS
from ..losses import smooth_l1_loss
from .ascend_anchor_head import AscendAnchorHead
from .ssd_head import SSDHead


@HEADS.register_module()
class AscendSSDHead(SSDHead, AscendAnchorHead):
    """Ascend SSD head used in https://arxiv.org/abs/1512.02325.

    Unlike a per-image loss loop, this head computes its losses on tensors
    that are batched over all images (see :meth:`batch_loss`).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int): Number of conv layers in cls and reg tower.
            Default: 0.
        feat_channels (int): Number of hidden channels when stacked_convs
            > 0. Default: 256.
        use_depthwise (bool): Whether to use DepthwiseSeparableConv.
            Default: False.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: None.
        act_cfg (dict): Dictionary to construct and config activation layer.
            Default: None.
        anchor_generator (dict): Config dict for anchor generator
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Default False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
            This head currently only supports ``False``.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes=80,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 stacked_convs=0,
                 feat_channels=256,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 anchor_generator=dict(
                     type='SSDAnchorGenerator',
                     scale_major=False,
                     input_size=300,
                     strides=[8, 16, 32, 64, 100, 300],
                     ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                     basesize_ratio_range=(0.1, 0.9)),
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[.0, .0, .0, .0],
                     target_stds=[1.0, 1.0, 1.0, 1.0],
                 ),
                 reg_decoded_bbox=False,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Xavier',
                     layer='Conv2d',
                     distribution='uniform',
                     bias=0)):
        """Build the head by forwarding every argument to the parent class.

        See the class docstring for the meaning of each argument.

        Raises:
            AssertionError: If ``reg_decoded_bbox`` is not ``False`` after
                the parent initialisation.
        """
        super(AscendSSDHead, self).__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            stacked_convs=stacked_convs,
            feat_channels=feat_channels,
            use_depthwise=use_depthwise,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            anchor_generator=anchor_generator,
            bbox_coder=bbox_coder,
            reg_decoded_bbox=reg_decoded_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        # ``batch_loss`` has no code path for decoded-bbox regression.
        assert self.reg_decoded_bbox is False, \
            'reg_decoded_bbox only support False now.'

    def get_static_anchors(self, featmap_sizes, img_metas, device='cuda'):
        """Get static anchors according to feature map sizes.

        The anchors and valid flags are computed with ``self.get_anchors``
        on the first call and stored on the instance. Every later call
        returns the stored values unchanged: the arguments of later calls
        are ignored and the cache is never invalidated, so this is only
        suitable when feature map sizes, image metas and device stay fixed.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): Device for returned tensors

        Returns:
            tuple:
                anchor_list (list[Tensor]): Anchors of each image.
                valid_flag_list (list[Tensor]): Valid flags of each image.
        """
        if not hasattr(self, 'static_anchors') or \
                not hasattr(self, 'static_valid_flags'):
            static_anchors, static_valid_flags = self.get_anchors(
                featmap_sizes, img_metas, device)
            self.static_anchors = static_anchors
            self.static_valid_flags = static_valid_flags
        return self.static_anchors, self.static_valid_flags

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False,
                    return_level=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        This is a direct delegation to ``AscendAnchorHead.get_targets``; the
        base class is named explicitly so that this implementation is used
        regardless of the method resolution order of the two parents.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, )
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.
            return_sampling_results (bool): Whether to return the result of
                sample.
            return_level (bool): Whether to map outputs back to the levels
                of feature map sizes.
        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each level.
                - bbox_weights_list (list[Tensor]): BBox weights of each level.
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.

            additional_returns: This function enables user-defined returns from
                `self._get_targets_single`. These returns are currently refined
                to properties at each feature map (i.e. having HxW dimension).
                The results will be concatenated after the end

            NOTE(review): the exact tuple layout is decided by
            ``AscendAnchorHead.get_targets``; :meth:`loss` calls this with
            ``return_level=False`` and unpacks ten items - confirm against
            the base class.
        """
        return AscendAnchorHead.get_targets(
            self,
            anchor_list,
            valid_flag_list,
            gt_bboxes_list,
            img_metas,
            gt_bboxes_ignore_list,
            gt_labels_list,
            label_channels,
            unmap_outputs,
            return_sampling_results,
            return_level,
        )

    def batch_loss(self, batch_cls_score, batch_bbox_pred, batch_anchor,
                   batch_labels, batch_label_weights, batch_bbox_targets,
                   batch_bbox_weights, batch_pos_mask, batch_neg_mask,
                   num_total_samples):
        """Compute loss of all images.

        The classification loss uses hard negative mining: per image, all
        positive anchors contribute, together with the
        ``train_cfg.neg_pos_ratio * num_pos`` negatives that have the largest
        loss (capped at the number of available negatives).

        Args:
            batch_cls_score (Tensor): Box scores for all image
                Has shape (num_imgs, num_total_anchors, num_classes).
            batch_bbox_pred (Tensor): Box energies / deltas for all image
                level with shape (num_imgs, num_total_anchors, 4).
            batch_anchor (Tensor): Box reference for all image with shape
                (num_imgs, num_total_anchors, 4).
            batch_labels (Tensor): Labels of all anchors with shape
                (num_imgs, num_total_anchors,).
            batch_label_weights (Tensor): Label weights of all anchor with
                shape (num_imgs, num_total_anchors,)
            batch_bbox_targets (Tensor): BBox regression targets of all anchor
                weight shape (num_imgs, num_total_anchors, 4).
            batch_bbox_weights (Tensor): BBox regression loss weights of
                all anchor with shape (num_imgs, num_total_anchors, 4).
            batch_pos_mask (Tensor): Positive samples mask in all images.
            batch_neg_mask (Tensor): negative samples mask in all images.
            num_total_samples (int): If sampling, num total samples equal to
                the number of total anchors; Otherwise, it is the number of
                positive anchors.

        Returns:
            tuple[Tensor, Tensor]:
                - loss_cls: Classification loss of each image divided by
                  ``num_total_samples``, with a leading singleton dimension
                  added (presumably shape (1, num_imgs) given the input
                  shapes documented above).
                - loss_bbox: Regression loss summed over every dimension
                  except the first and divided by ``num_total_samples``
                  (presumably shape (num_imgs,)).

        Raises:
            NotImplementedError: If ``self.reg_decoded_bbox`` is true. This
                is a subclass of ``RuntimeError``.
        """
        if self.reg_decoded_bbox:
            # TODO: support self.reg_decoded_bbox is True
            # Checked first so that no loss is computed before failing.
            raise NotImplementedError(
                'reg_decoded_bbox only support False now.')

        # The 3-way unpack also enforces that ``batch_anchor`` is 3-D.
        _, num_anchors, _ = batch_anchor.size()

        # Element-wise cross entropy, reshaped back to the weight layout and
        # scaled by the label weights.
        batch_loss_cls_all = F.cross_entropy(
            batch_cls_score.view((-1, self.cls_out_channels)),
            batch_labels.view(-1),
            reduction='none').view(
                batch_label_weights.size()) * batch_label_weights
        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        batch_num_pos_samples = torch.sum(batch_pos_mask, dim=1)
        batch_num_neg_samples = \
            self.train_cfg.neg_pos_ratio * batch_num_pos_samples

        # Never ask for more negatives than an image actually has.
        batch_num_neg_samples_max = torch.sum(batch_neg_mask, dim=1)
        batch_num_neg_samples = torch.min(batch_num_neg_samples,
                                          batch_num_neg_samples_max)

        # ``topk`` with k equal to the full anchor count acts as a
        # descending sort of the negative losses along the anchor axis.
        batch_topk_loss_cls_neg, _ = torch.topk(
            batch_loss_cls_all * batch_neg_mask, k=num_anchors, dim=1)
        batch_loss_cls_pos = torch.sum(
            batch_loss_cls_all * batch_pos_mask, dim=1)

        # Keep only the first ``batch_num_neg_samples`` sorted entries of
        # each image, i.e. its hardest negatives.
        anchor_index = torch.arange(
            end=num_anchors, dtype=torch.float,
            device=batch_anchor.device).view((1, -1))
        topk_loss_neg_mask = (anchor_index < batch_num_neg_samples.view(
            -1, 1)).float()

        batch_loss_cls_neg = torch.sum(
            batch_topk_loss_cls_neg * topk_loss_neg_mask, dim=1)
        loss_cls = \
            (batch_loss_cls_pos + batch_loss_cls_neg) / num_total_samples

        loss_bbox_all = smooth_l1_loss(
            batch_bbox_pred,
            batch_bbox_targets,
            batch_bbox_weights,
            reduction='none',
            beta=self.train_cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        eps = torch.finfo(torch.float32).eps

        # Reduce everything but the leading dimension, then normalise.
        sum_dim = tuple(range(1, loss_bbox_all.dim()))
        loss_bbox = loss_bbox_all.sum(sum_dim) / (num_total_samples + eps)
        return loss_cls[None], loss_bbox

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, list[Tensor]] | None: A dictionary of loss components
            with the keys ``loss_cls`` and ``loss_bbox``, each holding one
            entry per image. ``None`` is returned when ``get_targets``
            yields ``None``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            unmap_outputs=True,
            return_level=False)
        if cls_reg_targets is None:
            return None

        # The sampling result and the negative count are not needed here.
        (batch_labels, batch_label_weights, batch_bbox_targets,
         batch_bbox_weights, batch_pos_mask, batch_neg_mask, _,
         num_total_pos, _, batch_anchors) = cls_reg_targets

        num_imgs = len(img_metas)
        # Move channels last and flatten each level to
        # (num_imgs, anchors_of_level, C), then join the levels along the
        # anchor axis.
        batch_cls_score = torch.cat([
            s.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels)
            for s in cls_scores
        ], 1)

        batch_bbox_pred = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for b in bbox_preds
        ], 1)

        batch_losses_cls, batch_losses_bbox = self.batch_loss(
            batch_cls_score, batch_bbox_pred, batch_anchors, batch_labels,
            batch_label_weights, batch_bbox_targets, batch_bbox_weights,
            batch_pos_mask, batch_neg_mask, num_total_pos)
        # Split the batched losses back into one entry per image.
        losses_cls = [
            batch_losses_cls[:, index_imgs] for index_imgs in range(num_imgs)
        ]
        losses_bbox = list(batch_losses_bbox)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)