# dccrn-demo / utils / complexnn.py
# (Hugging Face upload metadata removed: uploaded by Ada312, commit 310317c, 15.8 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def get_casual_padding1d():
    """Placeholder for a 1-D causal-padding helper; not implemented yet."""
    return None
def get_casual_padding2d():
    """Placeholder for a 2-D causal-padding helper; not implemented yet."""
    return None
class cPReLU(nn.Module):
    """Complex PReLU: applies an independent PReLU to the real and imaginary
    halves of the input, which are concatenated along ``complex_axis``."""

    def __init__(self, complex_axis=1):
        super(cPReLU, self).__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        # Split into (real, imag) halves, activate each, and re-concatenate.
        real_part, imag_part = torch.chunk(inputs, 2, self.complex_axis)
        activated = [self.r_prelu(real_part), self.i_prelu(imag_part)]
        return torch.cat(activated, self.complex_axis)
class NavieComplexLSTM(nn.Module):
    """Naive complex-valued LSTM built from two real LSTMs.

    For complex input z = x + iy processed by a "real" LSTM A and an
    "imag" LSTM B, the outputs are combined as (Ax - By) + i(Bx + Ay),
    mirroring complex multiplication.

    Args:
        input_size: total feature size (real + imag); each LSTM sees half.
        hidden_size: total hidden size (real + imag); each LSTM gets half.
        projection_dim: optional total output size of a complex linear
            projection (halved per component), or None for no projection.
        bidirectional: forwarded to both internal LSTMs.
        batch_first: accepted but NOT forwarded -- both LSTMs are always
            built with batch_first=False.
            # NOTE(review): preserved from the original; confirm upstream intent.
    """

    def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
        super(NavieComplexLSTM, self).__init__()
        self.input_dim = input_size // 2
        self.rnn_units = hidden_size // 2
        self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
                                 batch_first=False)
        self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
                                 batch_first=False)
        num_directions = 2 if bidirectional else 1
        if projection_dim is not None:
            self.projection_dim = projection_dim // 2
            self.r_trans = nn.Linear(self.rnn_units * num_directions, self.projection_dim)
            self.i_trans = nn.Linear(self.rnn_units * num_directions, self.projection_dim)
        else:
            self.projection_dim = None

    def forward(self, inputs):
        """Run the complex LSTM.

        Args:
            inputs: either a [real, imag] list of tensors, or a single tensor
                whose LAST dimension holds the real and imaginary halves.

        Returns:
            [real_out, imag_out] list of tensors.

        Raises:
            TypeError: if inputs is neither a list nor a torch.Tensor.
        """
        if isinstance(inputs, list):
            real, imag = inputs
        elif isinstance(inputs, torch.Tensor):
            # BUG FIX: the original called torch.chunk(inputs, -1), which passes
            # -1 as the *number of chunks* and raises at runtime.  Split into 2
            # chunks along the last dimension instead.
            real, imag = torch.chunk(inputs, 2, -1)
        else:
            # Previously this fell through to an UnboundLocalError; fail clearly.
            raise TypeError('inputs must be a [real, imag] list or a torch.Tensor')
        r2r_out = self.real_lstm(real)[0]
        r2i_out = self.imag_lstm(real)[0]
        i2r_out = self.real_lstm(imag)[0]
        i2i_out = self.imag_lstm(imag)[0]
        # Complex-multiplication combination: (Ax - By) + i(Bx + Ay).
        real_out = r2r_out - i2i_out
        imag_out = i2r_out + r2i_out
        if self.projection_dim is not None:
            real_out = self.r_trans(real_out)
            imag_out = self.i_trans(imag_out)
        return [real_out, imag_out]

    def flatten_parameters(self):
        """Compact the LSTM weight memory (e.g. after moving devices)."""
        self.imag_lstm.flatten_parameters()
        self.real_lstm.flatten_parameters()
def complex_cat(inputs, axis):
    """Concatenate complex tensors along ``axis``.

    Each tensor in ``inputs`` holds real and imaginary halves along ``axis``;
    the result concatenates all real halves first, then all imaginary halves,
    along that same axis.
    """
    halves = [torch.chunk(tensor, 2, axis) for tensor in inputs]
    real = torch.cat([h[0] for h in halves], axis)
    imag = torch.cat([h[1] for h in halves], axis)
    return torch.cat([real, imag], axis)
