|  | import math | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  | from torch.nn.utils.parametrize import remove_parametrizations | 
					
						
						|  |  | 
					
						
						|  | from TTS.vocoder.layers.parallel_wavegan import ResidualBlock | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ParallelWaveganDiscriminator(nn.Module): | 
					
						
						|  | """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. | 
					
						
						|  | It classifies each audio window real/fake and returns a sequence | 
					
						
						|  | of predictions. | 
					
						
						|  | It is a stack of convolutional blocks with dilation. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels=1, | 
					
						
						|  | out_channels=1, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | num_layers=10, | 
					
						
						|  | conv_channels=64, | 
					
						
						|  | dilation_factor=1, | 
					
						
						|  | nonlinear_activation="LeakyReLU", | 
					
						
						|  | nonlinear_activation_params={"negative_slope": 0.2}, | 
					
						
						|  | bias=True, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size." | 
					
						
						|  | assert dilation_factor > 0, " [!] dilation factor must be > 0." | 
					
						
						|  | self.conv_layers = nn.ModuleList() | 
					
						
						|  | conv_in_channels = in_channels | 
					
						
						|  | for i in range(num_layers - 1): | 
					
						
						|  | if i == 0: | 
					
						
						|  | dilation = 1 | 
					
						
						|  | else: | 
					
						
						|  | dilation = i if dilation_factor == 1 else dilation_factor**i | 
					
						
						|  | conv_in_channels = conv_channels | 
					
						
						|  | padding = (kernel_size - 1) // 2 * dilation | 
					
						
						|  | conv_layer = [ | 
					
						
						|  | nn.Conv1d( | 
					
						
						|  | conv_in_channels, | 
					
						
						|  | conv_channels, | 
					
						
						|  | kernel_size=kernel_size, | 
					
						
						|  | padding=padding, | 
					
						
						|  | dilation=dilation, | 
					
						
						|  | bias=bias, | 
					
						
						|  | ), | 
					
						
						|  | getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), | 
					
						
						|  | ] | 
					
						
						|  | self.conv_layers += conv_layer | 
					
						
						|  | padding = (kernel_size - 1) // 2 | 
					
						
						|  | last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) | 
					
						
						|  | self.conv_layers += [last_conv_layer] | 
					
						
						|  | self.apply_weight_norm() | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | """ | 
					
						
						|  | x : (B, 1, T). | 
					
						
						|  | Returns: | 
					
						
						|  | Tensor: (B, 1, T) | 
					
						
						|  | """ | 
					
						
						|  | for f in self.conv_layers: | 
					
						
						|  | x = f(x) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  | def apply_weight_norm(self): | 
					
						
						|  | def _apply_weight_norm(m): | 
					
						
						|  | if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): | 
					
						
						|  | torch.nn.utils.parametrizations.weight_norm(m) | 
					
						
						|  |  | 
					
						
						|  | self.apply(_apply_weight_norm) | 
					
						
						|  |  | 
					
						
						|  | def remove_weight_norm(self): | 
					
						
						|  | def _remove_weight_norm(m): | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | remove_parametrizations(m, "weight") | 
					
						
						|  | except ValueError: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | self.apply(_remove_weight_norm) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ResidualParallelWaveganDiscriminator(nn.Module): | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels=1, | 
					
						
						|  | out_channels=1, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | num_layers=30, | 
					
						
						|  | stacks=3, | 
					
						
						|  | res_channels=64, | 
					
						
						|  | gate_channels=128, | 
					
						
						|  | skip_channels=64, | 
					
						
						|  | dropout=0.0, | 
					
						
						|  | bias=True, | 
					
						
						|  | nonlinear_activation="LeakyReLU", | 
					
						
						|  | nonlinear_activation_params={"negative_slope": 0.2}, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." | 
					
						
						|  |  | 
					
						
						|  | self.in_channels = in_channels | 
					
						
						|  | self.out_channels = out_channels | 
					
						
						|  | self.num_layers = num_layers | 
					
						
						|  | self.stacks = stacks | 
					
						
						|  | self.kernel_size = kernel_size | 
					
						
						|  | self.res_factor = math.sqrt(1.0 / num_layers) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | assert num_layers % stacks == 0 | 
					
						
						|  | layers_per_stack = num_layers // stacks | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.first_conv = nn.Sequential( | 
					
						
						|  | nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True), | 
					
						
						|  | getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.conv_layers = nn.ModuleList() | 
					
						
						|  | for layer in range(num_layers): | 
					
						
						|  | dilation = 2 ** (layer % layers_per_stack) | 
					
						
						|  | conv = ResidualBlock( | 
					
						
						|  | kernel_size=kernel_size, | 
					
						
						|  | res_channels=res_channels, | 
					
						
						|  | gate_channels=gate_channels, | 
					
						
						|  | skip_channels=skip_channels, | 
					
						
						|  | aux_channels=-1, | 
					
						
						|  | dilation=dilation, | 
					
						
						|  | dropout=dropout, | 
					
						
						|  | bias=bias, | 
					
						
						|  | use_causal_conv=False, | 
					
						
						|  | ) | 
					
						
						|  | self.conv_layers += [conv] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.last_conv_layers = nn.ModuleList( | 
					
						
						|  | [ | 
					
						
						|  | getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), | 
					
						
						|  | nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True), | 
					
						
						|  | getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), | 
					
						
						|  | nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True), | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.apply_weight_norm() | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | """ | 
					
						
						|  | x: (B, 1, T). | 
					
						
						|  | """ | 
					
						
						|  | x = self.first_conv(x) | 
					
						
						|  |  | 
					
						
						|  | skips = 0 | 
					
						
						|  | for f in self.conv_layers: | 
					
						
						|  | x, h = f(x, None) | 
					
						
						|  | skips += h | 
					
						
						|  | skips *= self.res_factor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x = skips | 
					
						
						|  | for f in self.last_conv_layers: | 
					
						
						|  | x = f(x) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  | def apply_weight_norm(self): | 
					
						
						|  | def _apply_weight_norm(m): | 
					
						
						|  | if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): | 
					
						
						|  | torch.nn.utils.parametrizations.weight_norm(m) | 
					
						
						|  |  | 
					
						
						|  | self.apply(_apply_weight_norm) | 
					
						
						|  |  | 
					
						
						|  | def remove_weight_norm(self): | 
					
						
						|  | def _remove_weight_norm(m): | 
					
						
						|  | try: | 
					
						
						|  | print(f"Weight norm is removed from {m}.") | 
					
						
						|  | remove_parametrizations(m, "weight") | 
					
						
						|  | except ValueError: | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | self.apply(_remove_weight_norm) | 
					
						
						|  |  |