# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmaction.models import SwinTransformer3D
from mmaction.testing import generate_backbone_demo_inputs


def test_swin_backbone():
    """Test swin backbone."""
    with pytest.raises(AssertionError):
        # 'arch' must be a predefined arch name
        SwinTransformer3D(arch='-t')

    with pytest.raises(AssertionError):
        # a custom arch dict must contain 'embed_dims', 'depths' and
        # 'num_heads' keys
        SwinTransformer3D(arch={'embed_dims': 96})

    with pytest.raises(AssertionError):
        # 'depths' and 'num_heads' must have the same length
        SwinTransformer3D(arch={
            'embed_dims': 96,
            'depths': [2, 2, 6],
            'num_heads': [3, 6, 12, 24]
        })

    with pytest.raises(AssertionError):
        # the number of stages must not exceed 4
        SwinTransformer3D(
            arch={
                'embed_dims': 96,
                'depths': [2, 2, 6, 2, 2],
                'num_heads': [3, 6, 12, 24, 48]
            })

    with pytest.raises(AssertionError):
        # 'out_indices' must be within the range of stages
        SwinTransformer3D(arch='t', out_indices=(4, ))

    with pytest.raises(TypeError):
        # 'pretrained' must be a str or None
        swin_t = SwinTransformer3D(arch='t', pretrained=[0, 1, 1])
        swin_t.init_weights()

    with pytest.raises(TypeError):
        # 'pretrained' passed to init_weights must be a str or None
        swin_t = SwinTransformer3D(arch='t')
        swin_t.init_weights(pretrained=[0, 1, 1])

    # train from scratch without a pretrained checkpoint
    swin_b = SwinTransformer3D(arch='b', pretrained=None, pretrained2d=False)
    swin_b.init_weights()
    swin_b.train()

    # inflate a 2d pretrained checkpoint into the 3d backbone
    pretrained_url = 'https://download.openmmlab.com/mmaction/v1.0/' \
                     'recognition/swin/swin_tiny_patch4_window7_224.pth'
    swin_t_pre = SwinTransformer3D(
        arch='t', pretrained=pretrained_url, pretrained2d=True)
    swin_t_pre.init_weights()
    swin_t_pre.train()

    from mmengine.runner.checkpoint import _load_checkpoint
    ckpt_2d = _load_checkpoint(pretrained_url, map_location='cpu')
    state_dict = ckpt_2d['model']

    # the 2d patch embedding kernel should be repeated along the temporal
    # dimension and averaged
    patch_embed_weight2d = state_dict['patch_embed.proj.weight'].data
    patch_embed_weight3d = swin_t_pre.patch_embed.proj.weight.data
    assert torch.equal(
        patch_embed_weight3d,
        patch_embed_weight2d.unsqueeze(2).expand_as(patch_embed_weight3d) /
        patch_embed_weight3d.shape[2])

    # the final norm layer should be copied from the 2d checkpoint
    norm = swin_t_pre.norm3
    assert torch.equal(norm.weight.data, state_dict['norm.weight'])
    assert torch.equal(norm.bias.data, state_dict['norm.bias'])

    # the 2d relative position bias tables should be tiled along the
    # temporal relative-position axis
    for name, param in swin_t_pre.named_parameters():
        if 'relative_position_bias_table' in name:
            bias2d = state_dict[name]
            assert torch.equal(
                param.data,
                bias2d.repeat(2 * swin_t_pre.window_size[0] - 1, 1))

    # parameters in the frozen stages should not require gradients
    frozen_stages = 1
    swin_t_frozen = SwinTransformer3D(
        arch='t',
        pretrained=None,
        pretrained2d=False,
        frozen_stages=frozen_stages)
    swin_t_frozen.init_weights()
    swin_t_frozen.train()
    for param in swin_t_frozen.patch_embed.parameters():
        assert param.requires_grad is False
    for i in range(frozen_stages):
        layer = swin_t_frozen.layers[i]
        for param in layer.parameters():
            assert param.requires_grad is False

    input_shape = (1, 3, 6, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)
    feat = swin_t_frozen(imgs)
    assert feat.shape == torch.Size([1, 768, 3, 2, 2])

    # input shapes not divisible by the patch size are padded internally,
    # so the output shape is unchanged
    input_shape = (1, 3, 5, 63, 63)
    imgs = generate_backbone_demo_inputs(input_shape)
    feat = swin_t_frozen(imgs)
    assert feat.shape == torch.Size([1, 768, 3, 2, 2])

    # output the features of all four stages
    swin_t_all_stages = SwinTransformer3D(arch='t', out_indices=(0, 1, 2, 3))
    feats = swin_t_all_stages(imgs)
    assert feats[0].shape == torch.Size([1, 96, 3, 16, 16])
    assert feats[1].shape == torch.Size([1, 192, 3, 8, 8])
    assert feats[2].shape == torch.Size([1, 384, 3, 4, 4])
    assert feats[3].shape == torch.Size([1, 768, 3, 2, 2])

    # output the features of all four stages after their downsample layers
    swin_t_all_stages_after_ds = SwinTransformer3D(
        arch='t', out_indices=(0, 1, 2, 3), out_after_downsample=True)
    feats = swin_t_all_stages_after_ds(imgs)
    assert feats[0].shape == torch.Size([1, 192, 3, 8, 8])
    assert feats[1].shape == torch.Size([1, 384, 3, 4, 4])
    assert feats[2].shape == torch.Size([1, 768, 3, 2, 2])
    assert feats[3].shape == torch.Size([1, 768, 3, 2, 2])
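

# Illustrative sketch (not part of the upstream test): the 2d -> 3d patch
# embedding inflation that the torch.equal check above verifies. A 2d conv
# kernel of shape (C_out, C_in, H, W) gains a temporal dimension of size D
# by repetition and is then divided by D, so a temporally constant input
# yields the same response as the original 2d kernel. The helper name and
# shapes are hypothetical.
def _inflate_patch_embed_weight(weight2d: torch.Tensor,
                                temporal_size: int) -> torch.Tensor:
    """Inflate (C_out, C_in, H, W) to (C_out, C_in, D, H, W) by temporal
    repetition and averaging."""
    weight3d = weight2d.unsqueeze(2).repeat(1, 1, temporal_size, 1, 1)
    return weight3d / temporal_size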
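

# Illustrative sketch (not part of the upstream test): the relative position
# bias inflation checked above. A 2d table of shape
# ((2*Wh - 1) * (2*Ww - 1), num_heads) is tiled (2*Wd - 1) times along dim 0
# to cover the extra temporal relative positions of a 3d attention window,
# where Wd is the temporal window size. The helper name is hypothetical.
def _inflate_rel_pos_bias_table(bias2d: torch.Tensor,
                                temporal_window_size: int) -> torch.Tensor:
    """Tile a 2d relative position bias table for a 3d attention window."""
    return bias2d.repeat(2 * temporal_window_size - 1, 1)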