class ComplexConv2d(nn.Module):
    """Complex 2-D convolution over [B, C, D, T] inputs.

    ``in_channels``/``out_channels`` count real+imag together; each of the two
    internal real convolutions handles half.  Frequency (D) padding is done by
    the convolutions themselves; time (T) padding is applied in ``forward``:
    left-only when ``causal`` is True, symmetric otherwise.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            dilation=1,
            groups=1,
            causal=True,
            complex_axis=1,
    ):
        """
        in_channels / out_channels: real+imag combined.
        kernel_size, padding: given as [D, T] for [B, C, D, T] input.
        causal: pad only the left of the time dimension when True.
        """
        super(ComplexConv2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis
        # Only the frequency padding is handed to Conv2d; time padding happens
        # manually in forward() so causal (left-only) padding is possible.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=self.stride,
            padding=[self.padding[0], 0],
            dilation=self.dilation,
            groups=self.groups,
        )
        self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, **conv_kwargs)
        self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, **conv_kwargs)
        nn.init.normal_(self.real_conv.weight.data, std=0.05)
        nn.init.normal_(self.imag_conv.weight.data, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        # Time-axis padding: left-only for causal convolutions, else symmetric.
        if self.causal and self.padding[1] != 0:
            inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])
        else:
            inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])
        if self.complex_axis == 0:
            # Real/imag stacked along the batch axis: convolve everything, then
            # split each conv's output into its real->*/imag->* contributions.
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
        else:
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)
        # Complex multiply: (rr - ii) + i(ri + ir).
        real = real2real - imag2imag
        imag = real2imag + imag2real
        return torch.cat([real, imag], self.complex_axis)
class ComplexConvTranspose2d(nn.Module):
    """Complex transposed 2-D convolution.

    ``in_channels``/``out_channels`` count real+imag together; two internal
    real transposed convolutions each handle half, and their outputs are
    combined by the complex-multiplication rule.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            output_padding=(0, 0),
            causal=False,
            complex_axis=1,
            groups=1
    ):
        """
        in_channels / out_channels: real+imag combined.
        """
        super(ComplexConvTranspose2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        deconv_kwargs = dict(
            kernel_size=kernel_size,
            stride=self.stride,
            padding=self.padding,
            output_padding=output_padding,
            groups=self.groups,
        )
        self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, **deconv_kwargs)
        self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, **deconv_kwargs)
        self.complex_axis = complex_axis
        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        # Accept a single tensor (split along complex_axis) or a [real, imag] pair.
        if isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
        elif isinstance(inputs, (tuple, list)):
            real = inputs[0]
            imag = inputs[1]
        if self.complex_axis == 0:
            # Real/imag stacked along the batch axis: run both convs on the
            # whole input and split each output into its contributions.
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
        else:
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)
        # Complex multiply: (rr - ii) + i(ri + ir).
        real = real2real - imag2imag
        imag = real2imag + imag2real
        return torch.cat([real, imag], self.complex_axis)
# Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
# from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55
class ComplexBatchNorm(torch.nn.Module):
    """Complex batch normalization (whitening) for real/imag-stacked tensors.

    The input holds the real and imaginary halves concatenated along
    ``complex_axis``; ``num_features`` counts real+imag channels together.
    Per channel, the 2x2 covariance of (real, imag) is estimated, its inverse
    matrix square root is applied (full complex whitening), then an optional
    symmetric 2x2 affine weight W and complex bias B are applied, following
    Trabelsi et al., "Deep Complex Networks".
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, complex_axis=1):
        super(ComplexBatchNorm, self).__init__()
        # Each complex component has half of the declared feature count.
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.complex_axis = complex_axis
        if self.affine:
            # Symmetric per-channel 2x2 weight (Wrr, Wri, Wii) and bias (Br, Bi).
            self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)
        if self.track_running_stats:
            # Running mean (RMr, RMi) and running covariance (RVrr, RVri, RVii),
            # initialized to zero mean / identity covariance.
            self.register_buffer('RMr', torch.zeros(self.num_features))
            self.register_buffer('RMi', torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones(self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr', None)
            self.register_parameter('RMi', None)
            self.register_parameter('RVrr', None)
            self.register_parameter('RVri', None)
            self.register_parameter('RVii', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running statistics to zero mean / identity covariance."""
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running stats and (if affine) the learnable W and B."""
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        # Sanity check for callers: equal shapes, channel dim matches features.
        assert (xr.shape == xi.shape)
        assert (xr.size(1) == self.num_features)

    def forward(self, inputs):
        """Whiten and (optionally) affine-transform the complex input.

        Args:
            inputs: tensor with real/imag halves stacked along complex_axis;
                the per-component channel dimension is assumed to be dim 1.

        Returns:
            Tensor of the same shape as ``inputs``.
        """
        # Split into real/imag components along the complex axis.
        # FIX: use the documented positional `dim` argument of torch.chunk
        # (the original passed an `axis=` keyword), consistent with the rest
        # of this file.
        xr, xi = torch.chunk(inputs, 2, self.complex_axis)
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum
        #
        # NOTE: The precise meaning of the "training flag" is:
        #   True:  Normalize using batch statistics, update running statistics
        #          if they are being collected.
        #   False: Normalize using running statistics, ignore batch statistics.
        #
        training = self.training or not self.track_running_stats
        # All dims except the channel dim (1) are reduced when estimating stats.
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        # Broadcast shape placing per-channel values along dim 1.
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)
        #
        # Mean M Computation and Centering
        #
        # Includes running mean update if training and running.
        #
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                # NOTE(review): .squeeze() assumes no non-channel dim is size 1
                # after reduction; kept as-is from the original implementation.
                self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi
        #
        # Variance Matrix V Computation
        #
        # Includes epsilon numerical stabilizer/Tikhonov regularizer.
        # Includes running variance update if training and running.
        #
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        # Tikhonov regularization on the diagonal only.
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps
        #
        # Matrix Inverse Square Root U = V^-0.5
        #
        # sqrt of a 2x2 matrix,
        # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii
        # FIX: the original used torch.addcmul(Vrr * Vii, -1, Vri, Vri), the
        # deprecated positional-`value` form (removed in current PyTorch, where
        # `value` is keyword-only).  The determinant is written out directly.
        delta = Vrr * Vii - Vri * Vri
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()
        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (- Vri) * rst
        #
        # Optionally left-multiply U by affine weights W to produce combined
        # weights Z, left-multiply the inputs by Z, then optionally bias them.
        #
        # y = Zx + B
        # y = WUx + B
        # y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #     [Wir Wii][Uir Uii] [xi]   [Bi]
        #
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)
        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)
        return torch.cat([yr, yi], self.complex_axis)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)
def complex_cat(inputs, axis):
    """Concatenate complex tensors along ``axis`` (real halves first, then
    imaginary halves).

    NOTE(review): this is an exact duplicate of the complex_cat defined
    earlier in this file and shadows it; consider removing one copy.
    """
    reals, imags = [], []
    for tensor in inputs:
        r_half, i_half = torch.chunk(tensor, 2, axis)
        reals.append(r_half)
        imags.append(i_half)
    return torch.cat(reals + imags, axis)
if __name__ == '__main__':
    # Smoke test: compare this file's complex conv layers against the
    # reference implementations in the project-local dc_crn7 module.
    import dc_crn7
    torch.manual_seed(20)
    # Reference layers (dc_crn7) with the same hyperparameters.
    onet1 = dc_crn7.ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    onet2 = dc_crn7.ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    inputs = torch.randn([1, 12, 12, 10])
    # print(onet1.real_kernel[0,0,0,0])
    # Layers under test from this file.
    nnet1 = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
    # print(nnet1.real_conv.weight[0,0,0,0])
    nnet2 = ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    # Mean output difference; presumably expected to be ~0 when both
    # implementations draw identical weights from the seeded RNG -- TODO confirm.
    # NOTE(review): nnet2/onet2 are constructed but never compared here.
    print(torch.mean(nnet1(inputs) - onet1(inputs)))