ashu316 committed on
Commit
41aae2b
·
verified ·
1 Parent(s): 1249292

Upload 14 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/ml_wikipedia.csv filter=lfs diff=lfs merge=lfs -text
data/ml_wikipedia.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85fd8e50e5ffbb1348173b85d0b7b69ee270550f511c505daf062c9bf9db8027
3
+ size 347159369
data/ml_wikipedia.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f601ac36dfaafdd78759d174611204a0660be70c93aa16f30edb56a7bc642b53
3
+ size 216685728
data/ml_wikipedia_node.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85f2054b5fe9d76188a5bf014232c81a64b3ffedb5a146ff30e1daa71215278b
3
+ size 12697856
model/temporal_attention.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from utils.utils import MergeLayer
5
+
6
+
7
class TemporalAttentionLayer(torch.nn.Module):
    """
    Temporal attention layer. Return the temporal embedding of a node given the node itself,
    its neighbors and the edge timestamps.
    """

    def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
                 output_dimension, n_head=2,
                 dropout=0.1):
        """
        :param n_node_features: size of the source node feature vector
        :param n_neighbors_features: size of each neighbor feature vector
        :param n_edge_features: size of each edge feature vector
        :param time_dim: size of the time encoding
        :param output_dimension: last size argument handed to MergeLayer
          (presumably the size of the returned embedding -- confirm against utils.utils.MergeLayer)
        :param n_head: number of attention heads
        :param dropout: dropout probability used inside the multi head attention
        """
        super(TemporalAttentionLayer, self).__init__()

        self.n_head = n_head

        self.feat_dim = n_node_features
        self.time_dim = time_dim

        # The query is [node features | time encoding]; keys/values additionally carry edge features
        self.query_dim = n_node_features + time_dim
        self.key_dim = n_neighbors_features + time_dim + n_edge_features

        self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

        self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                       kdim=self.key_dim,
                                                       vdim=self.key_dim,
                                                       num_heads=n_head,
                                                       dropout=dropout)

    def forward(self, src_node_features, src_time_features, neighbors_features,
                neighbors_time_features, edge_features, neighbors_padding_mask):
        """
        Temporal attention model
        :param src_node_features: float Tensor of shape [batch_size, n_node_features]
        :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
        :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
        :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors,
        time_dim]
        :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
        :param neighbors_padding_mask: boolean Tensor of shape [batch_size, n_neighbors], True for
        padded (fake) neighbors. The tensor passed in is not modified.
        :return:
        attn_output: output of the merge layer applied to the attention output and the source
        node features, one row per batch element
        attn_output_weights: [batch_size, n_neighbors]
        """

        src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)

        query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
        key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

        # Reshape the tensors to the (sequence, batch, features) layout used by multi head attention
        query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
        key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]

        # Compute mask of which source nodes have no valid neighbors
        invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)

        # Work on a copy so that the caller's mask is left untouched
        neighbors_padding_mask = neighbors_padding_mask.clone()
        # If a source node has no valid neighbor, set its first neighbor to be valid. This will
        # force the attention to just 'attend' on this neighbor (which has the same features as all
        # the others since they are fake neighbors) and will produce an equivalent result to the
        # original tgat paper which was forcing fake neighbors to all have same attention of 1e-10.
        # Squeeze only dim 1 so that a batch of size one still yields a 1-D row selector.
        neighbors_padding_mask[invalid_neighborhood_mask.squeeze(dim=1), 0] = False

        attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                                  key_padding_mask=neighbors_padding_mask)

        # Drop exactly the length-one query axis. A bare squeeze() would also remove the batch
        # axis when batch_size == 1 (or the neighbor axis when n_neighbors == 1).
        attn_output = attn_output.squeeze(dim=0)  # [batch_size, query_dim]
        attn_output_weights = attn_output_weights.squeeze(dim=1)  # [batch_size, n_neighbors]

        # Source nodes with no neighbors have an all zero attention output. The attention output is
        # then added or concatenated to the original source node features and then fed into an MLP.
        # This means that an all zero vector is not used.
        attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
        attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

        # Skip connection with temporal attention over neighborhood and the features of the node itself
        attn_output = self.merger(attn_output, src_node_features)

        return attn_output, attn_output_weights
