# mmaction2/tests/models/losses/test_ssn_loss.py @ d3dbf03
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine import ConfigDict
from mmaction.models import OHEMHingeLoss, SSNLoss
def test_ssn_loss():
    """Check each SSNLoss term against an inline reference computation.

    The activity, completeness and class-wise regression losses are first
    exercised through their standalone methods and compared with the same
    computation written out by hand. ``SSNLoss.forward`` is then expected to
    reproduce those values, scaled by the configured loss weights, when given
    proposal types that select the same rows.

    Random fixtures are drawn from a private, fixed-seed generator so that a
    failure is reproducible without mutating the global torch RNG state.
    """
    gen = torch.Generator().manual_seed(0)
    ssn_loss = SSNLoss()
    # Every proposal carries class label 8; shared by all sections below.
    labels = torch.full((8, ), 8, dtype=torch.long)

    # test activity_loss: expected to be plain cross-entropy over the rows
    # selected by activity_indexer.
    activity_score = torch.rand((8, 21), generator=gen)
    activity_indexer = torch.tensor([0, 7])
    output_activity_loss = ssn_loss.activity_loss(activity_score, labels,
                                                  activity_indexer)
    assert torch.equal(
        output_activity_loss,
        F.cross_entropy(activity_score[activity_indexer, :],
                        labels[activity_indexer]))

    # test completeness_loss: the 7 selected rows form a single group of
    # 1 positive followed by 6 incomplete proposals.
    completeness_score = torch.rand((8, 20),
                                    generator=gen,
                                    requires_grad=True)
    completeness_indexer = torch.tensor([0, 1, 2, 3, 4, 5, 6])
    positive_per_video = 1
    incomplete_per_video = 6
    output_completeness_loss = ssn_loss.completeness_loss(
        completeness_score, labels, completeness_indexer, positive_per_video,
        incomplete_per_video)
    # Reference: reshape the selected rows into per-video groups, split each
    # group into its positive and incomplete parts, and apply the OHEM hinge
    # loss to each part separately.
    pred = completeness_score[completeness_indexer, :]
    gt = labels[completeness_indexer]
    pred_dim = pred.size(1)
    pred = pred.view(-1, positive_per_video + incomplete_per_video, pred_dim)
    gt = gt.view(-1, positive_per_video + incomplete_per_video)
    # yapf:disable
    positive_pred = pred[:, :positive_per_video, :].contiguous().view(-1, pred_dim)  # noqa:E501
    incomplete_pred = pred[:, positive_per_video:, :].contiguous().view(-1, pred_dim)  # noqa:E501
    # yapf:enable
    # completeness_loss is called above without an explicit OHEM ratio, so
    # this assertion pins its default ratio to 0.17.
    ohem_ratio = 0.17
    positive_loss = OHEMHingeLoss.apply(
        positive_pred, gt[:, :positive_per_video].contiguous().view(-1), 1,
        1.0, positive_per_video)
    incomplete_loss = OHEMHingeLoss.apply(
        incomplete_pred, gt[:, positive_per_video:].contiguous().view(-1), -1,
        ohem_ratio, incomplete_per_video)
    # Normalise by all positives plus int(ohem_ratio * #incomplete rows).
    num_positives = positive_pred.size(0)
    num_incompletes = int(incomplete_pred.size(0) * ohem_ratio)
    assert_loss = ((positive_loss + incomplete_loss) /
                   float(num_positives + num_incompletes))
    assert torch.equal(output_completeness_loss, assert_loss)

    # test classwise_regression_loss: each selected proposal is compared
    # against the two regression outputs of column (label - 1), via
    # smooth L1 scaled by 2. (Presumably label 0 is background, which is why
    # the 20 regression columns start at label 1; confirm against SSNLoss.)
    bbox_pred = torch.rand((8, 20, 2), generator=gen)
    bbox_targets = torch.rand((8, 2), generator=gen)
    regression_indexer = torch.tensor([0])
    output_reg_loss = ssn_loss.classwise_regression_loss(
        bbox_pred, labels, bbox_targets, regression_indexer)
    pred = bbox_pred[regression_indexer, :, :]
    gt = labels[regression_indexer]
    reg_target = bbox_targets[regression_indexer, :]
    class_idx = gt.data - 1
    # Indexing with a length-N index tensor gives an (N, N, 2) tensor; the
    # diagonal pairs proposal i with the column of its own class.
    classwise_pred = pred[:, class_idx, :]
    classwise_reg_pred = torch.cat(
        (torch.diag(classwise_pred[:, :, 0]).view(-1, 1),
         torch.diag(classwise_pred[:, :, 1]).view(-1, 1)),
        dim=1)
    assert torch.equal(
        output_reg_loss,
        F.smooth_l1_loss(classwise_reg_pred.view(-1), reg_target.view(-1)) * 2)

    # test forward: a single video of 8 proposals. forward() is expected to
    # derive the same indexers used above from these types: [0, 7] for
    # activity, [0..6] for completeness and [0] for regression. Presumably
    # 0 = positive, 1 = incomplete, 2 = background, consistent with the
    # 1:6:1 sampler ratios below -- confirm against SSNLoss.forward.
    proposal_type = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 2]])
    train_cfg = ConfigDict(
        dict(
            ssn=dict(
                sampler=dict(
                    num_per_video=8,
                    positive_ratio=1,
                    background_ratio=1,
                    incomplete_ratio=6,
                    add_gt_as_proposals=True),
                loss_weight=dict(comp_loss_weight=0.1, reg_loss_weight=0.1))))
    output_loss = ssn_loss(activity_score, completeness_score, bbox_pred,
                           proposal_type, labels, bbox_targets, train_cfg)
    assert torch.equal(output_loss['loss_activity'], output_activity_loss)
    # Completeness and regression terms are scaled by their 0.1 loss weights.
    assert torch.equal(output_loss['loss_completeness'],
                       output_completeness_loss * 0.1)
    assert torch.equal(output_loss['loss_reg'], output_reg_loss * 0.1)