mmaction2 / tests /models /heads /test_gcn_head.py
niobures's picture
mmaction2
d3dbf03 verified
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmaction.models import GCNHead
def test_gcn_head():
"""Test GCNHead."""
with pytest.raises(AssertionError):
GCNHead(4, 5)(torch.rand((1, 2, 6, 75, 17)))
gcn_head = GCNHead(num_classes=60, in_channels=256)
gcn_head.init_weights()
feat = torch.rand(1, 2, 256, 75, 25)
cls_scores = gcn_head(feat)
assert gcn_head.num_classes == 60
assert gcn_head.in_channels == 256
assert cls_scores.shape == torch.Size([1, 60])
gcn_head = GCNHead(num_classes=60, in_channels=256, dropout=0.1)
gcn_head.init_weights()
feat = torch.rand(1, 2, 256, 75, 25)
cls_scores = gcn_head(feat)
assert gcn_head.num_classes == 60
assert gcn_head.in_channels == 256
assert cls_scores.shape == torch.Size([1, 60])