import torch
import torch.nn.functional as F
from collections import defaultdict


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx


def fill_with_neg_inf2(t):
    """FP16-compatible function that fills a tensor with -1e8, a large
    finite negative value standing in for -inf."""
    return t.float().fill_(-1e8).type_as(t)


def softmax(x, dim):
    # Always compute the softmax in float32 for numerical stability,
    # even when the input is half precision.
    return F.softmax(x, dim=dim, dtype=torch.float32)
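
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): the toy
# tensors, shapes, and padding_idx below are assumptions chosen for the demo.
# It shows make_positions deriving position ids from a padded batch, and
# fill_with_neg_inf2 + softmax masking out padding in a row of attention
# scores.
if __name__ == "__main__":
    padding_idx = 0  # assumed padding index for this example
    tokens = torch.tensor([[5, 7, 9, 0, 0],
                           [3, 4, 0, 0, 0]])

    positions = make_positions(tokens, padding_idx)
    # tensor([[1, 2, 3, 0, 0],
    #         [1, 2, 0, 0, 0]])  -- counting starts at padding_idx + 1

    # Fill scores at padding positions with the large negative stand-in so
    # the float32 softmax assigns them (effectively) zero probability.
    scores = torch.randn(2, 5)  # toy attention scores, one row per sequence
    neg_inf = fill_with_neg_inf2(torch.empty_like(scores))
    masked = torch.where(tokens.ne(padding_idx), scores, neg_inf)
    probs = softmax(masked, dim=-1)

    print(positions)
    print(probs)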