pranjalchitale committed on
Commit 67587b6 · verified · 1 Parent(s): 66390ee

Address as_target_tokenizer deprecation in Transformers v5.
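For reference, the `as_target_tokenizer()` context manager is deprecated in Transformers v5; target-side text is now passed through the `text_target=` argument of the tokenizer call, which routes through `_switch_to_input_mode()` / `_switch_to_target_mode()` internally. A minimal before/after sketch (the checkpoint path and example sentences are illustrative placeholders, not part of this commit; loading assumes `trust_remote_code=True` since the tokenizer ships as custom code):

```python
from transformers import AutoTokenizer

# Hypothetical checkpoint path; any checkpoint bundling this custom tokenizer would do.
tokenizer = AutoTokenizer.from_pretrained(
    "path/to/indictrans2-checkpoint", trust_remote_code=True
)

src = "eng_Latn hin_Deva This is a test sentence."  # "src_lang tgt_lang text" format
tgt = "यह एक परीक्षण वाक्य है।"  # placeholder target sentence

# Deprecated pattern (what this commit moves away from):
# with tokenizer.as_target_tokenizer():
#     labels = tokenizer(tgt, return_tensors="pt").input_ids

# Current pattern: pass targets via `text_target`; the base class switches the
# tokenizer into target mode around the call and back afterwards.
batch = tokenizer(src, text_target=tgt, return_tensors="pt")
print(batch["input_ids"], batch["labels"])
```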

Files changed (1)
  1. tokenization_indictrans.py +134 -144
tokenization_indictrans.py CHANGED
@@ -1,55 +1,54 @@
 import os
 import json
 
+from transformers.utils import logging
 from typing import Dict, List, Optional, Union, Tuple
 
-from transformers.utils import logging
 from sentencepiece import SentencePieceProcessor
 from transformers.tokenization_utils import PreTrainedTokenizer
 
 
 logger = logging.get_logger(__name__)
 
-SPIECE_UNDERLINE = "▁"
-
-SPECIAL_TAGS = {
-    "_bt_",
-    "_ft_",
-    "asm_Beng",
-    "awa_Deva",
-    "ben_Beng",
-    "bho_Deva",
-    "brx_Deva",
-    "doi_Deva",
-    "eng_Latn",
-    "gom_Deva",
-    "gon_Deva",
-    "guj_Gujr",
-    "hin_Deva",
-    "hne_Deva",
-    "kan_Knda",
-    "kas_Arab",
-    "kas_Deva",
-    "kha_Latn",
-    "lus_Latn",
-    "mag_Deva",
-    "mai_Deva",
-    "mal_Mlym",
-    "mar_Deva",
-    "mni_Beng",
-    "mni_Mtei",
-    "npi_Deva",
-    "ory_Orya",
-    "pan_Guru",
-    "san_Deva",
-    "sat_Olck",
-    "snd_Arab",
-    "snd_Deva",
-    "tam_Taml",
-    "tel_Telu",
-    "urd_Arab",
-    "unr_Deva",
-}
+# Convert LANGUAGE_TAGS to a frozen set for faster lookups
+LANGUAGE_TAGS = frozenset(
+    {
+        "asm_Beng",
+        "awa_Deva",
+        "ben_Beng",
+        "bho_Deva",
+        "brx_Deva",
+        "doi_Deva",
+        "eng_Latn",
+        "gom_Deva",
+        "gon_Deva",
+        "guj_Gujr",
+        "hin_Deva",
+        "hne_Deva",
+        "kan_Knda",
+        "kas_Arab",
+        "kas_Deva",
+        "kha_Latn",
+        "lus_Latn",
+        "mag_Deva",
+        "mai_Deva",
+        "mal_Mlym",
+        "mar_Deva",
+        "mni_Beng",
+        "mni_Mtei",
+        "npi_Deva",
+        "ory_Orya",
+        "pan_Guru",
+        "san_Deva",
+        "sat_Olck",
+        "snd_Arab",
+        "snd_Deva",
+        "tam_Taml",
+        "tel_Telu",
+        "urd_Arab",
+        "unr_Deva",
+    }
+)
 
 VOCAB_FILES_NAMES = {
     "src_vocab_fp": "dict.SRC.json",
@@ -60,9 +59,8 @@ VOCAB_FILES_NAMES = {
 
 
 class IndicTransTokenizer(PreTrainedTokenizer):
-    _added_tokens_encoder = {}
-    _added_tokens_decoder = {}
-
+    _added_tokens_encoder: Dict[str, int] = {}
+    _added_tokens_decoder: Dict[str, int] = {}
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
 
@@ -79,47 +77,55 @@ class IndicTransTokenizer(PreTrainedTokenizer):
         do_lower_case=False,
         **kwargs,
     ):
-
-        self.src = True
-
         self.src_vocab_fp = src_vocab_fp
         self.tgt_vocab_fp = tgt_vocab_fp
         self.src_spm_fp = src_spm_fp
         self.tgt_spm_fp = tgt_spm_fp
 
-        self.unk_token = unk_token.content
-        self.pad_token = pad_token.content
-        self.eos_token = eos_token.content
-        self.bos_token = bos_token.content
+        # Store token content directly instead of accessing .content
+        self.unk_token = (
+            hasattr(unk_token, "content") and unk_token.content or unk_token
+        )
+        self.pad_token = (
+            hasattr(pad_token, "content") and pad_token.content or pad_token
+        )
+        self.eos_token = (
+            hasattr(eos_token, "content") and eos_token.content or eos_token
+        )
+        self.bos_token = (
+            hasattr(bos_token, "content") and bos_token.content or bos_token
+        )
+
+        # Load vocabularies
+        self.src_encoder = self._load_json(self.src_vocab_fp)
+        self.tgt_encoder = self._load_json(self.tgt_vocab_fp)
 
-        self.encoder = self._load_json(self.src_vocab_fp)
-        if self.unk_token not in self.encoder:
+        # Validate tokens
+        if self.unk_token not in self.src_encoder:
             raise KeyError("<unk> token must be in vocab")
-        assert self.pad_token in self.encoder
-        self.encoder_rev = {v: k for k, v in self.encoder.items()}
+        if self.pad_token not in self.src_encoder:
+            raise KeyError("<pad> token must be in vocab")
 
-        self.decoder = self._load_json(self.tgt_vocab_fp)
-        if self.unk_token not in self.encoder:
-            raise KeyError("<unk> token must be in vocab")
-        assert self.pad_token in self.encoder
-        self.decoder_rev = {v: k for k, v in self.decoder.items()}
+        # Pre-compute reverse mappings
+        self.src_decoder = {v: k for k, v in self.src_encoder.items()}
+        self.tgt_decoder = {v: k for k, v in self.tgt_encoder.items()}
 
-        # load SentencePiece model for pre-processing
+        # Load SPM models
        self.src_spm = self._load_spm(self.src_spm_fp)
        self.tgt_spm = self._load_spm(self.tgt_spm_fp)
 
-        self.current_spm = self.src_spm
-        self.current_encoder = self.encoder
-        self.current_encoder_rev = self.encoder_rev
+        # Initialize current settings
+        self._switch_to_input_mode()
 
-        self.unk_token_id = self.encoder[self.unk_token]
-        self.pad_token_id = self.encoder[self.pad_token]
-        self.eos_token_id = self.encoder[self.eos_token]
-        self.bos_token_id = self.encoder[self.bos_token]
+        # Cache token IDs
+        self.unk_token_id = self.src_encoder[self.unk_token]
+        self.pad_token_id = self.src_encoder[self.pad_token]
+        self.eos_token_id = self.src_encoder[self.eos_token]
+        self.bos_token_id = self.src_encoder[self.bos_token]
 
         super().__init__(
             src_vocab_file=self.src_vocab_fp,
-            tgt_vocab_file=self.src_vocab_fp,
+            tgt_vocab_file=self.tgt_vocab_fp,
             do_lower_case=do_lower_case,
             unk_token=unk_token,
             bos_token=bos_token,
@@ -128,134 +134,118 @@ class IndicTransTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
-    def add_new_special_tags(self, new_tags: List[str]):
-        SPECIAL_TAGS.update(new_tags)
+    def add_new_language_tags(self, new_tags: List[str]) -> None:
+        global LANGUAGE_TAGS
+        LANGUAGE_TAGS = frozenset(LANGUAGE_TAGS | set(new_tags))
 
-    def _switch_to_input_mode(self):
-        self.src = True
+    def _switch_to_input_mode(self) -> None:
+        self.spm = self.src_spm
         self.padding_side = "left"
-        self.current_spm = self.src_spm
-        self.current_encoder = self.encoder
-        self.current_encoder_rev = self.encoder_rev
+        self.encoder = self.src_encoder
+        self.decoder = self.src_decoder
+        self._tokenize = self._src_tokenize
 
-    def _switch_to_target_mode(self):
-        self.src = False
+    def _switch_to_target_mode(self) -> None:
+        self.spm = self.tgt_spm
         self.padding_side = "right"
-        self.current_spm = self.tgt_spm
-        self.current_encoder = self.decoder
-        self.current_encoder_rev = self.decoder_rev
+        self.encoder = self.tgt_encoder
+        self.decoder = self.tgt_decoder
+        self._tokenize = self._tgt_tokenize
 
-    def _load_spm(self, path: str) -> SentencePieceProcessor:
+    @staticmethod
+    def _load_spm(path: str) -> SentencePieceProcessor:
         return SentencePieceProcessor(model_file=path)
 
-    def _save_json(self, data, path: str) -> None:
+    @staticmethod
+    def _save_json(data: Union[Dict, List], path: str) -> None:
         with open(path, "w", encoding="utf-8") as f:
             json.dump(data, f, indent=2)
 
-    def _load_json(self, path: str) -> Union[Dict, List]:
+    @staticmethod
+    def _load_json(path: str) -> Union[Dict, List]:
         with open(path, "r", encoding="utf-8") as f:
             return json.load(f)
 
-    def _split_tags(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
-        tags = [token for token in tokens if token in SPECIAL_TAGS]
-        tokens = [token for token in tokens if token not in SPECIAL_TAGS]
-        return tags, tokens
-
-    def _split_pads(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
-        pads = [token for token in tokens if token == self.pad_token]
-        tokens = [token for token in tokens if token != self.pad_token]
-        return pads, tokens
-
     @property
     def src_vocab_size(self) -> int:
-        return len(self.encoder)
+        return len(self.src_encoder)
 
     @property
     def tgt_vocab_size(self) -> int:
-        return len(self.decoder)
+        return len(self.tgt_encoder)
 
     def get_src_vocab(self) -> Dict[str, int]:
-        return dict(self.encoder, **self.added_tokens_encoder)
+        return dict(self.src_encoder, **self.added_tokens_encoder)
 
     def get_tgt_vocab(self) -> Dict[str, int]:
-        return dict(self.decoder, **self.added_tokens_decoder)
+        return dict(self.tgt_encoder, **self.added_tokens_decoder)
 
-    # hack override
     def get_vocab(self) -> Dict[str, int]:
         return self.get_src_vocab()
 
-    # hack override
     @property
     def vocab_size(self) -> int:
         return self.src_vocab_size
 
     def _convert_token_to_id(self, token: str) -> int:
-        """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
-        return self.current_encoder.get(token, self.current_encoder[self.unk_token])
+        return self.encoder.get(token, self.unk_token_id)
 
     def _convert_id_to_token(self, index: int) -> str:
-        """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
-        return self.current_encoder_rev.get(index, self.unk_token)
+        return self.decoder.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Uses sentencepiece model for detokenization"""
-        pads, tokens = self._split_pads(tokens)
-
-        if self.src:
+        return "".join(tokens).replace("▁", " ").strip()
 
-            tags, non_tags = self._split_tags(tokens)
+    def _src_tokenize(self, text: str) -> List[str]:
+        src_lang, tgt_lang, text = text.split(" ", 2)
+        assert src_lang in LANGUAGE_TAGS, f"Invalid source language tag: {src_lang}"
+        assert tgt_lang in LANGUAGE_TAGS, f"Invalid target language tag: {tgt_lang}"
+        return [src_lang, tgt_lang] + self.spm.EncodeAsPieces(text)
 
-            return (
-                " ".join(pads)
-                + " "
-                + " ".join(tags)
-                + " "
-                + "".join(non_tags).replace(SPIECE_UNDERLINE, " ").strip()
-            )
+    def _tgt_tokenize(self, text: str) -> List[str]:
+        return self.spm.EncodeAsPieces(text)
 
-        return (
-            "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
-            + " "
-            + " ".join(pads)
+    def _decode(
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = None,
+        spaces_between_special_tokens: bool = True,
+        **kwargs,
+    ) -> str:
+        self._switch_to_target_mode()
+        decoded_token_ids = super()._decode(
+            token_ids=token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            **kwargs,
         )
-
-    def _tokenize(self, text) -> List[str]:
-        if self.src:
-            tokens = text.split(" ")
-            tags, non_tags = self._split_tags(tokens)
-            text = " ".join(non_tags)
-            tokens = self.current_spm.EncodeAsPieces(text)
-            return tags + tokens
-        else:
-            return self.current_spm.EncodeAsPieces(text)
+        self._switch_to_input_mode()
+        return decoded_token_ids
 
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
-        if token_ids_1 is None:
-            return token_ids_0 + [self.eos_token_id]
-        # We don't expect to process pairs, but leave the pair logic for API consistency
-        return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+        return token_ids_0 + [self.eos_token_id]
 
     def save_vocabulary(
         self, save_directory: str, filename_prefix: Optional[str] = None
-    ) -> Tuple[str]:
+    ) -> Tuple[str, ...]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-            return
+            return ()
 
         src_spm_fp = os.path.join(save_directory, "model.SRC")
         tgt_spm_fp = os.path.join(save_directory, "model.TGT")
         src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
         tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
 
-        self._save_json(self.encoder, src_vocab_fp)
-        self._save_json(self.decoder, tgt_vocab_fp)
-
-        with open(src_spm_fp, "wb") as f:
-            f.write(self.src_spm.serialized_model_proto())
+        self._save_json(self.src_encoder, src_vocab_fp)
+        self._save_json(self.tgt_encoder, tgt_vocab_fp)
 
-        with open(tgt_spm_fp, "wb") as f:
-            f.write(self.tgt_spm.serialized_model_proto())
+        for fp, spm in [(src_spm_fp, self.src_spm), (tgt_spm_fp, self.tgt_spm)]:
+            with open(fp, "wb") as f:
+                f.write(spm.serialized_model_proto())
 
-        return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
+        return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
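With the `_decode` override in place, decoding generated ids needs no manual mode handling: the tokenizer switches to the target vocabulary and SentencePiece model for the call and returns to input mode afterwards. A rough end-to-end sketch under the same assumptions as above (placeholder checkpoint path, `trust_remote_code=True`, and an illustrative `generate()` call that is not part of this diff):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt = "path/to/indictrans2-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, trust_remote_code=True)

# Source sentences carry the "src_lang tgt_lang text" prefix that
# _src_tokenize() expects and validates against LANGUAGE_TAGS.
inputs = tokenizer(
    ["eng_Latn hin_Deva How are you?"],
    padding="longest",  # padding_side is "left" in input mode
    return_tensors="pt",
)

with torch.no_grad():
    generated = model.generate(**inputs, max_length=64)

# batch_decode() goes through the overridden _decode(), which flips the
# tokenizer into target mode (target vocab + target SPM) and back automatically.
translations = tokenizer.batch_decode(generated, skip_special_tokens=True)
print(translations)
```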