Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from mmcv.cnn import ConvModule | |
| from ..builder import BACKBONES | |
| from ..utils import ResLayer | |
| from .resnet import BasicBlock | |
class HourglassModule(nn.Module):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    Args:
        depth (int): Depth of current HourglassModule. A module of depth
            ``d`` nests a module of depth ``d - 1``; at depth 1 the innermost
            branch is a plain ResLayer instead.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule. Needs at least ``depth + 1``
            entries (each level reads indices 0 and 1, then recurses on
            ``stage_channels[1:]``).
        stage_blocks (list[int]): Number of sub-modules stacked in current and
            follow-up HourglassModule. Needs at least ``depth + 1`` entries.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            passed on to every nested HourglassModule as well.
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 stage_blocks,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(HourglassModule, self).__init__()

        self.depth = depth

        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Skip branch kept at the input resolution.
        self.up1 = ResLayer(
            BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)

        # Entry of the low-resolution branch: stride 2 halves the resolution.
        self.low1 = ResLayer(
            BasicBlock,
            cur_channel,
            next_channel,
            cur_block,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            # Recurse on the remaining stages. ``norm_cfg`` must be forwarded,
            # otherwise nested modules silently fall back to the default BN
            # config regardless of what the caller requested.
            self.low2 = HourglassModule(
                depth - 1,
                stage_channels[1:],
                stage_blocks[1:],
                norm_cfg=norm_cfg)
        else:
            self.low2 = ResLayer(
                BasicBlock,
                next_channel,
                next_channel,
                next_block,
                norm_cfg=norm_cfg)

        # Map the low-resolution branch back to ``cur_channel`` channels.
        self.low3 = ResLayer(
            BasicBlock,
            next_channel,
            cur_channel,
            cur_block,
            norm_cfg=norm_cfg,
            downsample_first=False)

        # Undo the stride-2 downsampling of ``low1`` before the merge.
        self.up2 = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map, presumably with
                ``stage_channels[0]`` channels.

        Returns:
            Tensor: Element-wise sum of the skip branch and the upsampled
            low-resolution branch.
        """
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        # NOTE(review): the sum requires ``up2`` to match ``up1`` spatially,
        # which presumably needs H and W divisible by 2 ** depth - confirm.
        return up1 + up2
@BACKBONES.register_module()
class HourglassNet(nn.Module):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.
    More details can be found in the `paper
    <https://arxiv.org/abs/1603.06937>`_ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104. Must be >= 1.
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule. Must contain more than ``downsample_times``
            entries. ``stage_channels[0]`` also sets the stem's output width.
        stage_blocks (list[int]): Number of sub-modules stacked in a
            HourglassModule. Must have the same length as ``stage_channels``.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            used by every layer of the backbone, including the nested
            HourglassModules.

    Example:
        >>> from mmdet.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times=5,
                 num_stacks=2,
                 stage_channels=(256, 256, 384, 384, 384, 512),
                 stage_blocks=(2, 2, 2, 2, 2, 4),
                 feat_channel=256,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(HourglassNet, self).__init__()

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        # The stem (two stride-2 layers) must end with ``cur_channel``
        # channels because that is what the first HourglassModule consumes.
        # Deriving the widths from ``stage_channels[0]`` instead of
        # hard-coding 128/256 keeps the default network identical while
        # allowing other stage widths.
        self.stem = nn.Sequential(
            ConvModule(
                3, cur_channel // 2, 7, padding=3, stride=2,
                norm_cfg=norm_cfg),
            ResLayer(
                BasicBlock,
                cur_channel // 2,
                cur_channel,
                1,
                stride=2,
                norm_cfg=norm_cfg))

        # Forward ``norm_cfg`` so the hourglass bodies honour the configured
        # norm layer instead of the HourglassModule default.
        self.hourglass_modules = nn.ModuleList([
            HourglassModule(
                downsample_times,
                stage_channels,
                stage_blocks,
                norm_cfg=norm_cfg) for _ in range(num_stacks)
        ])

        # forward() indexes this as ``self.inters[ind]`` for
        # ``ind < num_stacks - 1``, one block per transition between stacks.
        self.inters = ResLayer(
            BasicBlock,
            cur_channel,
            cur_channel,
            num_stacks - 1,
            norm_cfg=norm_cfg)

        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        """Init module weights.

        Re-initializes every ``nn.Conv2d`` in the backbone by calling its
        ``reset_parameters()``. No pretrained model of HourglassNet is
        provided, so ``pretrained`` is ignored. The argument is kept because
        the detector's ``__init__()`` calls the backbone's ``init_weights()``
        with ``pretrained`` as input.

        Args:
            pretrained (str, optional): Unused.
        """
        # Training Centripetal Model needs to reset parameters for Conv2d
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.reset_parameters()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input image batch with 3 channels (the stem's first
                conv takes 3 input channels).

        Returns:
            list[Tensor]: One feature map per stacked HourglassModule, each
            produced by that stack's output conv with ``feat_channel``
            channels.
        """
        inter_feat = self.stem(x)
        out_feats = []

        stacks = zip(self.hourglass_modules, self.out_convs)
        for ind, (single_hourglass, out_conv) in enumerate(stacks):
            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            if ind < self.num_stacks - 1:
                # Merge this stack's input with its remapped output to build
                # the input of the next stack.
                inter_feat = (
                    self.conv1x1s[ind](inter_feat) +
                    self.remap_convs[ind](out_feat))
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats