File size: 6,675 Bytes
51f6859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.runner import ModuleList

from ..builder import HEADS
from ..utils import ConvUpsample
from .base_semantic_head import BaseSemanticHead


@HEADS.register_module()
class PanopticFPNHead(BaseSemanticHead):
    """PanopticFPNHead used in Panoptic FPN.

    In this head, the number of output channels is ``num_stuff_classes
    + 1``, including all stuff classes and one thing class. The stuff
    classes will be reset from ``0`` to ``num_stuff_classes - 1``, the
    thing classes will be merged to ``num_stuff_classes``-th channel.

    Args:
        num_things_classes (int): Number of thing classes. Default: 80.
        num_stuff_classes (int): Number of stuff classes. Default: 53.
        num_classes (int): Number of classes, including all stuff
            classes and one thing class. This argument is deprecated,
            please use ``num_things_classes`` and ``num_stuff_classes``.
            The module will automatically infer the num_classes by
            ``num_stuff_classes + 1``.
        in_channels (int): Number of channels in the input feature
            map.
        inner_channels (int): Number of channels in inner features.
        start_level (int): The start level of the input features
            used in PanopticFPN.
        end_level (int): The end level of the used features, the
            ``end_level``-th layer will not be used.
        fg_range (tuple): Range of the foreground classes. It starts
            from ``0`` to ``num_things_classes-1``. Deprecated, please use
            ``num_things_classes`` directly. When given together with
            ``bg_range`` it overrides ``num_things_classes``.
        bg_range (tuple): Range of the background classes. It starts
            from ``num_things_classes`` to ``num_things_classes +
            num_stuff_classes - 1``. Deprecated, please use
            ``num_stuff_classes`` and ``num_things_classes`` directly.
            When given together with ``fg_range`` it overrides
            ``num_stuff_classes``.
        conv_cfg (dict): Dictionary to construct and config
            conv layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Use ``GN`` by default.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        loss_seg (dict): the loss of the semantic head.
    """

    def __init__(self,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_classes=None,
                 in_channels=256,
                 inner_channels=128,
                 start_level=0,
                 end_level=4,
                 fg_range=None,
                 bg_range=None,
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 init_cfg=None,
                 loss_seg=dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0)):
        # Resolve the deprecated inclusive ranges *before* the base class is
        # initialised, so that the number of output channels always equals
        # ``num_stuff_classes + 1`` for the class counts that
        # ``_set_things_to_void`` later uses to remap the labels.
        use_deprecated_ranges = fg_range is not None and bg_range is not None
        if use_deprecated_ranges:
            num_things_classes = fg_range[1] - fg_range[0] + 1
            num_stuff_classes = bg_range[1] - bg_range[0] + 1

        if num_classes is not None:
            warnings.warn(
                '`num_classes` is deprecated now, please set '
                '`num_stuff_classes` directly, the `num_classes` will be '
                'set to `num_stuff_classes + 1`')
            # num_classes = num_stuff_classes + 1 for PanopticFPN.
            assert num_classes == num_stuff_classes + 1, (
                f'`num_classes` ({num_classes}) must be equal to '
                f'`num_stuff_classes + 1` ({num_stuff_classes + 1})')

        # ``self.num_classes`` (read below for ``conv_logits``) is not set
        # in this class; presumably the base head stores its first argument
        # under that name -- confirm against BaseSemanticHead.
        super(PanopticFPNHead, self).__init__(num_stuff_classes + 1, init_cfg,
                                              loss_seg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        if use_deprecated_ranges:
            self.fg_range = fg_range
            self.bg_range = bg_range
            warnings.warn(
                '`fg_range` and `bg_range` are deprecated now, '
                f'please use `num_things_classes`={self.num_things_classes} '
                f'and `num_stuff_classes`={self.num_stuff_classes} instead.')

        # Used feature layers are [start_level, end_level)
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        # One ConvUpsample branch per used level. Level ``i`` is configured
        # with ``i`` layers and ``i`` upsample steps, except level 0 which
        # gets a single layer and no upsampling.
        self.conv_upsample_layers = ModuleList()
        for i in range(start_level, end_level):
            self.conv_upsample_layers.append(
                ConvUpsample(
                    in_channels,
                    inner_channels,
                    num_layers=i if i > 0 else 1,
                    num_upsample=i if i > 0 else 0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                ))
        # 1x1 conv mapping the fused features to per-class logits.
        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)

    def _set_things_to_void(self, gt_semantic_seg):
        """Merge thing classes to one class.

        In PanopticFPN, the background labels will be reset from `0` to
        `self.num_stuff_classes-1`, the foreground labels will be merged to
        `self.num_stuff_classes`-th channel.

        Args:
            gt_semantic_seg (Tensor): Label map in which values below
                ``num_things_classes`` are thing labels and values in
                ``[num_things_classes, num_things_classes +
                num_stuff_classes)`` are stuff labels.

        Returns:
            Tensor: Remapped label map, cast to ``torch.int32``. Values
            outside both ranges are left unchanged.
        """
        gt_semantic_seg = gt_semantic_seg.int()
        fg_mask = gt_semantic_seg < self.num_things_classes
        bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
            gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)

        new_gt_seg = torch.clone(gt_semantic_seg)
        # Shift stuff labels down so that they start from 0.
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 new_gt_seg)
        # Collapse every thing label onto the single extra channel.
        new_gt_seg = torch.where(fg_mask,
                                 fg_mask.int() * self.num_stuff_classes,
                                 new_gt_seg)
        return new_gt_seg

    def loss(self, seg_preds, gt_semantic_seg):
        """The loss of PanopticFPN head.

        Things classes will be merged to one class in PanopticFPN.

        Args:
            seg_preds (Tensor): Predicted logits, forwarded unchanged to
                the base class loss.
            gt_semantic_seg (Tensor): Ground truth label map; it is
                remapped by ``_set_things_to_void`` before the base class
                loss is computed.

        Returns:
            Whatever the base class ``loss`` returns.
        """
        gt_semantic_seg = self._set_things_to_void(gt_semantic_seg)
        return super().loss(seg_preds, gt_semantic_seg)

    def init_weights(self):
        """Initialize the weights of the head.

        Runs the base class initialization first, then draws the
        ``conv_logits`` weight from a normal distribution (mean 0,
        std 0.01) and zeroes its bias.
        """
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()

    def forward(self, x):
        """Fuse the selected feature levels and predict the logits.

        Args:
            x (Sequence[Tensor]): Multi-level features. Levels
                ``start_level`` to ``end_level - 1`` are used.

        Returns:
            dict: ``seg_preds`` holds the output of ``conv_logits`` and
            ``feats`` holds the summed per-level features it was
            computed from.
        """
        # the number of subnets must be not more than
        # the length of features.
        assert self.num_stages <= len(x)
        # The loop below indexes up to ``x[end_level - 1]``; with
        # ``start_level > 0`` the check above alone does not guarantee
        # that this index exists.
        assert self.end_level <= len(x), (
            f'`end_level` ({self.end_level}) exceeds the number of input '
            f'feature levels ({len(x)})')

        feats = []
        for i, layer in enumerate(self.conv_upsample_layers):
            f = layer(x[self.start_level + i])
            feats.append(f)

        # ``torch.stack`` requires every per-level output to share one
        # shape; the levels are then fused by element-wise summation.
        feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(feats)
        out = dict(seg_preds=seg_preds, feats=feats)
        return out