model/tgn.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ import torch
4
+ from collections import defaultdict
5
+
6
+ from utils.utils import MergeLayer
7
+ from modules.memory import Memory
8
+ from modules.message_aggregator import get_message_aggregator
9
+ from modules.message_function import get_message_function
10
+ from modules.memory_updater import get_memory_updater
11
+ from modules.embedding_module import get_embedding_module
12
+ from model.time_encoding import TimeEncode
13
+
14
+
15
class TGN(torch.nn.Module):
    """
    Temporal graph model that wires together a time encoder, an optional node memory (with its
    message aggregator, message function and memory updater), an embedding module and a
    MergeLayer used to score pairs of node embeddings.
    """

    def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
                 n_heads=2, dropout=0.1, use_memory=False,
                 memory_update_at_start=True, message_dimension=100,
                 memory_dimension=500, embedding_module_type="graph_attention",
                 message_function="mlp",
                 mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
                 std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
                 memory_updater_type="gru",
                 use_destination_embedding_in_message=False,
                 use_source_embedding_in_message=False,
                 dyrep=False):
        """
        :param neighbor_finder: neighbor lookup object forwarded to the embedding module
        :param node_features: 2-D numpy array, one row of raw features per node
        :param edge_features: 2-D numpy array, one row of raw features per edge
        :param device: torch device on which the feature tensors and sub-modules are placed
        :param n_layers: number of layers forwarded to the embedding module
        :param n_heads: number of attention heads forwarded to the embedding module
        :param dropout: dropout probability forwarded to the embedding module
        :param use_memory: whether to build and use the node memory
        :param memory_update_at_start: if True the memory is updated with the stored messages
          before computing the embeddings, otherwise it is updated right after
        :param message_dimension: message size (ignored when message_function is "identity", in
          which case the raw message size is used)
        :param memory_dimension: size of each node's memory vector
        :param embedding_module_type: name passed to get_embedding_module
        :param message_function: name passed to get_message_function
        :param mean_time_shift_src: mean used to normalize source time differences
        :param std_time_shift_src: std used to normalize source time differences
        :param mean_time_shift_dst: mean used to normalize destination/negative time differences
        :param std_time_shift_dst: std used to normalize destination/negative time differences
        :param n_neighbors: forwarded to the embedding module
        :param aggregator_type: name passed to get_message_aggregator
        :param memory_updater_type: name passed to get_memory_updater
        :param use_destination_embedding_in_message: use the destination embedding instead of
          its memory when building raw messages
        :param use_source_embedding_in_message: use the source embedding instead of its memory
          when building raw messages
        :param dyrep: if True (and memory is used) return memory rows instead of embeddings
        """
        super(TGN, self).__init__()

        self.n_layers = n_layers
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)

        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors
        self.embedding_module_type = embedding_module_type
        self.use_destination_embedding_in_message = use_destination_embedding_in_message
        self.use_source_embedding_in_message = use_source_embedding_in_message
        self.dyrep = dyrep

        self.use_memory = use_memory
        self.time_encoder = TimeEncode(dimension=self.n_node_features)
        self.memory = None

        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst

        if self.use_memory:
            self.memory_dimension = memory_dimension
            self.memory_update_at_start = memory_update_at_start
            # A raw message is [source memory | destination memory | edge features | time encoding]
            raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
                                    self.time_encoder.dimension
            message_dimension = message_dimension if message_function != "identity" else raw_message_dimension
            self.memory = Memory(n_nodes=self.n_nodes,
                                 memory_dimension=self.memory_dimension,
                                 input_dimension=message_dimension,
                                 message_dimension=message_dimension,
                                 device=device)
            self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
                                                             device=device)
            self.message_function = get_message_function(module_type=message_function,
                                                         raw_message_dimension=raw_message_dimension,
                                                         message_dimension=message_dimension)
            self.memory_updater = get_memory_updater(module_type=memory_updater_type,
                                                     memory=self.memory,
                                                     message_dimension=message_dimension,
                                                     memory_dimension=self.memory_dimension,
                                                     device=device)

        self.embedding_module = get_embedding_module(module_type=embedding_module_type,
                                                     node_features=self.node_raw_features,
                                                     edge_features=self.edge_raw_features,
                                                     memory=self.memory,
                                                     neighbor_finder=self.neighbor_finder,
                                                     time_encoder=self.time_encoder,
                                                     n_layers=self.n_layers,
                                                     n_node_features=self.n_node_features,
                                                     n_edge_features=self.n_edge_features,
                                                     n_time_features=self.n_node_features,
                                                     embedding_dimension=self.embedding_dimension,
                                                     device=self.device,
                                                     n_heads=n_heads, dropout=dropout,
                                                     use_memory=use_memory,
                                                     n_neighbors=self.n_neighbors)

        # MLP to compute probability on an edge given two node embeddings
        self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
                                         self.n_node_features,
                                         1)

    def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                    edge_idxs, n_neighbors=20):
        """
        Compute temporal embeddings for sources, destinations, and negatively sampled destinations.

        :param source_nodes [batch_size]: source ids.
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Temporal embeddings for sources, destinations and negatives
        """

        n_samples = len(source_nodes)
        nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
        positives = np.concatenate([source_nodes, destination_nodes])
        timestamps = np.concatenate([edge_times, edge_times, edge_times])

        memory = None
        time_diffs = None
        if self.use_memory:
            if self.memory_update_at_start:
                # Update memory for all nodes with messages stored in previous batches
                memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                              self.memory.messages)
            else:
                memory = self.memory.get_memory(list(range(self.n_nodes)))
                last_update = self.memory.last_update

            # Compute differences between the time the memory of a node was last updated,
            # and the time for which we want to compute the embedding of a node.
            # The edge time tensor is identical for the three groups, so build it once.
            edge_times_long = torch.LongTensor(edge_times).to(self.device)

            source_time_diffs = edge_times_long - last_update[source_nodes].long()
            source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
            destination_time_diffs = edge_times_long - last_update[destination_nodes].long()
            destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
            # Negatives are normalized with the destination statistics
            negative_time_diffs = edge_times_long - last_update[negative_nodes].long()
            negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

            time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
                                   dim=0)

        # Compute the embeddings using the embedding module
        node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                                 source_nodes=nodes,
                                                                 timestamps=timestamps,
                                                                 n_layers=self.n_layers,
                                                                 n_neighbors=n_neighbors,
                                                                 time_diffs=time_diffs)

        # The embeddings come back in the same [sources | destinations | negatives] order
        source_node_embedding = node_embedding[:n_samples]
        destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
        negative_node_embedding = node_embedding[2 * n_samples:]

        if self.use_memory:
            if self.memory_update_at_start:
                # Persist the updates to the memory only for sources and destinations (since now we have
                # new messages for them)
                self.update_memory(positives, self.memory.messages)

                assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                    "Something wrong in how the memory was updated"

                # Remove messages for the positives since we have already updated the memory using them
                self.memory.clear_messages(positives)

            unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                          source_node_embedding,
                                                                          destination_nodes,
                                                                          destination_node_embedding,
                                                                          edge_times, edge_idxs)
            unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                    destination_node_embedding,
                                                                                    source_nodes,
                                                                                    source_node_embedding,
                                                                                    edge_times, edge_idxs)
            if self.memory_update_at_start:
                # Keep the new messages for a later batch
                self.memory.store_raw_messages(unique_sources, source_id_to_messages)
                self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
            else:
                # Apply the new messages to the memory right away
                self.update_memory(unique_sources, source_id_to_messages)
                self.update_memory(unique_destinations, destination_id_to_messages)

            if self.dyrep:
                # In dyrep mode the memory rows themselves are returned as the embeddings
                source_node_embedding = memory[source_nodes]
                destination_node_embedding = memory[destination_nodes]
                negative_node_embedding = memory[negative_nodes]

        return source_node_embedding, destination_node_embedding, negative_node_embedding

    def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                   edge_idxs, n_neighbors=20):
        """
        Compute probabilities for edges between sources and destination and between sources and
        negatives by first computing temporal embeddings using the TGN encoder and then feeding them
        into the MLP decoder.
        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Probabilities for both the positive and negative edges
        """
        n_samples = len(source_nodes)
        source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(
            source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors)

        # Score (source, destination) and (source, negative) pairs in a single decoder call.
        # NOTE(review): squeeze(dim=0) only has an effect when the leading dimension is 1; it is
        # kept as is so that the returned shapes do not change for existing callers.
        score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                    torch.cat([destination_node_embedding,
                                               negative_node_embedding])).squeeze(dim=0)
        pos_score = score[:n_samples]
        neg_score = score[n_samples:]

        return pos_score.sigmoid(), neg_score.sigmoid()

    def _aggregate_messages(self, nodes, messages):
        """Aggregate the messages of each node and run the message function on the result."""
        unique_nodes, unique_messages, unique_timestamps = \
            self.message_aggregator.aggregate(
                nodes,
                messages)

        # The message function is only applied when there is something to transform
        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        return unique_nodes, unique_messages, unique_timestamps

    def update_memory(self, nodes, messages):
        """
        Aggregate the messages for the given nodes and let the memory updater apply them.

        :param nodes: ids of the nodes whose messages should be considered
        :param messages: mapping from node id to its list of stored messages
        """
        unique_nodes, unique_messages, unique_timestamps = self._aggregate_messages(nodes, messages)

        # Update the memory with the aggregated messages
        self.memory_updater.update_memory(unique_nodes, unique_messages,
                                          timestamps=unique_timestamps)

    def get_updated_memory(self, nodes, messages):
        """
        Aggregate the messages for the given nodes and ask the memory updater for the resulting
        memory and last-update values.

        :param nodes: ids of the nodes whose messages should be considered
        :param messages: mapping from node id to its list of stored messages
        :return: the pair (updated_memory, updated_last_update) produced by the memory updater
        """
        unique_nodes, unique_messages, unique_timestamps = self._aggregate_messages(nodes, messages)

        updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
                                                                                     unique_messages,
                                                                                     timestamps=unique_timestamps)

        return updated_memory, updated_last_update

    def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                         destination_node_embedding, edge_times, edge_idxs):
        """
        Build one raw message per interaction, grouped by source node.

        Each message is the concatenation of the source representation, the destination
        representation, the edge features and the encoding of the time elapsed since the source
        memory was last updated. The representation is the node's memory, or its embedding when
        the corresponding use_*_embedding_in_message flag is set.

        :param source_nodes [batch_size]: source ids
        :param source_node_embedding: embeddings of the sources (used only if
          use_source_embedding_in_message is set)
        :param destination_nodes [batch_size]: destination ids
        :param destination_node_embedding: embeddings of the destinations (used only if
          use_destination_embedding_in_message is set)
        :param edge_times [batch_size]: numpy array with the timestamps of the interactions
        :param edge_idxs [batch_size]: index of interaction
        :return: (unique source ids, dict mapping source id -> list of (message, timestamp))
        """
        edge_times = torch.from_numpy(edge_times).float().to(self.device)
        edge_features = self.edge_raw_features[edge_idxs]

        source_memory = self.memory.get_memory(source_nodes) if not \
            self.use_source_embedding_in_message else source_node_embedding
        destination_memory = self.memory.get_memory(destination_nodes) if \
            not self.use_destination_embedding_in_message else destination_node_embedding

        source_time_delta = edge_times - self.memory.last_update[source_nodes]
        source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
            source_nodes), -1)

        source_message = torch.cat([source_memory, destination_memory, edge_features,
                                    source_time_delta_encoding],
                                   dim=1)
        messages = defaultdict(list)
        unique_sources = np.unique(source_nodes)

        for node, message, timestamp in zip(source_nodes, source_message, edge_times):
            messages[node].append((message, timestamp))

        return unique_sources, messages

    def set_neighbor_finder(self, neighbor_finder):
        """Replace the neighbor finder on both this model and its embedding module."""
        self.neighbor_finder = neighbor_finder
        self.embedding_module.neighbor_finder = neighbor_finder
model/time_encoding.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class TimeEncode(torch.nn.Module):
    """Time Encoding proposed by TGAT: cosine of a linear projection of the time value."""

    def __init__(self, dimension):
        """
        :param dimension: number of output features of the encoding
        """
        super().__init__()

        self.dimension = dimension
        self.w = torch.nn.Linear(1, dimension)

        # Replace the default initialization: weights are 1 / 10^e with e evenly spaced
        # in [0, 9], and the bias starts at zero.
        exponents = np.linspace(0, 9, dimension)
        initial_weight = torch.from_numpy(1 / 10 ** exponents).float()
        self.w.weight = torch.nn.Parameter(initial_weight.reshape(dimension, -1))

        initial_bias = torch.zeros(dimension).float()
        self.w.bias = torch.nn.Parameter(initial_bias)

    def forward(self, t):
        """
        :param t: Tensor of shape [batch_size, seq_len]
        :return: Tensor of shape [batch_size, seq_len, dimension]
        """
        # A trailing axis of size one is needed for the linear layer: [batch_size, seq_len, 1]
        expanded = t.unsqueeze(dim=2)
        projected = self.w(expanded)
        return torch.cos(projected)
