|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from mmengine import ConfigDict
|
|
|
|
from mmaction.models import OHEMHingeLoss, SSNLoss
|
|
|
|
|
|
def _make_labels():
    """Return a fresh 1-D int64 tensor of eight labels, all equal to 8."""
    return torch.LongTensor([8] * 8).squeeze()


def _reference_completeness_loss(score, labels, indexer, num_positive,
                                 num_incomplete, ohem_ratio):
    """Recompute the completeness loss directly with ``OHEMHingeLoss``.

    Rows of ``score`` selected by ``indexer`` are reshaped into groups of
    ``num_positive + num_incomplete`` rows. The leading ``num_positive``
    rows of each group are passed to ``OHEMHingeLoss.apply`` with ``1`` and
    ratio ``1.0``; the trailing rows with ``-1`` and ``ohem_ratio``. The
    summed result is divided by the number of leading rows plus
    ``int(trailing_rows * ohem_ratio)``.

    The op sequence matches what the test expects from
    ``SSNLoss.completeness_loss`` so the results can be compared bitwise.
    """
    group_size = num_positive + num_incomplete
    picked = score[indexer, :]
    num_classes = picked.size(1)
    grouped = picked.view(-1, group_size, num_classes)
    grouped_gt = labels[indexer].view(-1, group_size)

    head_pred = grouped[:, :num_positive, :].contiguous()
    head_pred = head_pred.view(-1, num_classes)
    tail_pred = grouped[:, num_positive:, :].contiguous()
    tail_pred = tail_pred.view(-1, num_classes)

    head_gt = grouped_gt[:, :num_positive].contiguous().view(-1)
    head_loss = OHEMHingeLoss.apply(head_pred, head_gt, 1, 1.0, num_positive)
    tail_gt = grouped_gt[:, num_positive:].contiguous().view(-1)
    tail_loss = OHEMHingeLoss.apply(tail_pred, tail_gt, -1, ohem_ratio,
                                    num_incomplete)

    # Only int(len(tail) * ohem_ratio) trailing rows count toward the mean.
    kept_rows = head_pred.size(0) + int(tail_pred.size(0) * ohem_ratio)
    return (head_loss + tail_loss) / float(kept_rows)


def _reference_regression_loss(bbox_pred, labels, bbox_targets, indexer):
    """Recompute the class-wise regression loss for the rows in ``indexer``.

    For each selected row, the two regression outputs stored at class index
    ``label - 1`` are gathered, then compared to the matching targets with
    ``F.smooth_l1_loss``. That loss is multiplied by 2.
    """
    picked_pred = bbox_pred[indexer, :, :]
    picked_target = bbox_targets[indexer, :]
    # Index is label - 1, so labels look 1-based here; TODO confirm.
    class_idx = labels[indexer].data - 1
    per_class = picked_pred[:, class_idx, :]
    # Advanced indexing pairs every row with every class index; the
    # diagonal keeps row i's prediction for its own class index.
    column_0 = torch.diag(per_class[:, :, 0]).view(-1, 1)
    column_1 = torch.diag(per_class[:, :, 1]).view(-1, 1)
    reg_pred = torch.cat((column_0, column_1), dim=1)
    return F.smooth_l1_loss(reg_pred.view(-1), picked_target.view(-1)) * 2


def test_ssn_loss():
    """Compare every ``SSNLoss`` term and its combined forward to references.

    All checks use ``torch.equal``, so each reference must reproduce the
    library's arithmetic exactly (same ops, same order).
    """
    loss_fn = SSNLoss()

    # Activity loss: plain cross-entropy over the indexed rows.
    activity_score = torch.rand((8, 21))
    labels = _make_labels()
    act_idx = torch.tensor([0, 7])
    activity_loss = loss_fn.activity_loss(activity_score, labels, act_idx)
    expected_activity = F.cross_entropy(activity_score[act_idx, :],
                                        labels[act_idx])
    assert torch.equal(activity_loss, expected_activity)

    # Completeness loss: one group of 1 positive + 6 incomplete rows.
    completeness_score = torch.rand((8, 20), requires_grad=True)
    labels = _make_labels()
    comp_idx = torch.tensor([0, 1, 2, 3, 4, 5, 6])
    n_pos, n_inc = 1, 6
    completeness_loss = loss_fn.completeness_loss(
        completeness_score, labels, comp_idx, n_pos, n_inc)
    expected_completeness = _reference_completeness_loss(
        completeness_score, labels, comp_idx, n_pos, n_inc, ohem_ratio=0.17)
    assert torch.equal(completeness_loss, expected_completeness)

    # Class-wise regression loss on a single indexed row.
    bbox_pred = torch.rand((8, 20, 2))
    labels = _make_labels()
    bbox_targets = torch.rand((8, 2))
    reg_idx = torch.tensor([0])
    reg_loss = loss_fn.classwise_regression_loss(bbox_pred, labels,
                                                 bbox_targets, reg_idx)
    assert torch.equal(
        reg_loss,
        _reference_regression_loss(bbox_pred, labels, bbox_targets, reg_idx))

    # Combined forward: the completeness and regression terms come back
    # scaled by the 0.1 weights configured below; activity is unscaled.
    proposal_type = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 2]])
    sampler_cfg = dict(
        num_per_video=8,
        positive_ratio=1,
        background_ratio=1,
        incomplete_ratio=6,
        add_gt_as_proposals=True)
    weight_cfg = dict(comp_loss_weight=0.1, reg_loss_weight=0.1)
    train_cfg = ConfigDict(
        dict(ssn=dict(sampler=sampler_cfg, loss_weight=weight_cfg)))
    losses = loss_fn(activity_score, completeness_score, bbox_pred,
                     proposal_type, labels, bbox_targets, train_cfg)
    assert torch.equal(losses['loss_activity'], activity_loss)
    assert torch.equal(losses['loss_completeness'], completeness_loss * 0.1)
    assert torch.equal(losses['loss_reg'], reg_loss * 0.1)
|
|
|