Spaces:
Sleeping
Sleeping
from torch import nn | |
class MessageFunction(nn.Module):
    """Base class for modules that turn raw interaction messages into messages.

    Concrete message functions subclass this and override
    :meth:`compute_message`. The base class itself defines no layers or
    parameters of its own.
    """

    def compute_message(self, raw_messages):
        """Compute the message for ``raw_messages``.

        This default implementation ignores its argument and yields ``None``;
        subclasses are expected to provide the real computation.
        """
        del raw_messages  # unused by the base implementation
        return None
class MLPMessageFunction(MessageFunction):
    """Message function that maps raw messages through a two-layer MLP.

    The network is ``Linear(raw, hidden) -> ReLU -> Linear(hidden, message)``
    where ``hidden = max(1, raw_message_dimension // 2)``.

    Args:
        raw_message_dimension: Size of the last dimension of the raw messages
            (the ``in_features`` of the first ``Linear``).
        message_dimension: Size of the last dimension of the produced messages
            (the ``out_features`` of the last ``Linear``).
    """

    def __init__(self, raw_message_dimension, message_dimension):
        super(MLPMessageFunction, self).__init__()
        # Halve the width for the hidden layer, but never down to 0: with
        # raw_message_dimension == 1 the floor division yields 0, and the second
        # Linear would then output only its bias, ignoring the input entirely.
        # For every raw_message_dimension >= 2 this equals raw // 2 exactly.
        hidden_dimension = max(1, raw_message_dimension // 2)
        self.mlp = nn.Sequential(
            nn.Linear(raw_message_dimension, hidden_dimension),
            nn.ReLU(),
            nn.Linear(hidden_dimension, message_dimension),
        )
        # ``layers`` is an alias of the same Sequential. Both attributes are
        # registered as submodules (in this order: mlp, then layers), so
        # state_dict() holds both ``mlp.*`` and ``layers.*`` keys; the alias is
        # kept so checkpoints saved from this class still load strictly.
        self.layers = self.mlp

    def compute_message(self, raw_messages):
        """Apply the MLP to ``raw_messages`` and return the resulting messages.

        ``raw_messages`` is assumed to be a tensor whose last dimension has size
        ``raw_message_dimension`` (as required by the first ``nn.Linear``); the
        output has the same leading dimensions and last dimension
        ``message_dimension``.
        """
        return self.mlp(raw_messages)
class IdentityMessageFunction(MessageFunction):
    """Message function that passes raw messages through untouched.

    It defines no layers, so it has no parameters of its own.
    """

    def compute_message(self, raw_messages):
        """Return ``raw_messages`` itself (the same object, not a copy)."""
        # Nothing to learn or transform: the message *is* the raw message.
        messages = raw_messages
        return messages
def get_message_function(module_type, raw_message_dimension, message_dimension):
    """Build the message function selected by ``module_type``.

    Args:
        module_type: ``"mlp"`` for ``MLPMessageFunction`` or ``"identity"`` for
            ``IdentityMessageFunction``.
        raw_message_dimension: Size of the raw messages; only used by ``"mlp"``.
        message_dimension: Size of the produced messages; only used by ``"mlp"``.
            The ``"identity"`` variant ignores both dimensions and returns its
            input unchanged. NOTE(review): callers presumably need
            ``message_dimension == raw_message_dimension`` in that case; confirm.

    Returns:
        A new message-function module.

    Raises:
        ValueError: If ``module_type`` is not one of the supported names.
            Failing here is clearer than silently returning ``None`` and
            crashing later at the first ``compute_message`` call.
    """
    if module_type == "mlp":
        return MLPMessageFunction(raw_message_dimension, message_dimension)
    if module_type == "identity":
        return IdentityMessageFunction()
    raise ValueError(
        f"Unknown message function type {module_type!r}; "
        "expected 'mlp' or 'identity'."
    )