# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def _expand_onehot_labels(labels, label_weights, label_channels):
    """Expand index labels to one-hot binary labels and matching weights."""
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(
        (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMC(nn.Module):
"""GHM Classification Loss.
Details of the theorem can be viewed in the paper
`Gradient Harmonized Single-stage Detector
<https://arxiv.org/abs/1811.05181>`_.
Args:
bins (int): Number of the unit regions for distribution calculation.
momentum (float): The parameter for moving average.
use_sigmoid (bool): Can only be true for BCE based loss now.
loss_weight (float): The weight of the total GHM-C loss.
reduction (str): Options are "none", "mean" and "sum".
Defaults to "mean"
"""

    def __init__(self,
                 bins=10,
                 momentum=0,
                 use_sigmoid=True,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        # nudge the last edge so that g == 1 falls into the last bin
        self.edges[-1] += 1e-6
        if momentum > 0:
            # running estimate of the number of examples in each bin
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                target,
                label_weight,
                reduction_override=None,
                **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of the classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample.
            label_weight (float tensor of size [batch_num, class_num]):
                The value is 1 if the sample is valid and 0 if ignored.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # the target should be a binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_onehot_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length: |sigmoid(pred) - target|, in [0, 1]
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # exponential moving average of the bin population
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    # weight is inversely proportional to gradient density
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none')
        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMR(nn.Module):
"""GHM Regression Loss.
Details of the theorem can be viewed in the paper
`Gradient Harmonized Single-stage Detector
<https://arxiv.org/abs/1811.05181>`_.
Args:
mu (float): The parameter for the Authentic Smooth L1 loss.
bins (int): Number of the unit regions for distribution calculation.
momentum (float): The parameter for moving average.
loss_weight (float): The weight of the total GHM-R loss.
reduction (str): Options are "none", "mean" and "sum".
Defaults to "mean"
"""

    def __init__(self,
                 mu=0.02,
                 bins=10,
                 momentum=0,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMR, self).__init__()
        self.mu = mu
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        # widen the last bin so that all large gradient lengths fall into it
        self.edges[-1] = 1e3
        self.momentum = momentum
        if momentum > 0:
            # running estimate of the number of examples in each bin
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.loss_weight = loss_weight
        self.reduction = reduction

    # TODO: support reduction parameter
    def forward(self,
                pred,
                target,
                label_weight,
                avg_factor=None,
                reduction_override=None):
        """Calculate the GHM-R loss.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of the box regression layer. Channel number
                can be 4 or 4 * class_num depending on whether it is
                class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size as pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        mu = self.mu
        edges = self.edges
        mmt = self.momentum

        # ASL1 loss: a smooth L1 variant, sqrt(d^2 + mu^2) - mu
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length: |d| / sqrt(d^2 + mu^2), always in [0, 1)
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        tot = max(label_weight.float().sum().item(), 1.0)
        n = 0  # n: valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    # exponential moving average of the bin population
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    # weight is inversely proportional to gradient density
                    weights[inds] = tot / num_in_bin
        if n > 0:
            weights /= n

        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight
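

# A minimal config sketch (illustrative, not part of the original module):
# since both classes are registered in LOSSES, an MMDetection head config can
# refer to them by type name, e.g. for a RetinaNet-style GHM setup:
#
#     loss_cls=dict(
#         type='GHMC', bins=10, momentum=0.7, use_sigmoid=True,
#         loss_weight=1.0),
#     loss_bbox=dict(
#         type='GHMR', mu=0.02, bins=10, momentum=0.7, loss_weight=10.0)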