import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer
from mmengine.model import BaseModule, ModuleList, Sequential

from mmaction.models.utils import unit_tcn
from mmaction.models.utils.graph import k_adjacency, normalize_digraph


class MLP(BaseModule):
    """A stack of 1x1 convolutions, each followed by BatchNorm2d and an
    optional activation, with optional dropout in front of every conv."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='ReLU'),
                 dropout=0):
        super().__init__()
        channels = [in_channels] + out_channels
        self.layers = ModuleList()
        for i in range(1, len(channels)):
            if dropout > 1e-3:
                self.layers.append(nn.Dropout(p=dropout))
            self.layers.append(
                nn.Conv2d(channels[i - 1], channels[i], kernel_size=1))
            self.layers.append(nn.BatchNorm2d(channels[i]))
            if act_cfg:
                self.layers.append(build_activation_layer(act_cfg))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class MSGCN(BaseModule):
    """Multi-scale graph convolution: features are aggregated over normalized
    k-hop adjacency matrices of the joint graph and fused by a 1x1-conv MLP
    (the multi-scale GCN of MS-G3D)."""

    def __init__(self,
                 num_scales,
                 in_channels,
                 out_channels,
                 A,
                 dropout=0,
                 act_cfg=dict(type='ReLU')):
        super().__init__()
        self.num_scales = num_scales

        # One k-hop adjacency matrix per scale, each normalized separately.
        A_powers = [
            k_adjacency(A, k, with_self=True) for k in range(num_scales)
        ]
        A_powers = np.stack([normalize_digraph(g) for g in A_powers])

        self.register_buffer('A', torch.Tensor(A_powers))
        # Learnable residual adjacency, initialised close to zero.
        self.PA = nn.Parameter(self.A.clone())
        nn.init.uniform_(self.PA, -1e-6, 1e-6)

        self.mlp = MLP(
            in_channels * num_scales, [out_channels],
            dropout=dropout,
            act_cfg=act_cfg)

    def forward(self, x):
        N, C, T, V = x.shape
        A = self.A + self.PA

        # Aggregate features over every scale, then fold the scale axis into
        # the channel axis before the 1x1-conv MLP.
        support = torch.einsum('kvu,nctv->nkctu', A, x)
        support = support.reshape(N, self.num_scales * C, T, V)
        out = self.mlp(support)
        return out


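# Illustrative shape walkthrough for MSGCN (sizes are made up, not taken from
# any config): with num_scales=4 the buffer A holds 4 normalized k-hop
# adjacency matrices of shape (4, V, V). For input x of shape (N, C, T, V),
# torch.einsum('kvu,nctv->nkctu', A, x) aggregates features per scale into
# (N, 4, C, T, V); the reshape folds the scale axis into the channels,
# (N, 4 * C, T, V), and the 1x1-conv MLP maps this to (N, out_channels, T, V).

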
class MSTCN(BaseModule):
    """Multi-scale temporal convolution: parallel dilated temporal-conv
    branches, a max-pooling branch and a 1x1 branch are concatenated along
    the channel axis, with an optional residual connection."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1, 2, 3, 4],
                 residual=True,
                 act_cfg=dict(type='ReLU'),
                 init_cfg=[
                     dict(type='Constant', layer='BatchNorm2d', val=1),
                     dict(type='Kaiming', layer='Conv2d', mode='fan_out')
                 ],
                 tcn_dropout=0):

        super().__init__(init_cfg=init_cfg)

        # Dilated branches + max-pooling branch + 1x1 branch.
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        # The last (1x1) branch takes whatever channels remain so that the
        # concatenated output has exactly `out_channels` channels.
        branch_channels_rem = out_channels - branch_channels * (
            self.num_branches - 1)

        if isinstance(kernel_size, list):
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)

        self.branches = ModuleList([
            Sequential(
                nn.Conv2d(
                    in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                build_activation_layer(act_cfg),
                unit_tcn(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation),
            ) for ks, dilation in zip(kernel_size, dilations)
        ])

        # Max-pooling branch.
        self.branches.append(
            Sequential(
                nn.Conv2d(
                    in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                build_activation_layer(act_cfg),
                nn.MaxPool2d(
                    kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
                nn.BatchNorm2d(branch_channels)))

        # Plain 1x1 branch.
        self.branches.append(
            Sequential(
                nn.Conv2d(
                    in_channels,
                    branch_channels_rem,
                    kernel_size=1,
                    padding=0,
                    stride=(stride, 1)), nn.BatchNorm2d(branch_channels_rem)))

        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride)

        self.act = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(tcn_dropout)

    def forward(self, x):
        res = self.residual(x)
        branch_outs = []
        for tempconv in self.branches:
            out = tempconv(x)
            branch_outs.append(out)

        out = torch.cat(branch_outs, dim=1)
        out += res
        out = self.act(out)
        out = self.drop(out)
        return out


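# Channel budget of MSTCN, worked through with illustrative numbers: for
# out_channels=64 and dilations=[1, 2, 3, 4] there are 4 + 2 = 6 branches.
# Each dilated branch and the max-pool branch gets 64 // 6 = 10 channels, and
# the final 1x1 branch takes the remaining 64 - 5 * 10 = 14, so concatenating
# all branches restores exactly 64 output channels.

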
class UnfoldTemporalWindows(BaseModule):
    """Unfold sliding temporal windows so that every output step holds
    ``window_size * V`` joints gathered from neighbouring frames."""

    def __init__(self, window_size, window_stride, window_dilation=1):
        super().__init__()
        self.window_size = window_size
        self.window_stride = window_stride
        self.window_dilation = window_dilation

        # 'same'-style padding along the temporal axis.
        self.padding = (window_size + (window_size - 1) *
                        (window_dilation - 1) - 1) // 2
        self.unfold = nn.Unfold(
            kernel_size=(self.window_size, 1),
            dilation=(self.window_dilation, 1),
            stride=(self.window_stride, 1),
            padding=(self.padding, 0))

    def forward(self, x):
        # Input: (N, C, T, V); output: (N, C, T', window_size * V).
        N, C, T, V = x.shape
        x = self.unfold(x)

        # (N, C * window_size, T' * V) -> (N, C, T', window_size, V)
        x = x.reshape(N, C, self.window_size, -1, V)
        x = x.permute(0, 1, 3, 2, 4).contiguous()
        x = x.reshape(N, C, -1, self.window_size * V)
        return x


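# Illustrative example for UnfoldTemporalWindows (made-up sizes): with
# window_size=3, window_stride=1 and window_dilation=1, an input of shape
# (N, C, T, V) keeps its temporal length (padding is 1) and each output step
# holds the current frame plus its two neighbours, i.e. the output is
# (N, C, T, 3 * V) and ST_MSGCN treats the 3 * V joints as one enlarged graph.

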
class ST_MSGCN(BaseModule):
    """Spatio-temporal multi-scale graph convolution (the G3D unit): MSGCN
    applied to the enlarged window graph produced by
    ``UnfoldTemporalWindows``."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 num_scales,
                 window_size,
                 residual=False,
                 dropout=0,
                 act_cfg=dict(type='ReLU')):

        super().__init__()
        self.num_scales = num_scales
        self.window_size = window_size
        A = self.build_st_graph(A, window_size)

        A_scales = [
            k_adjacency(A, k, with_self=True) for k in range(num_scales)
        ]
        A_scales = np.stack([normalize_digraph(g) for g in A_scales])

        self.register_buffer('A', torch.Tensor(A_scales))
        self.V = len(A)

        # Learnable residual adjacency, initialised close to zero.
        self.PA = nn.Parameter(self.A.clone())
        nn.init.uniform_(self.PA, -1e-6, 1e-6)

        self.mlp = MLP(
            in_channels * num_scales, [out_channels],
            dropout=dropout,
            act_cfg=act_cfg)

        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels):
            self.residual = lambda x: x
        else:
            self.residual = MLP(in_channels, [out_channels], act_cfg=None)

        self.act = build_activation_layer(act_cfg)

    def build_st_graph(self, A, window_size):
        """Tile the joint adjacency (with self-loops) into the adjacency of
        the spatio-temporal window graph of size ``window_size * V``."""
        if not isinstance(A, np.ndarray):
            A = A.data.cpu().numpy()

        assert len(A.shape) == 2 and A.shape[0] == A.shape[1]
        V = len(A)
        A_with_I = A + np.eye(V, dtype=A.dtype)

        A_large = np.tile(A_with_I, (window_size, window_size)).copy()
        return A_large

    def forward(self, x):
        N, C, T, V = x.shape
        A = self.A + self.PA

        res = self.residual(x)
        agg = torch.einsum('kvu,nctv->nkctu', A, x)
        agg = agg.reshape(N, self.num_scales * C, T, V)
        out = self.mlp(agg)
        # `res` is the integer 0 when the residual path is disabled; adding it
        # to a tensor is then a no-op, so one code path handles both cases.
        return self.act(out + res)


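# Note on the window graph (derived from build_st_graph above): the joint
# adjacency plus identity is tiled into a (window_size * V, window_size * V)
# block matrix, so every joint connects to itself and its graph neighbours in
# every frame of the window; ST_MSGCN then runs the same multi-scale
# aggregation as MSGCN, only on this larger spatio-temporal graph.

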
class MSG3DBlock(BaseModule):
    """One MS-G3D pathway: unfold temporal windows, run ST_MSGCN on the
    enlarged window graph, then collapse the window axis with a Conv3d."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 num_scales,
                 window_size,
                 window_stride,
                 window_dilation,
                 embed_factor=1,
                 activation='relu'):

        super().__init__()
        self.window_size = window_size
        self.out_channels = out_channels
        self.embed_channels_in = out_channels // embed_factor
        self.embed_channels_out = out_channels // embed_factor
        if embed_factor == 1:
            self.in1x1 = nn.Identity()
            self.embed_channels_in = self.embed_channels_out = in_channels
            # For raw 3-channel skeleton input, widen the channels inside the
            # ST_MSGCN unit instead of in a separate embedding layer.
            if in_channels == 3:
                self.embed_channels_out = out_channels
        else:
            self.in1x1 = MLP(in_channels, [self.embed_channels_in])

        self.gcn3d = Sequential(
            UnfoldTemporalWindows(window_size, window_stride, window_dilation),
            ST_MSGCN(
                in_channels=self.embed_channels_in,
                out_channels=self.embed_channels_out,
                A=A,
                num_scales=num_scales,
                window_size=window_size))

        self.out_conv = nn.Conv3d(
            self.embed_channels_out,
            out_channels,
            kernel_size=(1, self.window_size, 1))
        self.out_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        N, _, T, V = x.shape
        x = self.in1x1(x)

        # (N, C, T', window_size * V)
        x = self.gcn3d(x)

        # Collapse the window axis: (N, C, T', window_size, V) -> (N, C, T', V).
        x = x.reshape(N, self.embed_channels_out, -1, self.window_size, V)
        x = self.out_conv(x).squeeze(dim=3)
        x = self.out_bn(x)

        return x


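# Data-flow sketch for MSG3DBlock (shapes are illustrative): (N, C_in, T, V)
# -> in1x1 -> (N, C_emb, T, V) -> UnfoldTemporalWindows
# -> (N, C_emb, T', window_size * V) -> ST_MSGCN
# -> (N, C_out_emb, T', window_size * V) -> reshape to
# (N, C_out_emb, T', window_size, V) -> Conv3d with kernel (1, window_size, 1)
# -> (N, C_out, T', V) after the squeeze and BatchNorm.

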
class MW_MSG3DBlock(BaseModule):
    """Multi-window MS-G3D block: runs one MSG3DBlock per window size and
    sums the outputs."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 num_scales,
                 window_sizes=[3, 5],
                 window_stride=1,
                 window_dilations=[1, 1]):

        super().__init__()
        self.gcn3d = ModuleList([
            MSG3DBlock(in_channels, out_channels, A, num_scales, window_size,
                       window_stride, window_dilation)
            for window_size, window_dilation in zip(window_sizes,
                                                    window_dilations)
        ])

    def forward(self, x):
        out_sum = 0
        for gcn3d in self.gcn3d:
            out_sum += gcn3d(x)
        return out_sum
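

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the library API. The joint count
    # (V = 5), the chain-graph adjacency and all feature/temporal sizes below
    # are made-up values chosen only to illustrate the expected tensor shapes.
    V = 5
    A = np.zeros((V, V))
    for i in range(V - 1):
        A[i, i + 1] = A[i + 1, i] = 1

    block = MW_MSG3DBlock(
        in_channels=3,
        out_channels=16,
        A=A,
        num_scales=4,
        window_sizes=[3, 5],
        window_dilations=[1, 1])
    x = torch.randn(2, 3, 20, V)  # (N, C, T, V)
    # With stride-1 windows the temporal length is preserved, so this should
    # print torch.Size([2, 16, 20, 5]).
    print(block(x).shape)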