Spaces:
Build error
Build error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, | |
| constant_init, normal_init) | |
| from mmpose.models.builder import HEADS, build_loss | |
| from mmpose.models.utils.ops import resize | |
class DeconvHead(nn.Module):
    """Simple deconv head.

    Transforms (optionally multi-scale) input features, upsamples them with a
    stack of deconvolution (transposed-conv) layers and applies a final conv
    layer to produce heatmaps.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters.
            If num_deconv_layers > 0, the length of num_deconv_filters
            should equal num_deconv_layers.
        num_deconv_kernels (list|tuple): Kernel sizes of the deconv layers.
            If num_deconv_layers > 0, the length should equal
            num_deconv_layers. Each kernel must be one of {2, 3, 4}.
        extra (dict|None): Extra config of the final layers. Recognized keys:
            ``final_conv_kernel`` (one of 0/1/3, 0 means identity mapping),
            ``num_conv_layers`` and ``num_conv_kernels`` for optional extra
            conv layers before the final conv. Default: None.
        in_index (int|Sequence[int]): Input feature index. Default: 0
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resized to the
              same size as the first one and then concat together.
              Usually used in FCN head of HRNet.
            - 'multiple_select': Multiple feature maps will be bundle into
              a list and passed into decode head.
            - None: Only one select feature map is allowed.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        loss_keypoint (dict): Config for loss. Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=17,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 extra=None,
                 in_index=0,
                 input_transform=None,
                 align_corners=False,
                 loss_keypoint=None):
        super().__init__()

        self.in_channels = in_channels
        self.loss = build_loss(loss_keypoint)

        self._init_inputs(in_channels, in_index, input_transform)
        self.in_index = in_index
        self.align_corners = align_corners

        if extra is not None and not isinstance(extra, dict):
            raise TypeError('extra should be dict or None.')

        if num_deconv_layers > 0:
            self.deconv_layers = self._make_deconv_layer(
                num_deconv_layers,
                num_deconv_filters,
                num_deconv_kernels,
            )
        elif num_deconv_layers == 0:
            self.deconv_layers = nn.Identity()
        else:
            raise ValueError(
                f'num_deconv_layers ({num_deconv_layers}) should >= 0.')

        identity_final_layer = False
        if extra is not None and 'final_conv_kernel' in extra:
            assert extra['final_conv_kernel'] in [0, 1, 3]
            if extra['final_conv_kernel'] == 3:
                padding = 1
            elif extra['final_conv_kernel'] == 1:
                padding = 0
            else:
                # 0 for Identity mapping.
                identity_final_layer = True
            kernel_size = extra['final_conv_kernel']
        else:
            kernel_size = 1
            padding = 0

        if identity_final_layer:
            self.final_layer = nn.Identity()
        else:
            # If there are deconv layers, the final conv consumes the last
            # deconv's output channels; otherwise it consumes the head input.
            conv_channels = num_deconv_filters[
                -1] if num_deconv_layers > 0 else self.in_channels

            layers = []
            if extra is not None:
                num_conv_layers = extra.get('num_conv_layers', 0)
                num_conv_kernels = extra.get('num_conv_kernels',
                                             [1] * num_conv_layers)

                for i in range(num_conv_layers):
                    layers.append(
                        build_conv_layer(
                            dict(type='Conv2d'),
                            in_channels=conv_channels,
                            out_channels=conv_channels,
                            kernel_size=num_conv_kernels[i],
                            stride=1,
                            # "same" padding for odd kernel sizes
                            padding=(num_conv_kernels[i] - 1) // 2))
                    layers.append(
                        build_norm_layer(dict(type='BN'), conv_channels)[1])
                    layers.append(nn.ReLU(inplace=True))

            layers.append(
                build_conv_layer(
                    cfg=dict(type='Conv2d'),
                    in_channels=conv_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding))

            if len(layers) > 1:
                self.final_layer = nn.Sequential(*layers)
            else:
                self.final_layer = layers[0]

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.

                - 'resize_concat': Multiple feature maps will be resize to the
                  same size as first one and than concat together.
                  Usually used in FCN head of HRNet.
                - 'multiple_select': Multiple feature maps will be bundle into
                  a list and passed into decode head.
                - None: Only one select feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Concatenated channels: sum over the selected feature maps.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if not isinstance(inputs, list):
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers.

        Each layer is ConvTranspose2d(stride=2) -> BatchNorm2d -> ReLU,
        doubling the spatial resolution.

        Raises:
            ValueError: If the lengths of ``num_filters`` or ``num_kernels``
                do not match ``num_layers``.
        """
        if num_layers != len(num_filters):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_filters({len(num_filters)})'
            raise ValueError(error_msg)
        if num_layers != len(num_kernels):
            error_msg = f'num_layers({num_layers}) ' \
                        f'!= length of num_kernels({len(num_kernels)})'
            raise ValueError(error_msg)

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=self.in_channels,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            # Next deconv layer consumes this layer's output channels.
            self.in_channels = planes

        return nn.Sequential(*layers)

    @staticmethod
    def _get_deconv_cfg(deconv_kernel):
        """Get configurations for deconv layers.

        Note: this must be a staticmethod — it is called as
        ``self._get_deconv_cfg(k)`` with a single kernel-size argument;
        without the decorator, ``self`` would be bound to ``deconv_kernel``
        and every call would raise ValueError.

        Returns:
            tuple[int, int, int]: (kernel_size, padding, output_padding),
            chosen so that a stride-2 deconv exactly doubles resolution.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def get_loss(self, outputs, targets, masks):
        """Calculate bottom-up masked mse loss.

        Note:
            - batch_size: N
            - num_channels: C
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            outputs (List(torch.Tensor[N,C,H,W])): Multi-scale outputs.
            targets (List(torch.Tensor[N,C,H,W])): Multi-scale targets.
            masks (List(torch.Tensor[N,H,W])): Masks of multi-scale targets.
        """
        losses = dict()

        # Sum the per-scale losses into a single 'loss' entry.
        for idx in range(len(targets)):
            if 'loss' not in losses:
                losses['loss'] = self.loss(outputs[idx], targets[idx],
                                           masks[idx])
            else:
                losses['loss'] += self.loss(outputs[idx], targets[idx],
                                            masks[idx])

        return losses

    def forward(self, x):
        """Forward function.

        Returns:
            list[Tensor]: A single-element list containing the output
            heatmaps (kept as a list for multi-scale-style consumers).
        """
        x = self._transform_inputs(x)
        final_outputs = []
        x = self.deconv_layers(x)
        y = self.final_layer(x)
        final_outputs.append(y)
        return final_outputs

    def init_weights(self):
        """Initialize model weights."""
        for _, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)