File size: 813 Bytes
c8ddb9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import pickle
import re
from typing import List


class TAIMGANTokenizer:
    """Word-level tokenizer backed by vocabulary tables stored in a pickle.

    The pickled object is indexed positionally: element 2 is used as the
    id -> word table and element 3 as the word -> id table. The id of the
    ``"<end>"`` entry doubles as the padding / unknown-word id.
    """

    def __init__(self, captions_path):
        """Load the vocabulary tables from the pickle at *captions_path*.

        NOTE(review): ``pickle.load`` can execute arbitrary code, so this
        should only be pointed at trusted files.

        Raises:
            KeyError: if the word -> id table has no ``"<end>"`` entry
                (assuming the table is a dict -- TODO confirm).
        """
        with open(captions_path, "rb") as handle:
            payload = pickle.load(handle)

        # Only positions 2 and 3 of the payload are consumed here.
        self.ix_to_word = payload[2]
        self.word_to_ix = payload[3]

        self.pad_repr = "[PAD]"
        self.token_regex = r'\w+'
        self.pad_token_id = self.word_to_ix["<end>"]

    def encode(self, text: str) -> List[int]:
        """Lower-case *text*, split it into ``\\w+`` runs and map each to an id.

        Words missing from the vocabulary are encoded as ``pad_token_id``.
        """
        ids = []
        for piece in re.findall(self.token_regex, text.lower()):
            ids.append(self.word_to_ix.get(piece, self.pad_token_id))
        return ids

    def decode(self, tokens: List[int]) -> str:
        """Turn ids back into a space-separated string.

        Ids equal to ``pad_token_id`` are rendered as ``pad_repr``; every
        other id is looked up in ``ix_to_word``.
        """
        rendered = []
        for token_id in tokens:
            if token_id != self.pad_token_id:
                rendered.append(self.ix_to_word[token_id])
            else:
                rendered.append(self.pad_repr)
        return ' '.join(rendered)