# mmaction2/projects/msg3d/models/msg3d_utils.py
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmaction.models.utils import unit_tcn
from mmaction.models.utils.graph import k_adjacency, normalize_digraph


class MLP(BaseModule):
    """Multi-layer perceptron built from 1x1 convolutions.

    ``out_channels`` is a list with one entry per Conv2d-BN(-activation)
    stage; dropout is applied before each stage when ``dropout`` is non-zero.
    """
def __init__(self,
in_channels,
out_channels,
act_cfg=dict(type='ReLU'),
dropout=0):
super().__init__()
channels = [in_channels] + out_channels
self.layers = ModuleList()
for i in range(1, len(channels)):
if dropout > 1e-3:
self.layers.append(nn.Dropout(p=dropout))
self.layers.append(
nn.Conv2d(channels[i - 1], channels[i], kernel_size=1))
self.layers.append(nn.BatchNorm2d(channels[i]))
if act_cfg:
self.layers.append(build_activation_layer(act_cfg))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
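# Usage sketch (illustrative, not part of the original file): project raw
# 3-channel joint coordinates to 64 channels; shapes are (N, C, T, V) and
# assume a 25-joint skeleton.
#
#   mlp = MLP(in_channels=3, out_channels=[64])
#   y = mlp(torch.randn(2, 3, 16, 25))  # -> (2, 64, 16, 25)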


class MSGCN(BaseModule):
    """Multi-scale graph convolution (MS-GCN).

    Aggregates features over ``num_scales`` normalized k-hop adjacency
    matrices and fuses the scales with a 1x1-conv MLP.
    """
def __init__(self,
num_scales,
in_channels,
out_channels,
A,
dropout=0,
act_cfg=dict(type='ReLU')):
super().__init__()
self.num_scales = num_scales
A_powers = [
k_adjacency(A, k, with_self=True) for k in range(num_scales)
]
A_powers = np.stack([normalize_digraph(g) for g in A_powers])
        # A_powers: (K, V, V), one normalized k-hop adjacency per scale
self.register_buffer('A', torch.Tensor(A_powers))
self.PA = nn.Parameter(self.A.clone())
nn.init.uniform_(self.PA, -1e-6, 1e-6)
self.mlp = MLP(
in_channels * num_scales, [out_channels],
dropout=dropout,
act_cfg=act_cfg)
def forward(self, x):
N, C, T, V = x.shape
A = self.A
A = A + self.PA
support = torch.einsum('kvu,nctv->nkctu', A, x)
support = support.reshape(N, self.num_scales * C, T, V)
out = self.mlp(support)
return out
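# Usage sketch (illustrative; assumes `A` is a (25, 25) numpy adjacency
# matrix of the skeleton graph, e.g. the NTU layout):
#
#   gcn = MSGCN(num_scales=4, in_channels=3, out_channels=64, A=A)
#   y = gcn(torch.randn(2, 3, 16, 25))  # -> (2, 64, 16, 25)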


# Note: this implementation of MSTCN differs from the one in the
# original MS-G3D codebase.
class MSTCN(BaseModule):
    """Multi-scale temporal convolution: parallel dilated temporal-conv
    branches plus a max-pooling branch and a strided 1x1 branch, whose
    outputs are concatenated along the channel dimension.
    """
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
dilations=[1, 2, 3, 4],
residual=True,
act_cfg=dict(type='ReLU'),
init_cfg=[
dict(type='Constant', layer='BatchNorm2d', val=1),
dict(type='Kaiming', layer='Conv2d', mode='fan_out')
],
tcn_dropout=0):
super().__init__(init_cfg=init_cfg)
# Multiple branches of temporal convolution
self.num_branches = len(dilations) + 2
branch_channels = out_channels // self.num_branches
branch_channels_rem = out_channels - branch_channels * (
self.num_branches - 1)
        if isinstance(kernel_size, list):
assert len(kernel_size) == len(dilations)
else:
kernel_size = [kernel_size] * len(dilations)
self.branches = ModuleList([
Sequential(
nn.Conv2d(
in_channels, branch_channels, kernel_size=1, padding=0),
nn.BatchNorm2d(branch_channels),
build_activation_layer(act_cfg),
unit_tcn(
branch_channels,
branch_channels,
kernel_size=ks,
stride=stride,
dilation=dilation),
) for ks, dilation in zip(kernel_size, dilations)
])
# Additional Max & 1x1 branch
self.branches.append(
Sequential(
nn.Conv2d(
in_channels, branch_channels, kernel_size=1, padding=0),
nn.BatchNorm2d(branch_channels),
build_activation_layer(act_cfg),
nn.MaxPool2d(
kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
nn.BatchNorm2d(branch_channels)))
self.branches.append(
Sequential(
nn.Conv2d(
in_channels,
branch_channels_rem,
kernel_size=1,
padding=0,
stride=(stride, 1)), nn.BatchNorm2d(branch_channels_rem)))
# Residual connection
if not residual:
self.residual = lambda x: 0
elif (in_channels == out_channels) and (stride == 1):
self.residual = lambda x: x
else:
self.residual = unit_tcn(
in_channels, out_channels, kernel_size=1, stride=stride)
self.act = build_activation_layer(act_cfg)
self.drop = nn.Dropout(tcn_dropout)
def forward(self, x):
# Input dim: (N,C,T,V)
res = self.residual(x)
branch_outs = []
for tempconv in self.branches:
out = tempconv(x)
branch_outs.append(out)
out = torch.cat(branch_outs, dim=1)
out += res
out = self.act(out)
out = self.drop(out)
return out
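# Usage sketch (illustrative): with the default four dilations the module
# builds 4 + 2 = 6 branches whose channels concatenate back to out_channels;
# stride 2 halves the temporal dimension.
#
#   tcn = MSTCN(in_channels=64, out_channels=64, stride=2)
#   y = tcn(torch.randn(2, 64, 16, 25))  # -> (2, 64, 8, 25)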


class UnfoldTemporalWindows(BaseModule):
    """Unfold sliding temporal windows and fold the window dimension into
    the graph dimension: (N, C, T, V) -> (N, C, num_windows,
    window_size * V).
    """
def __init__(self, window_size, window_stride, window_dilation=1):
super().__init__()
self.window_size = window_size
self.window_stride = window_stride
self.window_dilation = window_dilation
self.padding = (window_size + (window_size - 1) *
(window_dilation - 1) - 1) // 2
self.unfold = nn.Unfold(
kernel_size=(self.window_size, 1),
dilation=(self.window_dilation, 1),
stride=(self.window_stride, 1),
padding=(self.padding, 0))
def forward(self, x):
        # Input: (N, C, T, V); output: (N, C, num_windows, window_size * V)
N, C, T, V = x.shape
x = self.unfold(x)
# Permute extra channels from window size to the graph dimension;
# -1 for number of windows
x = x.reshape(N, C, self.window_size, -1, V).permute(0, 1, 3, 2,
4).contiguous()
x = x.reshape(N, C, -1, self.window_size * V)
return x
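# Usage sketch (illustrative): with window_size=3, stride 1 and "same"
# padding, each frame becomes a 3-frame window folded into the graph axis.
#
#   unfold = UnfoldTemporalWindows(window_size=3, window_stride=1)
#   y = unfold(torch.randn(2, 64, 16, 25))  # -> (2, 64, 16, 75)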


class ST_MSGCN(BaseModule):
    """Spatial-temporal multi-scale GCN (the G3D component): multi-scale
    graph convolution over the spatial-temporal graph built from unfolded
    temporal windows.
    """
def __init__(self,
in_channels,
out_channels,
A,
num_scales,
window_size,
residual=False,
dropout=0,
act_cfg=dict(type='ReLU')):
super().__init__()
self.num_scales = num_scales
self.window_size = window_size
A = self.build_st_graph(A, window_size)
A_scales = [
k_adjacency(A, k, with_self=True) for k in range(num_scales)
]
A_scales = np.stack([normalize_digraph(g) for g in A_scales])
self.register_buffer('A', torch.Tensor(A_scales))
self.V = len(A)
self.PA = nn.Parameter(self.A.clone())
nn.init.uniform_(self.PA, -1e-6, 1e-6)
self.mlp = MLP(
in_channels * num_scales, [out_channels],
dropout=dropout,
act_cfg=act_cfg)
# Residual connection
if not residual:
self.residual = lambda x: 0
elif (in_channels == out_channels):
self.residual = lambda x: x
else:
self.residual = MLP(in_channels, [out_channels], act_cfg=None)
self.act = build_activation_layer(act_cfg)
def build_st_graph(self, A, window_size):
if not isinstance(A, np.ndarray):
A = A.data.cpu().numpy()
assert len(A.shape) == 2 and A.shape[0] == A.shape[1]
V = len(A)
A_with_I = A + np.eye(V, dtype=A.dtype)
        # Tile (A + I) into a (window_size * V, window_size * V) block
        # matrix: each joint connects to its 1-hop neighborhood (plus
        # itself) in every frame of the window
        A_large = np.tile(A_with_I, (window_size, window_size)).copy()
        return A_large
def forward(self, x):
        # T: number of windows; V = self.V = window_size * num_joints
        N, C, T, V = x.shape
A = self.A + self.PA
# Perform Graph Convolution
res = self.residual(x)
agg = torch.einsum('kvu,nctv->nkctu', A, x)
agg = agg.reshape(N, self.num_scales * C, T, V)
out = self.mlp(agg)
        # `res` is either the scalar 0 (residual disabled) or a tensor;
        # adding it directly handles both cases, whereas comparing a
        # multi-element tensor to 0 inside an `if` would raise an error
        return self.act(out + res)
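# Usage sketch (illustrative; `A` is the (25, 25) skeleton adjacency and the
# input is the output of UnfoldTemporalWindows with window_size=3):
#
#   st_gcn = ST_MSGCN(in_channels=64, out_channels=64, A=A,
#                     num_scales=4, window_size=3)
#   y = st_gcn(torch.randn(2, 64, 16, 75))  # -> (2, 64, 16, 75)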


class MSG3DBlock(BaseModule):
    """MS-G3D block: unfold temporal windows, run ST_MSGCN over the
    resulting spatial-temporal graph, then collapse the window dimension
    with a Conv3d.
    """
def __init__(self,
in_channels,
out_channels,
A,
num_scales,
window_size,
window_stride,
window_dilation,
embed_factor=1,
activation='relu'):
super().__init__()
self.window_size = window_size
self.out_channels = out_channels
self.embed_channels_in = out_channels // embed_factor
self.embed_channels_out = out_channels // embed_factor
if embed_factor == 1:
self.in1x1 = nn.Identity()
self.embed_channels_in = self.embed_channels_out = in_channels
# The first STGC block changes channels right away;
# others change at collapse
if in_channels == 3:
self.embed_channels_out = out_channels
else:
self.in1x1 = MLP(in_channels, [self.embed_channels_in])
self.gcn3d = Sequential(
UnfoldTemporalWindows(window_size, window_stride, window_dilation),
ST_MSGCN(
in_channels=self.embed_channels_in,
out_channels=self.embed_channels_out,
A=A,
num_scales=num_scales,
window_size=window_size))
self.out_conv = nn.Conv3d(
self.embed_channels_out,
out_channels,
kernel_size=(1, self.window_size, 1))
self.out_bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
N, _, T, V = x.shape
x = self.in1x1(x)
# Construct temporal windows and apply MS-GCN
x = self.gcn3d(x)
# Collapse the window dimension
x = x.reshape(N, self.embed_channels_out, -1, self.window_size, V)
x = self.out_conv(x).squeeze(dim=3)
x = self.out_bn(x)
# no activation
return x
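# Usage sketch (illustrative; `A` is the (25, 25) skeleton adjacency):
#
#   blk = MSG3DBlock(in_channels=3, out_channels=64, A=A, num_scales=4,
#                    window_size=3, window_stride=1, window_dilation=1)
#   y = blk(torch.randn(2, 3, 16, 25))  # -> (2, 64, 16, 25)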


class MW_MSG3DBlock(BaseModule):
    """Multi-window MS-G3D block: runs one MSG3DBlock per window size and
    sums their outputs.
    """
def __init__(self,
in_channels,
out_channels,
A,
num_scales,
window_sizes=[3, 5],
window_stride=1,
window_dilations=[1, 1]):
super().__init__()
self.gcn3d = ModuleList([
MSG3DBlock(in_channels, out_channels, A, num_scales, window_size,
window_stride, window_dilation) for window_size,
window_dilation in zip(window_sizes, window_dilations)
])
def forward(self, x):
out_sum = 0
for gcn3d in self.gcn3d:
out_sum += gcn3d(x)
return out_sum
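# Usage sketch (illustrative; `A` is the (25, 25) skeleton adjacency):
#
#   mw = MW_MSG3DBlock(in_channels=3, out_channels=64, A=A, num_scales=4,
#                      window_sizes=[3, 5], window_dilations=[1, 1])
#   y = mw(torch.randn(2, 3, 16, 25))  # -> (2, 64, 16, 25)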