Spaces:
Build error
Build error
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |
def sort_pack_padded_sequence(input, lengths):
    """Pack a batch-first padded tensor after sorting it by decreasing length.

    Args:
        input: batch-first padded tensor; rows are reordered by ``lengths``.
        lengths: 1-D tensor of per-row valid lengths.

    Returns:
        ``(packed, inv_ix)`` where ``packed`` is the PackedSequence built from
        the length-sorted rows and ``inv_ix`` is the inverse permutation, i.e.
        ``sorted_rows[inv_ix]`` restores the original row order.
    """
    order = torch.sort(lengths, descending=True)
    # pack_padded_sequence expects its lengths argument on the CPU.
    packed = pack_padded_sequence(
        input[order.indices], order.values.cpu(), batch_first=True
    )
    # Scatter 0..n-1 into the sorted positions to invert the permutation.
    restore = torch.empty_like(order.indices)
    restore[order.indices] = torch.arange(order.indices.numel()).type_as(restore)
    return packed, restore
def pad_unsort_packed_sequence(input, inv_ix):
    """Unpack a PackedSequence into a batch-first padded tensor, then reorder rows.

    Args:
        input: PackedSequence whose rows were sorted before packing.
        inv_ix: inverse permutation (as returned by
            ``sort_pack_padded_sequence``) that restores the original order.

    Returns:
        The batch-first padded tensor with rows indexed by ``inv_ix``. The
        time axis is as long as ``pad_packed_sequence`` pads it.
    """
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]
def pack_wrapper(module, attn_feats, attn_feat_lens):
    """Apply ``module`` to variable-length, batch-first features via packing.

    Rows are sorted and packed, so padded positions are never fed to
    ``module``. The output is then padded and restored to the original
    batch order.

    Args:
        module: either a ``torch.nn.RNNBase`` (called on the PackedSequence,
            first element of its output is used) or any module applied to the
            flat packed data tensor.
        attn_feats: batch-first padded input tensor.
        attn_feat_lens: 1-D tensor of per-row valid lengths.

    Returns:
        Batch-first padded output in the original row order.
    """
    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
    if not isinstance(module, torch.nn.RNNBase):
        # Non-recurrent modules run on the flat packed data; re-wrap the result
        # with the original batch_sizes so it can be unpacked again.
        data, batch_sizes = packed[0], packed[1]
        output = PackedSequence(module(data), batch_sizes)
    else:
        output = module(packed)[0]
    return pad_unsort_packed_sequence(output, inv_ix)
def generate_length_mask(lens, max_length=None):
    """Build a boolean validity mask from per-row lengths.

    Args:
        lens: per-row lengths (tensor or anything ``torch.as_tensor`` accepts).
        max_length: width of the mask; defaults to the largest value in
            ``lens`` (0 for an empty ``lens``).

    Returns:
        Bool tensor of shape ``[N, max_length]`` on ``lens.device`` where
        ``mask[i, t]`` is ``t < lens[i]``.
    """
    lens = torch.as_tensor(lens)
    if max_length is None:
        # Tensor reduction instead of Python's builtin max(), which iterates the
        # tensor element by element and fails on an empty batch.
        max_length = lens.max().item() if lens.numel() > 0 else 0
    idxs = torch.arange(max_length, device=lens.device)
    # Broadcast [1, T] against [N, 1] instead of materialising N copies of idxs.
    return idxs.unsqueeze(0) < lens.view(-1, 1)
def mean_with_lens(features, lens):
    """Average ``features`` over the length axis, ignoring padded positions.

    Args:
        features: tensor of shape [N, T, ...] (the second dimension is
            assumed to be the length/time axis).
        lens: [N,] valid lengths per row (tensor or array-like).

    Returns:
        Tensor of shape [N, ...]: the sum over the first ``lens[i]`` steps
        divided by ``lens[i]``. A row with ``lens[i] == 0`` divides by zero
        (nan/inf), as before.
    """
    lens = torch.as_tensor(lens)
    # Always size the mask to the padded time axis. The former
    # "max(lens) == T" branch produced exactly this same mask.
    mask = generate_length_mask(lens, features.size(1)).to(features.device)  # [N, T]
    while mask.ndim < features.ndim:
        mask = mask.unsqueeze(-1)
    feature_sum = (features * mask).sum(1)
    lens = lens.to(features.device)
    # Broadcast lens [N] against the reduced [N, ...] tensor.
    while lens.ndim < feature_sum.ndim:
        lens = lens.unsqueeze(1)
    return feature_sum / lens
def max_with_lens(features, lens):
    """Max-pool ``features`` over the length axis, ignoring padded positions.

    Args:
        features: tensor of shape [N, T, ...] (the second dimension is
            assumed to be the length/time axis).
        lens: [N,] valid lengths per row (tensor or array-like).

    Returns:
        Tensor of shape [N, ...] holding the max over the first ``lens[i]``
        steps of each row. Rows with ``lens[i] == 0`` yield ``-inf``.
    """
    lens = torch.as_tensor(lens)
    # Size the mask to the padded time axis T. Using the default width
    # max(lens) breaks the boolean index below whenever max(lens) != T.
    mask = generate_length_mask(lens, features.size(1)).to(features.device)  # [N, T]
    feature_max = features.clone()
    # Boolean [N, T] indexing fills every trailing element of padded steps.
    feature_max[~mask] = float("-inf")
    feature_max, _ = feature_max.max(1)
    return feature_max
def repeat_tensor(x, n):
    """Return ``n`` copies of ``x`` stacked along a new leading dimension.

    The result has shape ``[n, *x.shape]`` and is a materialised copy
    (``Tensor.repeat``), not a view.
    """
    tile_sizes = (n,) + (1,) * x.dim()
    return x[None].repeat(*tile_sizes)
def init(m, method="kaiming"):
    """Initialise the parameters of a single module in place.

    Intended to be used as ``model.apply(init)``; modules of other types are
    left untouched.

    - Conv1d/Conv2d, Linear, Embedding: weight gets ``kaiming_uniform_`` or
      ``xavier_uniform_`` (per ``method``); a bias, if present, is zeroed.
    - BatchNorm1d/BatchNorm2d: weight set to 1 and bias to 0, when present.
      With ``affine=False`` both are None and are skipped.

    Args:
        m: the module to initialise.
        method: ``"kaiming"`` or ``"xavier"``.

    Raises:
        ValueError: if ``method`` is unsupported and ``m`` is a module whose
            weight it would be used for. ValueError subclasses Exception, so
            callers catching Exception are unaffected.
    """
    weight_inits = {
        "kaiming": nn.init.kaiming_uniform_,
        "xavier": nn.init.xavier_uniform_,
    }
    if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear, nn.Embedding)):
        if method not in weight_inits:
            raise ValueError(f"initialization method {method} not supported")
        weight_inits[method](m.weight)
        # nn.Embedding has no bias attribute at all.
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # affine=False BatchNorm layers have weight and bias set to None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input, then apply dropout.

    Args:
        d_model: feature size E of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions in the precomputed table; inputs must
            have ``T <= max_len``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there are ceil(d/2) even columns but only floor(d/2)
        # odd ones, so the cosine half uses just the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        # NOTE(review): kept as a non-trainable Parameter (a register_buffer
        # variant existed but was commented out), so it is listed in
        # .parameters() and saved in state_dict under "pe".
        self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions.

        Args:
            x: tensor of shape [T, N, E].

        Returns:
            Tensor of shape [T, N, E] after adding the encoding and dropout.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)