modules/embedding_module.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ import math
5
+
6
+ from model.temporal_attention import TemporalAttentionLayer
7
+
8
+
9
class EmbeddingModule(nn.Module):
    """
    Base class of the embedding modules. It only stores the shared configuration; subclasses
    provide the actual compute_embedding implementation.
    """

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 dropout):
        """
        :param node_features: raw node features
        :param edge_features: raw edge features
        :param memory: accepted for interface compatibility; it is not stored on the module
        :param neighbor_finder: object used to look up temporal neighbors
        :param time_encoder: module used to encode time values
        :param n_layers: number of layers
        :param n_node_features: size of the node features
        :param n_edge_features: size of the edge features
        :param n_time_features: size of the time encoding
        :param embedding_dimension: size of the produced embeddings
        :param device: torch device
        :param dropout: dropout probability
        """
        super(EmbeddingModule, self).__init__()
        self.node_features = node_features
        self.edge_features = edge_features
        self.neighbor_finder = neighbor_finder
        self.time_encoder = time_encoder
        self.n_layers = n_layers
        self.n_node_features = n_node_features
        self.n_edge_features = n_edge_features
        self.n_time_features = n_time_features
        self.dropout = dropout
        self.embedding_dimension = embedding_dimension
        self.device = device

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """
        Compute the embeddings of source_nodes at the given timestamps.

        :raises NotImplementedError: always; subclasses must override this method
        """
        # Raise instead of returning the NotImplemented singleton, which callers would
        # otherwise silently receive in place of an embedding tensor.
        raise NotImplementedError
30
+
31
+
32
class IdentityEmbedding(EmbeddingModule):
    """Embedding module that returns the memory rows of the requested nodes unchanged."""

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """
        :param memory: indexable memory with one row per node
        :param source_nodes: ids of the nodes to embed
        :return: the rows of memory selected by source_nodes
        """
        selected_rows = memory[source_nodes, :]
        return selected_rows
36
+
37
+
38
class TimeEmbedding(EmbeddingModule):
    """Embedding module that rescales each node's memory by a learned function of the time difference."""

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True, n_neighbors=1):
        """
        Same arguments as EmbeddingModule; n_heads, use_memory and n_neighbors are accepted but
        not used here.
        """
        super().__init__(node_features, edge_features, memory,
                         neighbor_finder, time_encoder, n_layers,
                         n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device, dropout)

        class NormalLinear(nn.Linear):
            # From Jodie code
            def reset_parameters(self):
                # Normal initialization with std 1 / sqrt(in_features) for weight, then bias
                std = 1. / math.sqrt(self.weight.size(1))
                self.weight.data.normal_(0, std)
                if self.bias is not None:
                    self.bias.data.normal_(0, std)

        self.embedding_layer = NormalLinear(1, self.n_node_features)

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """
        :param memory: indexable memory with one row per node
        :param source_nodes: ids of the nodes to embed
        :param time_diffs: one time difference per source node
        :return: memory rows multiplied by (1 + linear projection of the time difference)
        """
        time_projection = self.embedding_layer(time_diffs.unsqueeze(1))
        node_memory = memory[source_nodes, :]
        source_embeddings = node_memory * (1 + time_projection)

        return source_embeddings
