"""LSTM-based textual encoder for tokenized input"""

from typing import Tuple

import torch
from torch import nn


class TextEncoder(nn.Module):
    """Simple text encoder based on RNN"""

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int) -> None:
        """
        Initialize embeddings lookup for tokens and main LSTM

        :param vocab_size:
            Size of the vocabulary built for the textual input (L from the paper).
        :param emb_dim: Dimensionality of the embedding for each word.
        :param hidden_dim:
            Size of the hidden state of an LSTM cell; 2 x hidden_dim = C (from the LWGAN paper).
        """
        super().__init__()
        self.embs = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Propagate the text token input through the LSTM and return
        two types of embeddings: word-level and sentence-level

        :param torch.Tensor tokens: Input text tokens from the vocabulary
        :return: Word-level embeddings (BxCxL) and sentence-level embeddings (BxC)
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        embs = self.embs(tokens)
        # output: (B, L, 2*hidden_dim); hidden_states: (2, B, hidden_dim)
        output, (hidden_states, _) = self.lstm(embs)
        # Word-level embeddings: transpose (B, L, C) -> (B, C, L)
        word_embs = torch.transpose(output, 1, 2)
        # Sentence-level embeddings: concatenate the final backward ([-1]) and forward ([0]) hidden states -> (B, C)
        sent_embs = torch.cat((hidden_states[-1, :, :], hidden_states[0, :, :]), dim=1)
        return word_embs, sent_embs
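

if __name__ == "__main__":
    # Minimal usage sketch; the vocabulary size, dimensions, batch size, and
    # sequence length below are illustrative assumptions, not values from the
    # LWGAN paper. It only demonstrates the expected output shapes.
    encoder = TextEncoder(vocab_size=5000, emb_dim=300, hidden_dim=128)
    tokens = torch.randint(0, 5000, (4, 20))  # batch of 4 sequences, 20 tokens each
    word_embs, sent_embs = encoder(tokens)
    print(word_embs.shape)  # torch.Size([4, 256, 20]) -> (B, C, L), C = 2 * hidden_dim
    print(sent_embs.shape)  # torch.Size([4, 256])     -> (B, C)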