|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from mmaction.models import HVULoss
|
|
|
|
|
|
def test_hvu_loss():
    """Compare HVULoss outputs with hand-computed BCE-with-logits references.

    Four configurations are exercised on the same 2-sample, 4-class batch
    (two categories, 'action' and 'scene', with two classes each):

    1. ``loss_type='all'``, ``with_mask=False``, ``reduction='sum'``
    2. ``loss_type='all'``, ``with_mask=True``
    3. ``loss_type='individual'``, ``with_mask=True``
    4. ``loss_type='individual'``, ``with_mask=False``, ``reduction='sum'``

    Each expected value is rebuilt from
    ``F.binary_cross_entropy_with_logits``. Comparisons use
    ``torch.allclose`` rather than exact equality because the reference and
    the loss module may accumulate floating-point values in a different
    order.
    """
    pred = torch.tensor([[-1.0525, -0.7085, 0.1819, -0.8011],
                         [0.1555, -1.5550, 0.5586, 1.9746]])
    gt = torch.tensor([[1., 0., 0., 0.], [0., 0., 1., 1.]])
    # Per-class mask: sample 0 covers columns 0-1, sample 1 covers columns 2-3.
    mask = torch.tensor([[1., 1., 0., 0.], [0., 0., 1., 1.]])
    # Per-category mask: sample 0 -> first category, sample 1 -> second.
    category_mask = torch.tensor([[1., 0.], [0., 1.]])
    categories = ['action', 'scene']
    category_nums = (2, 2)
    category_loss_weights = (1, 1)

    # Case 1: all classes together, no mask, 'sum' reduction.
    loss_all_nomask_sum = HVULoss(
        categories=categories,
        category_nums=category_nums,
        category_loss_weights=category_loss_weights,
        loss_type='all',
        with_mask=False,
        reduction='sum')
    loss = loss_all_nomask_sum(pred, gt, mask, category_mask)
    # Reference: element-wise BCE summed per sample, then averaged over batch.
    loss1 = F.binary_cross_entropy_with_logits(pred, gt, reduction='none')
    loss1 = torch.sum(loss1, dim=1)
    assert torch.allclose(loss['loss_cls'], torch.mean(loss1))

    # Case 2: all classes together, masked (reduction left unspecified).
    loss_all_mask = HVULoss(
        categories=categories,
        category_nums=category_nums,
        category_loss_weights=category_loss_weights,
        loss_type='all',
        with_mask=True)
    loss = loss_all_mask(pred, gt, mask, category_mask)
    # Reference: masked per-sample mean of element-wise BCE, then batch mean.
    loss1 = F.binary_cross_entropy_with_logits(pred, gt, reduction='none')
    loss1 = torch.sum(loss1 * mask, dim=1) / torch.sum(mask, dim=1)
    loss1 = torch.mean(loss1)
    assert torch.allclose(loss['loss_cls'], loss1)

    # Case 3: one loss per category, masked (reduction left unspecified).
    loss_ind_mask = HVULoss(
        categories=categories,
        category_nums=category_nums,
        category_loss_weights=category_loss_weights,
        loss_type='individual',
        with_mask=True)
    loss = loss_ind_mask(pred, gt, mask, category_mask)
    # Reference: each category uses only the sample that category_mask marks
    # for it (sample 0 for columns 0-1, sample 1 for columns 2-3).
    action_loss = F.binary_cross_entropy_with_logits(pred[:1, :2], gt[:1, :2])
    scene_loss = F.binary_cross_entropy_with_logits(pred[1:, 2:], gt[1:, 2:])
    loss1 = (action_loss + scene_loss) / 2
    assert torch.allclose(loss['loss_cls'], loss1)

    # Case 4: one loss per category, no mask, 'sum' reduction.
    loss_ind_nomask_sum = HVULoss(
        categories=categories,
        category_nums=category_nums,
        category_loss_weights=category_loss_weights,
        loss_type='individual',
        with_mask=False,
        reduction='sum')
    loss = loss_ind_nomask_sum(pred, gt, mask, category_mask)
    # Reference: per category, sum BCE over its classes per sample, average
    # over the batch; then average the two category losses (equal weights).
    action_loss = F.binary_cross_entropy_with_logits(
        pred[:, :2], gt[:, :2], reduction='none')
    action_loss = torch.sum(action_loss, dim=1)
    action_loss = torch.mean(action_loss)

    scene_loss = F.binary_cross_entropy_with_logits(
        pred[:, 2:], gt[:, 2:], reduction='none')
    scene_loss = torch.sum(scene_loss, dim=1)
    scene_loss = torch.mean(scene_loss)

    loss1 = (action_loss + scene_loss) / 2
    assert torch.allclose(loss['loss_cls'], loss1)
|
|
|