62
+
63
+
64
class GraphEmbedding(EmbeddingModule):
    """
    Base class of the graph based embedding modules. It implements the recursive neighborhood
    traversal and leaves the combination step (aggregate) to the subclasses.
    """

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        """
        Same arguments as EmbeddingModule, plus:

        :param n_heads: accepted for subclasses; not used by this class
        :param use_memory: if True the node memory is added to the raw node features
        """
        super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                             neighbor_finder, time_encoder, n_layers,
                                             n_node_features, n_edge_features, n_time_features,
                                             embedding_dimension, device, dropout)

        self.use_memory = use_memory
        self.device = device

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Recursive implementation of n_layers temporal graph layers.

        :param memory: indexable memory with one row per node (only read when use_memory is set)
        :param source_nodes [batch_size]: numpy array with the ids of the nodes to embed
        :param timestamps [batch_size]: numpy array with the time at which each embedding is
          requested
        :param n_layers [scalar]: number of layers to stack; 0 returns the (memory augmented)
          raw node features
        :param n_neighbors [scalar]: number of temporal neighbors to consider in each layer
        :param time_diffs: not used by this implementation
        :param use_time_proj: not used by this implementation
        :return: one embedding per source node
        """

        assert (n_layers >= 0)

        source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
        source_node_features = self.node_features[source_nodes_torch, :]

        if self.use_memory:
            source_node_features = memory[source_nodes, :] + source_node_features

        # Base case of the recursion: no layer left to apply. Returning before the time
        # encoding is computed avoids work whose result would be discarded.
        if n_layers == 0:
            return source_node_features

        timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

        # query node always has the start time -> time span == 0
        source_nodes_time_embedding = self.time_encoder(torch.zeros_like(
            timestamps_torch))

        # Embeddings of the source nodes produced by the previous layer
        source_node_conv_embeddings = self.compute_embedding(memory,
                                                             source_nodes,
                                                             timestamps,
                                                             n_layers=n_layers - 1,
                                                             n_neighbors=n_neighbors)

        neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
            source_nodes,
            timestamps,
            n_neighbors=n_neighbors)

        neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)

        edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

        # Time elapsed between each neighbor interaction and the query time
        edge_deltas = timestamps[:, np.newaxis] - edge_times

        edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

        # Embeddings of all the neighbors produced by the previous layer, computed in one
        # flat batch and then reshaped back to one group per source node
        neighbors = neighbors.flatten()
        neighbor_embeddings = self.compute_embedding(memory,
                                                     neighbors,
                                                     np.repeat(timestamps, n_neighbors),
                                                     n_layers=n_layers - 1,
                                                     n_neighbors=n_neighbors)

        effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
        edge_time_embeddings = self.time_encoder(edge_deltas_torch)

        edge_features = self.edge_features[edge_idxs, :]

        # Neighbor id 0 is treated as padding
        mask = neighbors_torch == 0

        source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
                                          source_nodes_time_embedding,
                                          neighbor_embeddings,
                                          edge_time_embeddings,
                                          edge_features,
                                          mask)

        return source_embedding

    def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """
        Combine the source representation with the representations of its neighbors.

        :raises NotImplementedError: always; subclasses must override this method
        """
        # Raise instead of returning the NotImplemented singleton, which compute_embedding
        # would otherwise hand back to its caller as if it were an embedding.
        raise NotImplementedError
151
+
152
+
153
class GraphSumEmbedding(GraphEmbedding):
    """Graph embedding module that combines the neighbors by summing their linear projections."""

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        """Same arguments as GraphEmbedding; one pair of linear layers is created per layer."""
        super(GraphSumEmbedding, self).__init__(node_features=node_features,
                                                edge_features=edge_features,
                                                memory=memory,
                                                neighbor_finder=neighbor_finder,
                                                time_encoder=time_encoder, n_layers=n_layers,
                                                n_node_features=n_node_features,
                                                n_edge_features=n_edge_features,
                                                n_time_features=n_time_features,
                                                embedding_dimension=embedding_dimension,
                                                device=device,
                                                n_heads=n_heads, dropout=dropout,
                                                use_memory=use_memory)
        # linear_1 projects [neighbor embedding | time encoding | edge features]
        self.linear_1 = torch.nn.ModuleList([torch.nn.Linear(embedding_dimension + n_time_features +
                                                             n_edge_features, embedding_dimension)
                                             for _ in range(n_layers)])
        # linear_2 projects [neighbors sum | source features | source time encoding]
        self.linear_2 = torch.nn.ModuleList(
            [torch.nn.Linear(embedding_dimension + n_node_features + n_time_features,
                             embedding_dimension) for _ in range(n_layers)])

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """
        :param n_layer: 1-based index of the layer whose linear maps are used
        :param source_node_features: [batch_size, features] representation of the source nodes
        :param source_nodes_time_embedding: [batch_size, 1, time features] time encoding of the
          source nodes
        :param neighbor_embeddings: [batch_size, n_neighbors, features]
        :param edge_time_embeddings: [batch_size, n_neighbors, time features]
        :param edge_features: [batch_size, n_neighbors, edge features]
        :param mask: not used by this implementation
        :return: [batch_size, embedding_dimension] embedding of the source nodes
        """
        neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features],
                                       dim=2)
        neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
        neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))

        # Remove only the length-one middle axis. A bare squeeze() would also drop the batch
        # axis when batch_size == 1 and make the concatenation below fail.
        source_features = torch.cat([source_node_features,
                                     source_nodes_time_embedding.squeeze(dim=1)], dim=1)
        source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
        source_embedding = self.linear_2[n_layer - 1](source_embedding)

        return source_embedding
190
+
191
+
192
class GraphAttentionEmbedding(GraphEmbedding):
    """Graph embedding module that combines the neighbors with one temporal attention layer per layer."""

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        """Same arguments as GraphEmbedding; n_layers attention layers are created."""
        super().__init__(node_features, edge_features, memory,
                         neighbor_finder, time_encoder, n_layers,
                         n_node_features, n_edge_features,
                         n_time_features,
                         embedding_dimension, device,
                         n_heads, dropout,
                         use_memory)

        layers = []
        for _ in range(n_layers):
            layers.append(TemporalAttentionLayer(n_node_features=n_node_features,
                                                 n_neighbors_features=n_node_features,
                                                 n_edge_features=n_edge_features,
                                                 time_dim=n_time_features,
                                                 n_head=n_heads,
                                                 dropout=dropout,
                                                 output_dimension=n_node_features))
        self.attention_models = torch.nn.ModuleList(layers)

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """
        Run the attention layer of the given (1-based) layer index and return its embedding;
        the attention weights it also returns are discarded.
        """
        layer = self.attention_models[n_layer - 1]
        result = layer(source_node_features,
                       source_nodes_time_embedding,
                       neighbor_embeddings,
                       edge_time_embeddings,
                       edge_features,
                       mask)
        source_embedding = result[0]

        return source_embedding
227
+
228
+
229
def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True):
    """
    Build the embedding module selected by module_type.

    :param module_type: one of "graph_attention", "graph_sum", "identity" or "time"
    :return: the constructed embedding module
    :raises ValueError: if module_type is none of the supported names
    """
    # Arguments that every embedding module receives
    common = dict(node_features=node_features,
                  edge_features=edge_features,
                  memory=memory,
                  neighbor_finder=neighbor_finder,
                  time_encoder=time_encoder,
                  n_layers=n_layers,
                  n_node_features=n_node_features,
                  n_edge_features=n_edge_features,
                  n_time_features=n_time_features,
                  embedding_dimension=embedding_dimension,
                  device=device,
                  dropout=dropout)

    if module_type == "graph_attention":
        return GraphAttentionEmbedding(n_heads=n_heads, use_memory=use_memory, **common)

    if module_type == "graph_sum":
        return GraphSumEmbedding(n_heads=n_heads, use_memory=use_memory, **common)

    if module_type == "identity":
        return IdentityEmbedding(**common)

    if module_type == "time":
        return TimeEmbedding(n_neighbors=n_neighbors, **common)

    raise ValueError("Embedding Module {} not supported".format(module_type))
290
+
291
+
modules/memory.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from collections import defaultdict
5
+ from copy import deepcopy
6
+
7
+
8
class Memory(nn.Module):
    """Per-node memory table plus the raw messages still waiting to be consumed.

    The memory matrix and the per-node last-update vector are stored as
    non-trainable parameters so that they travel with the module's state.
    Pending messages are kept in a ``defaultdict`` mapping a node id to a list
    of ``(message, timestamp)`` tuples.
    """

    def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                 device="cpu", combination_method='sum'):
        super().__init__()
        self.n_nodes = n_nodes
        self.memory_dimension = memory_dimension
        self.input_dimension = input_dimension
        self.message_dimension = message_dimension
        self.device = device

        self.combination_method = combination_method

        self.__init_memory__()

    def __init_memory__(self):
        """
        Reset the memory and the last-update times to zeros and drop every pending
        message. It should be called at the start of each epoch.
        """
        zero_memory = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
        zero_last_update = torch.zeros(self.n_nodes).to(self.device)

        # Registered as parameters (without gradients) so that they are saved and
        # loaded together with the model.
        self.memory = nn.Parameter(zero_memory, requires_grad=False)
        self.last_update = nn.Parameter(zero_last_update, requires_grad=False)

        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_id_to_messages):
        """Append the messages of every node in ``nodes`` to its pending list."""
        for node_id in nodes:
            self.messages[node_id].extend(node_id_to_messages[node_id])

    def get_memory(self, node_idxs):
        """Return the memory rows selected by ``node_idxs``."""
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        """Overwrite the memory rows selected by ``node_idxs`` with ``values``."""
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        """Return the last-update times of the nodes selected by ``node_idxs``."""
        return self.last_update[node_idxs]

    def backup_memory(self):
        """Return cloned copies of ``(memory, last_update, messages)``."""
        messages_copy = {
            node_id: [(entry[0].clone(), entry[1].clone()) for entry in node_messages]
            for node_id, node_messages in self.messages.items()
        }
        return self.memory.data.clone(), self.last_update.data.clone(), messages_copy

    def restore_memory(self, memory_backup):
        """Load a triple produced by :meth:`backup_memory`, cloning its contents."""
        saved_memory, saved_last_update, saved_messages = (
            memory_backup[0], memory_backup[1], memory_backup[2])

        self.memory.data = saved_memory.clone()
        self.last_update.data = saved_last_update.clone()

        self.messages = defaultdict(list)
        for node_id, node_messages in saved_messages.items():
            self.messages[node_id] = [(entry[0].clone(), entry[1].clone())
                                      for entry in node_messages]

    def detach_memory(self):
        """Detach the memory and every stored message tensor from the autograd graph."""
        self.memory.detach_()

        # Only the message tensor is detached; the timestamp is kept as stored.
        for node_id, node_messages in self.messages.items():
            self.messages[node_id] = [(entry[0].detach(), entry[1])
                                      for entry in node_messages]

    def clear_messages(self, nodes):
        """Drop the pending messages of every node in ``nodes``."""
        for node_id in nodes:
            self.messages[node_id] = []
modules/memory_updater.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+
4
+
5
class MemoryUpdater(nn.Module):
    """Base class for memory updaters; the default update is a no-op."""

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        """Do nothing and return ``None``; subclasses provide the real update."""
        return None
8
+
9
+
10
class SequenceMemoryUpdater(MemoryUpdater):
    """Memory updater driven by a recurrent cell stored in ``self.memory_updater``.

    The cell itself is not created here; the subclasses in this file assign it.
    """

    def __init__(self, memory, message_dimension, memory_dimension, device):
        super().__init__()
        self.memory = memory
        self.layer_norm = torch.nn.LayerNorm(memory_dimension)
        self.message_dimension = message_dimension
        self.device = device

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        """Update the stored memory and last-update times of the given nodes in place."""
        if len(unique_node_ids) == 0:
            return

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
            "Trying to update memory to time in the past"

        # Read the current memory before touching anything, then record the new
        # update times and finally write back the cell output.
        current_memory = self.memory.get_memory(unique_node_ids)
        self.memory.last_update[unique_node_ids] = timestamps

        self.memory.set_memory(unique_node_ids,
                               self.memory_updater(unique_messages, current_memory))

    def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
        """Return updated clones of ``(memory, last_update)`` without modifying the store."""
        if len(unique_node_ids) == 0:
            return self.memory.memory.data.clone(), self.memory.last_update.data.clone()

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
            "Trying to update memory to time in the past"

        memory_snapshot = self.memory.memory.data.clone()
        memory_snapshot[unique_node_ids] = self.memory_updater(unique_messages,
                                                               memory_snapshot[unique_node_ids])

        last_update_snapshot = self.memory.last_update.data.clone()
        last_update_snapshot[unique_node_ids] = timestamps

        return memory_snapshot, last_update_snapshot
46
+
47
+
48
class GRUMemoryUpdater(SequenceMemoryUpdater):
    """Sequence memory updater whose recurrent cell is an ``nn.GRUCell``."""

    def __init__(self, memory, message_dimension, memory_dimension, device):
        super().__init__(memory, message_dimension, memory_dimension, device)

        # The message is the cell input and the node memory is its hidden state.
        gru_cell = nn.GRUCell(input_size=message_dimension, hidden_size=memory_dimension)
        self.memory_updater = gru_cell
54
+
55
+
56
class RNNMemoryUpdater(SequenceMemoryUpdater):
    """Sequence memory updater whose recurrent cell is an ``nn.RNNCell``."""

    def __init__(self, memory, message_dimension, memory_dimension, device):
        super().__init__(memory, message_dimension, memory_dimension, device)

        # The message is the cell input and the node memory is its hidden state.
        rnn_cell = nn.RNNCell(input_size=message_dimension, hidden_size=memory_dimension)
        self.memory_updater = rnn_cell
62
+
63
+
64
def get_memory_updater(module_type, memory, message_dimension, memory_dimension, device):
    """Build the memory updater selected by ``module_type``.

    :param module_type: ``"gru"`` or ``"rnn"``
    :param memory: memory store handed to the updater
    :param message_dimension: size of the messages fed to the recurrent cell
    :param memory_dimension: size of the per-node memory (the cell's hidden size)
    :param device: device identifier stored on the updater
    :return: a ``GRUMemoryUpdater`` or ``RNNMemoryUpdater``
    :raises ValueError: if ``module_type`` is not one of the supported values
    """
    if module_type == "gru":
        return GRUMemoryUpdater(memory, message_dimension, memory_dimension, device)
    elif module_type == "rnn":
        return RNNMemoryUpdater(memory, message_dimension, memory_dimension, device)
    else:
        # Fail here instead of returning None, which would only surface later as an
        # AttributeError on the first update; mirrors the other factories in the project.
        raise ValueError("Memory updater {} not implemented".format(module_type))
modules/message_aggregator.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
class MessageAggregator(torch.nn.Module):
    """
    Base class for message aggregators: given a batch of node ids and the messages
    associated with them, an aggregator combines the messages that belong to the
    same node id.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def aggregate(self, node_ids, messages):
        """
        Combine the messages of each node id using the strategy of the concrete subclass.
        This base implementation does nothing and returns ``None``.
        :param node_ids: node ids whose messages should be aggregated
        :param messages: messages to aggregate (the subclasses in this file index it by
        node id and expect a list of ``(message, timestamp)`` tuples per node)
        :return: defined by the subclass
        """

    def group_by_id(self, node_ids, messages, timestamps):
        """
        Group ``(message, timestamp)`` pairs by node id.
        :param node_ids: node ids, one per message
        :param messages: messages, indexable by position
        :param timestamps: timestamps, indexable by position
        :return: ``defaultdict(list)`` mapping node id to its list of pairs
        """
        grouped = defaultdict(list)

        for position, node_id in enumerate(node_ids):
            pair = (messages[position], timestamps[position])
            grouped[node_id].append(pair)

        return grouped
32
+
33
+
34
class LastMessageAggregator(MessageAggregator):
    """Aggregator that keeps, for every node, only its most recently stored message."""

    def __init__(self, device):
        super().__init__(device)

    def aggregate(self, node_ids, messages):
        """
        Only keep the last message for each node.
        :param node_ids: node ids to consider (duplicates are collapsed)
        :param messages: mapping from node id to a list of ``(message, timestamp)`` tuples
        :return: ``(to_update_node_ids, unique_messages, unique_timestamps)``; the last two
        are stacked tensors, or empty lists when no node has any message
        """
        to_update_node_ids = []
        last_messages = []
        last_timestamps = []

        for node_id in np.unique(node_ids):
            node_messages = messages[node_id]
            if len(node_messages) > 0:
                newest_message, newest_timestamp = node_messages[-1][0], node_messages[-1][1]
                to_update_node_ids.append(node_id)
                last_messages.append(newest_message)
                last_timestamps.append(newest_timestamp)

        if len(to_update_node_ids) > 0:
            unique_messages = torch.stack(last_messages)
            unique_timestamps = torch.stack(last_timestamps)
        else:
            unique_messages = []
            unique_timestamps = []

        return to_update_node_ids, unique_messages, unique_timestamps
56
+
57
+
58
class MeanMessageAggregator(MessageAggregator):
    """Aggregator that averages all the stored messages of each node."""

    def __init__(self, device):
        super(MeanMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages):
        """
        Average the messages of each node; the timestamp of the last message is kept.
        :param node_ids: node ids to consider (duplicates are collapsed)
        :param messages: mapping from node id to a list of ``(message, timestamp)`` tuples
        :return: ``(to_update_node_ids, unique_messages, unique_timestamps)``; the last two
        are stacked tensors, or empty lists when no node has any message
        """
        unique_node_ids = np.unique(node_ids)
        unique_messages = []
        unique_timestamps = []

        to_update_node_ids = []

        for node_id in unique_node_ids:
            node_messages = messages[node_id]
            if len(node_messages) > 0:
                to_update_node_ids.append(node_id)
                # Element-wise mean over all the messages stored for this node
                unique_messages.append(torch.mean(torch.stack([m[0] for m in node_messages]), dim=0))
                unique_timestamps.append(node_messages[-1][1])

        unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
        unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []

        return to_update_node_ids, unique_messages, unique_timestamps
82
+
83
+
84
def get_message_aggregator(aggregator_type, device):
    """Build the message aggregator named by ``aggregator_type`` (``"last"`` or ``"mean"``).

    :raises ValueError: if ``aggregator_type`` is not supported
    """
    if aggregator_type == "last":
        return LastMessageAggregator(device=device)

    if aggregator_type == "mean":
        return MeanMessageAggregator(device=device)

    raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
modules/message_function.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
class MessageFunction(nn.Module):
    """
    Base module for computing the message of an interaction from its raw message.
    """

    def compute_message(self, raw_messages):
        """Return ``None``; subclasses return the computed messages."""
        return None
11
+
12
+
13
class MLPMessageFunction(MessageFunction):
    """Message function that passes raw messages through a two-layer perceptron."""

    def __init__(self, raw_message_dimension, message_dimension):
        super().__init__()

        hidden_dimension = raw_message_dimension // 2
        network = nn.Sequential(
            nn.Linear(raw_message_dimension, hidden_dimension),
            nn.ReLU(),
            nn.Linear(hidden_dimension, message_dimension),
        )
        # The same network is exposed under both attribute names, in this order.
        self.mlp = network
        self.layers = network

    def compute_message(self, raw_messages):
        """Return the output of the perceptron applied to ``raw_messages``."""
        return self.mlp(raw_messages)
27
+
28
+
29
class IdentityMessageFunction(MessageFunction):
    """Message function that uses the raw messages unchanged."""

    def compute_message(self, raw_messages):
        """Return ``raw_messages`` as is."""
        messages = raw_messages
        return messages
34
+
35
+
36
def get_message_function(module_type, raw_message_dimension, message_dimension):
    """Build the message function selected by ``module_type``.

    :param module_type: ``"mlp"`` or ``"identity"``
    :param raw_message_dimension: size of the raw messages (input of the MLP)
    :param message_dimension: size of the computed messages (output of the MLP);
    not used by the identity function
    :return: an ``MLPMessageFunction`` or an ``IdentityMessageFunction``
    :raises ValueError: if ``module_type`` is not one of the supported values
    """
    if module_type == "mlp":
        return MLPMessageFunction(raw_message_dimension, message_dimension)
    elif module_type == "identity":
        return IdentityMessageFunction()
    else:
        # Fail here instead of returning None, which would only surface later when
        # compute_message is called; mirrors the other factories in the project.
        raise ValueError("Message function {} not implemented".format(module_type))
modules/tgn.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ import torch
4
+ from collections import defaultdict
5
+
6
+ from utils.utils import MergeLayer
7
+ from modules.memory import Memory
8
+ from modules.message_aggregator import get_message_aggregator
9
+ from modules.message_function import get_message_function
10
+ from modules.memory_updater import get_memory_updater
11
+ from modules.embedding_module import get_embedding_module
12
+ from model.time_encoding import TimeEncode
13
+
14
+
15
class TGN(torch.nn.Module):
    """
    Temporal graph model that wires together an optional node memory (with its message
    function, message aggregator and memory updater), an embedding module and an MLP
    decoder (``affinity_score``) that scores an edge from two node embeddings.
    """

    def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
                 n_heads=2, dropout=0.1, use_memory=False,
                 memory_update_at_start=True, message_dimension=100,
                 memory_dimension=500, embedding_module_type="graph_attention",
                 message_function="mlp",
                 mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
                 std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
                 memory_updater_type="gru",
                 use_destination_embedding_in_message=False,
                 use_source_embedding_in_message=False,
                 dyrep=False):
        """
        :param neighbor_finder: object handed to the embedding module to look up neighbors
        :param node_features: numpy array of node features, converted to a float32 tensor
        :param edge_features: numpy array of edge features, converted to a float32 tensor
        :param device: device on which the feature tensors and the memory are placed
        :param n_layers: number of layers passed to the embedding module
        :param n_heads: number of heads passed to the embedding module
        :param dropout: dropout passed to the embedding module
        :param use_memory: whether to create and use the node memory
        :param memory_update_at_start: whether stored messages are consumed at the start of
        the next batch (True) or the memory is updated right after computing the messages
        :param message_dimension: message size (ignored for the "identity" message function,
        where the raw message size is used)
        :param memory_dimension: size of the per-node memory
        :param embedding_module_type: embedding module name given to get_embedding_module
        :param message_function: message function name given to get_message_function
        :param mean_time_shift_src, std_time_shift_src: normalization of source time diffs
        :param mean_time_shift_dst, std_time_shift_dst: normalization of destination and
        negative time diffs
        :param n_neighbors: number of neighbors passed to the embedding module
        :param aggregator_type: aggregator name given to get_message_aggregator
        :param memory_updater_type: updater name given to get_memory_updater
        :param use_destination_embedding_in_message: use the destination embedding instead
        of its memory when building raw messages
        :param use_source_embedding_in_message: use the source embedding instead of its
        memory when building raw messages
        :param dyrep: if True (and memory is used) return memory rows as embeddings
        """
        super(TGN, self).__init__()

        self.n_layers = n_layers
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)

        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        # The embedding size is tied to the node feature size
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors
        self.embedding_module_type = embedding_module_type
        self.use_destination_embedding_in_message = use_destination_embedding_in_message
        self.use_source_embedding_in_message = use_source_embedding_in_message
        self.dyrep = dyrep

        self.use_memory = use_memory
        self.time_encoder = TimeEncode(dimension=self.n_node_features)
        self.memory = None

        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst

        if self.use_memory:
            self.memory_dimension = memory_dimension
            self.memory_update_at_start = memory_update_at_start
            # Matches the concatenation built in get_raw_messages: source memory,
            # destination memory, edge features and the time-delta encoding
            raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
                                    self.time_encoder.dimension
            message_dimension = message_dimension if message_function != "identity" else raw_message_dimension
            self.memory = Memory(n_nodes=self.n_nodes,
                                 memory_dimension=self.memory_dimension,
                                 input_dimension=message_dimension,
                                 message_dimension=message_dimension,
                                 device=device)
            self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
                                                             device=device)
            self.message_function = get_message_function(module_type=message_function,
                                                         raw_message_dimension=raw_message_dimension,
                                                         message_dimension=message_dimension)
            self.memory_updater = get_memory_updater(module_type=memory_updater_type,
                                                     memory=self.memory,
                                                     message_dimension=message_dimension,
                                                     memory_dimension=self.memory_dimension,
                                                     device=device)

        # NOTE(review): same assignment as above; redundant but harmless
        self.embedding_module_type = embedding_module_type

        self.embedding_module = get_embedding_module(module_type=embedding_module_type,
                                                     node_features=self.node_raw_features,
                                                     edge_features=self.edge_raw_features,
                                                     memory=self.memory,
                                                     neighbor_finder=self.neighbor_finder,
                                                     time_encoder=self.time_encoder,
                                                     n_layers=self.n_layers,
                                                     n_node_features=self.n_node_features,
                                                     n_edge_features=self.n_edge_features,
                                                     n_time_features=self.n_node_features,
                                                     embedding_dimension=self.embedding_dimension,
                                                     device=self.device,
                                                     n_heads=n_heads, dropout=dropout,
                                                     use_memory=use_memory,
                                                     n_neighbors=self.n_neighbors)

        # MLP to compute probability on an edge given two node embeddings
        self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
                                         self.n_node_features,
                                         1)

    def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                    edge_idxs, n_neighbors=20):
        """
        Compute temporal embeddings for sources, destinations, and negatively sampled destinations.

        :param source_nodes [batch_size]: source ids.
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Temporal embeddings for sources, destinations and negatives
        """

        n_samples = len(source_nodes)
        # All three node groups are embedded in a single call; the result is split back
        # by position below
        nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
        positives = np.concatenate([source_nodes, destination_nodes])
        timestamps = np.concatenate([edge_times, edge_times, edge_times])

        memory = None
        time_diffs = None
        if self.use_memory:
            if self.memory_update_at_start:
                # Update memory for all nodes with messages stored in previous batches
                memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                              self.memory.messages)
            else:
                memory = self.memory.get_memory(list(range(self.n_nodes)))
                last_update = self.memory.last_update

            ### Compute differences between the time the memory of a node was last updated,
            ### and the time for which we want to compute the embedding of a node
            source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
                source_nodes].long()
            source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
            destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
                destination_nodes].long()
            destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
            negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
                negative_nodes].long()
            # Negatives are normalized with the destination statistics
            negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

            time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
                                   dim=0)

        # Compute the embeddings using the embedding module
        node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                                 source_nodes=nodes,
                                                                 timestamps=timestamps,
                                                                 n_layers=self.n_layers,
                                                                 n_neighbors=n_neighbors,
                                                                 time_diffs=time_diffs)

        source_node_embedding = node_embedding[:n_samples]
        destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
        negative_node_embedding = node_embedding[2 * n_samples:]

        if self.use_memory:
            if self.memory_update_at_start:
                # Persist the updates to the memory only for sources and destinations (since now we have
                # new messages for them)
                self.update_memory(positives, self.memory.messages)

                # Sanity check: the persisted rows must match the ones computed above
                assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                    "Something wrong in how the memory was updated"

                # Remove messages for the positives since we have already updated the memory using them
                self.memory.clear_messages(positives)

            unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                          source_node_embedding,
                                                                          destination_nodes,
                                                                          destination_node_embedding,
                                                                          edge_times, edge_idxs)
            # Same computation with the roles of source and destination swapped
            unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                    destination_node_embedding,
                                                                                    source_nodes,
                                                                                    source_node_embedding,
                                                                                    edge_times, edge_idxs)
            if self.memory_update_at_start:
                # Keep the new messages for the next batch
                self.memory.store_raw_messages(unique_sources, source_id_to_messages)
                self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
            else:
                # Consume the new messages immediately
                self.update_memory(unique_sources, source_id_to_messages)
                self.update_memory(unique_destinations, destination_id_to_messages)

            if self.dyrep:
                # Return the memory rows themselves in place of the computed embeddings
                source_node_embedding = memory[source_nodes]
                destination_node_embedding = memory[destination_nodes]
                negative_node_embedding = memory[negative_nodes]

        return source_node_embedding, destination_node_embedding, negative_node_embedding

    def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                   edge_idxs, n_neighbors=20):
        """
        Compute probabilities for edges between sources and destination and between sources and
        negatives by first computing temporal embeddings using the TGN encoder and then feeding them
        into the MLP decoder.
        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
        layer
        :return: Probabilities for both the positive and negative edges
        """
        n_samples = len(source_nodes)
        source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(
            source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors)

        # Sources are duplicated so that the first half scores (source, destination) pairs
        # and the second half scores (source, negative) pairs.
        # NOTE(review): squeeze(dim=0) only drops the first axis when it has size 1, so the
        # scores presumably keep a trailing axis of size 1 — confirm whether dim=1 was intended.
        score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                    torch.cat([destination_node_embedding,
                                               negative_node_embedding])).squeeze(dim=0)
        pos_score = score[:n_samples]
        neg_score = score[n_samples:]

        return pos_score.sigmoid(), neg_score.sigmoid()

    def update_memory(self, nodes, messages):
        """
        Aggregate the messages of ``nodes``, run them through the message function and
        persist the result with the memory updater.
        :param nodes: node ids to update
        :param messages: mapping from node id to its list of (message, timestamp) tuples
        """
        # Aggregate messages for the same nodes
        unique_nodes, unique_messages, unique_timestamps = \
            self.message_aggregator.aggregate(
                nodes,
                messages)

        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        # Update the memory with the aggregated messages
        self.memory_updater.update_memory(unique_nodes, unique_messages,
                                          timestamps=unique_timestamps)

    def get_updated_memory(self, nodes, messages):
        """
        Same pipeline as update_memory, but the result is returned instead of persisted.
        :param nodes: node ids to update
        :param messages: mapping from node id to its list of (message, timestamp) tuples
        :return: the (updated_memory, updated_last_update) pair produced by the memory updater
        """
        # Aggregate messages for the same nodes
        unique_nodes, unique_messages, unique_timestamps = \
            self.message_aggregator.aggregate(
                nodes,
                messages)

        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
                                                                                     unique_messages,
                                                                                     timestamps=unique_timestamps)

        return updated_memory, updated_last_update

    def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                         destination_node_embedding, edge_times, edge_idxs):
        """
        Build one raw message per interaction, addressed to its source node.
        :param source_nodes: ids of the nodes receiving the messages
        :param source_node_embedding: embeddings used instead of the source memory when
        use_source_embedding_in_message is set
        :param destination_nodes: ids of the nodes on the other end of each interaction
        :param destination_node_embedding: embeddings used instead of the destination memory
        when use_destination_embedding_in_message is set
        :param edge_times: numpy array with the timestamp of each interaction
        :param edge_idxs: indices into the edge feature matrix
        :return: the unique source ids and a defaultdict mapping each source id to its list
        of (message, timestamp) tuples
        """
        edge_times = torch.from_numpy(edge_times).float().to(self.device)
        edge_features = self.edge_raw_features[edge_idxs]

        source_memory = self.memory.get_memory(source_nodes) if not \
            self.use_source_embedding_in_message else source_node_embedding
        destination_memory = self.memory.get_memory(destination_nodes) if \
            not self.use_destination_embedding_in_message else destination_node_embedding

        # Time elapsed since the last memory update of each source, then encoded
        source_time_delta = edge_times - self.memory.last_update[source_nodes]
        source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
            source_nodes), -1)

        source_message = torch.cat([source_memory, destination_memory, edge_features,
                                    source_time_delta_encoding],
                                   dim=1)
        messages = defaultdict(list)
        unique_sources = np.unique(source_nodes)

        for i in range(len(source_nodes)):
            messages[source_nodes[i]].append((source_message[i], edge_times[i]))

        return unique_sources, messages

    def set_neighbor_finder(self, neighbor_finder):
        """Replace the neighbor finder on both this model and its embedding module."""
        self.neighbor_finder = neighbor_finder
        self.embedding_module.neighbor_finder = neighbor_finder
