"""Contains various kinds of embeddings like Glove, BERT, etc.""" from torch.nn import Module, Embedding, Flatten from src.utils.mapper import configmapper @configmapper.map("embeddings", "glove") class GloveEmbedding(Module): """Implement Glove based Word Embedding.""" def __init__(self, embedding_matrix, padding_idx, static=True): """Construct GloveEmbedding. Args: embedding_matrix (torch.Tensor): The matrix contrainining the embedding weights padding_idx (int): The padding index in the tokenizer. static (bool): Whether or not to freeze embeddings. """ super(GloveEmbedding, self).__init__() self.embedding = Embedding.from_pretrained(embedding_matrix) self.embedding.padding_idx = padding_idx if static: self.embedding.weight.required_grad = False self.flatten = Flatten(start_dim=1) def forward(self, x_input): """Pass the input through the embedding. Args: x_input (torch.Tensor): The numericalized tokenized input Returns: x_output (torch.Tensor): The output from the embedding """ x_output = self.embedding(x_input) x_output = self.flatten(x_output) return x_output