Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| class PreNet(nn.Module): | |
| def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5): | |
| super().__init__() | |
| self.fc1 = nn.Linear(in_dims, fc1_dims) | |
| self.fc2 = nn.Linear(fc1_dims, fc2_dims) | |
| self.p = dropout | |
| def forward(self, x): | |
| x = self.fc1(x) | |
| x = F.relu(x) | |
| x = F.dropout(x, self.p, training=self.training) | |
| x = self.fc2(x) | |
| x = F.relu(x) | |
| x = F.dropout(x, self.p, training=self.training) | |
| return x | |
| class HighwayNetwork(nn.Module): | |
| def __init__(self, size): | |
| super().__init__() | |
| self.W1 = nn.Linear(size, size) | |
| self.W2 = nn.Linear(size, size) | |
| self.W1.bias.data.fill_(0.) | |
| def forward(self, x): | |
| x1 = self.W1(x) | |
| x2 = self.W2(x) | |
| g = torch.sigmoid(x2) | |
| y = g * F.relu(x1) + (1. - g) * x | |
| return y | |
| class BatchNormConv(nn.Module): | |
| def __init__(self, in_channels, out_channels, kernel, relu=True): | |
| super().__init__() | |
| self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False) | |
| self.bnorm = nn.BatchNorm1d(out_channels) | |
| self.relu = relu | |
| def forward(self, x): | |
| x = self.conv(x) | |
| x = F.relu(x) if self.relu is True else x | |
| return self.bnorm(x) | |
| class ConvNorm(torch.nn.Module): | |
| def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, | |
| padding=None, dilation=1, bias=True, w_init_gain='linear'): | |
| super(ConvNorm, self).__init__() | |
| if padding is None: | |
| assert (kernel_size % 2 == 1) | |
| padding = int(dilation * (kernel_size - 1) / 2) | |
| self.conv = torch.nn.Conv1d(in_channels, out_channels, | |
| kernel_size=kernel_size, stride=stride, | |
| padding=padding, dilation=dilation, | |
| bias=bias) | |
| torch.nn.init.xavier_uniform_( | |
| self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) | |
| def forward(self, signal): | |
| conv_signal = self.conv(signal) | |
| return conv_signal | |
| class CBHG(nn.Module): | |
| def __init__(self, K, in_channels, channels, proj_channels, num_highways): | |
| super().__init__() | |
| # List of all rnns to call `flatten_parameters()` on | |
| self._to_flatten = [] | |
| self.bank_kernels = [i for i in range(1, K + 1)] | |
| self.conv1d_bank = nn.ModuleList() | |
| for k in self.bank_kernels: | |
| conv = BatchNormConv(in_channels, channels, k) | |
| self.conv1d_bank.append(conv) | |
| self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) | |
| self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3) | |
| self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False) | |
| # Fix the highway input if necessary | |
| if proj_channels[-1] != channels: | |
| self.highway_mismatch = True | |
| self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False) | |
| else: | |
| self.highway_mismatch = False | |
| self.highways = nn.ModuleList() | |
| for i in range(num_highways): | |
| hn = HighwayNetwork(channels) | |
| self.highways.append(hn) | |
| self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True) | |
| self._to_flatten.append(self.rnn) | |
| # Avoid fragmentation of RNN parameters and associated warning | |
| self._flatten_parameters() | |
| def forward(self, x): | |
| # Although we `_flatten_parameters()` on init, when using DataParallel | |
| # the model gets replicated, making it no longer guaranteed that the | |
| # weights are contiguous in GPU memory. Hence, we must call it again | |
| self._flatten_parameters() | |
| # Save these for later | |
| residual = x | |
| seq_len = x.size(-1) | |
| conv_bank = [] | |
| # Convolution Bank | |
| for conv in self.conv1d_bank: | |
| c = conv(x) # Convolution | |
| conv_bank.append(c[:, :, :seq_len]) | |
| # Stack along the channel axis | |
| conv_bank = torch.cat(conv_bank, dim=1) | |
| # dump the last padding to fit residual | |
| x = self.maxpool(conv_bank)[:, :, :seq_len] | |
| # Conv1d projections | |
| x = self.conv_project1(x) | |
| x = self.conv_project2(x) | |
| # Residual Connect | |
| x = x + residual | |
| # Through the highways | |
| x = x.transpose(1, 2) | |
| if self.highway_mismatch is True: | |
| x = self.pre_highway(x) | |
| for h in self.highways: | |
| x = h(x) | |
| # And then the RNN | |
| x, _ = self.rnn(x) | |
| return x | |
| def _flatten_parameters(self): | |
| """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used | |
| to improve efficiency and avoid PyTorch yelling at us.""" | |
| [m.flatten_parameters() for m in self._to_flatten] | |
| class TacotronEncoder(nn.Module): | |
| def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout): | |
| super().__init__() | |
| self.embedding = nn.Embedding(num_chars, embed_dims) | |
| self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout) | |
| self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels, | |
| proj_channels=[cbhg_channels, cbhg_channels], | |
| num_highways=num_highways) | |
| self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels) | |
| def forward(self, x): | |
| x = self.embedding(x) | |
| x = self.pre_net(x) | |
| x.transpose_(1, 2) | |
| x = self.cbhg(x) | |
| x = self.proj_out(x) | |
| return x | |
| class RNNEncoder(nn.Module): | |
| def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5): | |
| super(RNNEncoder, self).__init__() | |
| self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0) | |
| convolutions = [] | |
| for _ in range(n_convolutions): | |
| conv_layer = nn.Sequential( | |
| ConvNorm(embedding_dim, | |
| embedding_dim, | |
| kernel_size=kernel_size, stride=1, | |
| padding=int((kernel_size - 1) / 2), | |
| dilation=1, w_init_gain='relu'), | |
| nn.BatchNorm1d(embedding_dim)) | |
| convolutions.append(conv_layer) | |
| self.convolutions = nn.ModuleList(convolutions) | |
| self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1, | |
| batch_first=True, bidirectional=True) | |
| def forward(self, x): | |
| input_lengths = (x > 0).sum(-1) | |
| input_lengths = input_lengths.cpu().numpy() | |
| x = self.embedding(x) | |
| x = x.transpose(1, 2) # [B, H, T] | |
| for conv in self.convolutions: | |
| x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x | |
| x = x.transpose(1, 2) # [B, T, H] | |
| # pytorch tensor are not reversible, hence the conversion | |
| x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False) | |
| self.lstm.flatten_parameters() | |
| outputs, _ = self.lstm(x) | |
| outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) | |
| return outputs | |
| class DecoderRNN(torch.nn.Module): | |
| def __init__(self, hidden_size, decoder_rnn_dim, dropout): | |
| super(DecoderRNN, self).__init__() | |
| self.in_conv1d = nn.Sequential( | |
| torch.nn.Conv1d( | |
| in_channels=hidden_size, | |
| out_channels=hidden_size, | |
| kernel_size=9, padding=4, | |
| ), | |
| torch.nn.ReLU(), | |
| torch.nn.Conv1d( | |
| in_channels=hidden_size, | |
| out_channels=hidden_size, | |
| kernel_size=9, padding=4, | |
| ), | |
| ) | |
| self.ln = nn.LayerNorm(hidden_size) | |
| if decoder_rnn_dim == 0: | |
| decoder_rnn_dim = hidden_size * 2 | |
| self.rnn = torch.nn.LSTM( | |
| input_size=hidden_size, | |
| hidden_size=decoder_rnn_dim, | |
| num_layers=1, | |
| batch_first=True, | |
| bidirectional=True, | |
| dropout=dropout | |
| ) | |
| self.rnn.flatten_parameters() | |
| self.conv1d = torch.nn.Conv1d( | |
| in_channels=decoder_rnn_dim * 2, | |
| out_channels=hidden_size, | |
| kernel_size=3, | |
| padding=1, | |
| ) | |
| def forward(self, x): | |
| input_masks = x.abs().sum(-1).ne(0).data[:, :, None] | |
| input_lengths = input_masks.sum([-1, -2]) | |
| input_lengths = input_lengths.cpu().numpy() | |
| x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2) | |
| x = self.ln(x) | |
| x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False) | |
| self.rnn.flatten_parameters() | |
| x, _ = self.rnn(x) # [B, T, C] | |
| x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) | |
| x = x * input_masks | |
| pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2) # [B, T, C] | |
| pre_mel = pre_mel * input_masks | |
| return pre_mel | |