utils/data_processing.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import pandas as pd
4
+
5
+
6
class Data:
    """Container for a set of interactions and a few statistics derived from them."""

    def __init__(self, sources, destinations, timestamps, edge_idxs, labels):
        # Parallel sequences, one entry per interaction
        self.sources, self.destinations = sources, destinations
        self.timestamps = timestamps
        self.edge_idxs = edge_idxs
        self.labels = labels

        # Derived statistics
        self.n_interactions = len(sources)
        self.unique_nodes = set(sources).union(set(destinations))
        self.n_unique_nodes = len(self.unique_nodes)
16
+
17
+
18
def get_data_node_classification(dataset_name, use_validation=False):
    """
    Load a dataset from ``./data`` and split it chronologically for node classification.

    The split times are the 0.70 and 0.85 quantiles of the timestamps. Without validation
    the training split extends up to the test time and the validation split is the test
    split itself.
    :param dataset_name: name used to build the ``ml_<name>`` file names
    :param use_validation: whether to hold out a separate validation split
    :return: full_data, node_features, edge_features, train_data, val_data, test_data
    """
    graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name))
    edge_features = np.load('./data/ml_{}.npy'.format(dataset_name))
    node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name))

    val_time, test_time = list(np.quantile(graph_df.ts, [0.70, 0.85]))

    sources = graph_df.u.values
    destinations = graph_df.i.values
    edge_idxs = graph_df.idx.values
    labels = graph_df.label.values
    timestamps = graph_df.ts.values

    random.seed(2020)

    test_mask = timestamps > test_time
    if use_validation:
        train_mask = timestamps <= val_time
        val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
    else:
        train_mask = timestamps <= test_time
        val_mask = test_mask

    def select(mask):
        # Build the Data object for the interactions selected by a boolean mask
        return Data(sources[mask], destinations[mask], timestamps[mask],
                    edge_idxs[mask], labels[mask])

    full_data = Data(sources, destinations, timestamps, edge_idxs, labels)
    train_data = select(train_mask)
    val_data = select(val_mask)
    test_data = select(test_mask)

    return full_data, node_features, edge_features, train_data, val_data, test_data
