from collections import defaultdict

import torch
import numpy as np
class MessageAggregator(torch.nn.Module):
    """
    Abstract class for the message aggregator module, which given a batch of node ids and
    corresponding messages, aggregates messages with the same node id.
    """

    def __init__(self, device):
        super(MessageAggregator, self).__init__()
        self.device = device
    def aggregate(self, node_ids, messages):
        """
        Given a list of node ids and the stored messages, aggregate the different
        messages for the same id using one of the possible strategies.
        :param node_ids: A list of node ids of length batch_size
        :param messages: A dict mapping each node id to a list of (message, timestamp)
            tuples, where each message is a tensor of shape [message_length]
        :return: A tuple (to_update_node_ids, unique_messages, unique_timestamps), where
            unique_messages is a tensor of shape [n_unique_node_ids, message_length]
        """
        raise NotImplementedError
    def group_by_id(self, node_ids, messages, timestamps):
        node_id_to_messages = defaultdict(list)

        for i, node_id in enumerate(node_ids):
            node_id_to_messages[node_id].append((messages[i], timestamps[i]))

        return node_id_to_messages
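
# Illustration (hypothetical values): group_by_id preserves per-id arrival order,
# e.g. group_by_id([0, 1, 0], msgs, ts) returns
# {0: [(msgs[0], ts[0]), (msgs[2], ts[2])], 1: [(msgs[1], ts[1])]}.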
class LastMessageAggregator(MessageAggregator):
    def __init__(self, device):
        super(LastMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages):
        """Only keep the last message for each node."""
        unique_node_ids = np.unique(node_ids)
        unique_messages = []
        unique_timestamps = []

        to_update_node_ids = []

        for node_id in unique_node_ids:
            if len(messages[node_id]) > 0:
                to_update_node_ids.append(node_id)
                unique_messages.append(messages[node_id][-1][0])
                unique_timestamps.append(messages[node_id][-1][1])

        # Nodes with no pending messages are skipped; return empty lists rather
        # than stacking zero tensors so callers can test len() uniformly.
        unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
        unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []

        return to_update_node_ids, unique_messages, unique_timestamps
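
# Illustration (hypothetical values): with messages[0] == [(m_a, t_a), (m_b, t_b)],
# the aggregator above returns ([0], stack([m_b]), stack([t_b])) — only the most
# recent message and its timestamp survive for each node.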
class MeanMessageAggregator(MessageAggregator):
    def __init__(self, device):
        super(MeanMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages):
        """Average all messages for each node, keeping the timestamp of the last one."""
        unique_node_ids = np.unique(node_ids)
        unique_messages = []
        unique_timestamps = []

        to_update_node_ids = []

        for node_id in unique_node_ids:
            if len(messages[node_id]) > 0:
                to_update_node_ids.append(node_id)
                # Element-wise mean over all stored messages for this node.
                unique_messages.append(torch.mean(torch.stack([m[0] for m in messages[node_id]]), dim=0))
                # The aggregated message is stamped with the most recent timestamp.
                unique_timestamps.append(messages[node_id][-1][1])

        unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
        unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []

        return to_update_node_ids, unique_messages, unique_timestamps
def get_message_aggregator(aggregator_type, device):
    if aggregator_type == "last":
        return LastMessageAggregator(device=device)
    elif aggregator_type == "mean":
        return MeanMessageAggregator(device=device)
    else:
        raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
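
# A minimal usage sketch (illustrative, not part of the original module). It
# assumes messages are stored as a dict of node id -> [(message, timestamp)],
# which is the shape the aggregators above index into; all ids, dimensions and
# timestamps below are made up.
if __name__ == "__main__":
    aggregator = get_message_aggregator("mean", device="cpu")

    # Two raw messages for node 0 and one for node 1.
    messages = defaultdict(list)
    messages[0].append((torch.ones(4), torch.tensor(1.0)))
    messages[0].append((torch.zeros(4), torch.tensor(2.0)))
    messages[1].append((torch.full((4,), 2.0), torch.tensor(3.0)))

    to_update, agg_messages, agg_timestamps = aggregator.aggregate([0, 0, 1], messages)
    print(to_update)       # node ids with pending messages: 0 and 1
    print(agg_messages)    # per-node means: [[0.5, 0.5, 0.5, 0.5], [2., 2., 2., 2.]]
    print(agg_timestamps)  # last timestamp per node: [2., 3.]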