Upload 14 files

- .gitattributes +1 -0
- data/ml_wikipedia.csv +3 -0
- data/ml_wikipedia.npy +3 -0
- data/ml_wikipedia_node.npy +3 -0
- model/temporal_attention.py +90 -0
- model/tgn.py +278 -0
- model/time_encoding.py +25 -0
- modules/embedding_module.py +291 -0
- modules/memory.py +75 -0
- modules/memory_updater.py +68 -0
- modules/message_aggregator.py +90 -0
- modules/message_function.py +40 -0
- modules/tgn.py +278 -0
- utils/data_processing.py +184 -0
- utils/utils.py +186 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/ml_wikipedia.csv filter=lfs diff=lfs merge=lfs -text
data/ml_wikipedia.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85fd8e50e5ffbb1348173b85d0b7b69ee270550f511c505daf062c9bf9db8027
+size 347159369
data/ml_wikipedia.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f601ac36dfaafdd78759d174611204a0660be70c93aa16f30edb56a7bc642b53
+size 216685728
data/ml_wikipedia_node.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85f2054b5fe9d76188a5bf014232c81a64b3ffedb5a146ff30e1daa71215278b
+size 12697856
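The three data files above are Git LFS pointers, not the data itself; after a `git lfs pull`, a downloaded object can be checked against the oid and size recorded in its pointer. A minimal sketch of that check (the helper below is ours, not part of the commit):

import hashlib

def verify_lfs_object(path, expected_oid, expected_size):
    # Hash the pulled file and compare with the sha256 oid and byte size
    # from the LFS pointer file.
    h = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
            size += len(chunk)
    return h.hexdigest() == expected_oid and size == expected_size

# e.g. for the node-feature file above:
verify_lfs_object("data/ml_wikipedia_node.npy",
                  "85f2054b5fe9d76188a5bf014232c81a64b3ffedb5a146ff30e1daa71215278b",
                  12697856)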
model/temporal_attention.py
ADDED
@@ -0,0 +1,90 @@
+import torch
+from torch import nn
+
+from utils.utils import MergeLayer
+
+
+class TemporalAttentionLayer(torch.nn.Module):
+  """
+  Temporal attention layer. Return the temporal embedding of a node given the node itself,
+  its neighbors and the edge timestamps.
+  """
+
+  def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
+               output_dimension, n_head=2,
+               dropout=0.1):
+    super(TemporalAttentionLayer, self).__init__()
+
+    self.n_head = n_head
+
+    self.feat_dim = n_node_features
+    self.time_dim = time_dim
+
+    self.query_dim = n_node_features + time_dim
+    self.key_dim = n_neighbors_features + time_dim + n_edge_features
+
+    self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)
+
+    self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
+                                                   kdim=self.key_dim,
+                                                   vdim=self.key_dim,
+                                                   num_heads=n_head,
+                                                   dropout=dropout)
+
+  def forward(self, src_node_features, src_time_features, neighbors_features,
+              neighbors_time_features, edge_features, neighbors_padding_mask):
+    """
+    Temporal attention model.
+    :param src_node_features: float Tensor of shape [batch_size, n_node_features]
+    :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
+    :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
+    :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors, time_dim]
+    :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
+    :param neighbors_padding_mask: bool Tensor of shape [batch_size, n_neighbors]
+    :return:
+    attn_output: float Tensor of shape [1, batch_size, n_node_features]
+    attn_output_weights: [batch_size, 1, n_neighbors]
+    """
+
+    src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)
+
+    query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
+    key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)
+
+    # print(neighbors_features.shape, edge_features.shape, neighbors_time_features.shape)
+    # Reshape tensors to the shape expected by the multi-head attention module
+    query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
+    key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]
+
+    # Compute mask of which source nodes have no valid neighbors
+    invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
+    # If a source node has no valid neighbor, set its first neighbor to be valid. This will
+    # force the attention to just 'attend' on this neighbor (which has the same features as all
+    # the others since they are fake neighbors) and will produce an equivalent result to the
+    # original tgat paper which was forcing fake neighbors to all have the same attention of 1e-10
+    neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False
+
+    # print(query.shape, key.shape)
+
+    attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
+                                                              key_padding_mask=neighbors_padding_mask)
+
+    # mask = torch.unsqueeze(neighbors_padding_mask, dim=2)  # mask [B, N, 1]
+    # mask = mask.permute([0, 2, 1])
+    # attn_output, attn_output_weights = self.multi_head_target(q=query, k=key, v=key,
+    #                                                           mask=mask)
+
+    attn_output = attn_output.squeeze()
+    attn_output_weights = attn_output_weights.squeeze()
+
+    # Source nodes with no neighbors have an all-zero attention output. The attention output is
+    # then added or concatenated to the original source node features and then fed into an MLP.
+    # This means that an all-zero vector is not used.
+    attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
+    attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)
+
+    # Skip connection with temporal attention over neighborhood and the features of the node itself
+    attn_output = self.merger(attn_output, src_node_features)
+
+    return attn_output, attn_output_weights
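The layer attends from a single query (the node's features concatenated with its time encoding) over keys built from neighbor features, edge features, and neighbor time encodings. A minimal shape-level sketch of calling it, with hypothetical dimensions; MergeLayer lives in utils/utils.py, which is not shown in this diff, and is assumed to map its two inputs to output_dimension:

import torch
from model.temporal_attention import TemporalAttentionLayer

layer = TemporalAttentionLayer(n_node_features=172, n_neighbors_features=172,
                               n_edge_features=172, time_dim=172,
                               output_dimension=172, n_head=2, dropout=0.1)

batch, n_neighbors = 4, 10
out, weights = layer(
    src_node_features=torch.randn(batch, 172),
    src_time_features=torch.randn(batch, 1, 172),
    neighbors_features=torch.randn(batch, n_neighbors, 172),
    neighbors_time_features=torch.randn(batch, n_neighbors, 172),
    edge_features=torch.randn(batch, n_neighbors, 172),
    neighbors_padding_mask=torch.zeros(batch, n_neighbors, dtype=torch.bool))
# out: [batch, output_dimension], weights: [batch, n_neighbors]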
model/tgn.py
ADDED
@@ -0,0 +1,278 @@
+import logging
+import numpy as np
+import torch
+from collections import defaultdict
+
+from utils.utils import MergeLayer
+from modules.memory import Memory
+from modules.message_aggregator import get_message_aggregator
+from modules.message_function import get_message_function
+from modules.memory_updater import get_memory_updater
+from modules.embedding_module import get_embedding_module
+from model.time_encoding import TimeEncode
+
+
+class TGN(torch.nn.Module):
+  def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
+               n_heads=2, dropout=0.1, use_memory=False,
+               memory_update_at_start=True, message_dimension=100,
+               memory_dimension=500, embedding_module_type="graph_attention",
+               message_function="mlp",
+               mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
+               std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
+               memory_updater_type="gru",
+               use_destination_embedding_in_message=False,
+               use_source_embedding_in_message=False,
+               dyrep=False):
+    super(TGN, self).__init__()
+
+    self.n_layers = n_layers
+    self.neighbor_finder = neighbor_finder
+    self.device = device
+    self.logger = logging.getLogger(__name__)
+
+    self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
+    self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)
+
+    self.n_node_features = self.node_raw_features.shape[1]
+    self.n_nodes = self.node_raw_features.shape[0]
+    self.n_edge_features = self.edge_raw_features.shape[1]
+    self.embedding_dimension = self.n_node_features
+    self.n_neighbors = n_neighbors
+    self.embedding_module_type = embedding_module_type
+    self.use_destination_embedding_in_message = use_destination_embedding_in_message
+    self.use_source_embedding_in_message = use_source_embedding_in_message
+    self.dyrep = dyrep
+
+    self.use_memory = use_memory
+    self.time_encoder = TimeEncode(dimension=self.n_node_features)
+    self.memory = None
+
+    self.mean_time_shift_src = mean_time_shift_src
+    self.std_time_shift_src = std_time_shift_src
+    self.mean_time_shift_dst = mean_time_shift_dst
+    self.std_time_shift_dst = std_time_shift_dst
+
+    if self.use_memory:
+      self.memory_dimension = memory_dimension
+      self.memory_update_at_start = memory_update_at_start
+      raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
+                              self.time_encoder.dimension
+      message_dimension = message_dimension if message_function != "identity" else raw_message_dimension
+      self.memory = Memory(n_nodes=self.n_nodes,
+                           memory_dimension=self.memory_dimension,
+                           input_dimension=message_dimension,
+                           message_dimension=message_dimension,
+                           device=device)
+      self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
+                                                       device=device)
+      self.message_function = get_message_function(module_type=message_function,
+                                                   raw_message_dimension=raw_message_dimension,
+                                                   message_dimension=message_dimension)
+      self.memory_updater = get_memory_updater(module_type=memory_updater_type,
+                                               memory=self.memory,
+                                               message_dimension=message_dimension,
+                                               memory_dimension=self.memory_dimension,
+                                               device=device)
+
+    self.embedding_module_type = embedding_module_type
+
+    self.embedding_module = get_embedding_module(module_type=embedding_module_type,
+                                                 node_features=self.node_raw_features,
+                                                 edge_features=self.edge_raw_features,
+                                                 memory=self.memory,
+                                                 neighbor_finder=self.neighbor_finder,
+                                                 time_encoder=self.time_encoder,
+                                                 n_layers=self.n_layers,
+                                                 n_node_features=self.n_node_features,
+                                                 n_edge_features=self.n_edge_features,
+                                                 n_time_features=self.n_node_features,
+                                                 embedding_dimension=self.embedding_dimension,
+                                                 device=self.device,
+                                                 n_heads=n_heads, dropout=dropout,
+                                                 use_memory=use_memory,
+                                                 n_neighbors=self.n_neighbors)
+
+    # MLP to compute probability on an edge given two node embeddings
+    self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
+                                     self.n_node_features,
+                                     1)
+
+  def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
+                                  edge_idxs, n_neighbors=20):
+    """
+    Compute temporal embeddings for sources, destinations, and negatively sampled destinations.
+
+    :param source_nodes [batch_size]: source ids
+    :param destination_nodes [batch_size]: destination ids
+    :param negative_nodes [batch_size]: ids of negative sampled destinations
+    :param edge_times [batch_size]: timestamp of interaction
+    :param edge_idxs [batch_size]: index of interaction
+    :param n_neighbors [scalar]: number of temporal neighbors to consider in each convolutional
+    layer
+    :return: Temporal embeddings for sources, destinations and negatives
+    """
+
+    n_samples = len(source_nodes)
+    nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
+    positives = np.concatenate([source_nodes, destination_nodes])
+    timestamps = np.concatenate([edge_times, edge_times, edge_times])
+
+    memory = None
+    time_diffs = None
+    if self.use_memory:
+      if self.memory_update_at_start:
+        # Update memory for all nodes with messages stored in previous batches
+        memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
+                                                      self.memory.messages)
+      else:
+        memory = self.memory.get_memory(list(range(self.n_nodes)))
+        last_update = self.memory.last_update
+
+      ### Compute differences between the time the memory of a node was last updated,
+      ### and the time for which we want to compute the embedding of a node
+      source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
+        source_nodes].long()
+      source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
+      destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
+        destination_nodes].long()
+      destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
+      negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
+        negative_nodes].long()
+      negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
+
+      time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
+                             dim=0)
+
+    # Compute the embeddings using the embedding module
+    node_embedding = self.embedding_module.compute_embedding(memory=memory,
+                                                             source_nodes=nodes,
+                                                             timestamps=timestamps,
+                                                             n_layers=self.n_layers,
+                                                             n_neighbors=n_neighbors,
+                                                             time_diffs=time_diffs)
+
+    source_node_embedding = node_embedding[:n_samples]
+    destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
+    negative_node_embedding = node_embedding[2 * n_samples:]
+
+    if self.use_memory:
+      if self.memory_update_at_start:
+        # Persist the updates to the memory only for sources and destinations (since now we have
+        # new messages for them)
+        self.update_memory(positives, self.memory.messages)
+
+        assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
+          "Something wrong in how the memory was updated"
+
+        # Remove messages for the positives since we have already updated the memory using them
+        self.memory.clear_messages(positives)
+
+      unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
+                                                                    source_node_embedding,
+                                                                    destination_nodes,
+                                                                    destination_node_embedding,
+                                                                    edge_times, edge_idxs)
+      unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
+                                                                              destination_node_embedding,
+                                                                              source_nodes,
+                                                                              source_node_embedding,
+                                                                              edge_times, edge_idxs)
+      if self.memory_update_at_start:
+        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
+        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
+      else:
+        self.update_memory(unique_sources, source_id_to_messages)
+        self.update_memory(unique_destinations, destination_id_to_messages)
+
+      if self.dyrep:
+        source_node_embedding = memory[source_nodes]
+        destination_node_embedding = memory[destination_nodes]
+        negative_node_embedding = memory[negative_nodes]
+
+    return source_node_embedding, destination_node_embedding, negative_node_embedding
+
+  def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
+                                 edge_idxs, n_neighbors=20):
+    """
+    Compute probabilities for edges between sources and destinations and between sources and
+    negatives by first computing temporal embeddings using the TGN encoder and then feeding them
+    into the MLP decoder.
+    :param source_nodes [batch_size]: source ids
+    :param destination_nodes [batch_size]: destination ids
+    :param negative_nodes [batch_size]: ids of negative sampled destinations
+    :param edge_times [batch_size]: timestamp of interaction
+    :param edge_idxs [batch_size]: index of interaction
+    :param n_neighbors [scalar]: number of temporal neighbors to consider in each convolutional
+    layer
+    :return: Probabilities for both the positive and negative edges
+    """
+    n_samples = len(source_nodes)
+    source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(
+      source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors)
+
+    score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
+                                torch.cat([destination_node_embedding,
+                                           negative_node_embedding])).squeeze(dim=0)
+    pos_score = score[:n_samples]
+    neg_score = score[n_samples:]
+
+    return pos_score.sigmoid(), neg_score.sigmoid()
+
+  def update_memory(self, nodes, messages):
+    # Aggregate messages for the same nodes
+    unique_nodes, unique_messages, unique_timestamps = \
+      self.message_aggregator.aggregate(
+        nodes,
+        messages)
+
+    if len(unique_nodes) > 0:
+      unique_messages = self.message_function.compute_message(unique_messages)
+
+    # Update the memory with the aggregated messages
+    self.memory_updater.update_memory(unique_nodes, unique_messages,
+                                      timestamps=unique_timestamps)
+
+  def get_updated_memory(self, nodes, messages):
+    # Aggregate messages for the same nodes
+    unique_nodes, unique_messages, unique_timestamps = \
+      self.message_aggregator.aggregate(
+        nodes,
+        messages)
+
+    if len(unique_nodes) > 0:
+      unique_messages = self.message_function.compute_message(unique_messages)
+
+    updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
+                                                                                 unique_messages,
+                                                                                 timestamps=unique_timestamps)
+
+    return updated_memory, updated_last_update
+
+  def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
+                       destination_node_embedding, edge_times, edge_idxs):
+    edge_times = torch.from_numpy(edge_times).float().to(self.device)
+    edge_features = self.edge_raw_features[edge_idxs]
+
+    source_memory = self.memory.get_memory(source_nodes) if not \
+      self.use_source_embedding_in_message else source_node_embedding
+    destination_memory = self.memory.get_memory(destination_nodes) if \
+      not self.use_destination_embedding_in_message else destination_node_embedding
+
+    source_time_delta = edge_times - self.memory.last_update[source_nodes]
+    source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
+      source_nodes), -1)
+
+    source_message = torch.cat([source_memory, destination_memory, edge_features,
+                                source_time_delta_encoding],
+                               dim=1)
+    messages = defaultdict(list)
+    unique_sources = np.unique(source_nodes)
+
+    for i in range(len(source_nodes)):
+      messages[source_nodes[i]].append((source_message[i], edge_times[i]))
+
+    return unique_sources, messages
+
+  def set_neighbor_finder(self, neighbor_finder):
+    self.neighbor_finder = neighbor_finder
+    self.embedding_module.neighbor_finder = neighbor_finder
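A minimal sketch of driving this encoder/decoder pair on one batch. The neighbor finder comes from utils/utils.py and the batch arrays (sources, destinations, negatives, timestamps, edge_idxs) from utils/data_processing.py, neither of which is shown in this diff, so their construction here is assumed; the batch arrays are NumPy arrays:

import numpy as np
import torch
from model.tgn import TGN

node_features = np.load("data/ml_wikipedia_node.npy")
edge_features = np.load("data/ml_wikipedia.npy")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tgn = TGN(neighbor_finder=neighbor_finder, node_features=node_features,
          edge_features=edge_features, device=device, n_layers=1,
          use_memory=True, embedding_module_type="graph_attention",
          message_function="identity", aggregator_type="last",
          memory_updater_type="gru").to(device)

criterion = torch.nn.BCELoss()
tgn.memory.__init_memory__()  # reset memory at the start of each epoch
pos_prob, neg_prob = tgn.compute_edge_probabilities(
    sources, destinations, negatives, timestamps, edge_idxs, n_neighbors=10)
loss = criterion(pos_prob.squeeze(), torch.ones(len(sources), device=device)) + \
       criterion(neg_prob.squeeze(), torch.zeros(len(sources), device=device))
loss.backward()
tgn.memory.detach_memory()  # stop gradients from flowing across batches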
model/time_encoding.py
ADDED
@@ -0,0 +1,25 @@
+import torch
+import numpy as np
+
+
+class TimeEncode(torch.nn.Module):
+  # Time Encoding proposed by TGAT
+  def __init__(self, dimension):
+    super(TimeEncode, self).__init__()
+
+    self.dimension = dimension
+    self.w = torch.nn.Linear(1, dimension)
+
+    self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
+                                       .float().reshape(dimension, -1))
+    self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())
+
+  def forward(self, t):
+    # t has shape [batch_size, seq_len]
+    # Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1]
+    t = t.unsqueeze(dim=2)
+
+    # output has shape [batch_size, seq_len, dimension]
+    output = torch.cos(self.w(t))
+
+    return output
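The encoding is output = cos(w * t + b): w is initialized to the geometric frequency grid 1 / 10 ** linspace(0, 9, dimension) and b to zero (both stay trainable, since they are registered as parameters), so a zero time delta initially maps to the all-ones vector. A quick sketch:

import torch
from model.time_encoding import TimeEncode

enc = TimeEncode(dimension=8)
t = torch.tensor([[0.0, 1.0, 100.0]])  # [batch_size=1, seq_len=3]
print(enc(t).shape)                    # torch.Size([1, 3, 8])
print(enc(torch.zeros(1, 1)))          # cos(0) = 1 in every component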
modules/embedding_module.py
ADDED
@@ -0,0 +1,291 @@
+import torch
+from torch import nn
+import numpy as np
+import math
+
+from model.temporal_attention import TemporalAttentionLayer
+
+
+class EmbeddingModule(nn.Module):
+  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
+               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
+               dropout):
+    super(EmbeddingModule, self).__init__()
+    self.node_features = node_features
+    self.edge_features = edge_features
+    # self.memory = memory
+    self.neighbor_finder = neighbor_finder
+    self.time_encoder = time_encoder
+    self.n_layers = n_layers
+    self.n_node_features = n_node_features
+    self.n_edge_features = n_edge_features
+    self.n_time_features = n_time_features
+    self.dropout = dropout
+    self.embedding_dimension = embedding_dimension
+    self.device = device
+
+  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
+                        use_time_proj=True):
+    return NotImplemented
+
+
+class IdentityEmbedding(EmbeddingModule):
+  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
+                        use_time_proj=True):
+    return memory[source_nodes, :]
+
+
+class TimeEmbedding(EmbeddingModule):
+  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
+               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
+               n_heads=2, dropout=0.1, use_memory=True, n_neighbors=1):
+    super(TimeEmbedding, self).__init__(node_features, edge_features, memory,
+                                        neighbor_finder, time_encoder, n_layers,
+                                        n_node_features, n_edge_features, n_time_features,
+                                        embedding_dimension, device, dropout)
+
+    class NormalLinear(nn.Linear):
+      # From Jodie code
+      def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.weight.size(1))
+        self.weight.data.normal_(0, stdv)
+        if self.bias is not None:
+          self.bias.data.normal_(0, stdv)
+
+    self.embedding_layer = NormalLinear(1, self.n_node_features)
+
+  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
+                        use_time_proj=True):
+    source_embeddings = memory[source_nodes, :] * (1 + self.embedding_layer(time_diffs.unsqueeze(1)))
+
+    return source_embeddings
+
+
+class GraphEmbedding(EmbeddingModule):
+  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
+               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
+               n_heads=2, dropout=0.1, use_memory=True):
+    super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
+                                         neighbor_finder, time_encoder, n_layers,
+                                         n_node_features, n_edge_features, n_time_features,
+                                         embedding_dimension, device, dropout)
+
+    self.use_memory = use_memory
+    self.device = device
+
+  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
+                        use_time_proj=True):
+    """Recursive implementation of curr_layers temporal graph attention layers.
+
+    src_idx_l [batch_size]: users / items input ids.
+    cut_time_l [batch_size]: scalar representing the instant of time at which we want to extract the user / item representation.
+    curr_layers [scalar]: number of temporal convolutional layers to stack.
+    num_neighbors [scalar]: number of temporal neighbors to consider in each convolutional layer.
+    """
+
+    assert (n_layers >= 0)
+
+    source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
+    timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)
+
+    # query node always has the start time -> time span == 0
+    source_nodes_time_embedding = self.time_encoder(torch.zeros_like(
+      timestamps_torch))
+
+    source_node_features = self.node_features[source_nodes_torch, :]
+
+    if self.use_memory:
+      source_node_features = memory[source_nodes, :] + source_node_features
+
+    if n_layers == 0:
+      return source_node_features
+    else:
+
+      source_node_conv_embeddings = self.compute_embedding(memory,
+                                                           source_nodes,
+                                                           timestamps,
+                                                           n_layers=n_layers - 1,
+                                                           n_neighbors=n_neighbors)
+
+      neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
+        source_nodes,
+        timestamps,
+        n_neighbors=n_neighbors)
+
+      neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)
+
+      edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)
+
+      edge_deltas = timestamps[:, np.newaxis] - edge_times
+
+      edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)
+
+      neighbors = neighbors.flatten()
+      neighbor_embeddings = self.compute_embedding(memory,
+                                                   neighbors,
+                                                   np.repeat(timestamps, n_neighbors),
+                                                   n_layers=n_layers - 1,
+                                                   n_neighbors=n_neighbors)
+
+      effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
+      neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
+      edge_time_embeddings = self.time_encoder(edge_deltas_torch)
+
+      edge_features = self.edge_features[edge_idxs, :]
+
+      mask = neighbors_torch == 0
+
+      source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
+                                        source_nodes_time_embedding,
+                                        neighbor_embeddings,
+                                        edge_time_embeddings,
+                                        edge_features,
+                                        mask)
+
+      return source_embedding
+
+  def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
+                neighbor_embeddings,
+                edge_time_embeddings, edge_features, mask):
+    return NotImplemented
+
+
+class GraphSumEmbedding(GraphEmbedding):
+  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
+               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
+               n_heads=2, dropout=0.1, use_memory=True):
+    super(GraphSumEmbedding, self).__init__(node_features=node_features,
+                                            edge_features=edge_features,
+                                            memory=memory,
+                                            neighbor_finder=neighbor_finder,
+                                            time_encoder=time_encoder, n_layers=n_layers,
+                                            n_node_features=n_node_features,
+                                            n_edge_features=n_edge_features,
+                                            n_time_features=n_time_features,
+                                            embedding_dimension=embedding_dimension,
+                                            device=device,
+                                            n_heads=n_heads, dropout=dropout,
+                                            use_memory=use_memory)
+    self.linear_1 = torch.nn.ModuleList([torch.nn.Linear(embedding_dimension + n_time_features +
+                                                         n_edge_features, embedding_dimension)
+                                         for _ in range(n_layers)])
+    self.linear_2 = torch.nn.ModuleList(
+      [torch.nn.Linear(embedding_dimension + n_node_features + n_time_features,
+                       embedding_dimension) for _ in range(n_layers)])
+
+  def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
+                neighbor_embeddings,
+                edge_time_embeddings, edge_features, mask):
+    neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features],
+                                   dim=2)
+    neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
+    neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))
+
+    source_features = torch.cat([source_node_features,
+                                 source_nodes_time_embedding.squeeze()], dim=1)
+    source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
+    source_embedding = self.linear_2[n_layer - 1](source_embedding)
+
+    return source_embedding
+
+
+class GraphAttentionEmbedding(GraphEmbedding):
+  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
+               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
+               n_heads=2, dropout=0.1, use_memory=True):
+    super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
+                                                  neighbor_finder, time_encoder, n_layers,
+                                                  n_node_features, n_edge_features,
+                                                  n_time_features,
+                                                  embedding_dimension, device,
+                                                  n_heads, dropout,
+                                                  use_memory)
+
+    self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
+      n_node_features=n_node_features,
+      n_neighbors_features=n_node_features,
+      n_edge_features=n_edge_features,
+      time_dim=n_time_features,
+      n_head=n_heads,
+      dropout=dropout,
+      output_dimension=n_node_features)
+      for _ in range(n_layers)])
+
+  def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
+                neighbor_embeddings,
+                edge_time_embeddings, edge_features, mask):
+    attention_model = self.attention_models[n_layer - 1]
+
+    source_embedding, _ = attention_model(source_node_features,
+                                          source_nodes_time_embedding,
+                                          neighbor_embeddings,
+                                          edge_time_embeddings,
+                                          edge_features,
+                                          mask)
+
+    return source_embedding
+
+
+def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
+                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
+                         embedding_dimension, device,
+                         n_heads=2, dropout=0.1, n_neighbors=None,
+                         use_memory=True):
+  if module_type == "graph_attention":
+    return GraphAttentionEmbedding(node_features=node_features,
+                                   edge_features=edge_features,
+                                   memory=memory,
+                                   neighbor_finder=neighbor_finder,
+                                   time_encoder=time_encoder,
+                                   n_layers=n_layers,
+                                   n_node_features=n_node_features,
+                                   n_edge_features=n_edge_features,
+                                   n_time_features=n_time_features,
+                                   embedding_dimension=embedding_dimension,
+                                   device=device,
+                                   n_heads=n_heads, dropout=dropout, use_memory=use_memory)
+  elif module_type == "graph_sum":
+    return GraphSumEmbedding(node_features=node_features,
+                             edge_features=edge_features,
+                             memory=memory,
+                             neighbor_finder=neighbor_finder,
+                             time_encoder=time_encoder,
+                             n_layers=n_layers,
+                             n_node_features=n_node_features,
+                             n_edge_features=n_edge_features,
+                             n_time_features=n_time_features,
+                             embedding_dimension=embedding_dimension,
+                             device=device,
+                             n_heads=n_heads, dropout=dropout, use_memory=use_memory)
+
+  elif module_type == "identity":
+    return IdentityEmbedding(node_features=node_features,
+                             edge_features=edge_features,
+                             memory=memory,
+                             neighbor_finder=neighbor_finder,
+                             time_encoder=time_encoder,
+                             n_layers=n_layers,
+                             n_node_features=n_node_features,
+                             n_edge_features=n_edge_features,
+                             n_time_features=n_time_features,
+                             embedding_dimension=embedding_dimension,
+                             device=device,
+                             dropout=dropout)
+  elif module_type == "time":
+    return TimeEmbedding(node_features=node_features,
+                         edge_features=edge_features,
+                         memory=memory,
+                         neighbor_finder=neighbor_finder,
+                         time_encoder=time_encoder,
+                         n_layers=n_layers,
+                         n_node_features=n_node_features,
+                         n_edge_features=n_edge_features,
+                         n_time_features=n_time_features,
+                         embedding_dimension=embedding_dimension,
+                         device=device,
+                         dropout=dropout,
+                         n_neighbors=n_neighbors)
+  else:
+    raise ValueError("Embedding Module {} not supported".format(module_type))
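The graph modules are recursive: an L-layer embedding aggregates over the neighbors' (L-1)-layer embeddings, bottoming out at the raw node features (plus memory, when enabled). A minimal sketch of the factory call that the TGN constructor issues, with hypothetical inputs (the feature tensors, memory, neighbor finder, and time encoder are assumed to exist as in model/tgn.py):

from modules.embedding_module import get_embedding_module

emb_module = get_embedding_module(module_type="graph_attention",
                                  node_features=node_raw_features,
                                  edge_features=edge_raw_features,
                                  memory=memory, neighbor_finder=neighbor_finder,
                                  time_encoder=time_encoder, n_layers=2,
                                  n_node_features=172, n_edge_features=172,
                                  n_time_features=172, embedding_dimension=172,
                                  device="cpu", n_heads=2, dropout=0.1,
                                  use_memory=True, n_neighbors=10)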
modules/memory.py
ADDED
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+
+from collections import defaultdict
+from copy import deepcopy
+
+
+class Memory(nn.Module):
+
+  def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
+               device="cpu", combination_method='sum'):
+    super(Memory, self).__init__()
+    self.n_nodes = n_nodes
+    self.memory_dimension = memory_dimension
+    self.input_dimension = input_dimension
+    self.message_dimension = message_dimension
+    self.device = device
+
+    self.combination_method = combination_method
+
+    self.__init_memory__()
+
+  def __init_memory__(self):
+    """
+    Initializes the memory to all zeros. It should be called at the start of each epoch.
+    """
+    # Treat memory as a parameter so that it is saved and loaded together with the model
+    self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
+                               requires_grad=False)
+    self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
+                                    requires_grad=False)
+
+    self.messages = defaultdict(list)
+
+  def store_raw_messages(self, nodes, node_id_to_messages):
+    for node in nodes:
+      self.messages[node].extend(node_id_to_messages[node])
+
+  def get_memory(self, node_idxs):
+    return self.memory[node_idxs, :]
+
+  def set_memory(self, node_idxs, values):
+    self.memory[node_idxs, :] = values
+
+  def get_last_update(self, node_idxs):
+    return self.last_update[node_idxs]
+
+  def backup_memory(self):
+    messages_clone = {}
+    for k, v in self.messages.items():
+      messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]
+
+    return self.memory.data.clone(), self.last_update.data.clone(), messages_clone
+
+  def restore_memory(self, memory_backup):
+    self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone()
+
+    self.messages = defaultdict(list)
+    for k, v in memory_backup[2].items():
+      self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]
+
+  def detach_memory(self):
+    self.memory.detach_()
+
+    # Detach all stored messages
+    for k, v in self.messages.items():
+      new_node_messages = []
+      for message in v:
+        new_node_messages.append((message[0].detach(), message[1]))
+
+      self.messages[k] = new_node_messages
+
+  def clear_messages(self, nodes):
+    for node in nodes:
+      self.messages[node] = []
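Because the memory is mutated in place during both training and evaluation, a training loop typically snapshots and restores it around validation. A minimal sketch of that cycle (sizes are hypothetical):

from modules.memory import Memory

mem = Memory(n_nodes=100, memory_dimension=16, input_dimension=32)
snapshot = mem.backup_memory()  # clones of (memory, last_update, messages)
# ... run validation, which mutates the memory in place ...
mem.restore_memory(snapshot)    # roll back to the pre-validation state
mem.detach_memory()             # cut the autograd graph between batches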
modules/memory_updater.py
ADDED
@@ -0,0 +1,68 @@
+from torch import nn
+import torch
+
+
+class MemoryUpdater(nn.Module):
+  def update_memory(self, unique_node_ids, unique_messages, timestamps):
+    pass
+
+
+class SequenceMemoryUpdater(MemoryUpdater):
+  def __init__(self, memory, message_dimension, memory_dimension, device):
+    super(SequenceMemoryUpdater, self).__init__()
+    self.memory = memory
+    self.layer_norm = torch.nn.LayerNorm(memory_dimension)
+    self.message_dimension = message_dimension
+    self.device = device
+
+  def update_memory(self, unique_node_ids, unique_messages, timestamps):
+    if len(unique_node_ids) <= 0:
+      return
+
+    assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \
+      "update memory to time in the past"
+
+    memory = self.memory.get_memory(unique_node_ids)
+    self.memory.last_update[unique_node_ids] = timestamps
+
+    updated_memory = self.memory_updater(unique_messages, memory)
+
+    self.memory.set_memory(unique_node_ids, updated_memory)
+
+  def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
+    if len(unique_node_ids) <= 0:
+      return self.memory.memory.data.clone(), self.memory.last_update.data.clone()
+
+    assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \
+      "update memory to time in the past"
+
+    updated_memory = self.memory.memory.data.clone()
+    updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids])
+
+    updated_last_update = self.memory.last_update.data.clone()
+    updated_last_update[unique_node_ids] = timestamps
+
+    return updated_memory, updated_last_update
+
+
+class GRUMemoryUpdater(SequenceMemoryUpdater):
+  def __init__(self, memory, message_dimension, memory_dimension, device):
+    super(GRUMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)
+
+    self.memory_updater = nn.GRUCell(input_size=message_dimension,
+                                     hidden_size=memory_dimension)
+
+
+class RNNMemoryUpdater(SequenceMemoryUpdater):
+  def __init__(self, memory, message_dimension, memory_dimension, device):
+    super(RNNMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)
+
+    self.memory_updater = nn.RNNCell(input_size=message_dimension,
+                                     hidden_size=memory_dimension)
+
+
+def get_memory_updater(module_type, memory, message_dimension, memory_dimension, device):
+  if module_type == "gru":
+    return GRUMemoryUpdater(memory, message_dimension, memory_dimension, device)
+  elif module_type == "rnn":
+    return RNNMemoryUpdater(memory, message_dimension, memory_dimension, device)
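Each node's memory row is treated as the hidden state of a recurrent cell and its aggregated message as the input, so one memory update is a single cell step over the batch of updated nodes. A standalone illustration with hypothetical dimensions, not tied to the classes above:

import torch
from torch import nn

cell = nn.GRUCell(input_size=32, hidden_size=16)
messages = torch.randn(5, 32)     # one aggregated message per updated node
memory_rows = torch.zeros(5, 16)  # current memory of those nodes
new_rows = cell(messages, memory_rows)  # updated memory, shape [5, 16]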
modules/message_aggregator.py
ADDED
@@ -0,0 +1,90 @@
+from collections import defaultdict
+import torch
+import numpy as np
+
+
+class MessageAggregator(torch.nn.Module):
+  """
+  Abstract class for the message aggregator module, which given a batch of node ids and
+  corresponding messages, aggregates messages with the same node id.
+  """
+  def __init__(self, device):
+    super(MessageAggregator, self).__init__()
+    self.device = device
+
+  def aggregate(self, node_ids, messages):
+    """
+    Given a list of node ids, and a list of messages of the same length, aggregate different
+    messages for the same id using one of the possible strategies.
+    :param node_ids: a list of node ids of length batch_size
+    :param messages: a tensor of shape [batch_size, message_length]
+    :param timestamps: a tensor of shape [batch_size]
+    :return: a tensor of shape [n_unique_node_ids, message_length] with the aggregated messages
+    """
+
+  def group_by_id(self, node_ids, messages, timestamps):
+    node_id_to_messages = defaultdict(list)
+
+    for i, node_id in enumerate(node_ids):
+      node_id_to_messages[node_id].append((messages[i], timestamps[i]))
+
+    return node_id_to_messages
+
+
+class LastMessageAggregator(MessageAggregator):
+  def __init__(self, device):
+    super(LastMessageAggregator, self).__init__(device)
+
+  def aggregate(self, node_ids, messages):
+    """Only keep the last message for each node"""
+    unique_node_ids = np.unique(node_ids)
+    unique_messages = []
+    unique_timestamps = []
+
+    to_update_node_ids = []
+
+    for node_id in unique_node_ids:
+      if len(messages[node_id]) > 0:
+        to_update_node_ids.append(node_id)
+        unique_messages.append(messages[node_id][-1][0])
+        unique_timestamps.append(messages[node_id][-1][1])
+
+    unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
+    unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []
+
+    return to_update_node_ids, unique_messages, unique_timestamps
+
+
+class MeanMessageAggregator(MessageAggregator):
+  def __init__(self, device):
+    super(MeanMessageAggregator, self).__init__(device)
+
+  def aggregate(self, node_ids, messages):
+    """Average all messages for each node"""
+    unique_node_ids = np.unique(node_ids)
+    unique_messages = []
+    unique_timestamps = []
+
+    to_update_node_ids = []
+    n_messages = 0
+
+    for node_id in unique_node_ids:
+      if len(messages[node_id]) > 0:
+        n_messages += len(messages[node_id])
+        to_update_node_ids.append(node_id)
+        unique_messages.append(torch.mean(torch.stack([m[0] for m in messages[node_id]]), dim=0))
+        unique_timestamps.append(messages[node_id][-1][1])
+
+    unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
+    unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []
+
+    return to_update_node_ids, unique_messages, unique_timestamps
+
+
+def get_message_aggregator(aggregator_type, device):
+  if aggregator_type == "last":
+    return LastMessageAggregator(device=device)
+  elif aggregator_type == "mean":
+    return MeanMessageAggregator(device=device)
+  else:
+    raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
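A small sketch of the "last" strategy: when a node accumulated several messages within a batch, only the most recent one survives aggregation (the toy message dict below is ours):

import torch
from modules.message_aggregator import get_message_aggregator

agg = get_message_aggregator("last", device="cpu")
messages = {3: [(torch.ones(4), torch.tensor(1.0)),
                (torch.full((4,), 2.0), torch.tensor(5.0))]}
ids, msgs, ts = agg.aggregate([3], messages)
# ids == [3]; msgs is the single row [2., 2., 2., 2.]; ts == tensor([5.])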
modules/message_function.py
ADDED
@@ -0,0 +1,40 @@
+from torch import nn
+
+
+class MessageFunction(nn.Module):
+  """
+  Module which computes the message for a given interaction.
+  """
+
+  def compute_message(self, raw_messages):
+    return None
+
+
+class MLPMessageFunction(MessageFunction):
+  def __init__(self, raw_message_dimension, message_dimension):
+    super(MLPMessageFunction, self).__init__()
+
+    self.mlp = self.layers = nn.Sequential(
+      nn.Linear(raw_message_dimension, raw_message_dimension // 2),
+      nn.ReLU(),
+      nn.Linear(raw_message_dimension // 2, message_dimension),
+    )
+
+  def compute_message(self, raw_messages):
+    messages = self.mlp(raw_messages)
+
+    return messages
+
+
+class IdentityMessageFunction(MessageFunction):
+
+  def compute_message(self, raw_messages):
+    return raw_messages
+
+
+def get_message_function(module_type, raw_message_dimension, message_dimension):
+  if module_type == "mlp":
+    return MLPMessageFunction(raw_message_dimension, message_dimension)
+  elif module_type == "identity":
+    return IdentityMessageFunction()
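A quick sketch of the "mlp" option, which halves the raw dimension in a hidden layer before projecting to the message dimension. The numbers below are hypothetical; in model/tgn.py the raw dimension is 2 * memory_dimension + n_edge_features + the time encoder dimension:

import torch
from modules.message_function import get_message_function

fn = get_message_function("mlp", raw_message_dimension=1344, message_dimension=100)
raw = torch.randn(8, 1344)
print(fn.compute_message(raw).shape)  # torch.Size([8, 100])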
modules/tgn.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
from utils.utils import MergeLayer
|
7 |
+
from modules.memory import Memory
|
8 |
+
from modules.message_aggregator import get_message_aggregator
|
9 |
+
from modules.message_function import get_message_function
|
10 |
+
from modules.memory_updater import get_memory_updater
|
11 |
+
from modules.embedding_module import get_embedding_module
|
12 |
+
from model.time_encoding import TimeEncode
|
13 |
+
|
14 |
+
|
15 |
+
class TGN(torch.nn.Module):
|
16 |
+
def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
|
17 |
+
n_heads=2, dropout=0.1, use_memory=False,
|
18 |
+
memory_update_at_start=True, message_dimension=100,
|
19 |
+
memory_dimension=500, embedding_module_type="graph_attention",
|
20 |
+
message_function="mlp",
|
21 |
+
mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
|
22 |
+
std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
|
23 |
+
memory_updater_type="gru",
|
24 |
+
use_destination_embedding_in_message=False,
|
25 |
+
use_source_embedding_in_message=False,
|
26 |
+
dyrep=False):
|
27 |
+
super(TGN, self).__init__()
|
28 |
+
|
29 |
+
self.n_layers = n_layers
|
30 |
+
self.neighbor_finder = neighbor_finder
|
31 |
+
self.device = device
|
32 |
+
self.logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
|
35 |
+
self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)
|
36 |
+
|
37 |
+
self.n_node_features = self.node_raw_features.shape[1]
|
38 |
+
self.n_nodes = self.node_raw_features.shape[0]
|
39 |
+
self.n_edge_features = self.edge_raw_features.shape[1]
|
40 |
+
self.embedding_dimension = self.n_node_features
|
41 |
+
self.n_neighbors = n_neighbors
|
42 |
+
self.embedding_module_type = embedding_module_type
|
43 |
+
self.use_destination_embedding_in_message = use_destination_embedding_in_message
|
44 |
+
self.use_source_embedding_in_message = use_source_embedding_in_message
|
45 |
+
self.dyrep = dyrep
|
46 |
+
|
47 |
+
self.use_memory = use_memory
|
48 |
+
self.time_encoder = TimeEncode(dimension=self.n_node_features)
|
49 |
+
self.memory = None
|
50 |
+
|
51 |
+
self.mean_time_shift_src = mean_time_shift_src
|
52 |
+
self.std_time_shift_src = std_time_shift_src
|
53 |
+
self.mean_time_shift_dst = mean_time_shift_dst
|
54 |
+
self.std_time_shift_dst = std_time_shift_dst
|
55 |
+
|
56 |
+
if self.use_memory:
|
57 |
+
self.memory_dimension = memory_dimension
|
58 |
+
self.memory_update_at_start = memory_update_at_start
|
59 |
+
raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
|
60 |
+
self.time_encoder.dimension
|
61 |
+
message_dimension = message_dimension if message_function != "identity" else raw_message_dimension
|
62 |
+
self.memory = Memory(n_nodes=self.n_nodes,
|
63 |
+
memory_dimension=self.memory_dimension,
|
64 |
+
input_dimension=message_dimension,
|
65 |
+
message_dimension=message_dimension,
|
66 |
+
device=device)
|
67 |
+
self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
|
68 |
+
device=device)
|
69 |
+
self.message_function = get_message_function(module_type=message_function,
|
70 |
+
raw_message_dimension=raw_message_dimension,
|
71 |
+
message_dimension=message_dimension)
|
72 |
+
self.memory_updater = get_memory_updater(module_type=memory_updater_type,
|
73 |
+
memory=self.memory,
|
74 |
+
message_dimension=message_dimension,
|
75 |
+
memory_dimension=self.memory_dimension,
|
76 |
+
device=device)
|
77 |
+
|
78 |
+
self.embedding_module_type = embedding_module_type
|
79 |
+
|
80 |
+
self.embedding_module = get_embedding_module(module_type=embedding_module_type,
|
81 |
+
node_features=self.node_raw_features,
|
82 |
+
edge_features=self.edge_raw_features,
|
83 |
+
memory=self.memory,
|
84 |
+
neighbor_finder=self.neighbor_finder,
|
85 |
+
time_encoder=self.time_encoder,
|
86 |
+
n_layers=self.n_layers,
|
87 |
+
n_node_features=self.n_node_features,
|
88 |
+
n_edge_features=self.n_edge_features,
|
89 |
+
n_time_features=self.n_node_features,
|
90 |
+
embedding_dimension=self.embedding_dimension,
|
91 |
+
device=self.device,
|
92 |
+
n_heads=n_heads, dropout=dropout,
|
93 |
+
use_memory=use_memory,
|
94 |
+
n_neighbors=self.n_neighbors)
|
95 |
+
|
96 |
+
# MLP to compute probability on an edge given two node embeddings
|
97 |
+
self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
|
98 |
+
self.n_node_features,
|
99 |
+
1)
|
100 |
+
|
101 |
+
  def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                  edge_idxs, n_neighbors=20):
    """
    Compute temporal embeddings for sources, destinations, and negatively sampled destinations.

    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destinations
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbors to consider in each convolutional
    layer
    :return: Temporal embeddings for sources, destinations and negatives
    """

    n_samples = len(source_nodes)
    nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
    positives = np.concatenate([source_nodes, destination_nodes])
    timestamps = np.concatenate([edge_times, edge_times, edge_times])

    memory = None
    time_diffs = None
    if self.use_memory:
      if self.memory_update_at_start:
        # Update memory for all nodes with messages stored in previous batches
        memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                      self.memory.messages)
      else:
        memory = self.memory.get_memory(list(range(self.n_nodes)))
        last_update = self.memory.last_update

      ### Compute differences between the time the memory of a node was last updated,
      ### and the time for which we want to compute the embedding of a node
      source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
        source_nodes].long()
      source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
      destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
        destination_nodes].long()
      destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
      negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
        negative_nodes].long()
      negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

      time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
                             dim=0)

    # Compute the embeddings using the embedding module
    node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                             source_nodes=nodes,
                                                             timestamps=timestamps,
                                                             n_layers=self.n_layers,
                                                             n_neighbors=n_neighbors,
                                                             time_diffs=time_diffs)

    source_node_embedding = node_embedding[:n_samples]
    destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
    negative_node_embedding = node_embedding[2 * n_samples:]

    if self.use_memory:
      if self.memory_update_at_start:
        # Persist the updates to the memory only for sources and destinations (since now we have
        # new messages for them)
        self.update_memory(positives, self.memory.messages)

        assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
          "Something wrong in how the memory was updated"

        # Remove messages for the positives since we have already updated the memory using them
        self.memory.clear_messages(positives)

      unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                    source_node_embedding,
                                                                    destination_nodes,
                                                                    destination_node_embedding,
                                                                    edge_times, edge_idxs)
      unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                              destination_node_embedding,
                                                                              source_nodes,
                                                                              source_node_embedding,
                                                                              edge_times, edge_idxs)
      if self.memory_update_at_start:
        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
      else:
        self.update_memory(unique_sources, source_id_to_messages)
        self.update_memory(unique_destinations, destination_id_to_messages)

      if self.dyrep:
        source_node_embedding = memory[source_nodes]
        destination_node_embedding = memory[destination_nodes]
        negative_node_embedding = memory[negative_nodes]

    return source_node_embedding, destination_node_embedding, negative_node_embedding

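  # Note on memory_update_at_start (editorial annotation, not upstream code): when this
  # flag is set, each batch (1) refreshes the memory by applying the messages stored by
  # *previous* batches, (2) computes embeddings from the refreshed memory, and
  # (3) stores the current batch's raw messages for the next batch. Updating the memory
  # with a batch's own interactions only after its predictions avoids information leakage.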
  def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                 edge_idxs, n_neighbors=20):
    """
    Compute probabilities for edges between sources and destinations and between sources and
    negatives by first computing temporal embeddings using the TGN encoder and then feeding them
    into the MLP decoder.
    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destinations
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbors to consider in each convolutional
    layer
    :return: Probabilities for both the positive and negative edges
    """
    n_samples = len(source_nodes)
    source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(
      source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors)

    # The source embeddings are repeated so that positive and negative pairs are scored in one pass
    score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                torch.cat([destination_node_embedding,
                                           negative_node_embedding])).squeeze(dim=0)
    pos_score = score[:n_samples]
    neg_score = score[n_samples:]

    return pos_score.sigmoid(), neg_score.sigmoid()

  def update_memory(self, nodes, messages):
    # Aggregate messages for the same nodes
    unique_nodes, unique_messages, unique_timestamps = \
      self.message_aggregator.aggregate(
        nodes,
        messages)

    if len(unique_nodes) > 0:
      unique_messages = self.message_function.compute_message(unique_messages)

    # Update the memory with the aggregated messages
    self.memory_updater.update_memory(unique_nodes, unique_messages,
                                      timestamps=unique_timestamps)

  def get_updated_memory(self, nodes, messages):
    # Aggregate messages for the same nodes
    unique_nodes, unique_messages, unique_timestamps = \
      self.message_aggregator.aggregate(
        nodes,
        messages)

    if len(unique_nodes) > 0:
      unique_messages = self.message_function.compute_message(unique_messages)

    updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
                                                                                 unique_messages,
                                                                                 timestamps=unique_timestamps)

    return updated_memory, updated_last_update

  def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                       destination_node_embedding, edge_times, edge_idxs):
    edge_times = torch.from_numpy(edge_times).float().to(self.device)
    edge_features = self.edge_raw_features[edge_idxs]

    source_memory = self.memory.get_memory(source_nodes) if not \
      self.use_source_embedding_in_message else source_node_embedding
    destination_memory = self.memory.get_memory(destination_nodes) if \
      not self.use_destination_embedding_in_message else destination_node_embedding

    source_time_delta = edge_times - self.memory.last_update[source_nodes]
    source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
      source_nodes), -1)

    source_message = torch.cat([source_memory, destination_memory, edge_features,
                                source_time_delta_encoding],
                               dim=1)
    messages = defaultdict(list)
    unique_sources = np.unique(source_nodes)

    for i in range(len(source_nodes)):
      messages[source_nodes[i]].append((source_message[i], edge_times[i]))

    return unique_sources, messages

  def set_neighbor_finder(self, neighbor_finder):
    self.neighbor_finder = neighbor_finder
    self.embedding_module.neighbor_finder = neighbor_finder

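For orientation, a minimal training-step sketch of how compute_edge_probabilities is typically driven. The names tgn, batch, and train_rand_sampler are illustrative (a TGN instance built as above, a chronological slice of the training Data, and the RandEdgeSampler defined in utils/utils.py); only compute_edge_probabilities itself comes from this file.

import torch

criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(tgn.parameters(), lr=1e-4)  # learning rate is illustrative

# One randomly sampled negative destination per positive edge
_, negatives = train_rand_sampler.sample(len(batch.sources))

optimizer.zero_grad()
pos_prob, neg_prob = tgn.compute_edge_probabilities(batch.sources, batch.destinations,
                                                    negatives, batch.timestamps,
                                                    batch.edge_idxs, n_neighbors=20)

# Positive edges are labelled 1, sampled negatives 0
loss = criterion(pos_prob.squeeze(), torch.ones_like(pos_prob.squeeze())) + \
       criterion(neg_prob.squeeze(), torch.zeros_like(neg_prob.squeeze()))
loss.backward()
optimizer.step()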
utils/data_processing.py
ADDED
@@ -0,0 +1,184 @@
import numpy as np
import random
import pandas as pd


class Data:
  def __init__(self, sources, destinations, timestamps, edge_idxs, labels):
    self.sources = sources
    self.destinations = destinations
    self.timestamps = timestamps
    self.edge_idxs = edge_idxs
    self.labels = labels
    self.n_interactions = len(sources)
    self.unique_nodes = set(sources) | set(destinations)
    self.n_unique_nodes = len(self.unique_nodes)


def get_data_node_classification(dataset_name, use_validation=False):
  ### Load data and train val test split
  graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name))
  edge_features = np.load('./data/ml_{}.npy'.format(dataset_name))
  node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name))

  val_time, test_time = list(np.quantile(graph_df.ts, [0.70, 0.85]))

  sources = graph_df.u.values
  destinations = graph_df.i.values
  edge_idxs = graph_df.idx.values
  labels = graph_df.label.values
  timestamps = graph_df.ts.values

  random.seed(2020)

  train_mask = timestamps <= val_time if use_validation else timestamps <= test_time
  test_mask = timestamps > test_time
  val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time) if use_validation else test_mask

  full_data = Data(sources, destinations, timestamps, edge_idxs, labels)

  train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask],
                    edge_idxs[train_mask], labels[train_mask])

  val_data = Data(sources[val_mask], destinations[val_mask], timestamps[val_mask],
                  edge_idxs[val_mask], labels[val_mask])

  test_data = Data(sources[test_mask], destinations[test_mask], timestamps[test_mask],
                   edge_idxs[test_mask], labels[test_mask])

  return full_data, node_features, edge_features, train_data, val_data, test_data

def get_data(dataset_name, different_new_nodes_between_val_and_test=False, randomize_features=False):
  ### Load data and train val test split
  graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name))
  edge_features = np.load('./data/ml_{}.npy'.format(dataset_name))
  node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name))

  if randomize_features:
    node_features = np.random.rand(node_features.shape[0], node_features.shape[1])

  val_time, test_time = list(np.quantile(graph_df.ts, [0.70, 0.85]))

  sources = graph_df.u.values
  destinations = graph_df.i.values
  edge_idxs = graph_df.idx.values
  labels = graph_df.label.values
  timestamps = graph_df.ts.values

  full_data = Data(sources, destinations, timestamps, edge_idxs, labels)

  random.seed(2020)

  node_set = set(sources) | set(destinations)
  n_total_unique_nodes = len(node_set)

  # Compute nodes which appear at test time
  test_node_set = set(sources[timestamps > val_time]).union(
    set(destinations[timestamps > val_time]))
  # Sample nodes which we keep as new nodes (to test inductiveness), so that we have to remove all
  # their edges from training. Sampling from a sorted list keeps the draw deterministic under the
  # seed and works on Python 3.11+, where random.sample no longer accepts a set.
  new_test_node_set = set(random.sample(sorted(test_node_set), int(0.1 * n_total_unique_nodes)))

  # Mask saying for each source and destination whether they are new test nodes
  new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
  new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values

  # Mask which is true for edges with both destination and source not being new test nodes (because
  # we want to remove all edges involving any new test node)
  observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask)

  # For train we keep edges happening before the validation time which do not involve any new node
  # used for inductiveness
  train_mask = np.logical_and(timestamps <= val_time, observed_edges_mask)

  train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask],
                    edge_idxs[train_mask], labels[train_mask])

  # Define the new node sets for testing inductiveness of the model
  train_node_set = set(train_data.sources).union(train_data.destinations)
  assert len(train_node_set & new_test_node_set) == 0
  new_node_set = node_set - train_node_set

  val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
  test_mask = timestamps > test_time

  if different_new_nodes_between_val_and_test:
    n_new_nodes = len(new_test_node_set) // 2
    val_new_node_set = set(list(new_test_node_set)[:n_new_nodes])
    test_new_node_set = set(list(new_test_node_set)[n_new_nodes:])

    edge_contains_new_val_node_mask = np.array(
      [(a in val_new_node_set or b in val_new_node_set) for a, b in zip(sources, destinations)])
    edge_contains_new_test_node_mask = np.array(
      [(a in test_new_node_set or b in test_new_node_set) for a, b in zip(sources, destinations)])
    new_node_val_mask = np.logical_and(val_mask, edge_contains_new_val_node_mask)
    new_node_test_mask = np.logical_and(test_mask, edge_contains_new_test_node_mask)

  else:
    edge_contains_new_node_mask = np.array(
      [(a in new_node_set or b in new_node_set) for a, b in zip(sources, destinations)])
    new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask)
    new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask)

  # Validation and test with all edges
  val_data = Data(sources[val_mask], destinations[val_mask], timestamps[val_mask],
                  edge_idxs[val_mask], labels[val_mask])

  test_data = Data(sources[test_mask], destinations[test_mask], timestamps[test_mask],
                   edge_idxs[test_mask], labels[test_mask])

  # Validation and test with edges that have at least one new node (not in the training set)
  new_node_val_data = Data(sources[new_node_val_mask], destinations[new_node_val_mask],
                           timestamps[new_node_val_mask],
                           edge_idxs[new_node_val_mask], labels[new_node_val_mask])

  new_node_test_data = Data(sources[new_node_test_mask], destinations[new_node_test_mask],
                            timestamps[new_node_test_mask], edge_idxs[new_node_test_mask],
                            labels[new_node_test_mask])

  print("The dataset has {} interactions, involving {} different nodes".format(full_data.n_interactions,
                                                                               full_data.n_unique_nodes))
  print("The training dataset has {} interactions, involving {} different nodes".format(
    train_data.n_interactions, train_data.n_unique_nodes))
  print("The validation dataset has {} interactions, involving {} different nodes".format(
    val_data.n_interactions, val_data.n_unique_nodes))
  print("The test dataset has {} interactions, involving {} different nodes".format(
    test_data.n_interactions, test_data.n_unique_nodes))
  print("The new node validation dataset has {} interactions, involving {} different nodes".format(
    new_node_val_data.n_interactions, new_node_val_data.n_unique_nodes))
  print("The new node test dataset has {} interactions, involving {} different nodes".format(
    new_node_test_data.n_interactions, new_node_test_data.n_unique_nodes))
  print("{} nodes were used for the inductive testing, i.e. are never seen during training".format(
    len(new_test_node_set)))

  return node_features, edge_features, full_data, train_data, val_data, test_data, \
         new_node_val_data, new_node_test_data

def compute_time_statistics(sources, destinations, timestamps):
  last_timestamp_sources = dict()
  last_timestamp_dst = dict()
  all_timediffs_src = []
  all_timediffs_dst = []
  for k in range(len(sources)):
    source_id = sources[k]
    dest_id = destinations[k]
    c_timestamp = timestamps[k]
    if source_id not in last_timestamp_sources.keys():
      last_timestamp_sources[source_id] = 0
    if dest_id not in last_timestamp_dst.keys():
      last_timestamp_dst[dest_id] = 0
    all_timediffs_src.append(c_timestamp - last_timestamp_sources[source_id])
    all_timediffs_dst.append(c_timestamp - last_timestamp_dst[dest_id])
    last_timestamp_sources[source_id] = c_timestamp
    last_timestamp_dst[dest_id] = c_timestamp
  assert len(all_timediffs_src) == len(sources)
  assert len(all_timediffs_dst) == len(sources)
  mean_time_shift_src = np.mean(all_timediffs_src)
  std_time_shift_src = np.std(all_timediffs_src)
  mean_time_shift_dst = np.mean(all_timediffs_dst)
  std_time_shift_dst = np.std(all_timediffs_dst)

  return mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst
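A short usage sketch for this module; the 'wikipedia' name matches the ml_wikipedia files shipped under data/, and everything else is defined above:

from utils.data_processing import get_data, compute_time_statistics

# Chronological 70/15/15 split (np.quantile at 0.70 and 0.85), plus the
# inductive new-node validation/test sets built in get_data
node_features, edge_features, full_data, train_data, val_data, test_data, \
  new_node_val_data, new_node_test_data = get_data('wikipedia')

# Mean/std of per-node inter-event times; TGN uses these to normalise the
# time deltas computed in compute_temporal_embeddings
mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst = \
  compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps)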
utils/utils.py
ADDED
@@ -0,0 +1,186 @@
import numpy as np
import torch


class MergeLayer(torch.nn.Module):
  def __init__(self, dim1, dim2, dim3, dim4):
    super().__init__()
    self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
    self.fc2 = torch.nn.Linear(dim3, dim4)
    self.act = torch.nn.ReLU()

    torch.nn.init.xavier_normal_(self.fc1.weight)
    torch.nn.init.xavier_normal_(self.fc2.weight)

  def forward(self, x1, x2):
    x = torch.cat([x1, x2], dim=1)
    h = self.act(self.fc1(x))
    return self.fc2(h)


class MLP(torch.nn.Module):
  def __init__(self, dim, drop=0.3):
    super().__init__()
    self.fc_1 = torch.nn.Linear(dim, 80)
    self.fc_2 = torch.nn.Linear(80, 10)
    self.fc_3 = torch.nn.Linear(10, 1)
    self.act = torch.nn.ReLU()
    self.dropout = torch.nn.Dropout(p=drop, inplace=False)

  def forward(self, x):
    x = self.act(self.fc_1(x))
    x = self.dropout(x)
    x = self.act(self.fc_2(x))
    x = self.dropout(x)
    return self.fc_3(x).squeeze(dim=1)


class EarlyStopMonitor(object):
  def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
    self.max_round = max_round
    self.num_round = 0

    self.epoch_count = 0
    self.best_epoch = 0

    self.last_best = None
    self.higher_better = higher_better
    self.tolerance = tolerance

  def early_stop_check(self, curr_val):
    if not self.higher_better:
      curr_val *= -1
    if self.last_best is None:
      self.last_best = curr_val
    elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
      self.last_best = curr_val
      self.num_round = 0
      self.best_epoch = self.epoch_count
    else:
      self.num_round += 1

    self.epoch_count += 1

    return self.num_round >= self.max_round

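# Usage sketch (editorial annotation, not part of the upstream file):
# EarlyStopMonitor is fed one validation metric per epoch and reports True once
# the metric has failed to improve for max_round consecutive epochs.
#
#   early_stopper = EarlyStopMonitor(max_round=3, higher_better=True)
#   for epoch in range(num_epochs):           # num_epochs, run_validation are hypothetical
#       val_ap = run_validation(model)
#       if early_stopper.early_stop_check(val_ap):
#           break                             # best epoch is early_stopper.best_epoch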
class RandEdgeSampler(object):
  def __init__(self, src_list, dst_list, seed=None):
    self.seed = None
    self.src_list = np.unique(src_list)
    self.dst_list = np.unique(dst_list)

    if seed is not None:
      self.seed = seed
      self.random_state = np.random.RandomState(self.seed)

  def sample(self, size):
    if self.seed is None:
      src_index = np.random.randint(0, len(self.src_list), size)
      dst_index = np.random.randint(0, len(self.dst_list), size)
    else:
      src_index = self.random_state.randint(0, len(self.src_list), size)
      dst_index = self.random_state.randint(0, len(self.dst_list), size)
    return self.src_list[src_index], self.dst_list[dst_index]

  def reset_random_state(self):
    self.random_state = np.random.RandomState(self.seed)

def get_neighbor_finder(data, uniform, max_node_idx=None):
  max_node_idx = max(data.sources.max(), data.destinations.max()) if max_node_idx is None else max_node_idx
  adj_list = [[] for _ in range(max_node_idx + 1)]
  for source, destination, edge_idx, timestamp in zip(data.sources, data.destinations,
                                                      data.edge_idxs,
                                                      data.timestamps):
    adj_list[source].append((destination, edge_idx, timestamp))
    adj_list[destination].append((source, edge_idx, timestamp))

  return NeighborFinder(adj_list, uniform=uniform)

class NeighborFinder:
  def __init__(self, adj_list, uniform=False, seed=None):
    self.node_to_neighbors = []
    self.node_to_edge_idxs = []
    self.node_to_edge_timestamps = []

    for neighbors in adj_list:
      # neighbors is a list of tuples (neighbor, edge_idx, timestamp);
      # sort it by timestamp so temporal lookups can use binary search
      sorted_neighbors = sorted(neighbors, key=lambda x: x[2])
      self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighbors]))
      self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighbors]))
      self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighbors]))

    self.uniform = uniform

    if seed is not None:
      self.seed = seed
      self.random_state = np.random.RandomState(self.seed)

  def find_before(self, src_idx, cut_time):
    """
    Extracts all the interactions of node src_idx happening strictly before cut_time in the
    overall interaction graph. The returned interactions are sorted by time.

    Returns 3 arrays: neighbors, edge_idxs, timestamps
    """
    i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

    return self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i], self.node_to_edge_timestamps[src_idx][:i]

  def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
    """
    Given a list of node ids and corresponding cut times, extracts a sampled temporal
    neighborhood for each node in the list.

    Params
    ------
    source_nodes: List[int]
    timestamps: List[float]
    n_neighbors: int
    """
    assert (len(source_nodes) == len(timestamps))

    tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
    # NB! All interactions described in these matrices are sorted in each row by time
    neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
      np.int32)  # entry (i,j) is the id of a node that source_nodes[i] interacted with before timestamps[i]
    edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
      np.float32)  # entry (i,j) is the timestamp of the interaction between source_nodes[i] and neighbors[i,j]
    edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
      np.int32)  # entry (i,j) is the interaction index of the interaction between source_nodes[i] and neighbors[i,j]

    for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
      # Extract all neighbors, interaction indexes and timestamps of interactions of
      # source_node happening before the cut time
      source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node,
                                                                               timestamp)

      if len(source_neighbors) > 0 and n_neighbors > 0:
        if self.uniform:
          # Uniform sampling: draw n_neighbors indices uniformly at random (with replacement)
          sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors)

          neighbors[i, :] = source_neighbors[sampled_idx]
          edge_times[i, :] = source_edge_times[sampled_idx]
          edge_idxs[i, :] = source_edge_idxs[sampled_idx]

          # Re-sort based on time
          pos = edge_times[i, :].argsort()
          neighbors[i, :] = neighbors[i, :][pos]
          edge_times[i, :] = edge_times[i, :][pos]
          edge_idxs[i, :] = edge_idxs[i, :][pos]
        else:
          # Take the n_neighbors most recent interactions
          source_edge_times = source_edge_times[-n_neighbors:]
          source_neighbors = source_neighbors[-n_neighbors:]
          source_edge_idxs = source_edge_idxs[-n_neighbors:]

          assert (len(source_neighbors) <= n_neighbors)
          assert (len(source_edge_times) <= n_neighbors)
          assert (len(source_edge_idxs) <= n_neighbors)

          # Left-pad with zeros so the most recent interactions sit at the end of the row
          neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
          edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
          edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs

    return neighbors, edge_idxs, edge_times
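Finally, a minimal sketch of building a temporal neighbor finder over a training split and querying it. train_data is assumed to be the training Data returned by get_data in utils/data_processing.py; everything else comes from this file.

from utils.utils import get_neighbor_finder

# The train-time finder only sees training edges; uniform=False keeps the most
# recent interactions instead of sampling uniformly from the history
train_ngh_finder = get_neighbor_finder(train_data, uniform=False)

# For each queried node, return up to 20 neighbors it interacted with strictly
# before the given timestamp (rows are zero-padded on the left when shorter)
neighbors, edge_idxs, edge_times = train_ngh_finder.get_temporal_neighbor(
  train_data.sources[:5], train_data.timestamps[:5], n_neighbors=20)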