import torch
import torch.nn as nn
from torch.nn import functional as F

from .SubLayers import MultiHeadAttention, PositionwiseFeedForward

class FFTBlock(torch.nn.Module):
    """FFT block: multi-head self-attention followed by a position-wise
    convolutional feed-forward network, the basic building block of the
    FastSpeech encoder and decoder."""

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        # Zero out padded positions (mask is True where the sequence is padding).
        # Guard against the documented default of mask=None, which would
        # otherwise crash on mask.unsqueeze(-1).
        if mask is not None:
            enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        enc_output = self.pos_ffn(enc_output)
        if mask is not None:
            enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        return enc_output, enc_slf_attn
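
# Usage sketch (illustrative comment only, not executed). The dimensions
# follow common FastSpeech 2 configs, and `kernel_size` is assumed to be the
# pair of convolution widths that PositionwiseFeedForward expects here:
#
#   block = FFTBlock(d_model=256, n_head=2, d_k=64, d_v=64,
#                    d_inner=1024, kernel_size=[9, 1])
#   x = torch.randn(2, 50, 256)                 # (batch, time, d_model)
#   pad = torch.zeros(2, 50, dtype=torch.bool)  # True marks padded steps
#   out, attn = block(x, mask=pad,
#                     slf_attn_mask=pad.unsqueeze(1).expand(-1, 50, -1))
#   # out: (2, 50, 256), same shape as x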
					
						
class ConvNorm(torch.nn.Module):
    """1-D convolution wrapper that defaults to "same" padding."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            # "Same" padding needs an odd kernel so the time axis is preserved.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        # Note: w_init_gain is accepted for API compatibility but unused here;
        # the convolution keeps PyTorch's default initialization.

    def forward(self, signal):
        # signal: (batch, in_channels, time) -> (batch, out_channels, time)
        return self.conv(signal)
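
# Shape sketch (comment only): with the default padding=None and stride 1, an
# odd kernel keeps the time axis unchanged, e.g.
#
#   conv = ConvNorm(80, 512, kernel_size=5)
#   y = conv(torch.randn(2, 80, 100))  # -> (2, 512, 100)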
					
						
class PostNet(nn.Module):
    """
    PostNet: five 1-D convolutions with 512 channels and kernel size 5,
    as in Tacotron 2. Its output is typically added to the coarse
    mel-spectrogram as a residual refinement.
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    ):
        super(PostNet, self).__init__()
        self.convolutions = nn.ModuleList()

        # First convolution: mel channels -> embedding dim, tanh gain.
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    n_mel_channels,
                    postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="tanh",
                ),
                nn.BatchNorm1d(postnet_embedding_dim),
            )
        )

        # Middle convolutions: embedding dim -> embedding dim.
        for _ in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        postnet_embedding_dim,
                        postnet_embedding_dim,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="tanh",
                    ),
                    nn.BatchNorm1d(postnet_embedding_dim),
                )
            )

        # Final convolution: embedding dim -> mel channels, linear gain.
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    postnet_embedding_dim,
                    n_mel_channels,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="linear",
                ),
                nn.BatchNorm1d(n_mel_channels),
            )
        )

    def forward(self, x):
        # (batch, time, n_mel_channels) -> (batch, n_mel_channels, time)
        x = x.contiguous().transpose(1, 2)

        # tanh + dropout after every layer except the last (linear) one.
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        x = x.contiguous().transpose(1, 2)
        return x
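
if __name__ == "__main__":
    # Minimal shape check for the classes that depend only on torch (FFTBlock
    # needs the package-relative SubLayers import, so it is skipped here).
    # Because of that relative import, run this file as a module, e.g.
    # `python -m transformer.Layers` if the package is named `transformer`
    # (an assumption; adjust to this repo's layout).
    conv = ConvNorm(80, 512, kernel_size=5)
    post = PostNet().eval()  # eval() so BatchNorm1d uses running statistics
    with torch.no_grad():
        mel = torch.randn(2, 100, 80)  # (batch, time, n_mel_channels)
        assert conv(mel.transpose(1, 2)).shape == (2, 512, 100)
        assert post(mel).shape == mel.shape  # PostNet preserves input shape
    print("ConvNorm and PostNet shape checks passed.")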
					
						