p1atdev committed
Commit 8d4a5ed · verified · 1 Parent(s): ab3a2b1

Upload tokenizer

special_tokens_map.json CHANGED
@@ -1,6 +1,30 @@
 {
-  "bos_token": "<|bos|>",
-  "eos_token": "<|eos|>",
-  "pad_token": "<|pad|>",
-  "unk_token": "<|unknown|>"
+  "bos_token": {
+    "content": "<|bos|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|eos|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "<|pad|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<|unknown|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
 }
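
The expanded entries carry the same fields as tokenizers.AddedToken, so the special tokens keep their exact surface form instead of being normalized away. As a rough correspondence, not part of this commit, the bos entry amounts to constructing:

from tokenizers import AddedToken

bos = AddedToken(
    "<|bos|>",
    lstrip=False,       # keep whitespace to the left of the token
    rstrip=False,       # keep whitespace to the right of the token
    normalized=False,   # do not run the normalizer over this token
    single_word=False,  # the token may also match mid-word
)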
tokenization_dart.py ADDED
@@ -0,0 +1,228 @@
+ import logging
+ import os
+ import json
+ from typing import Optional, Dict, List, Set, Tuple, Union, Literal, Type
+ from pydantic.dataclasses import dataclass
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from transformers import PreTrainedTokenizerFast
+
+ logger = logging.getLogger(__name__)
+
+ VOCAB_FILES_NAMES = {
+     "tag_category": "tag_category.json",
+ }
+
+
+ @dataclass
+ class Category:
+     name: str
+     max_count: Optional[int]
+     next_category: List[int]
+     can_end: bool
+     bos_token_id: int
+     eos_token_id: int
+     default_mask: int
+
+
+ @dataclass
+ class SpecialMapping:
+     allow: List[int]
+     disallow: List[int]
+
+
+ @dataclass
+ class TagCategoryConfig:
+     start_category: int
+     categories: Dict[str, Category]
+     special_mapping: Dict[
+         str, Dict[str, SpecialMapping]
+     ]  # {token_id: {category_id: SpecialMapping}}
+     category_tags_pairs: Dict[str, List[int]]
+
+
+ class OverrideMask:
+     allow: np.ndarray
+     disallow: np.ndarray
+
+     def __init__(self, allow: np.ndarray, disallow: np.ndarray) -> None:
+         self.allow = allow
+         self.disallow = disallow
+
+
+ def load_tag_category(config_json: str):
+     with open(config_json, "rb") as file:
+         config: TagCategoryConfig = TagCategoryConfig(**json.loads(file.read()))
+
+     return config
+
+
+ class DartTokenizer(PreTrainedTokenizerFast):
+     """Dart tokenizer"""
+
+     vocab_files_names = VOCAB_FILES_NAMES
+
+     def __init__(self, tag_category, **kwargs):
+         super().__init__(**kwargs)
+
+         self.tag_category_config = load_tag_category(tag_category)
+
+         self.category_bos_map = {
+             category.bos_token_id: category_id
+             for category_id, category in self.tag_category_config.categories.items()
+         }
+         self.category_eos_map = {
+             category.eos_token_id: category_id
+             for category_id, category in self.tag_category_config.categories.items()
+         }
+
+         self._id_to_category_map = np.zeros(self.vocab_size).astype("uint8")
+         for category_id, tokens in self.tag_category_config.category_tags_pairs.items():
+             self._id_to_category_map[tokens] = int(category_id)
+
+         self.category_mask = self.create_category_vocab_mask()
+
+     def create_vocab_mask(self, value: int = 1):
+         """Create an array of vocab size filled with the specified value"""
+         return np.full(self.vocab_size, value).astype("uint8")
+
+     def create_category_vocab_mask(self):
+         """Create a vocab mask for each category"""
+         return {
+             category_id: self.create_vocab_mask(
+                 value=category.default_mask,
+             )
+             for category_id, category in self.tag_category_config.categories.items()
+         }
+
+     def get_token_ids_in_category(self, category_id: Union[int, str]):
+         """Get the token ids in the specified category"""
+         return self.tag_category_config.category_tags_pairs[str(category_id)]
+
+     def get_category(self, category_id: Union[int, str]):
+         """Get the specified category's config"""
+         return self.tag_category_config.categories[str(category_id)]
+
+     def get_special_mapping(self, token_id: Union[int, str]):
+         """Get the special mapping of the specified token id"""
+         return self.tag_category_config.special_mapping[str(token_id)]
+
+     def get_banned_tokens_mask(self, tokens: Union[str, List[str], int, List[int]]):
+         """Create a vocab mask that bans the given tokens"""
+         if isinstance(tokens, str):
+             tokens = [tokens]
+         elif isinstance(tokens, int):
+             tokens = [tokens]
+         elif isinstance(tokens, list):
+             tokens = [
+                 self.convert_tokens_to_ids(token) if isinstance(token, str) else token
+                 for token in tokens
+             ]
+
+         assert isinstance(tokens, list) and all(
+             isinstance(token, int) for token in tokens
+         )
+
+         mask = self.create_vocab_mask(value=1)
+         mask[tokens] = 0
+
+         return mask
+
+     def convert_ids_to_category_ids(self, token_ids: Union[int, List[int]]):
+         return self._id_to_category_map[token_ids]
+
+     def get_next_tokens_mask(
+         self,
+         input_ids: List[int],
+         category_mask: Optional[Dict[str, np.ndarray]] = None,
+     ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
+         """Get the vocab mask to apply to the next token"""
+
+         if category_mask is None:
+             category_mask = self.category_mask
+
+         vocab_mask = self.create_vocab_mask(value=0)
+
+         if len(input_ids) == 0:
+             # only allow the bos token
+             vocab_mask[self.bos_token_id] = 1
+
+             return vocab_mask, category_mask
+
+         # the last token's id in the input ids
+         last_token_id = input_ids[-1]
+
+         if last_token_id == self.unk_token_id:
+             # unknown token
+             logger.warning(
+                 "The unk_token was provided! The vocab mask could not be created properly."
+             )
+             return self.create_vocab_mask(value=1), category_mask
+
+         # if the last token has a special mapping
+         if str(last_token_id) in self.tag_category_config.special_mapping.keys():
+             for category_id, mapping in self.get_special_mapping(last_token_id).items():
+                 # update the category mask
+                 category_mask[category_id][mapping.allow] = 1
+                 category_mask[category_id][mapping.disallow] = 0
+
+         if last_token_id == self.bos_token_id:
+             # the first category
+             start_category_id = self.tag_category_config.start_category
+             start_category = self.get_category(start_category_id)
+
+             # only allow the first category's bos token
+             vocab_mask[start_category.bos_token_id] = 1
+
+             return vocab_mask, category_mask
+
+         elif last_token_id == self.eos_token_id:
+             # end of text; only allow the pad token
+             vocab_mask[self.pad_token_id] = 1
+
+             return vocab_mask, category_mask
+
+         elif last_token_id in self.category_bos_map:
+             # beginning of a category
+
+             # only allow the same category's tags
+             current_category_id = self.category_bos_map[last_token_id]
+             category = self.get_category(current_category_id)
+
+             tokens_in_category = self.get_token_ids_in_category(current_category_id)
+             vocab_mask[tokens_in_category] = 1
+
+             vocab_mask *= category_mask[str(current_category_id)]
+             vocab_mask[category.eos_token_id] = 1
+
+             return vocab_mask, category_mask  # the current category's mask
+
+         elif last_token_id in self.category_eos_map:
+             # boundary between categories
+
+             current_category_id = self.category_eos_map[last_token_id]
+             category = self.get_category(current_category_id)
+
+             if category.can_end:
+                 # this category can finish generation
+                 vocab_mask[self.eos_token_id] = 1
+
+             for next_category_id in category.next_category:
+                 # allow each next category's bos token
+                 vocab_mask[self.get_category(next_category_id).bos_token_id] = 1
+
+             return vocab_mask, category_mask
+
+         else:
+             # inside a category
+             current_category_id = self.convert_ids_to_category_ids(last_token_id).item()
+             tokens_in_category = self.get_token_ids_in_category(current_category_id)
+
+             vocab_mask[tokens_in_category] = 1
+             vocab_mask[self.get_category(current_category_id).eos_token_id] = 1
+             vocab_mask *= category_mask[str(current_category_id)]
+             vocab_mask[input_ids] = 0  # do not reuse already-used tokens
+
+             return vocab_mask, category_mask
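
For context, not part of this commit: the mask returned by get_next_tokens_mask is meant to be applied to a model's next-token logits at each decoding step. A minimal greedy-decoding sketch, assuming a causal LM whose vocabulary matches this tokenizer; the repo id, loop length, and model choice are placeholders:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "p1atdev/dart-tokenizer" is a placeholder repo id for illustration only.
tokenizer = AutoTokenizer.from_pretrained("p1atdev/dart-tokenizer", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("p1atdev/dart-tokenizer")

input_ids = [tokenizer.bos_token_id]
category_mask = None  # None falls back to the tokenizer's default per-category masks

for _ in range(32):
    vocab_mask, category_mask = tokenizer.get_next_tokens_mask(input_ids, category_mask)
    logits = model(torch.tensor([input_ids])).logits[0, -1]
    logits[torch.from_numpy(vocab_mask == 0)] = float("-inf")  # ban disallowed tokens
    next_id = int(torch.argmax(logits))
    input_ids.append(next_id)
    if next_id == tokenizer.eos_token_id:
        break

print(tokenizer.decode(input_ids))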
tokenizer_config.json CHANGED
@@ -353,6 +353,12 @@
       "special": true
     }
   },
+  "auto_map": {
+    "AutoTokenizer": [
+      "tokenization_dart.DartTokenizer",
+      null
+    ]
+  },
   "bos_token": "<|bos|>",
   "clean_up_tokenization_spaces": true,
   "eos_token": "<|eos|>",
@@ -362,6 +368,6 @@
   "pad_token": "<|pad|>",
   "pad_token_type_id": 0,
   "padding_side": "right",
-  "tokenizer_class": "PreTrainedTokenizerFast",
+  "tokenizer_class": "DartTokenizer",
   "unk_token": "<|unknown|>"
 }
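
Not part of the diff itself, but worth noting: the auto_map entry is what lets AutoTokenizer import the custom class from the repo's own tokenization_dart.py, which requires opting into remote code. A minimal loading sketch; the repo id is a placeholder:

from transformers import AutoTokenizer

# Placeholder repo id; substitute the repository this commit belongs to.
tokenizer = AutoTokenizer.from_pretrained(
    "p1atdev/dart-tokenizer",
    trust_remote_code=True,  # needed so auto_map can load tokenization_dart.DartTokenizer
)
print(type(tokenizer).__name__)  # -> DartTokenizer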