50
+
51
+
52
def get_data(dataset_name, different_new_nodes_between_val_and_test=False, randomize_features=False):
    """
    Load a dataset from ``./data`` and build the transductive and inductive splits.

    The split times are the 0.70 and 0.85 quantiles of the timestamps. A random sample of
    the nodes appearing after the validation time is held out as "new" nodes: every edge
    touching them is removed from the training split.
    :param dataset_name: name used to build the ``ml_<name>`` file names
    :param different_new_nodes_between_val_and_test: if True, the held-out nodes are split
    in two halves, one used for the new-node validation set and one for the new-node test set
    :param randomize_features: if True, replace the node features with uniform random values
    :return: node_features, edge_features, full_data, train_data, val_data, test_data,
    new_node_val_data, new_node_test_data
    """
    ### Load data and train val test split
    graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name))
    edge_features = np.load('./data/ml_{}.npy'.format(dataset_name))
    node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name))

    if randomize_features:
        node_features = np.random.rand(node_features.shape[0], node_features.shape[1])

    val_time, test_time = list(np.quantile(graph_df.ts, [0.70, 0.85]))

    sources = graph_df.u.values
    destinations = graph_df.i.values
    edge_idxs = graph_df.idx.values
    labels = graph_df.label.values
    timestamps = graph_df.ts.values

    full_data = Data(sources, destinations, timestamps, edge_idxs, labels)

    random.seed(2020)

    node_set = set(sources) | set(destinations)
    n_total_unique_nodes = len(node_set)

    # Compute nodes which appear at test time
    test_node_set = set(sources[timestamps > val_time]).union(
        set(destinations[timestamps > val_time]))
    # Sample nodes which we keep as new nodes (to test inductiveness), so than we have to remove all
    # their edges from training.
    # random.sample no longer accepts a set (TypeError since Python 3.11); older versions
    # converted it with tuple(), so doing the same here keeps the sampled nodes unchanged.
    new_test_node_set = set(random.sample(tuple(test_node_set), int(0.1 * n_total_unique_nodes)))

    # Mask saying for each source and destination whether they are new test nodes
    new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
    new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values

    # Mask which is true for edges with both destination and source not being new test nodes (because
    # we want to remove all edges involving any new test node)
    observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask)

    # For train we keep edges happening before the validation time which do not involve any new node
    # used for inductiveness
    train_mask = np.logical_and(timestamps <= val_time, observed_edges_mask)

    train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask],
                      edge_idxs[train_mask], labels[train_mask])

    # define the new nodes sets for testing inductiveness of the model
    train_node_set = set(train_data.sources).union(train_data.destinations)
    assert len(train_node_set & new_test_node_set) == 0
    new_node_set = node_set - train_node_set

    val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
    test_mask = timestamps > test_time

    if different_new_nodes_between_val_and_test:
        # Split the held-out nodes in two halves, one per evaluation set
        n_new_nodes = len(new_test_node_set) // 2
        val_new_node_set = set(list(new_test_node_set)[:n_new_nodes])
        test_new_node_set = set(list(new_test_node_set)[n_new_nodes:])

        edge_contains_new_val_node_mask = np.array(
            [(a in val_new_node_set or b in val_new_node_set) for a, b in zip(sources, destinations)])
        edge_contains_new_test_node_mask = np.array(
            [(a in test_new_node_set or b in test_new_node_set) for a, b in zip(sources, destinations)])
        new_node_val_mask = np.logical_and(val_mask, edge_contains_new_val_node_mask)
        new_node_test_mask = np.logical_and(test_mask, edge_contains_new_test_node_mask)

    else:
        # Every node that never appears in training counts as new for both sets
        edge_contains_new_node_mask = np.array(
            [(a in new_node_set or b in new_node_set) for a, b in zip(sources, destinations)])
        new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask)
        new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask)

    # validation and test with all edges
    val_data = Data(sources[val_mask], destinations[val_mask], timestamps[val_mask],
                    edge_idxs[val_mask], labels[val_mask])

    test_data = Data(sources[test_mask], destinations[test_mask], timestamps[test_mask],
                     edge_idxs[test_mask], labels[test_mask])

    # validation and test with edges that at least has one new node (not in training set)
    new_node_val_data = Data(sources[new_node_val_mask], destinations[new_node_val_mask],
                             timestamps[new_node_val_mask],
                             edge_idxs[new_node_val_mask], labels[new_node_val_mask])

    new_node_test_data = Data(sources[new_node_test_mask], destinations[new_node_test_mask],
                              timestamps[new_node_test_mask], edge_idxs[new_node_test_mask],
                              labels[new_node_test_mask])

    print("The dataset has {} interactions, involving {} different nodes".format(full_data.n_interactions,
                                                                                 full_data.n_unique_nodes))
    print("The training dataset has {} interactions, involving {} different nodes".format(
        train_data.n_interactions, train_data.n_unique_nodes))
    print("The validation dataset has {} interactions, involving {} different nodes".format(
        val_data.n_interactions, val_data.n_unique_nodes))
    print("The test dataset has {} interactions, involving {} different nodes".format(
        test_data.n_interactions, test_data.n_unique_nodes))
    print("The new node validation dataset has {} interactions, involving {} different nodes".format(
        new_node_val_data.n_interactions, new_node_val_data.n_unique_nodes))
    print("The new node test dataset has {} interactions, involving {} different nodes".format(
        new_node_test_data.n_interactions, new_node_test_data.n_unique_nodes))
    print("{} nodes were used for the inductive testing, i.e. are never seen during training".format(
        len(new_test_node_set)))

    return node_features, edge_features, full_data, train_data, val_data, test_data, \
           new_node_val_data, new_node_test_data
