|
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmaction.models import MViTHead
|
|
|
|
|
|
class TestMViTHead(TestCase):
    """Checks for ``MViTHead``: construction, ``pre_logits`` and forward."""

    # Keyword arguments used to build the head in every test.
    DEFAULT_ARGS = {'in_channels': 768, 'num_classes': 5}

    # One-element tuple whose entry is a two-item list: a tensor of shape
    # (4, 768, 3, 2, 2) followed by a tensor of shape (4, 768).
    # NOTE(review): presumably a feature map plus a class token for a batch
    # of 4 -- confirm against the backbone's output format.
    fake_feats = ([torch.rand(4, 768, 3, 2, 2), torch.rand(4, 768)], )

    def test_init(self):
        """Build the head, initialise weights, then check its attributes."""
        mvit_head = MViTHead(**self.DEFAULT_ARGS)
        mvit_head.init_weights()

        # The dropout layer must be configured with the stored ratio.
        self.assertEqual(mvit_head.dropout.p, mvit_head.dropout_ratio)
        # The classifier must be a plain linear layer.
        self.assertIsInstance(mvit_head.fc_cls, nn.Linear)

        # Explicitly passed arguments and the expected defaults.
        self.assertEqual(mvit_head.num_classes, 5)
        self.assertEqual(mvit_head.dropout_ratio, 0.5)
        self.assertEqual(mvit_head.in_channels, 768)
        self.assertEqual(mvit_head.init_std, 0.02)

    def test_pre_logits(self):
        """``pre_logits`` must hand back the second item of the last entry."""
        mvit_head = MViTHead(**self.DEFAULT_ARGS)
        selected = mvit_head.pre_logits(self.fake_feats)

        # Identity (not equality): the very same tensor object is expected.
        last_stage = self.fake_feats[-1]
        self.assertIs(selected, last_stage[1])

    def test_forward(self):
        """Calling the head on the fake features yields (4, 5) scores."""
        mvit_head = MViTHead(**self.DEFAULT_ARGS)
        scores = mvit_head(self.fake_feats)

        expected_shape = (4, 5)
        self.assertEqual(scores.shape, expected_shape)
|
|
|