xlm-roberta-large / src /models /two_layer_nn.py
shayekh's picture
Upload 61 files
cc9c7ee
raw
history blame
1.44 kB
"""Implements a two layer Neural Network."""
from torch.nn import Module, Linear, ReLU
from src.utils.mapper import configmapper
@configmapper.map("models", "two_layer_nn")
class TwoLayerNN(Module):
"""Implements two layer neural network.
Methods:
forward(x_input): Returns the output of the neural network.
"""
def __init__(self, embedding, dims):
"""Construct the two layer Neural Network.
This method is used to initialize the two layer neural network,
with a given embedding type and corresponding arguments.
Args:
embedding (torch.nn.Module): The embedding layer for the model.
dims (list): List of dimensions for the neural network, input to output.
"""
super(TwoLayerNN, self).__init__()
self.embedding = embedding
self.linear1 = Linear(dims[0], dims[1])
self.relu = ReLU()
self.linear2 = Linear(dims[1], dims[2])
def forward(self, x_input):
"""
Return the output of the neural network for an input.
Args:
x_input (torch.Tensor): The input tensor to the neural network.
Returns:
x_output (torch.Tensor): The output tensor for the neural network.
"""
output = self.embedding(x_input)
output = self.linear1(output)
output = self.relu(output)
x_output = self.linear2(output)
return x_output