# mmaction2/tests/models/heads/test_mvit_head.py
# Page-header residue kept for provenance: uploader "niobures",
# repository "mmaction2", revision d3dbf03 (verified).
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
import torch.nn as nn
from mmaction.models import MViTHead
class TestMViTHead(TestCase):
    """Unit tests covering ``MViTHead`` set-up, pre-logits and forward."""

    # Keyword arguments used to build the head in every test case.
    DEFAULT_ARGS = {'in_channels': 768, 'num_classes': 5}

    # Fake backbone output: a one-element tuple whose only entry is a
    # two-item list of random tensors shaped (4, 768, 3, 2, 2) and (4, 768).
    fake_feats = ([torch.rand(4, 768, 3, 2, 2), torch.rand(4, 768)], )

    def test_init(self):
        """Check the layers and attributes of a head built from defaults."""
        mvit_head = MViTHead(**self.DEFAULT_ARGS)
        mvit_head.init_weights()

        # The dropout layer must be configured with the head's own ratio.
        self.assertEqual(mvit_head.dropout.p, mvit_head.dropout_ratio)
        self.assertIsInstance(mvit_head.fc_cls, nn.Linear)

        # Explicitly passed and default attribute values.
        self.assertEqual(mvit_head.num_classes, 5)
        self.assertEqual(mvit_head.dropout_ratio, 0.5)
        self.assertEqual(mvit_head.in_channels, 768)
        self.assertEqual(mvit_head.init_std, 0.02)

    def test_pre_logits(self):
        """``pre_logits`` must hand back the very same tensor object."""
        mvit_head = MViTHead(**self.DEFAULT_ARGS)
        features = mvit_head.pre_logits(self.fake_feats)

        # Expected result: second item of the last entry in ``fake_feats``,
        # compared by identity rather than by value.
        last_stage = self.fake_feats[-1]
        expected = last_stage[1]
        self.assertIs(features, expected)

    def test_forward(self):
        """Calling the head on the fake features yields a (4, 5) output."""
        mvit_head = MViTHead(**self.DEFAULT_ARGS)
        scores = mvit_head(self.fake_feats)

        # 4 matches the leading dimension of the fake tensors and 5 matches
        # ``num_classes`` in ``DEFAULT_ARGS``.
        expected_shape = (4, 5)
        self.assertEqual(scores.shape, expected_shape)