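"""Byte-level BPE (byte pair encoding) tokenizer trained on a Gujarati text corpus.

Training greedily merges the most frequent pair of adjacent token ids until the
target vocabulary size is reached; encode/decode then apply those learned merges.
"""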

import time


def read_corpus(corpus_path: str) -> str:
    """Read the full training corpus from disk as a single string."""
    with open(corpus_path, 'r', encoding='utf-8') as f:
        return f.read()



class BPEGujaratiTokenizer:
    """Byte-level BPE tokenizer trained on a Gujarati corpus."""

    def __init__(self, corpus_path: str, max_vocab_size: int = 5000, sample_size: int = 20000):
        self.corpus = read_corpus(corpus_path)
        self.max_vocab_size = max_vocab_size
        # Character-level vocabulary of the corpus (kept for reference;
        # the BPE below operates on raw UTF-8 bytes, not on these characters).
        self.corpus_vocab = sorted(set(self.corpus))
        self.corpus_vocab_size = len(self.corpus_vocab)
        self.stoi = {ch: i for i, ch in enumerate(self.corpus_vocab)}
        self.itos = {i: ch for i, ch in enumerate(self.corpus_vocab)}
        self.sample_size = sample_size

        self.vocab, self.merges = self.train_bpe(self.corpus, self.max_vocab_size, self.sample_size)


    def get_stats(self, ids):
        """Count how often each adjacent pair of ids occurs in the sequence."""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts


    def merge(self, ids, pair, idx):
        """Replace every occurrence of `pair` in `ids` with the new token id `idx`."""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids



    def train_bpe(self, corpus, max_vocab_size, sample_size=None):
        """Learn BPE merges on (a sample of) the corpus until the vocab reaches max_vocab_size."""
        # Base vocabulary: one token per byte value.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        if sample_size:
            corpus = corpus[:sample_size]
        num_merges = max_vocab_size - len(self.vocab)
        tokens = list(corpus.encode('utf-8'))
        ids = list(tokens)
        self.merges = {}  # (int, int) -> int
        print(f"Before training: ids length: {len(ids)}")
        print(f"Before training: tokens length: {len(tokens)}")
        print(f"Before training: merges length: {len(self.merges)}")

        for i in range(num_merges):
            stats = self.get_stats(ids)
            if not stats:
                # Fewer than two tokens left; nothing more to merge.
                break
            # Greedily merge the most frequent adjacent pair into a new token id.
            pair = max(stats, key=stats.get)
            idx = 256 + i
            ids = self.merge(ids, pair, idx)
            self.merges[pair] = idx
        # Extend the vocab with the byte strings of the merged tokens.
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
        print(f"After training: ids length: {len(ids)}")
        print(f"After training: tokens length: {len(tokens)}")
        print(f"After training: merges length: {len(self.merges)}")
        print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
        return self.vocab, self.merges

    def encode(self, text):
        """Encode a string into a list of token ids using the learned merges."""
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = self.get_stats(tokens)
            # Apply the earliest-learned merge that is still possible.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else can be merged
            idx = self.merges[pair]
            tokens = self.merge(tokens, pair, idx)
        return tokens

    
    def decode(self, tokens):
        """Decode a list of token ids back into a string."""
        text_bytes = b"".join(self.vocab[idx] for idx in tokens)
        return text_bytes.decode("utf-8", errors="replace")
    
if __name__ == "__main__":
    start_time = time.time()
    tokenizer = BPEGujaratiTokenizer(corpus_path="gu_corpus.txt", max_vocab_size=5000, sample_size=20000)
    end_time = time.time()
    print(f"Time taken to train: {end_time - start_time:.2f} seconds")
    print("--------------------------------")

    start_time = time.time()
    print(tokenizer.encode("હું તને પ્રેમ કરું છું"))  # "I love you"
    end_time = time.time()
    print(f"Time taken to encode: {end_time - start_time:.4f} seconds")
    print("--------------------------------")

    start_time = time.time()
    print(tokenizer.decode(tokenizer.encode("હું તને પ્રેમ કરું છું")))
    end_time = time.time()
    print(f"Time taken to decode: {end_time - start_time:.4f} seconds")
    print("--------------------------------")

    # Round-trip check: decode(encode(s)) should reproduce each Gujarati test sentence.
    start_time = time.time()
    sentences = [
        "હું આજે ખૂબ ખુશ છું.",
        "તું શું કરે છે? ",
        "મને ચા પીવી છે. ",
        "એ બધું સરસ છે. ",
        "આ પુસ્તક ખૂબ રસપ્રદ છે. ",
        "તારે ક્યારે આવવું છે? ",
        "આ મારો મિત્ર છે. ",
        "હું શાકભાજી લઈ આવ્યો છું. ",
        "આકાશ માં વાદળ છે. ",
        "શાળા ક્યારે શરૂ થશે? ",
        "આ પુસ્તક ખૂબ રસપ્રદ છે.",
    ]
    for sentence in sentences:
        print("original: ", sentence)
        print("encoded: ", tokenizer.encode(sentence))
        print("decoded: ", tokenizer.decode(tokenizer.encode(sentence)))
        print("round trip ok: ", tokenizer.decode(tokenizer.encode(sentence)) == sentence)
    end_time = time.time()
    print(f"Time taken for the round-trip check: {end_time - start_time:.4f} seconds")
    print("--------------------------------")