import numpy as np
from typing import Dict, List, NoReturn, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import STFT, ISTFT, magphase

from models.base import Base, init_layer, init_bn, act


class FiLM(nn.Module):
    r"""Produce additive FiLM offsets from a condition vector.

    ``film_meta`` is a (possibly nested) dict whose integer leaves give the
    number of features of the layer to modulate. One ``nn.Linear`` mapping
    ``condition_size -> num_features`` is created per leaf, and ``forward``
    returns a dict with the same nesting holding the projected conditions.
    """

    def __init__(self, film_meta, condition_size):
        super().__init__()

        self.condition_size = condition_size

        # NOTE(review): this attribute shadows ``nn.Module.modules()`` on FiLM
        # instances (calling ``film.modules()`` will fail). The name is kept
        # because external code may read it; the linear layers themselves are
        # registered through ``add_module`` under their unique names.
        self.modules, _ = self.create_film_modules(
            film_meta=film_meta,
            ancestor_names=[],
        )

    def create_film_modules(self, film_meta, ancestor_names):
        r"""Recursively build the FiLM layers described by ``film_meta``.

        Args:
            film_meta: dict mapping names to either an int (feature count of
                a leaf layer) or a nested dict of the same form
            ancestor_names: list of names on the path from the root to the
                current level; mutated in place during traversal and restored
                before returning

        Returns:
            modules: dict mirroring ``film_meta`` with ``nn.Linear`` leaves
            ancestor_names: the same list object that was passed in
        """
        modules = {}

        # Pre-order traversal of modules.
        for module_name, value in film_meta.items():

            if isinstance(value, int):
                ancestor_names.append(module_name)
                # The joined path is unique across the tree, so it is safe to
                # use as the registered sub-module name.
                unique_module_name = '->'.join(ancestor_names)

                modules[module_name] = self.add_film_layer_to_module(
                    num_features=value,
                    unique_module_name=unique_module_name,
                )

            elif isinstance(value, dict):
                ancestor_names.append(module_name)

                modules[module_name], _ = self.create_film_modules(
                    film_meta=value,
                    ancestor_names=ancestor_names,
                )

            # As in the original traversal, one name is popped per item.
            ancestor_names.pop()

        return modules, ancestor_names

    def add_film_layer_to_module(self, num_features, unique_module_name):
        r"""Create, initialize and register one condition-projection layer."""
        layer = nn.Linear(self.condition_size, num_features)
        init_layer(layer)
        self.add_module(name=unique_module_name, module=layer)

        return layer

    def forward(self, conditions):
        r"""Project ``conditions`` through every FiLM layer.

        Args:
            conditions: tensor accepted by ``nn.Linear(condition_size, ...)``;
                presumably (batch_size, condition_size) -- the result is
                indexed as a 2-D tensor

        Returns:
            film_dict: nested dict with the structure of ``film_meta`` whose
                leaves have two trailing singleton axes appended
        """
        return self.calculate_film_data(
            conditions=conditions,
            modules=self.modules,
        )

    def calculate_film_data(self, conditions, modules):
        r"""Recursively evaluate ``modules`` (as built above) on ``conditions``."""
        film_data = {}

        # Pre-order traversal of modules.
        for module_name, module in modules.items():

            if isinstance(module, nn.Module):
                # Append two singleton axes so the offset broadcasts over the
                # last two axes of a 4-D feature map.
                film_data[module_name] = module(conditions)[:, :, None, None]

            elif isinstance(module, dict):
                film_data[module_name] = self.calculate_film_data(conditions, module)

        return film_data


