import torch
from torch import nn

from utils.utils import MergeLayer


class TemporalAttentionLayer(torch.nn.Module):
  """
  Temporal attention layer. Return the temporal embedding of a node given the node itself,
  its neighbors and the edge timestamps.
  """
  def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
               output_dimension, n_head=2,
               dropout=0.1):
    super(TemporalAttentionLayer, self).__init__()

    self.n_head = n_head

    self.feat_dim = n_node_features
    self.time_dim = time_dim

    # Queries are built from the source node features plus its time encoding;
    # keys/values are built from neighbor features, their time encodings and the edge features.
    self.query_dim = n_node_features + time_dim
    self.key_dim = n_neighbors_features + time_dim + n_edge_features

    self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

    self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                   kdim=self.key_dim,
                                                   vdim=self.key_dim,
                                                   num_heads=n_head,
                                                   dropout=dropout)
  def forward(self, src_node_features, src_time_features, neighbors_features,
              neighbors_time_features, edge_features, neighbors_padding_mask):
    """
    Temporal attention model.
    :param src_node_features: float Tensor of shape [batch_size, n_node_features]
    :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
    :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
    :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors, time_dim]
    :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
    :param neighbors_padding_mask: bool Tensor of shape [batch_size, n_neighbors], where True
      marks a padded (fake) neighbor
    :return:
      attn_output: float Tensor of shape [batch_size, output_dimension]
      attn_output_weights: float Tensor of shape [batch_size, n_neighbors]
    """
    src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)

    query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
    key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

    # Reshape tensors to the sequence-first layout expected by nn.MultiheadAttention
    query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
    key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]

    # Compute mask of which source nodes have no valid neighbors
    invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
    # If a source node has no valid neighbor, set its first neighbor to be valid. This will
    # force the attention to just 'attend' on this neighbor (which has the same features as all
    # the others since they are fake neighbors) and will produce an equivalent result to the
    # original TGAT paper, which forced fake neighbors to all have the same attention of 1e-10.
    neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False
    attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                               key_padding_mask=neighbors_padding_mask)

    attn_output = attn_output.squeeze()
    attn_output_weights = attn_output_weights.squeeze()

    # Source nodes with no neighbors have an all-zero attention output. The attention output is
    # then added or concatenated to the original source node features and then fed into an MLP.
    # This means that an all-zero vector is not used.
    attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
    attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

    # Skip connection with temporal attention over neighborhood and the features of the node itself
    attn_output = self.merger(attn_output, src_node_features)

    return attn_output, attn_output_weights
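
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). All tensor sizes
# below are assumptions chosen purely for illustration; MergeLayer comes from
# utils.utils as imported above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  batch_size, n_neighbors = 4, 10
  n_node_features, n_edge_features, time_dim = 172, 172, 100

  layer = TemporalAttentionLayer(n_node_features=n_node_features,
                                 n_neighbors_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 time_dim=time_dim,
                                 output_dimension=n_node_features,
                                 n_head=2,
                                 dropout=0.1)

  src_node_features = torch.rand(batch_size, n_node_features)
  src_time_features = torch.rand(batch_size, 1, time_dim)
  neighbors_features = torch.rand(batch_size, n_neighbors, n_node_features)
  neighbors_time_features = torch.rand(batch_size, n_neighbors, time_dim)
  edge_features = torch.rand(batch_size, n_neighbors, n_edge_features)

  # True marks padded (fake) neighbors; here the last two slots of each row are padding.
  neighbors_padding_mask = torch.zeros(batch_size, n_neighbors, dtype=torch.bool)
  neighbors_padding_mask[:, -2:] = True

  out, weights = layer(src_node_features, src_time_features, neighbors_features,
                       neighbors_time_features, edge_features, neighbors_padding_mask)
  print(out.shape)      # torch.Size([batch_size, output_dimension])
  print(weights.shape)  # torch.Size([batch_size, n_neighbors])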