# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, Optional

import torch
import torch.onnx.operators
from fairseq import utils
from torch import Tensor, nn


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
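

# Minimal usage sketch (illustrative): assumes a fairseq installation so that
# `fairseq.utils.make_positions` is importable; embedding_dim, padding_idx, and
# the token batch below are arbitrary demo values, not part of the module.
if __name__ == "__main__":
    padding_idx = 1
    pos_emb = SinusoidalPositionalEmbedding(
        embedding_dim=16, padding_idx=padding_idx, init_size=32
    )

    # Two right-padded sequences of length 4.
    tokens = torch.tensor(
        [
            [5, 6, 7, padding_idx],
            [8, 9, padding_idx, padding_idx],
        ]
    )
    out = pos_emb(tokens)
    print(out.shape)  # torch.Size([2, 4, 16])

    # Padding positions map to the zeroed row of the sinusoidal table.
    print(out[0, 3].abs().sum().item())  # 0.0
    print(out[1, 2].abs().sum().item())  # 0.0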