Spaces:
Build error
Build error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmpose.models.backbones import SEResNeXt | |
| from mmpose.models.backbones.seresnext import SEBottleneck as SEBottleneckX | |
def test_bottleneck():
    """Check the SE bottleneck's arguments, conv2 layout and forward pass.

    Verifies that an unsupported ``style`` is rejected, that ``conv2``
    carries the requested stride/groups and the expected mid width for
    both a grouped (``groups=32``) and a plain (``groups=1``) block, and
    that a forward pass keeps the input shape.
    """

    def check_conv2(block, groups, mid_channels):
        # Structural expectations shared by the two stride-2 blocks below.
        conv = block.conv2
        assert conv.stride == (2, 2)
        assert conv.groups == groups
        assert conv.out_channels == mid_channels
        assert conv.out_channels == block.mid_channels

    # Style must be in ['pytorch', 'caffe']; 'tensorflow' is expected to
    # make the constructor raise AssertionError.
    with pytest.raises(AssertionError):
        SEBottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')

    # Grouped block: 32 groups of width 4 -> 128 mid channels expected.
    grouped = SEBottleneckX(
        64, 256, groups=32, width_per_group=4, stride=2, style='pytorch')
    assert grouped.width_per_group == 4
    check_conv2(grouped, groups=32, mid_channels=128)

    # Plain block (groups=1): the expected mid width is 64 even though
    # width_per_group is still passed as 4.
    plain = SEBottleneckX(
        64, 256, groups=1, width_per_group=4, stride=2, style='pytorch')
    check_conv2(plain, groups=1, mid_channels=64)
    assert plain.mid_channels == 64

    # Forward pass: no stride argument is given here, and the output is
    # expected to keep the (1, 64, 56, 56) input shape.
    block = SEBottleneckX(
        64, 64, base_channels=16, groups=32, width_per_group=4)
    inputs = torch.randn(1, 64, 56, 56)
    assert block(inputs).shape == torch.Size([1, 64, 56, 56])
def test_seresnext():
    """Check SEResNeXt-50 (32 groups x width 4) construction and outputs.

    Verifies that an unsupported depth is rejected, that every SE
    bottleneck in the model has a 32-group ``conv2``, and that the output
    feature shapes for a (1, 3, 224, 224) input match expectations for
    both a four-stage and a single-stage ``out_indices`` selection.
    """
    # SEResNeXt depth should be in [50, 101, 152]; the test expects
    # depth=18 to raise KeyError.
    with pytest.raises(KeyError):
        SEResNeXt(depth=18)

    def build_and_run(out_indices):
        # Construct the depth-50, 32x4 model, confirm the grouped conv2 in
        # every SE bottleneck, initialise weights, switch to train mode
        # and push one random (1, 3, 224, 224) batch through it.
        net = SEResNeXt(
            depth=50, groups=32, width_per_group=4, out_indices=out_indices)
        bottlenecks = (
            m for m in net.modules() if isinstance(m, SEBottleneckX))
        assert all(b.conv2.groups == 32 for b in bottlenecks)
        net.init_weights()
        net.train()
        return net(torch.randn(1, 3, 224, 224))

    # All four stages requested: four feature maps, channels doubling and
    # spatial size halving from one stage to the next.
    feats = build_and_run((0, 1, 2, 3))
    assert len(feats) == 4
    expected = ((256, 56), (512, 28), (1024, 14), (2048, 7))
    for feat, (channels, size) in zip(feats, expected):
        assert feat.shape == torch.Size([1, channels, size, size])

    # Only the last stage requested: the output's ``.shape`` is checked
    # directly instead of being indexed like the four-stage case.
    feat = build_and_run((3, ))
    assert feat.shape == torch.Size([1, 2048, 7, 7])