class ConvBlockRes(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Pre-activation residual block: (BN + FiLM offset -> LeakyReLU -> conv) x 2.

        A 1x1 convolution is used on the skip path when the channel count
        changes. ``has_film`` is only stored (it is read by ``get_film_meta``);
        ``forward`` always applies the offsets it is given.
        """
        super().__init__()

        # "Same" padding for odd kernel sizes.
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )

        self.has_film = has_film

        self.init_weights()

    def init_weights(self) -> None:
        r"""Initialize weights."""
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)
            film_dict: dict with keys 'beta1' and 'beta2', the additive
                offsets applied after ``bn1`` and ``bn2`` respectively

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        b1 = film_dict['beta1']
        b2 = film_dict['beta2']

        # The in-place activation acts on the fresh ``bn + beta`` tensor, so
        # ``input_tensor`` is left intact for the skip connection.
        x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1, negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x) + b2, negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(input_tensor) + x

        return input_tensor + x


class EncoderBlockRes1B(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        downsample: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Encoder block: one residual block followed by average pooling."""
        super().__init__()

        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, kernel_size, momentum, has_film,
        )
        self.downsample = downsample

    def forward(
        self, input_tensor: torch.Tensor, film_dict: Dict
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)
            film_dict: dict with key 'conv_block1' holding the FiLM offsets
                of the residual block

        Returns:
            encoder_pool: (batch_size, output_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            encoder: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        encoder = self.conv_block1(input_tensor, film_dict['conv_block1'])
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes1B(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        upsample: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Decoder block: one transposed convolution followed by one residual block."""
        super().__init__()

        self.kernel_size = kernel_size
        self.stride = upsample

        # Kernel size equals stride, so the output is exactly the input
        # upsampled by ``upsample`` with no overlap between taps.
        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        # The residual block sees the upsampled features concatenated with
        # the skip connection, hence ``out_channels * 2`` inputs.
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, kernel_size, momentum, has_film,
        )

        # NOTE(review): ``bn2`` is never used in ``forward`` here, but it is
        # read by ``get_film_meta`` (``bn2.num_features``) and is part of the
        # state dict, so it must stay.
        self.bn2 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.has_film = has_film

        self.init_weights()

    def init_weights(self):
        r"""Initialize weights."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(
        self,
        input_tensor: torch.Tensor,
        concat_tensor: torch.Tensor,
        film_dict: Dict,
    ) -> torch.Tensor:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            concat_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
            film_dict: dict with key 'beta1' (offset after ``bn1``) and key
                'conv_block2' (offsets of the residual block)

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        b1 = film_dict['beta1']

        x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1))
        # (batch_size, output_feature_maps, time_steps, freq_bins)

        x = torch.cat((x, concat_tensor), dim=1)
        # (batch_size, output_feature_maps * 2, time_steps, freq_bins)

        x = self.conv_block2(x, film_dict['conv_block2'])
        # (batch_size, output_feature_maps, time_steps, freq_bins)

        return x