158
+
159
+
160
def compute_time_statistics(sources, destinations, timestamps):
    """
    Compute the mean and standard deviation of the time elapsed between consecutive
    interactions of a node, separately for nodes acting as source and as destination.
    The first interaction of a node is measured from time 0.
    :return: mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst
    """
    previous_time_as_source = {}
    previous_time_as_destination = {}
    source_gaps = []
    destination_gaps = []

    for position in range(len(sources)):
        source_id, dest_id = sources[position], destinations[position]
        current_time = timestamps[position]

        # Nodes seen for the first time are treated as last seen at time 0
        source_gaps.append(current_time - previous_time_as_source.get(source_id, 0))
        destination_gaps.append(current_time - previous_time_as_destination.get(dest_id, 0))

        previous_time_as_source[source_id] = current_time
        previous_time_as_destination[dest_id] = current_time

    assert len(source_gaps) == len(sources)
    assert len(destination_gaps) == len(sources)

    source_stats = (np.mean(source_gaps), np.std(source_gaps))
    destination_stats = (np.mean(destination_gaps), np.std(destination_gaps))

    return source_stats[0], source_stats[1], destination_stats[0], destination_stats[1]
utils/utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
class MergeLayer(torch.nn.Module):
    """
    Two-layer perceptron applied to the concatenation of two inputs.

    The inputs are joined along dimension 1, projected from dim1 + dim2 to dim3, passed through
    a ReLU, and projected again from dim3 to dim4. Both weight matrices are initialised with
    Xavier-normal initialisation.
    """

    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        merged_width = dim1 + dim2
        self.fc1 = torch.nn.Linear(merged_width, dim3)
        self.fc2 = torch.nn.Linear(dim3, dim4)
        self.act = torch.nn.ReLU()

        # Initialise in declaration order so seeded runs draw the same random numbers.
        for linear in (self.fc1, self.fc2):
            torch.nn.init.xavier_normal_(linear.weight)

    def forward(self, x1, x2):
        """Concatenate x1 and x2 along dimension 1 and run them through both layers."""
        merged = torch.cat((x1, x2), dim=1)
        hidden = self.fc1(merged)
        hidden = self.act(hidden)
        return self.fc2(hidden)
