import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def get_causal_padding1d():
    pass


def get_causal_padding2d():
    pass
class cPReLU(nn.Module):

    def __init__(self, complex_axis=1):
        super(cPReLU, self).__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        real, imag = torch.chunk(inputs, 2, self.complex_axis)
        real = self.r_prelu(real)
        imag = self.i_prelu(imag)
        return torch.cat([real, imag], self.complex_axis)
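
# Usage sketch (shapes are illustrative, not from the original file): with
# complex_axis=1, the first half of the channels is treated as the real part
# and the second half as the imaginary part, each getting its own PReLU.
#
#   act = cPReLU(complex_axis=1)
#   y = act(torch.randn(1, 4, 8, 10))  # y.shape == torch.Size([1, 4, 8, 10])
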
class NavieComplexLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
        super(NavieComplexLSTM, self).__init__()

        self.input_dim = input_size // 2
        self.rnn_units = hidden_size // 2
        self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1,
                                 bidirectional=bidirectional, batch_first=batch_first)
        self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1,
                                 bidirectional=bidirectional, batch_first=batch_first)
        bidirectional = 2 if bidirectional else 1
        if projection_dim is not None:
            self.projection_dim = projection_dim // 2
            self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
            self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
        else:
            self.projection_dim = None
    def forward(self, inputs):
        if isinstance(inputs, list):
            real, imag = inputs
        elif isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, dim=-1)
        r2r_out = self.real_lstm(real)[0]
        r2i_out = self.imag_lstm(real)[0]
        i2r_out = self.real_lstm(imag)[0]
        i2i_out = self.imag_lstm(imag)[0]
        real_out = r2r_out - i2i_out
        imag_out = i2r_out + r2i_out
        if self.projection_dim is not None:
            real_out = self.r_trans(real_out)
            imag_out = self.i_trans(imag_out)
        return [real_out, imag_out]

    def flatten_parameters(self):
        self.imag_lstm.flatten_parameters()
        self.real_lstm.flatten_parameters()
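
# The four LSTM passes in forward() realize the complex product
#   (Lr + jLi)(xr + jxi) = (Lr(xr) - Li(xi)) + j(Lr(xi) + Li(xr)),
# i.e. real_out = r2r - i2i and imag_out = i2r + r2i.
# Usage sketch (shapes assumed, batch_first=False so input is [T, B, 2F]):
#
#   lstm = NavieComplexLSTM(input_size=24, hidden_size=32)
#   real, imag = lstm(torch.randn(10, 1, 24))  # each torch.Size([10, 1, 16])
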
def complex_cat(inputs, axis):
    real, imag = [], []
    for data in inputs:
        r, i = torch.chunk(data, 2, axis)
        real.append(r)
        imag.append(i)
    real = torch.cat(real, axis)
    imag = torch.cat(imag, axis)
    outputs = torch.cat([real, imag], axis)
    return outputs
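
# complex_cat concatenates several complex tensors along `axis` while keeping
# the [all-real | all-imag] channel layout intact. Sketch (assumed shapes):
#
#   a = torch.randn(1, 4, 8, 10)   # 2 real + 2 imag channels
#   b = torch.randn(1, 6, 8, 10)   # 3 real + 3 imag channels
#   complex_cat([a, b], axis=1)    # [1, 10, 8, 10]: 5 real, then 5 imag
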
class ComplexConv2d(nn.Module):

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            dilation=1,
            groups=1,
            causal=True,
            complex_axis=1,
    ):
        '''
        in_channels: real + imag channels
        out_channels: real + imag channels
        kernel_size: for input [B, C, D, T], kernel size in [D, T]
        padding: for input [B, C, D, T], padding in [D, T]
        causal: if True, pad only the left (past) side of the time
                dimension; otherwise pad both sides
        '''
        super(ComplexConv2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis
        # Frequency (D) padding is delegated to the conv itself; time (T)
        # padding is applied manually in forward() so it can be causal.
        self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
        self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)

        nn.init.normal_(self.real_conv.weight.data, std=0.05)
        nn.init.normal_(self.imag_conv.weight.data, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)
    def forward(self, inputs):
        # Manual time padding: left-only when causal, symmetric otherwise.
        if self.padding[1] != 0 and self.causal:
            inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])
        else:
            inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])

        if self.complex_axis == 0:
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
        else:
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        # Complex multiplication: (Wr + jWi)(xr + jxi)
        #   = (Wr*xr - Wi*xi) + j(Wi*xr + Wr*xi)
        real = real2real - imag2imag
        imag = real2imag + imag2real
        out = torch.cat([real, imag], self.complex_axis)
        return out
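
# Usage sketch (assumed shapes): 12 channels = 6 real + 6 imag. With
# causal=True and padding=(2, 1), one frame of left-only time padding keeps
# the output from depending on future frames.
#
#   conv = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
#   y = conv(torch.randn(1, 12, 12, 10))  # y.shape == torch.Size([1, 12, 14, 10])
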
class ComplexConvTranspose2d(nn.Module):

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            output_padding=(0, 0),
            causal=False,
            complex_axis=1,
            groups=1
    ):
        '''
        in_channels: real + imag channels
        out_channels: real + imag channels
        '''
        super(ComplexConvTranspose2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups

        self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding, groups=self.groups)
        self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding, groups=self.groups)
        self.complex_axis = complex_axis

        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)
    def forward(self, inputs):
        if isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
        elif isinstance(inputs, (tuple, list)):
            real = inputs[0]
            imag = inputs[1]

        if self.complex_axis == 0:
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
        else:
            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)
            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        real = real2real - imag2imag
        imag = real2imag + imag2real
        out = torch.cat([real, imag], self.complex_axis)
        return out
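
# Usage sketch (assumed shapes): mirrors ComplexConv2d on the decoder path,
# with the same four-way real/imag combination over nn.ConvTranspose2d kernels.
# Note that the causal time padding of the encoder conv is not undone here.
#
#   deconv = ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
#   y = deconv(torch.randn(1, 12, 14, 10))  # y.shape == torch.Size([1, 12, 12, 9])
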
# Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
# from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55
class ComplexBatchNorm(torch.nn.Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, complex_axis=1):
        super(ComplexBatchNorm, self).__init__()
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.complex_axis = complex_axis

        if self.affine:
            self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)

        if self.track_running_stats:
            self.register_buffer('RMr', torch.zeros(self.num_features))
            self.register_buffer('RMi', torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones(self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr', None)
            self.register_parameter('RMi', None)
            self.register_parameter('RVrr', None)
            self.register_parameter('RVri', None)
            self.register_parameter('RVii', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()
    def reset_running_stats(self):
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        assert (xr.shape == xi.shape)
        assert (xr.size(1) == self.num_features)
    def forward(self, inputs):
        # self._check_input_dim(xr, xi)
        xr, xi = torch.chunk(inputs, 2, dim=self.complex_axis)
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        #
        # NOTE: The precise meaning of the "training flag" is:
        #   True:  Normalize using batch statistics, update running statistics
        #          if they are being collected.
        #   False: Normalize using running statistics, ignore batch statistics.
        #
        training = self.training or not self.track_running_stats
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)
        #
        # Mean M Computation and Centering
        #
        # Includes running mean update if training and running.
        #
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        #
        # Variance Matrix V Computation
        #
        # Includes epsilon numerical stabilizer/Tikhonov regularizer.
        # Includes running variance update if training and running.
        #
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps
        #
        # Matrix Inverse Square Root U = V^-0.5
        #
        # sqrt of a 2x2 matrix:
        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii
        delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)  # det(V) = Vrr*Vii - Vri^2
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()

        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (-Vri) * rst
        #
        # Optionally left-multiply U by affine weights W to produce combined
        # weights Z, left-multiply the inputs by Z, then optionally bias them.
        #
        # y = Zx + B
        # y = WUx + B
        # y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #     [Wir Wii][Uir Uii] [xi]   [Bi]
        #
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        outputs = torch.cat([yr, yi], self.complex_axis)
        return outputs

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)
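
# ComplexBatchNorm whitens each complex feature with the inverse square root
# of its 2x2 covariance V = [[Vrr, Vri], [Vri, Vii]]: with tau = tr(V),
# delta = det(V), s = sqrt(delta), t = sqrt(tau + 2s), the inverse root is
#   U = V^(-1/2) = (1 / (s * t)) * [[Vii + s, -Vri], [-Vri, Vrr + s]],
# which matches Urr, Uri, Uii above. Usage sketch (assumed shapes,
# 12 channels = 6 real + 6 imag):
#
#   bn = ComplexBatchNorm(12)
#   y = bn(torch.randn(8, 12, 12, 10))  # same shape as the input
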
if __name__ == '__main__':
    # dc_crn7 is assumed to provide reference ComplexConv2d/ComplexConvTranspose2d
    # implementations that this file is checked against.
    import dc_crn7

    torch.manual_seed(20)
    onet1 = dc_crn7.ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    onet2 = dc_crn7.ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    inputs = torch.randn([1, 12, 12, 10])

    nnet1 = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
    nnet2 = ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    print(torch.mean(nnet1(inputs) - onet1(inputs)))
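
    # Additional sanity checks (not in the original script; shapes are
    # illustrative) for the remaining layers defined above.
    bn = ComplexBatchNorm(12)
    print(bn(inputs).shape)  # torch.Size([1, 12, 12, 10])

    lstm = NavieComplexLSTM(input_size=24, hidden_size=32)
    real_out, imag_out = lstm(torch.randn(10, 1, 24))  # [T, B, 2F], batch_first=False
    print(real_out.shape, imag_out.shape)  # torch.Size([10, 1, 16]) each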