import os
import json
import regex as re
from typing import Optional, Tuple

from transformers import PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

class AksharaTokenizer(PreTrainedTokenizer):
   """
   Akshara tokenizer for processing Indic language text.
   This tokenizer handles characters at the akshara (syllable) level.
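
   Example (file paths are illustrative; supply your own vocabulary files):

   ```python
   tokenizer = AksharaTokenizer(vocab_file="vocab.json", merges_file="merges.txt")
   ids = tokenizer("नमस्ते दुनिया")["input_ids"]
   text = tokenizer.decode(ids, skip_special_tokens=True)
   ```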
   """
   
   vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
   model_input_names = ["input_ids", "attention_mask"]
   
   def __init__(
       self,
       vocab_file,
       merges_file=None,
       unk_token="<unk>",
       bos_token="<s>",
       eos_token="</s>",
       pad_token="<pad>",
       mask_token="<mask>",
       add_prefix_space=False,
       **kwargs
   ):
       # Load the vocabulary and pre-tokenization state *before* calling
       # super().__init__, which may need the vocabulary to resolve the special
       # tokens on recent versions of transformers.
       with open(vocab_file, encoding="utf-8") as vocab_handle:
           self.encoder = json.load(vocab_handle)
       self.decoder = {v: k for k, v in self.encoder.items()}

       # Load BPE merge ranks if a merges file is provided (token pair -> priority).
       self.merges = {}
       if merges_file is not None and os.path.isfile(merges_file):
           with open(merges_file, encoding="utf-8") as merges_handle:
               merges = merges_handle.read().split("\n")
               self.merges = {tuple(merge.split()): i for i, merge in enumerate(merges) if merge}

       # Whether to prepend a space to the input before tokenizing (GPT-2-style).
       self.add_prefix_space = add_prefix_space

       # GPT-2-style pre-tokenization pattern: English contractions, runs of
       # letters, runs of digits, other symbols, and whitespace,
       # e.g. "it's 2024!" -> ["it", "'s", " 2024", "!"].
       self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

       super().__init__(
           unk_token=unk_token,
           bos_token=bos_token,
           eos_token=eos_token,
           pad_token=pad_token,
           mask_token=mask_token,
           add_prefix_space=add_prefix_space,
           **kwargs,
       )
   
   @property
   def vocab_size(self):
       return len(self.encoder)
   
   def get_vocab(self):
       return dict(self.encoder, **self.added_tokens_encoder)
   
   def _tokenize(self, text):
       """Tokenize text into akshara units."""
       if self.add_prefix_space and not text.startswith(" "):
           text = " " + text
       tokens = re.findall(self.pat, text)
       return tokens
   
   def _convert_token_to_id(self, token):
       """Convert a token to its ID in the vocabulary."""
       return self.encoder.get(token, self.encoder.get(self.unk_token))
   
   def _convert_id_to_token(self, index):
       """Convert an ID to its token in the vocabulary."""
       return self.decoder.get(index, self.unk_token)
   
   def convert_tokens_to_string(self, tokens):
       """Convert a sequence of tokens back into a single string."""
       # Tokens produced by the pre-tokenization pattern keep their leading
       # spaces, so a plain join already restores word boundaries; "▁" is mapped
       # back to a space in case the vocabulary uses SentencePiece-style markers.
       text = "".join(tokens)
       return text.replace("▁", " ").strip()
   
   def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
       """Build model inputs from a sequence by appending eos_token_id."""
       if token_ids_1 is None:
           return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
       return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
   
   def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
       """Get list where entries are [1] if a token is special and [0] otherwise."""
       if already_has_special_tokens:
           return super().get_special_tokens_mask(
               token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
           )
       
       if token_ids_1 is None:
           return [1] + ([0] * len(token_ids_0)) + [1]
       return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
   
   def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
       """Create a mask from the two sequences for sequence classification tasks."""
       eos = [self.eos_token_id]
       bos = [self.bos_token_id]
       
       if token_ids_1 is None:
           return len(bos + token_ids_0 + eos) * [0]
       return len(bos + token_ids_0 + eos) * [0] + len(token_ids_1 + eos) * [1]
   
   def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
       """Save the vocabulary (and the merges file, if merges were loaded) to `save_directory`."""
       if not os.path.isdir(save_directory):
           logger.error(f"Vocabulary path ({save_directory}) should be a directory")
           return

       vocab_file = os.path.join(
           save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
       )
       with open(vocab_file, "w", encoding="utf-8") as f:
           f.write(json.dumps(self.encoder, ensure_ascii=False))

       if not self.merges:
           return (vocab_file,)

       # Write merges in rank order so that __init__ can rebuild the same ranking.
       merges_file = os.path.join(
           save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["merges_file"]
       )
       with open(merges_file, "w", encoding="utf-8") as f:
           for merge, _ in sorted(self.merges.items(), key=lambda kv: kv[1]):
               f.write(" ".join(merge) + "\n")

       return vocab_file, merges_file


# Register the configuration and tokenizer with the Auto* classes so that the
# "akshara" model type can be resolved through AutoConfig / AutoTokenizer.
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING


class AksharaConfig(PretrainedConfig):
    model_type = "akshara"


if "akshara" not in CONFIG_MAPPING:
    AutoConfig.register("akshara", AksharaConfig)

# AutoTokenizer.register expects the configuration class, not the model-type string.
AutoTokenizer.register(AksharaConfig, slow_tokenizer_class=AksharaTokenizer)
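

# A minimal, self-contained smoke test; everything below is illustrative. The
# toy vocabulary covers only one sample sentence (it is not a real akshara
# inventory), and no pre-built vocab.json / merges.txt is assumed to exist.
if __name__ == "__main__":
    import tempfile

    # Space-prefixed tokens mirror what the pre-tokenization pattern produces.
    toy_vocab = {
        "<s>": 0, "</s>": 1, "<unk>": 2, "<pad>": 3, "<mask>": 4,
        "namaste": 5, " duniya": 6,
    }

    with tempfile.TemporaryDirectory() as tmp:
        vocab_path = os.path.join(tmp, "vocab.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(toy_vocab, f, ensure_ascii=False)

        tokenizer = AksharaTokenizer(vocab_file=vocab_path)
        encoded = tokenizer("namaste duniya")
        print("input_ids :", encoded["input_ids"])  # expected: [0, 5, 6, 1]
        print("round trip:", tokenizer.decode(encoded["input_ids"], skip_special_tokens=True))
        print("saved     :", tokenizer.save_vocabulary(tmp, filename_prefix="demo"))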