import torch
from torch import nn

from utils.utils import MergeLayer


class TemporalAttentionLayer(torch.nn.Module):
  """
  Temporal attention layer. Returns the temporal embedding of a node given the node itself,
  its neighbors and the edge timestamps.
  """

  def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
               output_dimension, n_head=2, dropout=0.1):
    super(TemporalAttentionLayer, self).__init__()

    self.n_head = n_head

    self.feat_dim = n_node_features
    self.time_dim = time_dim

    self.query_dim = n_node_features + time_dim
    self.key_dim = n_neighbors_features + time_dim + n_edge_features

    self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

    self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                   kdim=self.key_dim,
                                                   vdim=self.key_dim,
                                                   num_heads=n_head,
                                                   dropout=dropout)

  def forward(self, src_node_features, src_time_features, neighbors_features,
              neighbors_time_features, edge_features, neighbors_padding_mask):
    """
    Temporal attention model.
    :param src_node_features: float Tensor of shape [batch_size, n_node_features]
    :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
    :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
    :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors, time_dim]
    :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
    :param neighbors_padding_mask: bool Tensor of shape [batch_size, n_neighbors], where True
      marks a padded (fake) neighbor
    :return:
    attn_output: float Tensor of shape [batch_size, output_dimension]
    attn_output_weights: float Tensor of shape [batch_size, n_neighbors]
    """

    src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)

    query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
    key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

    # Reshape tensors to the [seq_len, batch_size, features] layout expected by
    # nn.MultiheadAttention
    query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
    key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]

    # Compute mask of which source nodes have no valid neighbors
    invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
    # If a source node has no valid neighbor, set its first neighbor to be valid. This will
    # force the attention to just 'attend' on this neighbor (which has the same features as all
    # the others since they are fake neighbors) and will produce an equivalent result to the
    # original TGAT paper, which forced all fake neighbors to have the same attention of 1e-10.
    neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False

    attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                              key_padding_mask=neighbors_padding_mask)

    # mask = torch.unsqueeze(neighbors_padding_mask, dim=2)  # mask [B, N, 1]
    # mask = mask.permute([0, 2, 1])
    # attn_output, attn_output_weights = self.multi_head_target(q=query, k=key, v=key,
    #                                                           mask=mask)

    attn_output = attn_output.squeeze()
    attn_output_weights = attn_output_weights.squeeze()

    # Source nodes with no neighbors have an all-zero attention output. The attention output is
    # then added or concatenated to the original source node features and fed into an MLP, so
    # an all-zero vector is never used on its own.
    attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
    attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

    # Skip connection with temporal attention over the neighborhood and the features of the
    # node itself
    attn_output = self.merger(attn_output, src_node_features)

    return attn_output, attn_output_weights
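

# A minimal smoke-test sketch of how the layer is called. The dimensions below are made up for
# illustration (they are not prescribed by this module), and it assumes utils.utils.MergeLayer
# from this repository is importable; it is not part of the training pipeline.
if __name__ == "__main__":
  batch_size, n_neighbors = 4, 10
  n_node_features, n_edge_features, time_dim = 172, 172, 100

  layer = TemporalAttentionLayer(n_node_features=n_node_features,
                                 n_neighbors_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 time_dim=time_dim,
                                 output_dimension=n_node_features)

  src_node_features = torch.rand(batch_size, n_node_features)
  src_time_features = torch.rand(batch_size, 1, time_dim)
  neighbors_features = torch.rand(batch_size, n_neighbors, n_node_features)
  neighbors_time_features = torch.rand(batch_size, n_neighbors, time_dim)
  edge_features = torch.rand(batch_size, n_neighbors, n_edge_features)

  # True marks a padded neighbor; here the last neighbor of every node is padding
  neighbors_padding_mask = torch.zeros(batch_size, n_neighbors, dtype=torch.bool)
  neighbors_padding_mask[:, -1] = True

  out, weights = layer(src_node_features, src_time_features, neighbors_features,
                       neighbors_time_features, edge_features, neighbors_padding_mask)
  print(out.shape)      # expected: [batch_size, output_dimension]
  print(weights.shape)  # expected: [batch_size, n_neighbors]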