SVECTOR-OFFICIAL committed on
Commit 0bf33af · verified · 1 Parent(s): f158e89

Create akshara_tokenizer.py

Files changed (1):
  1. akshara_tokenizer.py (+142 -0)
akshara_tokenizer.py ADDED
import json
import os
from typing import Dict, List, Optional, Tuple, Union

import regex as re
from transformers import PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

class AksharaTokenizer(PreTrainedTokenizer):
    """
    Akshara tokenizer for processing Indic language text.

    This tokenizer processes text at the akshara (syllable) level.
    """

    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    model_input_names = ["input_ids", "attention_mask"]

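    # Illustrative only (not part of the original file): vocab.json is expected to be
    # a flat JSON object mapping token strings to integer ids, e.g.
    #   {"<unk>": 0, "<s>": 1, "</s>": 2, "<pad>": 3, "<mask>": 4, "क": 5, "का": 6, ...}
    # The concrete entries above are hypothetical; only the token -> id shape is
    # implied by the json.load / encoder.get usage below.
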
    def __init__(
        self,
        vocab_file,
        merges_file=None,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        **kwargs
    ):
        # Load the vocabulary before calling super().__init__(): recent versions of
        # transformers look up the vocabulary (via get_vocab) while registering the
        # special tokens in the base-class constructor.
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}

        # Load merges if available
        self.merges = {}
        if merges_file is not None and os.path.isfile(merges_file):
            with open(merges_file, encoding="utf-8") as merges_handle:
                merges = merges_handle.read().split("\n")
            self.merges = {tuple(merge.split()): i for i, merge in enumerate(merges) if merge}

        # Special token handling
        self.add_prefix_space = add_prefix_space

        # Pre-compiled regex pattern used for tokenization
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Tokenize text into akshara units."""
        if self.add_prefix_space and not text.startswith(" "):
            text = " " + text
        tokens = re.findall(self.pat, text)
        return tokens

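    # Rough illustration (not part of the original file) of what the pattern above
    # produces with the default settings (add_prefix_space=False):
    #   _tokenize("Hello world!") -> ["Hello", " world", "!"]
    # Runs of letters, digits and punctuation become separate tokens, with any
    # preceding space kept on the following token.
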
    def _convert_token_to_id(self, token):
        """Convert a token to its ID in the vocabulary."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Convert an ID to its token in the vocabulary."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens to a single string."""
        # Literal spaces in the tokens are dropped and any "▁" markers are mapped
        # back to spaces.
        text = "".join(tokens)
        text = text.replace(" ", "").replace("▁", " ").strip()
        return text

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by adding special tokens:

        - single sequence: `<s> A </s>`
        - pair of sequences: `<s> A </s> B </s>`
        """
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return a list with 1 for every special token and 0 for every sequence token."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Create token type ids from the two sequences for sequence classification tasks."""
        eos = [self.eos_token_id]
        bos = [self.bos_token_id]

        if token_ids_1 is None:
            return len(bos + token_ids_0 + eos) * [0]
        return len(bos + token_ids_0 + eos) * [0] + len(token_ids_1 + eos) * [1]

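    # Worked example (not part of the original file), for token_ids_0=[10, 11] and
    # token_ids_1=[20], writing bos for bos_token_id and eos for eos_token_id:
    #   build_inputs_with_special_tokens      -> [bos, 10, 11, eos, 20, eos]
    #   get_special_tokens_mask               -> [1, 0, 0, 1, 0, 1]
    #   create_token_type_ids_from_sequences  -> [0, 0, 0, 0, 1, 1]
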
    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Save the vocabulary file to a directory (merges are not persisted)."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        return (vocab_file,)


# Register a minimal configuration and the tokenizer with the Auto classes so that
# AutoConfig / AutoTokenizer can resolve the "akshara" model type.
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig


class AksharaConfig(PretrainedConfig):
    model_type = "akshara"


AutoConfig.register("akshara", AksharaConfig)
# AutoTokenizer.register expects the configuration class (not the model-type string)
# as its first argument.
AutoTokenizer.register(AksharaConfig, slow_tokenizer_class=AksharaTokenizer)
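

# Minimal usage sketch (not part of the original commit). It assumes a local
# vocab.json that contains the special tokens declared above plus whatever akshara
# entries the model was trained with; the path and the sample text are placeholders.
if __name__ == "__main__":
    tokenizer = AksharaTokenizer(vocab_file="vocab.json")

    sample = "Hello world!"
    print(tokenizer.tokenize(sample))   # token strings, e.g. ['Hello', ' world', '!']
    print(tokenizer(sample).input_ids)  # <s> + token ids + </s>

    # save_pretrained writes vocab.json plus the tokenizer/special-token config files,
    # from which the tokenizer can be reloaded.
    tokenizer.save_pretrained("./akshara_tokenizer")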