# coding: utf-8
# Author: WangTianRui
# Date  : 2020/11/3 16:49
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.conv_stft import *
from utils.complexnn import *


class DCCRN(nn.Module):
    def __init__(self, rnn_layer=2, rnn_hidden=256,
                 win_len=400, hop_len=100, fft_len=512, win_type='hann',
                 use_clstm=True, use_cbn=False, masking_mode='E',
                 kernel_size=5, kernel_num=(32, 64, 128, 256, 256, 256)):
        super(DCCRN, self).__init__()
        self.rnn_layer = rnn_layer
        self.rnn_hidden = rnn_hidden
        self.win_len = win_len
        self.hop_len = hop_len
        self.fft_len = fft_len
        self.win_type = win_type
        self.use_clstm = use_clstm
        self.use_cbn = use_cbn
        self.masking_mode = masking_mode
        self.kernel_size = kernel_size
        # Prepend the two input channels (real, imag) to the encoder channel list.
        self.kernel_num = (2,) + kernel_num

        self.stft = ConvSTFT(self.win_len, self.hop_len, self.fft_len, self.win_type, 'complex', fix=True)
        self.istft = ConviSTFT(self.win_len, self.hop_len, self.fft_len, self.win_type, 'complex', fix=True)

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for idx in range(len(self.kernel_num) - 1):
            self.encoder.append(
                nn.Sequential(
                    ComplexConv2d(
                        self.kernel_num[idx],
                        self.kernel_num[idx + 1],
                        kernel_size=(self.kernel_size, 2),
                        stride=(2, 1),
                        padding=(2, 1)
                    ),
                    nn.BatchNorm2d(self.kernel_num[idx + 1]) if not use_cbn
                    else ComplexBatchNorm(self.kernel_num[idx + 1]),
                    nn.PReLU()
                )
            )

        # Frequency size left after the encoder: each layer halves the frequency axis.
        hidden_dim = self.fft_len // (2 ** (len(self.kernel_num)))

        if self.use_clstm:
            rnns = []
            for idx in range(rnn_layer):
                rnns.append(
                    NavieComplexLSTM(
                        input_size=hidden_dim * self.kernel_num[-1] if idx == 0 else self.rnn_hidden,
                        hidden_size=self.rnn_hidden,
                        batch_first=False,
                        projection_dim=hidden_dim * self.kernel_num[-1] if idx == rnn_layer - 1 else None
                    )
                )
            self.enhance = nn.Sequential(*rnns)
        else:
            self.enhance = nn.LSTM(
                input_size=hidden_dim * self.kernel_num[-1],
                hidden_size=self.rnn_hidden,
                num_layers=2,
                dropout=0.0,
                batch_first=False
            )
            self.transform = nn.Linear(self.rnn_hidden, hidden_dim * self.kernel_num[-1])

        for idx in range(len(self.kernel_num) - 1, 0, -1):
            if idx != 1:
                self.decoder.append(
                    nn.Sequential(
                        ComplexConvTranspose2d(
                            self.kernel_num[idx] * 2,  # skip connection doubles the input channels
                            self.kernel_num[idx - 1],
                            kernel_size=(self.kernel_size, 2),
                            stride=(2, 1),
                            padding=(2, 0),
                            output_padding=(1, 0)
                        ),
                        nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn
                        else ComplexBatchNorm(self.kernel_num[idx - 1]),
                        nn.PReLU()
                    )
                )
            else:
                # Last decoder layer: no normalization or activation on the mask output.
                self.decoder.append(
                    nn.Sequential(
                        ComplexConvTranspose2d(
                            self.kernel_num[idx] * 2,
                            self.kernel_num[idx - 1],
                            kernel_size=(self.kernel_size, 2),
                            stride=(2, 1),
                            padding=(2, 0),
                            output_padding=(1, 0)
                        )
                    )
                )

        if isinstance(self.enhance, nn.LSTM):
            self.enhance.flatten_parameters()

    def forward(self, x):
        stft = self.stft(x)
        # print("stft:", stft.size())
        real = stft[:, :self.fft_len // 2 + 1]
        imag = stft[:, self.fft_len // 2 + 1:]
        # print("real imag:", real.size(), imag.size())
        spec_mags = torch.sqrt(real ** 2 + imag ** 2 + 1e-8)
        spec_phase = torch.atan2(imag, real)
        spec_complex = torch.stack([real, imag], dim=1)[:, :, 1:]  # B, 2, F-1, T (DC bin dropped)
        # print("spec", spec_mags.size(), spec_phase.size(), spec_complex.size())
        out = spec_complex
        encoder_out = []
        for idx, encoder in enumerate(self.encoder):
            out = encoder(out)
            # print("encoder out:", out.size())
            encoder_out.append(out)

        B, C, D, T = out.size()
        out = out.permute(3, 0, 1, 2)  # T, B, C, D
        if self.use_clstm:
            # Split channels into real and imaginary halves for the complex LSTM.
            r_rnn_in = out[:, :, :C // 2]
            i_rnn_in = out[:, :, C // 2:]
            r_rnn_in = torch.reshape(r_rnn_in, [T, B, C // 2 * D])
            i_rnn_in = torch.reshape(i_rnn_in, [T, B, C // 2 * D])
            r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in])
            r_rnn_in = torch.reshape(r_rnn_in, [T, B, C // 2, D])
            i_rnn_in = torch.reshape(i_rnn_in, [T, B, C // 2, D])
            out = torch.cat([r_rnn_in, i_rnn_in], 2)
        else:
            out = torch.reshape(out, [T, B, C * D])
            out, _ = self.enhance(out)
            out = self.transform(out)
            out = torch.reshape(out, [T, B, C, D])

        out = out.permute(1, 2, 3, 0)  # B, C, D, T
        for idx in range(len(self.decoder)):
            out = complex_cat([out, encoder_out[-1 - idx]], 1)
            out = self.decoder[idx](out)
            out = out[..., 1:]  # drop the extra time frame added by the transposed conv
        mask_real = out[:, 0]
        mask_imag = out[:, 1]
        # Restore the DC bin that was dropped before the encoder.
        mask_real = F.pad(mask_real, [0, 0, 1, 0])
        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])

        if self.masking_mode == 'E':
            # Polar masking: bounded magnitude mask plus phase correction.
            mask_mags = (mask_real ** 2 + mask_imag ** 2) ** 0.5
            real_phase = mask_real / (mask_mags + 1e-8)
            imag_phase = mask_imag / (mask_mags + 1e-8)
            mask_phase = torch.atan2(imag_phase, real_phase)
            mask_mags = torch.tanh(mask_mags)
            est_mags = mask_mags * spec_mags
            est_phase = spec_phase + mask_phase
            real = est_mags * torch.cos(est_phase)
            imag = est_mags * torch.sin(est_phase)
        elif self.masking_mode == 'C':
            # Complex ratio masking; compute both parts before overwriting `real`.
            real, imag = real * mask_real - imag * mask_imag, real * mask_imag + imag * mask_real
        elif self.masking_mode == 'R':
            real = real * mask_real
            imag = imag * mask_imag

        out_spec = torch.cat([real, imag], 1)
        out_wav = self.istft(out_spec)
        out_wav = torch.squeeze(out_wav, 1)
        out_wav = out_wav.clamp_(-1, 1)
        return out_wav


def l2_norm(s1, s2):
    norm = torch.sum(s1 * s2, -1, keepdim=True)
    return norm


def si_snr(s1, s2, eps=1e-8):
    s1_s2_norm = l2_norm(s1, s2)
    s2_s2_norm = l2_norm(s2, s2)
    s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
    e_noise = s1 - s_target
    target_norm = l2_norm(s_target, s_target)
    noise_norm = l2_norm(e_noise, e_noise)
    snr = 10 * torch.log10(target_norm / (noise_norm + eps) + eps)
    return torch.mean(snr)


def loss(inputs, label):
    return -si_snr(inputs, label)


if __name__ == '__main__':
    test_model = DCCRN(rnn_hidden=256, masking_mode='E', use_clstm=True,
                       kernel_num=(32, 64, 128, 256, 256, 256))
    # model_test_timer is assumed to be provided by one of the wildcard utils imports.
    model_test_timer(test_model, (1, 16000 * 30))
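
    # Minimal smoke-test sketch (not part of the original script): runs a forward
    # pass on random audio and evaluates the SI-SNR loss. The tensor names and the
    # trimming of both signals to a common length are illustrative assumptions.
    noisy = torch.randn(1, 16000)
    clean = torch.randn(1, 16000)
    with torch.no_grad():
        enhanced = test_model(noisy)
    n = min(enhanced.shape[-1], clean.shape[-1])  # STFT/iSTFT may change the length slightly
    print("enhanced:", enhanced.shape, "SI-SNR loss:", loss(enhanced[..., :n], clean[..., :n]).item())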