tgn-playground / modules /message_aggregator.py
ashu316's picture
Upload 14 files
41aae2b verified
raw
history blame
3.25 kB
from collections import defaultdict
import torch
import numpy as np
class MessageAggregator(torch.nn.Module):
    """
    Abstract class for the message aggregator module, which given a batch of node ids and
    corresponding messages, aggregates messages with the same node id.

    Subclasses must override :meth:`aggregate`.
    """

    def __init__(self, device):
        super().__init__()
        # Kept for subclasses; this base class does not use it itself.
        self.device = device

    def aggregate(self, node_ids, messages):
        """
        Aggregate the pending messages of every distinct node id using a subclass-specific
        strategy.

        :param node_ids: Node ids of the batch (length batch_size); may contain duplicates.
        :param messages: Indexable by node id; each value is a list of
            ``(message, timestamp)`` pairs (presumably as built by :meth:`group_by_id` --
            confirm against the caller).
        :return: The concrete aggregators in this module return a tuple
            ``(to_update_node_ids, aggregated_messages, timestamps)``: the node ids that
            had at least one message, plus one stacked row of aggregated message and one
            timestamp per such node.
        :raises NotImplementedError: always; concrete subclasses must override this method.
        """
        # Previously this silently returned None, which hid a missing override.
        raise NotImplementedError(
            "{} must implement aggregate()".format(type(self).__name__)
        )

    def group_by_id(self, node_ids, messages, timestamps):
        """
        Group per-event messages and timestamps by node id.

        :param node_ids: Iterable of node ids, one per event.
        :param messages: Indexable sequence aligned with ``node_ids``.
        :param timestamps: Indexable sequence aligned with ``node_ids``.
        :return: ``defaultdict(list)`` mapping each node id to its list of
            ``(message, timestamp)`` pairs, in the order the events appear in ``node_ids``.
        """
        node_id_to_messages = defaultdict(list)
        for i, node_id in enumerate(node_ids):
            node_id_to_messages[node_id].append((messages[i], timestamps[i]))
        return node_id_to_messages
class LastMessageAggregator(MessageAggregator):
    """Message aggregator that retains only the final pending message of every node."""

    def __init__(self, device):
        super().__init__(device)

    def aggregate(self, node_ids, messages):
        """Keep, for each distinct node, the last entry of its pending-message list.

        :param node_ids: Node ids of the batch; duplicates are collapsed and the nodes are
            visited in the sorted order produced by ``np.unique``.
        :param messages: Indexable by node id; each value is a sequence of
            ``(message, timestamp)`` pairs (presumably built by ``group_by_id`` -- confirm
            against the caller).
        :return: ``(ids, msgs, times)`` where ``ids`` lists the node ids that had at least
            one message, and ``msgs`` / ``times`` stack the last message and last timestamp
            of each such node. When no node had a message, ``msgs`` and ``times`` are empty
            lists rather than tensors.
        """
        kept_ids, last_msgs, last_times = [], [], []

        for nid in np.unique(node_ids):
            pending = messages[nid]
            if len(pending) == 0:
                continue  # nothing queued for this node
            newest = pending[-1]
            kept_ids.append(nid)
            last_msgs.append(newest[0])
            last_times.append(newest[1])

        if len(kept_ids) == 0:
            # Empty lists instead of tensors: torch.stack cannot stack an empty sequence.
            return kept_ids, [], []
        return kept_ids, torch.stack(last_msgs), torch.stack(last_times)
class MeanMessageAggregator(MessageAggregator):
    """Message aggregator that averages all pending messages of every node."""

    def __init__(self, device):
        super().__init__(device)

    def aggregate(self, node_ids, messages):
        """Average each node's pending messages and keep the timestamp of its last entry.

        :param node_ids: Node ids of the batch; duplicates are collapsed and the nodes are
            visited in the sorted order produced by ``np.unique``.
        :param messages: Indexable by node id; each value is a sequence of
            ``(message, timestamp)`` pairs (presumably built by ``group_by_id`` -- confirm
            against the caller).
        :return: ``(to_update_node_ids, unique_messages, unique_timestamps)``:
            ``to_update_node_ids`` lists the node ids that had at least one message;
            ``unique_messages`` stacks the element-wise mean of each such node's messages;
            ``unique_timestamps`` stacks the timestamp of each node's last entry. When no
            node had a message, the last two are empty lists rather than tensors.
        """
        unique_node_ids = np.unique(node_ids)
        to_update_node_ids = []
        unique_messages = []
        unique_timestamps = []

        for node_id in unique_node_ids:
            node_messages = messages[node_id]
            if len(node_messages) == 0:
                continue  # nothing queued for this node
            to_update_node_ids.append(node_id)
            # Element-wise mean over all of this node's message vectors.
            unique_messages.append(
                torch.mean(torch.stack([m[0] for m in node_messages]), dim=0)
            )
            # Only the timestamp of the last entry is kept.
            unique_timestamps.append(node_messages[-1][1])

        if not to_update_node_ids:
            # Empty lists (not tensors) preserved for existing callers; torch.stack
            # cannot stack an empty sequence anyway.
            return to_update_node_ids, [], []
        return (
            to_update_node_ids,
            torch.stack(unique_messages),
            torch.stack(unique_timestamps),
        )
def get_message_aggregator(aggregator_type, device):
    """Build the message aggregator selected by name.

    :param aggregator_type: ``"last"`` for :class:`LastMessageAggregator` or ``"mean"``
        for :class:`MeanMessageAggregator`.
    :param device: Forwarded unchanged to the aggregator's constructor.
    :return: A new aggregator instance.
    :raises ValueError: if ``aggregator_type`` names no known aggregator.
    """
    if aggregator_type == "last":
        chosen_cls = LastMessageAggregator
    elif aggregator_type == "mean":
        chosen_cls = MeanMessageAggregator
    else:
        raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
    return chosen_cls(device=device)