import pickle
import re
from typing import List, Dict, Tuple, Optional

from datasets import load_dataset


class SanskritBPETokenizer:
    def __init__(self, vocab_path: Optional[str] = None, merges_path: Optional[str] = None, token_path: Optional[str] = None):
        """Initialize the tokenizer from saved merges/tokens, or train from a dataset"""
        self.vocab = {}
        self.merges = {}
        if merges_path:
            self.load_vocab(merges_path)
        if token_path:
            self.load_tokens(token_path)
        if vocab_path:
            self.create_tokens(vocab_path, token_path, merges_path)

    def create_tokens(self, vocab_path, token_path, merges_path):
        """Train byte-level BPE merges on a Hugging Face dataset and save them"""
        dataset = load_dataset(vocab_path)
        text = ''.join(example['translation']['sn'] for example in dataset['train'])
        text = self.regex_sanskrit_tokenize(text)  # pre-segment at sandhi boundaries before byte-encoding
        tokens = text.encode("utf-8")  # raw bytes
        tokens = list(map(int, tokens))  # convert to a list of integers in range 0..255 for convenience
        with open(token_path + '/saved.pkl', 'wb') as f:
            pickle.dump(tokens, f, pickle.HIGHEST_PROTOCOL)

        vocab_size = 5250  # the desired final vocabulary size
        num_merges = vocab_size - 256
        ids = list(tokens)  # copy so we don't destroy the original list

        merges = {}  # (int, int) -> int
        for i in range(num_merges):
            stats = self.get_stats(ids)
            pair = max(stats, key=stats.get)  # most frequent adjacent pair
            idx = 256 + i
            print(f"merging {pair} into a new token {idx}")
            ids = self.merge(ids, pair, idx)
            merges[pair] = idx

        with open(merges_path + '/merges_saved.pkl', 'wb') as f:
            pickle.dump(merges, f, pickle.HIGHEST_PROTOCOL)

        print("tokens length:", len(tokens))
        print("ids length:", len(ids))
        print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
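    # Illustrative walkthrough (toy numbers, not from the real corpus): one step
    # of the training loop above, applied to the UTF-8 bytes of "hehe":
    #   ids   = [104, 101, 104, 101]
    #   stats = {(104, 101): 2, (101, 104): 1}
    #   pair  = (104, 101)                # most frequent adjacent pair
    #   merge(ids, (104, 101), 256)  ->  [256, 256]
    # Each step shortens `ids` and records one new entry in `merges`.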
    def regex_sanskrit_tokenize(self, text):
        # Basic sandhi patterns
        sandhi_patterns = [
            # Earlier, broader patterns kept for reference:
            # # Visarga sandhi
            # r'ः\s*([कखगघङचछजझञटठडढणतथदधनपफबभम])',
            # # Vowel sandhi
            # r'([अआइईउऊऋॠऌॡएऐओऔ])्?\s*([अआइईउऊऋॠऌॡएऐओऔ])',
            # # Consonant sandhi
            # r'([क-ह])्\s*([क-ह])',
            # # Common contractions and combinations
            # r'([क-ह])्([यरलवहमनञणन])',
            # # Anusvara and chandrabindu combinations
            # r'[ंँ]([क-ह])',
            # # Handle special cases like ज्ञ, क्ष
            # r'(ज्ञ|क्ष)',
            # # Handle numbers and punctuation
            # r'([०-९])|([।॥,])',
            # # Handle specific compound formations
            # r'([क-ह])्य',  # -ya formations
            # r'([क-ह])्र',  # -ra formations
            # # Handle specific prefixes
            # r'(प्र|उप|अभि|नि|वि|आ|उद्|परि)',
            # # Handle specific suffixes
            # r'(तया|त्वम्|त्वात्)',

            # Anusvara and visarga combinations
            r'ं|ः',
            # Common vowel sandhis
            r'ा|ि|ी|ु|ू|ृ|ॄ|ॢ|ॣ|े|ै|ो|ौ',
            # Virama (halant) combinations
            r'्',
            # Common consonant combinations
            r'त्त|त्र|त्व|न्त|न्द|न्ध|श्च|श्व|ष्ट|स्त|स्थ|ह्म|ह्य',
            # Basic word boundaries
            r'\s+',
            # Punctuation and numbers
            r'[।॥॰,!?०-९]+',
        ]

        # Combine all patterns into a single alternation
        pattern = '|'.join(sandhi_patterns)

        # Function to process each match: add spaces around the matched token
        def split_token(match):
            token = match.group(0)
            return f' {token} '

        # Apply the regex
        tokenized_text = re.sub(pattern, split_token, text)
        # print('tokenized_text', tokenized_text)  # debug: very verbose on the full corpus

        # Clean up extra spaces and split
        tokens = [token.strip() for token in tokenized_text.split() if token.strip()]
        return ' '.join(tokens)

    def load_tokens(self, token_path: str):
        """Load the saved training token stream from file"""
        with open(token_path + "/saved.pkl", "rb") as f:
            self.tokens = pickle.load(f)
        print("tokens length:", len(self.tokens))

    def load_vocab(self, vocab_path: str):
        """Load the saved merge table and rebuild the vocabulary from it"""
        with open(vocab_path + "/merges_saved.pkl", "rb") as f:
            self.merges = pickle.load(f)

        # Reconstruct the vocab: start from the 256 raw bytes, then replay each
        # merge in insertion order so every new id maps to its byte string.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]

    def get_stats(self, tokens: List[int]) -> Dict[Tuple[int, int], int]:
        """Count frequency of adjacent token pairs"""
        stats = {}
        for pair in zip(tokens, tokens[1:]):  # Pythonic way to iterate consecutive elements
            stats[pair] = stats.get(pair, 0) + 1
        return stats

    def merge(self, tokens: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
        """Merge all occurrences of a token pair into a single new token"""
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
                new_tokens.append(idx)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens

    def encode(self, text: str) -> List[int]:
        """Encode text to token IDs"""
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = self.get_stats(tokens)
            # Pick the pair that was learned earliest (lowest merge index)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else can be merged
            idx = self.merges[pair]
            tokens = self.merge(tokens, pair, idx)
        return tokens

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to text"""
        tokens = b"".join(self.vocab[idx] for idx in ids)
        text = tokens.decode("utf-8", errors="replace")
        return text


if __name__ == "__main__":
    # Train from scratch: Sanskrit text loaded from Hugging Face
    vocab_path = 'rahular/itihasa'
    # SanskritBPETokenizer(vocab_path=vocab_path,
    #                      merges_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer',
    #                      token_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer')

    # Example usage with previously trained merges
    tokenizer = SanskritBPETokenizer(
        merges_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer/data/vocab',
        token_path='/Users/priye/Desktop/ERAV3/SanskritBPETokenizer/data/vocab',
    )
    sample_text = "विश्वामित्रवचः श्रुत्वा राघवः सहलक्ष्मणः। विस्मयं परमं गत्वा विश्वामित्रमथाब्रवीत्॥"
    encoded = tokenizer.encode(sample_text)
    decoded = tokenizer.decode(encoded)
    print(f"Original text: {sample_text}")
    print(f"Encoded tokens: {encoded}")
    print(f"Decoded text: {decoded}")
    # Round-trip check: encoding then decoding must reproduce the input exactly
    assert sample_text == tokenizer.decode(tokenizer.encode(sample_text))
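
    # A minimal sketch (added for illustration, not part of the original demo):
    # get_stats/merge can be exercised directly on a toy id list to watch a
    # single BPE step. The id 256 below is just a placeholder for a new token;
    # it does not touch the loaded vocab.
    toy_ids = list("अअ".encode("utf-8"))      # UTF-8 bytes of two identical characters
    toy_stats = tokenizer.get_stats(toy_ids)  # counts of adjacent byte pairs
    top_pair = max(toy_stats, key=toy_stats.get)
    print(f"toy ids: {toy_ids}")
    print(f"most frequent pair {top_pair} -> merged: {tokenizer.merge(toy_ids, top_pair, 256)}")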