class ResUNet30_Base(nn.Module, Base):
    r"""Residual U-Net operating on STFT magnitudes with FiLM conditioning.

    The network predicts, per output channel, ``K = 3`` maps (a magnitude
    mask and the real/imaginary parts of a phase mask) that are applied to
    the mixture spectrogram before the inverse STFT.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        window_size = 2048
        hop_size = 320
        center = True
        pad_mode = "reflect"
        window = "hann"
        momentum = 0.01

        self.output_channels = output_channels
        self.target_sources_num = 1
        self.K = 3

        # This number equals 2^{#encoder_blocks that downsample in time}.
        self.time_downsample_ratio = 2 ** 5

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Normalizes over frequency bins (see the transposes in ``forward``).
        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.pre_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=32,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.encoder_block1 = EncoderBlockRes1B(
            in_channels=32, out_channels=32, kernel_size=(3, 3),
            downsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.encoder_block2 = EncoderBlockRes1B(
            in_channels=32, out_channels=64, kernel_size=(3, 3),
            downsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.encoder_block3 = EncoderBlockRes1B(
            in_channels=64, out_channels=128, kernel_size=(3, 3),
            downsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.encoder_block4 = EncoderBlockRes1B(
            in_channels=128, out_channels=256, kernel_size=(3, 3),
            downsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.encoder_block5 = EncoderBlockRes1B(
            in_channels=256, out_channels=384, kernel_size=(3, 3),
            downsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.encoder_block6 = EncoderBlockRes1B(
            in_channels=384, out_channels=384, kernel_size=(3, 3),
            downsample=(1, 2), momentum=momentum, has_film=True,
        )
        self.conv_block7a = EncoderBlockRes1B(
            in_channels=384, out_channels=384, kernel_size=(3, 3),
            downsample=(1, 1), momentum=momentum, has_film=True,
        )

        self.decoder_block1 = DecoderBlockRes1B(
            in_channels=384, out_channels=384, kernel_size=(3, 3),
            upsample=(1, 2), momentum=momentum, has_film=True,
        )
        self.decoder_block2 = DecoderBlockRes1B(
            in_channels=384, out_channels=384, kernel_size=(3, 3),
            upsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.decoder_block3 = DecoderBlockRes1B(
            in_channels=384, out_channels=256, kernel_size=(3, 3),
            upsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.decoder_block4 = DecoderBlockRes1B(
            in_channels=256, out_channels=128, kernel_size=(3, 3),
            upsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.decoder_block5 = DecoderBlockRes1B(
            in_channels=128, out_channels=64, kernel_size=(3, 3),
            upsample=(2, 2), momentum=momentum, has_film=True,
        )
        self.decoder_block6 = DecoderBlockRes1B(
            in_channels=64, out_channels=32, kernel_size=(3, 3),
            upsample=(2, 2), momentum=momentum, has_film=True,
        )

        self.after_conv = nn.Conv2d(
            in_channels=32,
            out_channels=output_channels * self.K,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        r"""Initialize the layers owned directly by this module."""
        init_bn(self.bn0)
        init_layer(self.pre_conv)
        init_layer(self.after_conv)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert feature maps to waveform.

        Args:
            input_tensor: (batch_size, target_sources_num * output_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, input_channels, time_steps, freq_bins)
            sin_in: (batch_size, input_channels, time_steps, freq_bins)
            cos_in: (batch_size, input_channels, time_steps, freq_bins)
            audio_length: number of samples of the waveform to synthesize

            (There is input_channels == output_channels for the source separation task.)

        Outputs:
            waveform: (batch_size, target_sources_num * output_channels, segment_samples)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.output_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, output_channels, self.K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])

        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        cos_x = cos_in[:, None, :, :, :]
        sin_x = sin_in[:, None, :, :, :]
        out_cos = cos_x * mask_cos - sin_x * mask_sin
        out_sin = sin_x * mask_cos + cos_x * mask_sin
        # out_cos, out_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Calculate |Y|.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
        # out_mag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Reformat shape to (N, 1, time_steps, freq_bins) for ISTFT where
        # N = batch_size * target_sources_num * output_channels
        shape = (
            batch_size * self.target_sources_num * self.output_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * output_channels, segment_samples)

        # Reshape.
        waveform = x.reshape(
            batch_size, self.target_sources_num * self.output_channels, audio_length
        )
        # (batch_size, target_sources_num * output_channels, segment_samples)

        return waveform

    def forward(self, mixtures, film_dict):
        r"""Separate ``mixtures`` under the FiLM offsets in ``film_dict``.

        Args:
            mixtures: waveform tensor whose axis 2 is the sample axis
                (``mixtures.shape[2]`` is used as the output length);
                presumably (batch_size, channels_num, segment_samples)
            film_dict: nested dict keyed by the names of the encoder/decoder
                blocks, as produced by ``FiLM.forward``

        Outputs:
            output_dict: {
                'waveform': (batch_size, target_sources_num * output_channels, segment_samples)}
        """
        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
        x = mag

        # Batch normalization over frequency bins: move them to the channel
        # axis, normalize, and move them back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        # (batch_size, channels, time_steps, freq_bins)

        # Pad spectrogram to be evenly divided by downsample ratio.
        origin_len = x.shape[2]
        # Smallest non-negative amount that reaches the next multiple.
        pad_len = (-origin_len) % self.time_downsample_ratio
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # (batch_size, channels, padded_time_steps, freq_bins)

        # Let frequency bins be evenly divided by 2, e.g., 513 -> 512
        x = x[..., 0 : x.shape[-1] - 1]  # (bs, channels, T, F)

        # UNet
        x = self.pre_conv(x)
        x1_pool, x1 = self.encoder_block1(x, film_dict['encoder_block1'])  # x1_pool: (bs, 32, T / 2, F / 2)
        x2_pool, x2 = self.encoder_block2(x1_pool, film_dict['encoder_block2'])  # x2_pool: (bs, 64, T / 4, F / 4)
        x3_pool, x3 = self.encoder_block3(x2_pool, film_dict['encoder_block3'])  # x3_pool: (bs, 128, T / 8, F / 8)
        x4_pool, x4 = self.encoder_block4(x3_pool, film_dict['encoder_block4'])  # x4_pool: (bs, 256, T / 16, F / 16)
        x5_pool, x5 = self.encoder_block5(x4_pool, film_dict['encoder_block5'])  # x5_pool: (bs, 384, T / 32, F / 32)
        x6_pool, x6 = self.encoder_block6(x5_pool, film_dict['encoder_block6'])  # x6_pool: (bs, 384, T / 32, F / 64)
        x_center, _ = self.conv_block7a(x6_pool, film_dict['conv_block7a'])  # (bs, 384, T / 32, F / 64)
        x7 = self.decoder_block1(x_center, x6, film_dict['decoder_block1'])  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5, film_dict['decoder_block2'])  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4, film_dict['decoder_block3'])  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3, film_dict['decoder_block4'])  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2, film_dict['decoder_block5'])  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1, film_dict['decoder_block6'])  # (bs, 32, T, F)

        x = self.after_conv(x12)

        # Recover shape: restore the dropped frequency bin and remove the
        # time padding added above.
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]

        audio_length = mixtures.shape[2]

        # Apply the predicted masks to the mixture spectrogram and
        # synthesize the waveform.
        separated_audio = self.feature_maps_to_wav(
            input_tensor=x,
            # input_tensor: (batch_size, target_sources_num * output_channels * self.K, T, F')
            sp=mag,
            # sp: (batch_size, input_channels, T, F')
            sin_in=sin_in,
            # sin_in: (batch_size, input_channels, T, F')
            cos_in=cos_in,
            # cos_in: (batch_size, input_channels, T, F')
            audio_length=audio_length,
        )
        # (batch_size, target_sources_num * output_channels, segment_samples)

        output_dict = {'waveform': separated_audio}

        return output_dict


def get_film_meta(module):
    r"""Collect the FiLM layout of ``module`` and its descendants.

    A module that defines ``has_film`` contributes 'beta1' and 'beta2'
    entries sized by its ``bn1`` / ``bn2`` feature counts (or 0 when
    ``has_film`` is falsy). Children are visited recursively and only those
    with a non-empty layout are included, keyed by their child name.

    Returns:
        film_meta: nested dict suitable for ``FiLM(film_meta=...)``
    """
    film_meta = {}

    if hasattr(module, 'has_film'):
        if module.has_film:
            film_meta['beta1'] = module.bn1.num_features
            film_meta['beta2'] = module.bn2.num_features
        else:
            film_meta['beta1'] = 0
            film_meta['beta2'] = 0

    for child_name, child_module in module.named_children():
        child_meta = get_film_meta(child_module)

        if child_meta:
            film_meta[child_name] = child_meta

    return film_meta


class ResUNet30(nn.Module):
    r"""``ResUNet30_Base`` together with the FiLM network that conditions it."""

    def __init__(self, input_channels, output_channels, condition_size):
        super().__init__()

        self.base = ResUNet30_Base(
            input_channels=input_channels,
            output_channels=output_channels,
        )

        self.film_meta = get_film_meta(
            module=self.base,
        )

        self.film = FiLM(
            film_meta=self.film_meta,
            condition_size=condition_size,
        )

    def forward(self, input_dict):
        r"""Run conditioned separation.

        Args:
            input_dict: dict with keys 'mixture' (passed to the base network)
                and 'condition' (passed to the FiLM network)

        Returns:
            output_dict: the dict returned by the base network, i.e.
                {'waveform': ...}
        """
        mixtures = input_dict['mixture']
        conditions = input_dict['condition']

        film_dict = self.film(
            conditions=conditions,
        )

        output_dict = self.base(
            mixtures=mixtures,
            film_dict=film_dict,
        )

        return output_dict

    def _separate_chunk(self, mixtures, film_dict, start, stop):
        r"""Separate ``mixtures[:, :, start:stop]`` and return it as a NumPy array."""
        chunk_out = self.base(
            mixtures=mixtures[:, :, start:stop],
            film_dict=film_dict,
        )['waveform']
        return chunk_out.squeeze(0).detach().cpu().numpy()

    @torch.no_grad()
    def chunk_inference(self, input_dict):
        r"""Separate a long mixture chunk by chunk and stitch the results.

        Each chunk spans ``NL + NC + NR`` samples (left context, kept center,
        right context; 1 s / 3 s / 1 s at ``self.sampling_rate``) and the
        window advances by ``NC``. Only the center of every chunk is kept,
        except that the very first chunk also keeps its left context.

        NOTE(review): ``self.sampling_rate`` is not assigned anywhere in this
        file; it must be set on the instance before calling this method.
        NOTE(review): the output buffer is ``(1, L)`` and the batch axis is
        squeezed, so this presumably expects a single mono mixture -- confirm
        against callers.

        Args:
            input_dict: dict with keys 'mixture' (sample axis is axis 2) and
                'condition'

        Returns:
            out_np: float NumPy array of shape (1, L) with the separated audio
        """
        chunk_config = {
            'NL': 1.0,
            'NC': 3.0,
            'NR': 1.0,
            'RATE': self.sampling_rate,
        }

        mixtures = input_dict['mixture']
        conditions = input_dict['condition']

        # The condition is the same for every chunk, so project it once.
        film_dict = self.film(
            conditions=conditions,
        )

        NL = int(chunk_config['NL'] * chunk_config['RATE'])
        NC = int(chunk_config['NC'] * chunk_config['RATE'])
        NR = int(chunk_config['NR'] * chunk_config['RATE'])

        L = mixtures.shape[2]

        out_np = np.zeros([1, L])

        WINDOW = NL + NC + NR
        current_idx = 0

        while current_idx + WINDOW < L:
            chunk_out_np = self._separate_chunk(
                mixtures, film_dict, current_idx, current_idx + WINDOW
            )

            if current_idx == 0:
                # First chunk: there is no previous chunk to cover the left
                # context, so keep it.
                out_np[:, current_idx:current_idx + WINDOW - NR] = \
                    chunk_out_np[:, :-NR] if NR != 0 else chunk_out_np
            else:
                out_np[:, current_idx + NL:current_idx + WINDOW - NR] = \
                    chunk_out_np[:, NL:-NR] if NR != 0 else chunk_out_np[:, NL:]

            current_idx += NC

        if current_idx < L:
            # Tail chunk (possibly shorter than WINDOW); keep everything up
            # to the end of the signal.
            chunk_out_np = self._separate_chunk(
                mixtures, film_dict, current_idx, current_idx + WINDOW
            )

            seg_len = chunk_out_np.shape[1]

            # When the whole signal fits in one window the loop above never
            # ran, so the left context must be written here as well;
            # otherwise the first NL samples of the output would stay zero.
            keep_from = 0 if current_idx == 0 else NL
            out_np[:, current_idx + keep_from:current_idx + seg_len] = \
                chunk_out_np[:, keep_from:]

        return out_np