import torch
import torch.nn as nn


class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()
        # Learned positional embedding table of shape (max_len, d_model).
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # Expand the full position table across the batch: output shape is
        # (batch_size, max_len, d_model), independent of x's sequence length.
        batch_size = x.size(0)
        return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)
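

# Illustrative usage sketch (assumed shapes, not part of the original file):
# the input tensor is only read for its batch size, and the module returns
# one (max_len, d_model) position table per batch element.
if __name__ == "__main__":
    max_len, d_model = 128, 64
    pos_emb = PositionalEmbedding(max_len, d_model)
    tokens = torch.zeros(4, max_len, dtype=torch.long)  # (batch, seq_len)
    out = pos_emb(tokens)
    print(out.shape)  # torch.Size([4, 128, 64])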