Spaces:
Build error
Build error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from ..builder import HEADS | |
| from .deconv_head import DeconvHead | |
# NOTE(review): ``HEADS`` is imported in this file but is not used in the
# visible code. If this head is meant to be built from a registry, a
# registration decorator may have been dropped -- confirm.
class AESimpleHead(DeconvHead):
    """Associative embedding simple head.

    paper ref: Alejandro Newell et al. "Associative
    Embedding: End-to-end Learning for Joint Detection
    and Grouping"

    The head outputs ``num_joints`` heatmap channels and, when the AE loss
    is enabled, additional tag channels (``num_joints`` of them if
    ``tag_per_joint`` is True, otherwise 1).

    Args:
        in_channels (int): Number of input channels.
        num_joints (int): Number of joints.
        num_deconv_layers (int): Number of deconv layers.
            num_deconv_layers should >= 0. Note that 0 means
            no deconv layers.
        num_deconv_filters (list|tuple): Number of filters.
            If num_deconv_layers > 0, the length is presumably expected
            to match num_deconv_layers (handled by ``DeconvHead`` --
            TODO confirm).
        num_deconv_kernels (list|tuple): Kernel sizes.
        tag_per_joint (bool): If tag_per_joint is True,
            the dimension of tags equals to num_joints,
            else the dimension of tags is 1. Default: True
        with_ae_loss (list[bool] | None): Option to use ae loss or not.
            Only the first element is read by this head. ``None`` or an
            empty sequence is treated as "AE loss disabled", so no tag
            channels are added. Default: None.
        extra: Forwarded unchanged to ``DeconvHead``. Default: None.
        loss_keypoint (dict): Config for loss. Default: None.
    """

    def __init__(self,
                 in_channels,
                 num_joints,
                 num_deconv_layers=3,
                 num_deconv_filters=(256, 256, 256),
                 num_deconv_kernels=(4, 4, 4),
                 tag_per_joint=True,
                 with_ae_loss=None,
                 extra=None,
                 loss_keypoint=None):
        dim_tag = num_joints if tag_per_joint else 1

        # The signature defaults ``with_ae_loss`` to None; indexing it
        # directly would raise ``TypeError`` (or ``IndexError`` for an
        # empty sequence), so guard before reading the first flag.
        use_ae_loss = (
            with_ae_loss is not None and len(with_ae_loss) > 0
            and bool(with_ae_loss[0]))

        # Tag channels are appended after the heatmap channels only when
        # the AE loss is enabled.
        if use_ae_loss:
            out_channels = num_joints + dim_tag
        else:
            out_channels = num_joints

        super().__init__(
            in_channels,
            out_channels,
            num_deconv_layers=num_deconv_layers,
            num_deconv_filters=num_deconv_filters,
            num_deconv_kernels=num_deconv_kernels,
            extra=extra,
            loss_keypoint=loss_keypoint)

    def get_loss(self, outputs, targets, masks, joints):
        """Calculate bottom-up keypoint loss.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_outputs: O
            - heatmaps height: H
            - heatmaps width: W

        Args:
            outputs (list(torch.Tensor[N,K,H,W])): Multi-scale output heatmaps.
            targets (List(torch.Tensor[N,K,H,W])): Multi-scale target heatmaps.
            masks (List(torch.Tensor[N,H,W])): Masks of multi-scale target
                heatmaps
            joints(List(torch.Tensor[N,M,K,2])): Joints of multi-scale target
                heatmaps for ae loss

        Returns:
            dict: Losses summed over scales. May contain the keys
            ``'heatmap_loss'``, ``'push_loss'`` and ``'pull_loss'``; a key
            is present only if ``self.loss`` returned a non-None entry for
            it at one or more scales.
        """
        losses = dict()

        # ``self.loss`` is not defined in this class; it is expected to
        # return three per-scale sequences (entries may be None).
        heatmaps_losses, push_losses, pull_losses = self.loss(
            outputs, targets, masks, joints)

        # Pair every output key with its per-scale losses so the three loss
        # kinds share a single accumulation path.
        named_losses = (
            ('heatmap_loss', heatmaps_losses),
            ('push_loss', push_losses),
            ('pull_loss', pull_losses),
        )

        for idx in range(len(targets)):
            for name, per_scale in named_losses:
                scale_loss = per_scale[idx]
                if scale_loss is None:
                    # This loss is disabled / unavailable at this scale.
                    continue
                # Reduce over the first (batch) dimension before summing
                # across scales.
                scale_loss = scale_loss.mean(dim=0)
                if name not in losses:
                    losses[name] = scale_loss
                else:
                    losses[name] += scale_loss

        return losses