19
+
20
+
21
class MLP(torch.nn.Module):
    """
    Three-layer perceptron that maps each input row to a single value.

    Layer widths are dim -> 80 -> 10 -> 1, with a ReLU followed by dropout after each of the
    first two layers. The trailing dimension of size 1 is squeezed away on output.
    """

    def __init__(self, dim, drop=0.3):
        super().__init__()
        first_width, second_width = 80, 10
        self.fc_1 = torch.nn.Linear(dim, first_width)
        self.fc_2 = torch.nn.Linear(first_width, second_width)
        self.fc_3 = torch.nn.Linear(second_width, 1)
        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=drop, inplace=False)

    def forward(self, x):
        """Return one value per row of x (dimension 1 of the final projection is squeezed)."""
        hidden = x
        for linear in (self.fc_1, self.fc_2):
            hidden = self.dropout(self.act(linear(hidden)))
        scores = self.fc_3(hidden)
        return scores.squeeze(dim=1)
36
+
37
+
38
class EarlyStopMonitor(object):
    """
    Track a validation metric across epochs and signal when training should stop.

    Parameters
    ----------
    max_round: number of consecutive non-improving checks after which stopping is signalled.
    higher_better: if False the metric is negated internally, so that lower values count as
      improvements.
    tolerance: minimum relative improvement over the best value seen so far that counts as
      progress.
    """

    def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
        self.max_round = max_round
        self.num_round = 0

        self.epoch_count = 0
        self.best_epoch = 0

        self.last_best = None
        self.higher_better = higher_better
        self.tolerance = tolerance

    def _is_improvement(self, curr_val):
        """Return True if curr_val beats last_best by more than the relative tolerance."""
        gain = curr_val - self.last_best
        scale = np.abs(self.last_best)
        if scale == 0:
            # A relative gain over a zero baseline is undefined; dividing would only produce
            # inf/nan plus a RuntimeWarning. Any strictly positive gain counts as improvement,
            # which is the same decision the inf/nan comparison used to yield.
            return gain > 0
        return gain / scale > self.tolerance

    def early_stop_check(self, curr_val):
        """
        Record the metric of the current epoch.

        Returns True once the metric has failed to improve for max_round consecutive checks,
        False otherwise. The first call only initialises the best value.
        """
        if not self.higher_better:
            # Negate into a new value instead of `*=`, which would modify a caller-owned
            # mutable object (e.g. a numpy array) in place.
            curr_val = -curr_val

        if self.last_best is None:
            self.last_best = curr_val
        elif self._is_improvement(curr_val):
            self.last_best = curr_val
            self.num_round = 0
            self.best_epoch = self.epoch_count
        else:
            self.num_round += 1

        self.epoch_count += 1

        return self.num_round >= self.max_round
65
+
66
+
67
class RandEdgeSampler(object):
    """
    Draw random (source, destination) pairs from the unique values of two node lists.

    When a seed is given, samples come from a dedicated RandomState that can be rewound with
    reset_random_state(); otherwise numpy's global random generator is used.
    """

    def __init__(self, src_list, dst_list, seed=None):
        self.src_list = np.unique(src_list)
        self.dst_list = np.unique(dst_list)

        self.seed = seed
        if self.seed is not None:
            self.random_state = np.random.RandomState(self.seed)

    def sample(self, size):
        """Return two arrays of length size: sampled sources and sampled destinations."""
        generator = np.random if self.seed is None else self.random_state

        # Sources are drawn before destinations; both are sampled with replacement.
        src_positions = generator.randint(0, len(self.src_list), size)
        dst_positions = generator.randint(0, len(self.dst_list), size)
        return self.src_list[src_positions], self.dst_list[dst_positions]

    def reset_random_state(self):
        """Recreate the private random generator from the stored seed."""
        self.random_state = np.random.RandomState(self.seed)
89
+
90
+
91
def get_neighbor_finder(data, uniform, max_node_idx=None):
    """
    Build a NeighborFinder from the interactions stored in data.

    Reads data.sources, data.destinations, data.edge_idxs and data.timestamps in lockstep and
    registers every interaction under both of its endpoints, so the resulting adjacency is
    undirected. If max_node_idx is not given, the largest id found in sources/destinations is
    used to size the adjacency list.
    """
    if max_node_idx is None:
        max_node_idx = max(data.sources.max(), data.destinations.max())

    # One (initially empty) neighbor list per node id in [0, max_node_idx].
    adj_list = [[] for _ in range(max_node_idx + 1)]

    interactions = zip(data.sources, data.destinations, data.edge_idxs, data.timestamps)
    for src, dst, edge_id, ts in interactions:
        adj_list[src].append((dst, edge_id, ts))
        adj_list[dst].append((src, edge_id, ts))

    return NeighborFinder(adj_list, uniform=uniform)
101
+
102
+
103
class NeighborFinder:
    """
    Temporal neighborhood lookup over a per-node adjacency list.

    Params
    ------
    adj_list: list indexed by node id; each entry is a list of tuples
      (neighbor, edge_idx, timestamp).
    uniform: if True, get_temporal_neighbor samples past interactions uniformly at random
      (with replacement); otherwise it returns the most recent ones.
    seed: optional seed for the uniform sampling. When given, sampling uses a private
      RandomState and is reproducible; when None, numpy's global random generator is used.
    """

    def __init__(self, adj_list, uniform=False, seed=None):
        self.node_to_neighbors = []
        self.node_to_edge_idxs = []
        self.node_to_edge_timestamps = []

        for neighbors in adj_list:
            # Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
            # We sort the list based on timestamp
            sorted_neighbors = sorted(neighbors, key=lambda x: x[2])
            self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighbors]))
            self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighbors]))
            self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighbors]))

        self.uniform = uniform

        self.seed = seed
        self.random_state = np.random.RandomState(seed) if seed is not None else None

    def find_before(self, src_idx, cut_time):
        """
        Extracts all the interactions happening strictly before cut_time for node src_idx in the
        overall interaction graph. The returned interactions are sorted by time.

        Returns 3 arrays: neighbors, edge_idxs, timestamps
        """
        # The per-node timestamps are sorted, so a left-sided binary search yields the number
        # of interactions with timestamp < cut_time.
        i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

        return (self.node_to_neighbors[src_idx][:i],
                self.node_to_edge_idxs[src_idx][:i],
                self.node_to_edge_timestamps[src_idx][:i])

    def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
        """
        Given a list of node ids and relative cut times, extracts a sampled temporal neighborhood
        of each node in the list.

        Params
        ------
        source_nodes: List[int], node ids to look up
        timestamps: List[float], cut time for each entry of source_nodes
        n_neighbors: int, number of neighbors to return per node

        Returns
        -------
        neighbors, edge_idxs, edge_times: arrays of shape
        (len(source_nodes), max(n_neighbors, 1)) with dtypes int32, int32 and float32. Each row
        is sorted by time; rows with fewer than n_neighbors past interactions are left-padded
        with zeros.
        """
        assert (len(source_nodes) == len(timestamps))

        tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        n_rows = len(source_nodes)
        # NB! All interactions described in these matrices are sorted in each row by time
        # neighbors[i, j]: id of a node that interacted with source_nodes[i] before timestamps[i]
        neighbors = np.zeros((n_rows, tmp_n_neighbors)).astype(np.int32)
        # edge_times[i, j]: timestamp of the interaction between source_nodes[i] and neighbors[i, j]
        edge_times = np.zeros((n_rows, tmp_n_neighbors)).astype(np.float32)
        # edge_idxs[i, j]: index of the interaction between source_nodes[i] and neighbors[i, j]
        edge_idxs = np.zeros((n_rows, tmp_n_neighbors)).astype(np.int32)

        # Honour the seed passed to the constructor; fall back to the global generator otherwise.
        rng = self.random_state if self.random_state is not None else np.random

        for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
            # All neighbors, interaction indexes and timestamps of source_node before timestamp
            source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node,
                                                                                     timestamp)

            if len(source_neighbors) == 0 or n_neighbors <= 0:
                continue

            if self.uniform:
                # Uniform sampling with replacement among the past interactions
                sampled_idx = rng.randint(0, len(source_neighbors), n_neighbors)

                neighbors[i, :] = source_neighbors[sampled_idx]
                edge_times[i, :] = source_edge_times[sampled_idx]
                edge_idxs[i, :] = source_edge_idxs[sampled_idx]

                # re-sort based on time
                pos = edge_times[i, :].argsort()
                neighbors[i, :] = neighbors[i, :][pos]
                edge_times[i, :] = edge_times[i, :][pos]
                edge_idxs[i, :] = edge_idxs[i, :][pos]
            else:
                # Take most recent interactions (at most n_neighbors of them)
                source_edge_times = source_edge_times[-n_neighbors:]
                source_neighbors = source_neighbors[-n_neighbors:]
                source_edge_idxs = source_edge_idxs[-n_neighbors:]

                # Right-align so that missing history shows up as leading zeros
                start = n_neighbors - len(source_neighbors)
                neighbors[i, start:] = source_neighbors
                edge_times[i, start:] = source_edge_times
                edge_idxs[i, start:] = source_edge_idxs

        return neighbors, edge_idxs, edge_times