Spaces:
Sleeping
Sleeping
File size: 775 Bytes
41aae2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import numpy as np
class TimeEncode(torch.nn.Module):
    """Time encoding proposed by TGAT.

    Maps each scalar timestamp (or time delta) ``t`` to the vector
    ``cos(t * w + b)`` of length ``dimension``.

    ``w`` is initialised to ``1 / 10 ** linspace(0, 9, dimension)``. These are
    frequencies spanning ``1`` down to ``1e-9``. ``b`` is initialised to zeros.
    Both are trainable parameters of the inner ``torch.nn.Linear(1, dimension)``
    layer stored as ``self.w``.

    Args:
        dimension: Size of the output encoding produced for each timestamp.
    """

    def __init__(self, dimension):
        super().__init__()
        self.dimension = dimension
        # A Linear(1, dimension) applied to a trailing size-1 axis computes
        # t * w + b for every output feature at once. Its default random
        # initialisation is immediately replaced by the deterministic one below.
        self.w = torch.nn.Linear(1, dimension)
        # Geometric frequency ladder 10**0 ... 10**-9. It is computed in float64
        # by NumPy and stored as float32 with the (out_features, in_features)
        # shape that nn.Linear expects.
        frequencies = 1 / 10 ** np.linspace(0, 9, dimension)
        self.w.weight = torch.nn.Parameter(
            torch.from_numpy(frequencies).float().reshape(dimension, 1)
        )
        self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

    def forward(self, t):
        """Encode a tensor of timestamps.

        Args:
            t: Tensor of timestamps of any shape, e.g. ``[batch_size, seq_len]``.
               Integer inputs, or floats of a different precision, are cast
               to the dtype of the layer's parameters.

        Returns:
            Tensor of shape ``[*t.shape, dimension]`` containing
            ``cos(t * w + b)``.
        """
        # Append a size-1 feature axis so the linear layer sees exactly one
        # input feature per timestamp: [...] -> [..., 1]. Using -1 rather than a
        # fixed axis index makes this work for inputs of any rank. It is
        # identical to the old dim=2 for 2-D input.
        t = t.unsqueeze(dim=-1).to(self.w.weight.dtype)
        # [..., 1] -> [..., dimension]
        return torch.cos(self.w(t))
|