import torch
from torch import nn

from utils.utils import MergeLayer

class TemporalAttentionLayer(torch.nn.Module):
  """
  Temporal attention layer. Returns the temporal embedding of a node given the node itself,
  its neighbors, and the edge timestamps.
  """
  def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
               output_dimension, n_head=2,
               dropout=0.1):
    super(TemporalAttentionLayer, self).__init__()

    self.n_head = n_head

    self.feat_dim = n_node_features
    self.time_dim = time_dim

    # Query: source node features concatenated with their time encoding.
    self.query_dim = n_node_features + time_dim
    # Key/value: neighbor features concatenated with edge features and time encodings.
    self.key_dim = n_neighbors_features + time_dim + n_edge_features

    self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

    self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                   kdim=self.key_dim,
                                                   vdim=self.key_dim,
                                                   num_heads=n_head,
                                                   dropout=dropout)
  def forward(self, src_node_features, src_time_features, neighbors_features,
              neighbors_time_features, edge_features, neighbors_padding_mask):
    """
    Temporal attention model.

    :param src_node_features: float Tensor of shape [batch_size, n_node_features]
    :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
    :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_neighbors_features]
    :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors, time_dim]
    :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
    :param neighbors_padding_mask: bool Tensor of shape [batch_size, n_neighbors]; True marks
      a padded (fake) neighbor
    :return:
      attn_output: float Tensor of shape [batch_size, output_dimension]
      attn_output_weights: float Tensor of shape [batch_size, n_neighbors]
    """
    src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)

    query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
    key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

    # Reshape tensors to the shape expected by nn.MultiheadAttention (sequence length first).
    query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
    key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]
    # Compute a mask of which source nodes have no valid neighbors.
    invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
    # If a source node has no valid neighbor, set its first neighbor to be valid. This forces
    # the attention to 'attend' only to that neighbor (which has the same features as all the
    # others, since they are fake neighbors) and produces a result equivalent to the original
    # TGAT paper, which forced all fake neighbors to the same attention of 1e-10.
    # Squeeze dim 1 explicitly so a batch of size 1 is not reduced to a 0-d tensor.
    neighbors_padding_mask[invalid_neighborhood_mask.squeeze(dim=1), 0] = False
    attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                              key_padding_mask=neighbors_padding_mask)
    # Drop the singleton dimensions explicitly so a batch of size 1 is not squeezed away too.
    attn_output = attn_output.squeeze(dim=0)  # [batch_size, query_dim]
    attn_output_weights = attn_output_weights.squeeze(dim=1)  # [batch_size, n_neighbors]
    # Source nodes with no neighbors get an all-zero attention output. The attention output
    # is then merged with the original source node features and fed into an MLP, so the
    # all-zero vector is never used on its own.
    attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
    attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)
    # Skip connection: merge the temporal attention over the neighborhood with the features
    # of the source node itself.
    attn_output = self.merger(attn_output, src_node_features)

    return attn_output, attn_output_weights
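

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original layer): builds the layer and
# runs a forward pass on random tensors. All dimensions below are illustrative
# assumptions, as is the fully padded second sample, which exercises the
# invalid-neighborhood handling above. Running this requires the repo's
# utils.utils.MergeLayer to be importable.
if __name__ == "__main__":
  batch_size, n_neighbors = 4, 5
  n_node_features, n_edge_features, time_dim = 100, 100, 100

  layer = TemporalAttentionLayer(n_node_features=n_node_features,
                                 n_neighbors_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 time_dim=time_dim,
                                 output_dimension=n_node_features,
                                 n_head=2)

  src_node_features = torch.rand(batch_size, n_node_features)
  src_time_features = torch.rand(batch_size, 1, time_dim)
  neighbors_features = torch.rand(batch_size, n_neighbors, n_node_features)
  neighbors_time_features = torch.rand(batch_size, n_neighbors, time_dim)
  edge_features = torch.rand(batch_size, n_neighbors, n_edge_features)

  # True marks a padded (fake) neighbor; sample 1 has no valid neighbors at all.
  neighbors_padding_mask = torch.zeros(batch_size, n_neighbors, dtype=torch.bool)
  neighbors_padding_mask[1, :] = True

  out, weights = layer(src_node_features, src_time_features, neighbors_features,
                       neighbors_time_features, edge_features, neighbors_padding_mask)
  print(out.shape)      # torch.Size([4, 100])
  print(weights.shape)  # torch.Size([4, 5])
  print(weights[1])     # all zeros: the invalid neighborhood was masked out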