PriyePrabhakar committed on
Commit 232c9b6 · 1 Parent(s): 9bbe09e

Added files for SanskritBPE tokenizer

.DS_Store ADDED
Binary file (6.15 kB).
 
README.md CHANGED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,139 @@
+ import gradio as gr
+ from src.tokenizer import SanskritBPETokenizer
+ import os
+ import random
+
+ # Initialize tokenizer
+ tokenizer = SanskritBPETokenizer(
+     merges_path='data/vocab',
+     token_path='data/vocab'
+ )
+
+ def generate_color(token_id: int) -> str:
+     """Generate a consistent color for a token ID"""
+     random.seed(token_id)  # Make color consistent for same token
+     hue = random.randint(0, 360)
+     return f"hsl({hue}, 80%, 80%)"
+
+ def colorize_tokens(text: str) -> str:
+     """Convert text to HTML with colored token spans"""
+     if not text.strip():
+         return ""
+
+     tokens = tokenizer.encode(text)
+     decoded_pieces = []
+
+     for i, token_id in enumerate(tokens):
+         decoded_text = tokenizer.decode([token_id])
+         color = generate_color(token_id)
+         span = f'<span style="background-color: {color}; color: black; padding: 0 2px; border-radius: 3px; margin: 0 1px;" title="Token {token_id}">{decoded_text}</span>'
+         decoded_pieces.append(span)
+
+     return "".join(decoded_pieces)
+
+ def count_tokens(text: str, show_tokens: bool = False) -> tuple:
+     """Count tokens and return token visualization"""
+     if not text.strip():
+         return "0 tokens", ""
+
+     tokens = tokenizer.encode(text)
+     token_count = len(tokens)
+
+     if show_tokens:
+         decoded = tokenizer.decode(tokens)
+         token_info = f"{token_count} tokens\nTokens: {tokens}\nDecoded: {decoded}"
+     else:
+         token_info = f"{token_count} tokens"
+
+     colored_text = colorize_tokens(text)
+     return token_info, colored_text
+
+ # Custom CSS for better visualization
+ custom_css = """
+ footer {visibility: hidden}
+ .token-text {
+     font-family: monospace;
+     line-height: 1.8;
+     padding: 10px;
+     border-radius: 5px;
+     background: white;
+     margin: 10px 0;
+     color: black;
+ }
+ .gradio-container {
+     max-width: 1000px !important;
+ }
+ """
+
+ # Create the Gradio interface
+ with gr.Blocks(css=custom_css) as demo:
+     gr.Markdown(
+         """
+         # Sanskrit BPE Tokenizer
+
+         Test how the Sanskrit BPE tokenizer processes text. Enter Sanskrit text below to see how many tokens it uses.
+         Each colored span represents one token.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             text_input = gr.Textbox(
+                 label="Content",
+                 placeholder="Enter Sanskrit text here...",
+                 lines=5
+             )
+             show_tokens = gr.Checkbox(
+                 label="Show token IDs and decoded text",
+                 value=False
+             )
+
+         with gr.Column():
+             token_count = gr.Textbox(
+                 label="Token count",
+                 lines=2,
+                 interactive=False
+             )
+             token_viz = gr.HTML(
+                 label="Token visualization",
+                 elem_classes=["token-text"]
+             )
+
+     # Update token count and visualization when text changes or checkbox is toggled
+     text_input.change(
+         fn=count_tokens,
+         inputs=[text_input, show_tokens],
+         outputs=[token_count, token_viz]
+     )
+     show_tokens.change(
+         fn=count_tokens,
+         inputs=[text_input, show_tokens],
+         outputs=[token_count, token_viz]
+     )
+
+     gr.Markdown(
+         """
+         ### Examples
+         Try these Sanskrit text samples:
+         """
+     )
+
+     gr.Examples(
+         examples=[
+             ["विश्वामित्रवचः श्रुत्वा राघवः सहलक्ष्मणः।"],
+             ["धर्मक्षेत्रे कुरुक्षेत्रे समवेता युयुत्सवः।"],
+             ["यदा यदा हि धर्मस्य ग्लानिर्भवति भारत।"],
+         ],
+         inputs=text_input
+     )
+
+     gr.Markdown(
+         """
+         ---
+         Built with [Gradio](https://gradio.app) | [GitHub Repository](https://github.com/PRIYE/SanskritBPETokenizer)
+         """
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
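
For a quick check of the callback wiring above without launching the UI, something along these lines should work from the repo root (a sketch only: it assumes the data/vocab pickles are present and reuses one of the app's own example sentences):

# Run from the repo root; importing app builds the Blocks but does not launch them.
from app import count_tokens

info, html = count_tokens("धर्मक्षेत्रे कुरुक्षेत्रे समवेता युयुत्सवः।", show_tokens=True)
print(info)        # token count plus the raw IDs and decoded text
print(html[:120])  # start of the colored-span markup that feeds gr.HTML
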
data/vocab/merges_saved.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:831479c60a34b724bae34fa1d16571748b7f5d008d6e71cb1cc6e8489d596f48
+ size 54143
data/vocab/saved.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfd750f97d79eaba6b14d575f7d71c723365311b22658f850699a5bed75397b8
+ size 41003929
merges_saved.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:831479c60a34b724bae34fa1d16571748b7f5d008d6e71cb1cc6e8489d596f48
+ size 54143
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ numpy>=1.21.0
+ tqdm>=4.65.0
+ gradio>=4.11.0
+ datasets>=2.0.0
src/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .tokenizer import SanskritBPETokenizer
+ from .utils import save_merges, load_merges
+
+ __all__ = ['SanskritBPETokenizer', 'save_merges', 'load_merges']
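
These re-exports let callers import from the package root rather than the individual modules (app.py imports from src.tokenizer directly, but both routes resolve to the same class). A minimal sketch, assuming the data/vocab pickles are available:

from src import SanskritBPETokenizer

tok = SanskritBPETokenizer(merges_path='data/vocab', token_path='data/vocab')
ids = tok.encode("राघवः")
assert tok.decode(ids) == "राघवः"  # byte-level BPE round trip is lossless
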
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (320 Bytes).
 
src/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (9.35 kB).
 
src/__pycache__/utils.cpython-312.pyc ADDED
Binary file (1.92 kB).
 
src/tokenizer.py ADDED
@@ -0,0 +1,190 @@
+ import json
+ from pathlib import Path
+ import pickle
+ from typing import List, Dict, Tuple, Optional
+ from datasets import load_dataset
+ import re
+
+ class SanskritBPETokenizer:
+     def __init__(self, vocab_path: Optional[str] = None, merges_path: Optional[str] = None, token_path: Optional[str] = None):
+         """Initialize the tokenizer with vocabulary and merges"""
+         self.vocab = {}
+         self.merges = {}
+         if merges_path:
+             self.load_vocab(merges_path)
+         if token_path:
+             self.load_tokens(token_path)
+         if vocab_path:
+             self.create_tokens(vocab_path, token_path, merges_path)
+
+     def create_tokens(self, vocab_path, token_path, merges_path):
+         dataset = load_dataset(vocab_path)
+         text = ''.join([i['translation']['sn'] for i in dataset['train']])
+         text = self.regex_sanskrit_tokenize(text)  # pre-split the corpus before byte-level BPE
+         tokens = text.encode("utf-8")  # raw bytes
+         tokens = list(map(int, tokens))  # convert to a list of integers in range 0..255 for convenience
+         with open(token_path + '/saved.pkl', 'wb') as f:
+             pickle.dump(tokens, f, pickle.HIGHEST_PROTOCOL)
+         vocab_size = 5250  # the desired final vocabulary size
+         num_merges = vocab_size - 256
+         ids = list(tokens)  # copy so we don't destroy the original list
+         merges = {}  # (int, int) -> int
+         for i in range(num_merges):
+             stats = self.get_stats(ids)
+             pair = max(stats, key=stats.get)
+             idx = 256 + i
+             print(f"merging {pair} into a new token {idx}")
+             ids = self.merge(ids, pair, idx)
+             merges[pair] = idx
+         with open(merges_path + '/merges_saved.pkl', 'wb') as f:
+             pickle.dump(merges, f, pickle.HIGHEST_PROTOCOL)
+         print("tokens length:", len(tokens))
+         print("ids length:", len(ids))
+         print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
+
+
+     def regex_sanskrit_tokenize(self, text):
+         # Basic sandhi patterns
+         sandhi_patterns = [
+             # # Visarga sandhi
+             # r'ः\s*([कखगघङचछजझञटठडढणतथदधनपफबभम])',
+
+             # # Vowel sandhi
+             # r'([अआइईउऊऋॠऌॡएऐओऔ])्?\s*([अआइईउऊऋॠऌॡएऐओऔ])',
+
+             # # Consonant sandhi
+             # r'([क-ह])्\s*([क-ह])',
+
+             # # Common contractions and combinations
+             # r'([क-ह])्([यरलवहमनञणन])',
+
+             # # Anusvara and chandrabindu combinations
+             # r'[ंँ]([क-ह])',
+
+             # # Handle special cases like ज्ञ, क्ष
+             # r'(ज्ञ|क्ष)',
+
+             # # Handle numbers and punctuation
+             # r'([०-९])|([।॥,])',
+             # # Handle specific compound formations
+             # r'([क-ह])्य',  # -ya formations
+             # r'([क-ह])्र',  # -ra formations
+
+             # # Handle specific prefixes
+             # r'(प्र|उप|अभि|नि|वि|आ|उद्|परि)',
+
+             # # Handle specific suffixes
+             # r'(तया|त्वम्|त्वात्)',
+
+             ##################
+             # Anusvara and visarga combinations
+             r'ं|ः',
+
+             # Common vowel sandhis
+             r'ा|ि|ी|ु|ू|ृ|ॄ|ॢ|ॣ|े|ै|ो|ौ',
+
+             # Virama (halant) combinations
+             r'्',
+
+             # Common consonant combinations
+             r'त्त|त्र|त्व|न्त|न्द|न्ध|श्च|श्व|ष्ट|स्त|स्थ|ह्म|ह्य',
+
+             # Basic word boundaries
+             r'\s+',
+
+             # Punctuation and numbers
+             r'[।॥॰,!?०-९]+',
+         ]
+
+         # Combine all patterns
+         pattern = '|'.join(sandhi_patterns)
+
+         # Function to process each match
+         def split_token(match):
+             token = match.group(0)
+             # Add spaces around the matched token
+             return f' {token} '
+
+         # Apply the regex
+         tokenized_text = re.sub(pattern, split_token, text)
+         print('tokenized_text', tokenized_text)
+
+         # Clean up extra spaces and split
+         tokens = [token.strip() for token in tokenized_text.split() if token.strip()]
+
+         return ' '.join(tokens)
+
+     def load_tokens(self, token_path: str):
+         """Load the saved training token stream from file"""
+         with open(token_path + "/saved.pkl", "rb") as f:
+             self.tokens = pickle.load(f)
+         print("tokens length:", len(self.tokens))
+         chars = sorted(list(set(self.tokens)))
+
+
+     def load_vocab(self, vocab_path: str):
+         """Load merges from file and rebuild the vocabulary"""
+         with open(vocab_path + "/merges_saved.pkl", "rb") as f:
+             self.merges = pickle.load(f)
+         #print(self.merges)
+         # Create reverse vocab from merges
+         self.vocab = {idx: bytes([idx]) for idx in range(256)}
+         for (p0, p1), idx in self.merges.items():
+             self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
+         #print(self.vocab)
+
+     def get_stats(self, tokens: List[int]) -> Dict[Tuple[int, int], int]:
+         """Count frequency of token pairs"""
+         stats = {}
+         for pair in zip(tokens, tokens[1:]):  # Pythonic way to iterate consecutive elements
+             stats[pair] = stats.get(pair, 0) + 1
+         return stats
+
+     def merge(self, tokens: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
+         """Merge all occurrences of a token pair"""
+         new_tokens = []
+         i = 0
+         while i < len(tokens):
+             if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
+                 new_tokens.append(idx)
+                 i += 2
+             else:
+                 new_tokens.append(tokens[i])
+                 i += 1
+         return new_tokens
+
+     def encode(self, text: str) -> List[int]:
+         """Encode text to token IDs"""
+         tokens = list(text.encode("utf-8"))
+         while len(tokens) >= 2:
+             stats = self.get_stats(tokens)
+             pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+             if pair not in self.merges:
+                 break  # nothing else can be merged
+             idx = self.merges[pair]
+             tokens = self.merge(tokens, pair, idx)
+         return tokens
+
+     def decode(self, ids: List[int]) -> str:
+         """Decode token IDs back to text"""
+         tokens = b"".join(self.vocab[idx] for idx in ids)
+         text = tokens.decode("utf-8", errors="replace")
+         return text
+
+ if __name__ == "__main__":
+     # Create tokens from text
+     vocab_path = 'rahular/itihasa'  # loading Sanskrit text from Hugging Face
+     #SanskritBPETokenizer(vocab_path=vocab_path, merges_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer', token_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer')
+
+     # Example usage
+     tokenizer = SanskritBPETokenizer(merges_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer/data/vocab', token_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer/data/vocab')
+
+     sample_text = "विश्वामित्रवचः श्रुत्वा राघवः सहलक्ष्मणः। विस्मयं परमं गत्वा विश्वामित्रमथाब्रवीत्॥"
+     encoded = tokenizer.encode(sample_text)
+     decoded = tokenizer.decode(encoded)
+
+     print(f"Original text: {sample_text}")
+     print(f"Encoded tokens: {encoded}")
+     print(f"Decoded text: {decoded}")
+     print(tokenizer.decode(tokenizer.encode(sample_text)))
+     assert sample_text == tokenizer.decode(tokenizer.encode(sample_text))
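
As a toy illustration of what one training step in create_tokens does, the get_stats and merge helpers can be exercised on a made-up byte list (the values below are arbitrary; real training runs over the UTF-8 bytes of the itihasa corpus):

tok = SanskritBPETokenizer()              # no paths given, so nothing is loaded; the helpers still work
ids = [224, 164, 176, 224, 164, 176]      # two repeats of a three-byte sequence
stats = tok.get_stats(ids)                # {(224, 164): 2, (164, 176): 2, (176, 224): 1}
best = max(stats, key=stats.get)          # most frequent adjacent pair, here (224, 164)
print(tok.merge(ids, best, 256))          # [256, 176, 256, 176] -- the pair becomes new token 256
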
src/utils.py ADDED
@@ -0,0 +1,23 @@
+ import json
+ from pathlib import Path
+ from typing import Dict, Tuple
+
+ def save_merges(merges: Dict[Tuple[int, int], int], save_path: str):
+     """Save merges dictionary to JSON file"""
+     # Convert tuple keys to strings for JSON serialization
+     serializable_merges = {f"{k[0]},{k[1]}": v for k, v in merges.items()}
+
+     save_dir = Path(save_path)
+     save_dir.mkdir(parents=True, exist_ok=True)
+
+     with open(save_dir / "merges.json", "w", encoding="utf-8") as f:
+         json.dump(serializable_merges, f, ensure_ascii=False, indent=2)
+
+ def load_merges(load_path: str) -> Dict[Tuple[int, int], int]:
+     """Load merges dictionary from JSON file"""
+     with open(Path(load_path) / "merges.json", "r", encoding="utf-8") as f:
+         serialized_merges = json.load(f)
+
+     # Convert string keys back to tuples
+     merges = {tuple(map(int, k.split(","))): v for k, v in serialized_merges.items()}
+     return merges
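
A quick round trip through these helpers (the temporary directory and the toy merge table below are purely illustrative):

from src.utils import save_merges, load_merges

toy_merges = {(104, 105): 256, (256, 33): 257}    # hypothetical (pair) -> new-token-id table
save_merges(toy_merges, "/tmp/sanskrit_bpe_vocab")
assert load_merges("/tmp/sanskrit_bpe_vocab") == toy_merges
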