# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init,
normal_init, xavier_init)
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer_sequence)
from mmcv.runner import BaseModule, ModuleList

from mmdet.core.anchor import MlvlPointGenerator
from mmdet.models.utils.transformer import MultiScaleDeformableAttention


@PLUGIN_LAYERS.register_module()
class MSDeformAttnPixelDecoder(BaseModule):
"""Pixel decoder with multi-scale deformable attention.
Args:
in_channels (list[int] | tuple[int]): Number of channels in the
input feature maps.
strides (list[int] | tuple[int]): Output strides of feature from
backbone.
feat_channels (int): Number of channels for feature.
out_channels (int): Number of channels for output.
num_outs (int): Number of output scales.
norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
Defaults to dict(type='GN', num_groups=32).
act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
Defaults to dict(type='ReLU').
encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
encoder. Defaults to `DetrTransformerEncoder`.
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
transformer encoder position encoding. Defaults to
dict(type='SinePositionalEncoding', num_feats=128,
normalize=True).
init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
"""
def __init__(self,
in_channels=[256, 512, 1024, 2048],
strides=[4, 8, 16, 32],
feat_channels=256,
out_channels=256,
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
feedforward_channels=1024,
ffn_dropout=0.0,
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.strides = strides
self.num_input_levels = len(in_channels)
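        # Only the ``num_levels`` lowest-resolution maps are flattened and fed
        # to the deformable-attention encoder. ``encoder`` is read with
        # attribute access below, so it is expected to be an
        # ``mmcv.ConfigDict`` (as when built from a config file).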
self.num_encoder_levels = \
encoder.transformerlayers.attn_cfgs.num_levels
assert self.num_encoder_levels >= 1, \
'num_levels in attn_cfgs must be at least one'
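        # 1x1 convs that project the selected backbone levels to
        # ``feat_channels`` before they enter the encoder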
input_conv_list = []
# from top to down (low to high resolution)
for i in range(self.num_input_levels - 1,
self.num_input_levels - self.num_encoder_levels - 1,
-1):
input_conv = ConvModule(
in_channels[i],
feat_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=None,
bias=True)
input_conv_list.append(input_conv)
self.input_convs = ModuleList(input_conv_list)
self.encoder = build_transformer_layer_sequence(encoder)
        self.positional_encoding = build_positional_encoding(
            positional_encoding)
# high resolution to low resolution
self.level_encoding = nn.Embedding(self.num_encoder_levels,
feat_channels)
# fpn-like structure
self.lateral_convs = ModuleList()
self.output_convs = ModuleList()
self.use_bias = norm_cfg is None
# from top to down (low to high resolution)
        # FPN for the remaining features that were not fed into the encoder
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
-1):
lateral_conv = ConvModule(
in_channels[i],
feat_channels,
kernel_size=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=None)
output_conv = ConvModule(
feat_channels,
feat_channels,
kernel_size=3,
stride=1,
padding=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.lateral_convs.append(lateral_conv)
self.output_convs.append(output_conv)
self.mask_feature = Conv2d(
feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.num_outs = num_outs
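        # grid generator used to produce the reference points of the
        # deformable attention, one point grid per input level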
        self.point_generator = MlvlPointGenerator(strides)

    def init_weights(self):
"""Initialize weights."""
for i in range(0, self.num_encoder_levels):
xavier_init(
self.input_convs[i].conv,
gain=1,
bias=0,
distribution='uniform')
for i in range(0, self.num_input_levels - self.num_encoder_levels):
caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
caffe2_xavier_init(self.output_convs[i].conv, bias=0)
caffe2_xavier_init(self.mask_feature, bias=0)
normal_init(self.level_encoding, mean=0, std=1)
for p in self.encoder.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
# init_weights defined in MultiScaleDeformableAttention
for layer in self.encoder.layers:
for attn in layer.attentions:
if isinstance(attn, MultiScaleDeformableAttention):
                    attn.init_weights()

    def forward(self, feats):
"""
Args:
feats (list[Tensor]): Feature maps of each level. Each has
shape of (batch_size, c, h, w).
Returns:
tuple: A tuple containing the following:
- mask_feature (Tensor): shape (batch_size, c, h, w).
- multi_scale_features (list[Tensor]): Multi scale \
features, each in shape (batch_size, c, h, w).
"""
# generate padding mask for each level, for each image
batch_size = feats[0].shape[0]
encoder_input_list = []
padding_mask_list = []
level_positional_encoding_list = []
spatial_shapes = []
reference_points_list = []
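        # For each encoder level (lowest resolution first): project the
        # features, build positional + level embeddings, generate normalized
        # reference points and flatten everything into sequences.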
for i in range(self.num_encoder_levels):
level_idx = self.num_input_levels - i - 1
feat = feats[level_idx]
feat_projected = self.input_convs[i](feat)
h, w = feat.shape[-2:]
# no padding
padding_mask_resized = feat.new_zeros(
(batch_size, ) + feat.shape[-2:], dtype=torch.bool)
            pos_embed = self.positional_encoding(padding_mask_resized)
level_embed = self.level_encoding.weight[i]
level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
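            # broadcast the per-level embedding over the spatial dimensions:
            # (c, ) -> (1, c, 1, 1), added to (batch_size, c, h_i, w_i)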
# (h_i * w_i, 2)
reference_points = self.point_generator.single_level_grid_priors(
feat.shape[-2:], level_idx, device=feat.device)
# normalize
factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
reference_points = reference_points / factor
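            # single_level_grid_priors returns point centres in image
            # coordinates, so dividing by (w_i, h_i) * stride maps them into
            # the [0, 1] range expected by the deformable attention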
# shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
padding_mask_resized = padding_mask_resized.flatten(1)
encoder_input_list.append(feat_projected)
padding_mask_list.append(padding_mask_resized)
level_positional_encoding_list.append(level_pos_embed)
spatial_shapes.append(feat.shape[-2:])
reference_points_list.append(reference_points)
# shape (batch_size, total_num_query),
        # total_num_query = sum(h_i * w_i) over the encoder levels
padding_masks = torch.cat(padding_mask_list, dim=1)
# shape (total_num_query, batch_size, c)
encoder_inputs = torch.cat(encoder_input_list, dim=0)
level_positional_encodings = torch.cat(
level_positional_encoding_list, dim=0)
device = encoder_inputs.device
# shape (num_encoder_levels, 2), from low
# resolution to high resolution
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=device)
        # shape (num_encoder_levels, ), values (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
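        # e.g. spatial_shapes [[8, 8], [16, 16], [32, 32]] gives
        # level_start_index [0, 64, 320]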
reference_points = torch.cat(reference_points_list, dim=0)
reference_points = reference_points[None, :, None].repeat(
batch_size, 1, self.num_encoder_levels, 1)
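        # shape (batch_size, total_num_query, num_encoder_levels, 2): the same
        # normalized points are shared across all encoder levels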
        valid_ratios = reference_points.new_ones(
            (batch_size, self.num_encoder_levels, 2))
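        # all ones because the inputs are never padded here, so every
        # position of every level is valid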
# shape (num_total_query, batch_size, c)
memory = self.encoder(
query=encoder_inputs,
key=None,
value=None,
query_pos=level_positional_encodings,
key_pos=None,
attn_masks=None,
key_padding_mask=None,
query_key_padding_mask=padding_masks,
spatial_shapes=spatial_shapes,
reference_points=reference_points,
level_start_index=level_start_index,
            valid_ratios=valid_ratios)
# (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
memory = memory.permute(1, 2, 0)
# from low resolution to high resolution
num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
outs = torch.split(memory, num_query_per_level, dim=-1)
outs = [
x.reshape(batch_size, -1, spatial_shapes[i][0],
spatial_shapes[i][1]) for i, x in enumerate(outs)
]
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
-1):
x = feats[i]
cur_feat = self.lateral_convs[i](x)
y = cur_feat + F.interpolate(
outs[-1],
size=cur_feat.shape[-2:],
mode='bilinear',
align_corners=False)
y = self.output_convs[i](y)
outs.append(y)
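        # the first ``num_outs`` maps (coarsest first) are the multi-scale
        # features; the finest map is projected into the mask features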
multi_scale_features = outs[:self.num_outs]
mask_feature = self.mask_feature(outs[-1])
return mask_feature, multi_scale_features
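

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The decoder is
# normally built from a config file via the PLUGIN_LAYERS registry; when
# constructed directly, ``encoder`` has to be an ``mmcv.ConfigDict`` because
# ``__init__`` reads it with attribute access:
#
#     import torch
#     from mmcv import ConfigDict
#
#     decoder = MSDeformAttnPixelDecoder(
#         in_channels=[256, 512, 1024, 2048],
#         strides=[4, 8, 16, 32],
#         encoder=ConfigDict(
#             type='DetrTransformerEncoder',
#             num_layers=6,
#             transformerlayers=dict(
#                 type='BaseTransformerLayer',
#                 attn_cfgs=dict(
#                     type='MultiScaleDeformableAttention',
#                     embed_dims=256,
#                     num_heads=8,
#                     num_levels=3,
#                     num_points=4,
#                     dropout=0.0,
#                     batch_first=False),
#                 feedforward_channels=1024,
#                 ffn_dropout=0.0,
#                 operation_order=('self_attn', 'norm', 'ffn', 'norm'))))
#     decoder.init_weights()
#
#     # dummy backbone features for a 256x256 image
#     feats = [torch.rand(2, c, 256 // s, 256 // s)
#              for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])]
#     mask_feature, multi_scale_features = decoder(feats)
#     # mask_feature:         (2, 256, 64, 64)
#     # multi_scale_features: 3 maps, from 1/32 to 1/8 resolution
# ---------------------------------------------------------------------------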