Spaces:
Runtime error
Runtime error
File size: 5,598 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.core import multi_apply
from ..builder import HEADS
from ..losses import CrossEntropyLoss, SmoothL1Loss, carl_loss, isr_p
from .ssd_head import SSDHead
# TODO: add loss evaluator for SSD
@HEADS.register_module()
class PISASSDHead(SSDHead):
    """SSD head whose loss can re-weight targets with ISR-P and add CARL.

    Both extras are switched on through ``self.train_cfg``: an ``isr`` entry
    routes the targets through ``isr_p`` and a ``carl`` entry adds the output
    of ``carl_loss`` to the returned loss dict. When neither entry is set,
    the targets reach ``self.loss_single`` unchanged.
    """

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image
                with shape (num_obj, 4).
            gt_labels (list[Tensor]): Ground truth labels of each image.
                NOTE(review): presumably one label per box, i.e. shape
                (num_obj, ) -- confirm against ``get_targets``.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image.
                Default: None.

        Returns:
            dict | None: Loss dict with ``loss_cls`` and ``loss_bbox``, plus
            the entries returned by ``carl_loss`` when ``train_cfg`` has a
            ``carl`` entry. ``None`` if ``get_targets`` returns ``None``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)

        # Sampling results are requested because ISR-P consumes them below.
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            unmap_outputs=False,
            return_sampling_results=True)
        if cls_reg_targets is None:
            return None
        # num_total_neg is part of the returned tuple but is not used here.
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg,
         sampling_results_list) = cls_reg_targets

        # Merge the per-level outputs and targets into per-image tensors.
        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(num_images, -1,
                                          self.cls_out_channels)
            for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list,
                                      -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list,
                                     -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list,
                                     -2).view(num_images, -1, 4)

        # concat all level anchors to a single tensor per image
        all_anchors = [
            torch.cat(anchor_list[i]) for i in range(num_images)
        ]

        # ISR-P and CARL both work on tensors flattened over images and
        # anchors, so the flattened views are built once and shared.
        flat_cls_scores = all_cls_scores.view(-1, all_cls_scores.size(-1))
        flat_bbox_preds = all_bbox_preds.view(-1, 4)
        all_targets = (all_labels.view(-1), all_label_weights.view(-1),
                       all_bbox_targets.view(-1, 4),
                       all_bbox_weights.view(-1, 4))

        # apply ISR-P
        isr_cfg = self.train_cfg.get('isr', None)
        if isr_cfg is not None:
            all_targets = isr_p(
                flat_cls_scores,
                flat_bbox_preds,
                all_targets,
                torch.cat(all_anchors),
                sampling_results_list,
                loss_cls=CrossEntropyLoss(),
                bbox_coder=self.bbox_coder,
                **isr_cfg,
                num_class=self.num_classes)
            (new_labels, new_label_weights, new_bbox_targets,
             new_bbox_weights) = all_targets
            # Restore the per-image layout expected by ``loss_single``.
            all_labels = new_labels.view(all_labels.shape)
            all_label_weights = new_label_weights.view(
                all_label_weights.shape)
            all_bbox_targets = new_bbox_targets.view(all_bbox_targets.shape)
            all_bbox_weights = new_bbox_weights.view(all_bbox_weights.shape)

        # add CARL loss; it reads the labels and bbox targets from
        # ``all_targets``, i.e. the ISR-P output when ISR-P ran.
        carl_loss_cfg = self.train_cfg.get('carl', None)
        if carl_loss_cfg is not None:
            loss_carl = carl_loss(
                flat_cls_scores,
                all_targets[0],
                flat_bbox_preds,
                all_targets[2],
                SmoothL1Loss(beta=1.),
                **carl_loss_cfg,
                avg_factor=num_total_pos,
                num_class=self.num_classes)

        # check NaN and Inf
        assert torch.isfinite(all_cls_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(all_bbox_preds).all().item(), \
            'bbox predications become infinite or NaN!'

        # ``loss_single`` is applied once per image.
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_anchors,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos)
        loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
        if carl_loss_cfg is not None:
            loss_dict.update(loss_carl)
        return loss_dict
|