# tgn-playground/modules/message_function.py
# Uploaded by ashu316 ("Upload 14 files"), commit 41aae2b (verified), 1.02 kB.
from torch import nn
class MessageFunction(nn.Module):
    """Base module that turns raw interaction messages into messages.

    This base class performs no computation of its own; the concrete
    message functions in this module subclass it and override
    :meth:`compute_message`.
    """

    def compute_message(self, raw_messages):
        """Compute messages from ``raw_messages``.

        The base implementation has no transformation and yields ``None``.
        """
        # Placeholder contract: concrete subclasses supply the real result.
        return None
class MLPMessageFunction(MessageFunction):
    """Message function that maps raw messages through a two-layer MLP.

    The network is ``Linear(raw, hidden) -> ReLU -> Linear(hidden, message)``.
    ``hidden`` is ``raw_message_dimension // 2``, clamped to at least one unit.
    """

    def __init__(self, raw_message_dimension, message_dimension):
        """Build the MLP.

        Args:
            raw_message_dimension: Size of the last dimension of the raw
                messages passed to :meth:`compute_message`.
            message_dimension: Size of the last dimension of the produced
                messages.
        """
        super(MLPMessageFunction, self).__init__()
        # Floor-halving gives 0 hidden units when raw_message_dimension < 2.
        # The output would then be just the final layer's bias, ignoring the
        # input entirely. Keep at least one hidden unit.
        hidden_dimension = max(1, raw_message_dimension // 2)
        # The same Sequential is registered under two attribute names. Both
        # are kept because removing either would change the module's
        # state_dict keys ("mlp.*" / "layers.*").
        self.mlp = self.layers = nn.Sequential(
            nn.Linear(raw_message_dimension, hidden_dimension),
            nn.ReLU(),
            nn.Linear(hidden_dimension, message_dimension),
        )

    def compute_message(self, raw_messages):
        """Apply the MLP to ``raw_messages`` and return the resulting messages."""
        messages = self.mlp(raw_messages)
        return messages
class IdentityMessageFunction(MessageFunction):
    """Message function that passes raw messages through unchanged.

    It defines no layers; the returned messages are the very same object
    that was passed in (no copy is made).
    """

    def compute_message(self, raw_messages):
        """Return ``raw_messages`` as-is."""
        return raw_messages
def get_message_function(module_type, raw_message_dimension, message_dimension):
    """Build the message function selected by ``module_type``.

    Args:
        module_type: ``"mlp"`` for :class:`MLPMessageFunction` or
            ``"identity"`` for :class:`IdentityMessageFunction`.
        raw_message_dimension: Size of the raw messages' last dimension
            (only used by the MLP variant).
        message_dimension: Size of the produced messages' last dimension
            (only used by the MLP variant; the identity variant ignores it,
            so presumably callers must keep the two sizes equal when using
            it -- confirm against callers).

    Returns:
        A newly constructed message-function module.

    Raises:
        ValueError: If ``module_type`` is not a supported name.
    """
    if module_type == "mlp":
        return MLPMessageFunction(raw_message_dimension, message_dimension)
    if module_type == "identity":
        return IdentityMessageFunction()
    # Fail fast. Returning None here would only surface later as an
    # AttributeError at the first compute_message call.
    raise ValueError(
        f"Unknown message function type {module_type!r}; "
        "expected 'mlp' or 'identity'."
    )