DRAFT: Add a fast tokenizer implementation and converter

#8 · opened by chielo

Files changed (2):
  1. tokenization_chatglm.py +250 -10
  2. tokenizer_config.json +3 -3
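Before the diff, a minimal loading sketch for reviewers, assuming the updated auto_map in tokenizer_config.json below is in place (remote code must be trusted explicitly; this is an illustration, not part of the PR):

    from transformers import AutoTokenizer

    # with the auto_map entry added in this PR, use_fast=True should resolve
    # to ChatGLMTokenizerFast; trust_remote_code is required for Hub-hosted code
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm3-6b", trust_remote_code=True, use_fast=True
    )
    print(tokenizer("Hello, ChatGLM3!")["input_ids"])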
tokenization_chatglm.py CHANGED
@@ -1,11 +1,37 @@
 import json
 import os
-import torch
-from typing import List, Optional, Union, Dict
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
 from sentencepiece import SentencePieceProcessor
-from transformers import PreTrainedTokenizer
-from transformers.utils import logging, PaddingStrategy
-from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
+from tokenizers import AddedToken, decoders, normalizers, processors
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, SpmConverter
+from transformers.tokenization_utils_base import (
+    BatchEncoding,
+    EncodedInput,
+    PreTokenizedInput,
+    PreTokenizedInputPair,
+    TextInput,
+    TextInputPair,
+    TruncationStrategy,
+)
+from transformers.utils import PaddingStrategy
+
+ADDITIONAL_SPECIAL_TOKENS = [
+    "[MASK]",
+    "[gMASK]",
+    "[sMASK]",
+    "<!sop!>",
+    "<!eop!>",
+    "<|system|>",
+    "<|user|>",
+    "<|assistant|>",
+    "<|observation|>",
+]
+PREFIX_TOKENS = ["[gMASK]", "<!sop!>"]
+
+ENCODE_SEP_TOKEN_FOR_FAST = "<!encode-sep!>"
 
 
 class SPTokenizer:
@@ -21,11 +47,9 @@ class SPTokenizer:
         self.pad_id: int = self.sp_model.unk_id()
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
-        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop", "<|system|>", "<|user|>", "<|assistant|>",
-                          "<|observation|>"]
         self.special_tokens = {}
         self.index_special_tokens = {}
-        for token in special_tokens:
+        for token in ADDITIONAL_SPECIAL_TOKENS:
             self.special_tokens[token] = self.n_words
             self.index_special_tokens[self.n_words] = token
             self.n_words += 1
@@ -171,8 +195,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return (vocab_file,)
 
     def get_prefix_tokens(self):
-        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
-        return prefix_tokens
+        return list(map(self.get_command, PREFIX_TOKENS))
 
     def build_single_message(self, role, metadata, message):
         assert role in ["system", "user", "assistant", "observation"], role
@@ -281,3 +304,220 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
 
         return encoded_inputs
+
+
+class ChatGLMTokenizerFast(PreTrainedTokenizerFast):
+    # multiple breaking changes, no backward compatibility
+    slow_tokenizer_class = ChatGLMTokenizer
+    vocab_files_names = {
+        **ChatGLMTokenizer.vocab_files_names,
+        **PreTrainedTokenizerFast.vocab_files_names,
+    }
+
+    def __init__(self, **kwargs):
+        kwargs.setdefault("clean_up_tokenization_spaces", False)
+        kwargs.setdefault("bos_token", "<s>")
+        kwargs.setdefault("eos_token", "</s>")
+        kwargs.setdefault("unk_token", "<unk>")
+        kwargs.setdefault("pad_token", "<unk>")
+        super().__init__(**kwargs)
+
+    @property
+    def encode_sep_token(self):
+        return ENCODE_SEP_TOKEN_FOR_FAST
+
+    def _batch_encode_plus(
+        self,
+        batch_text_or_text_pairs: Union[
+            List[TextInput],
+            List[TextInputPair],
+            List[PreTokenizedInput],
+            List[PreTokenizedInputPair],
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[str] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+    ) -> BatchEncoding:
+        def split_sep(t: Union[TextInput, PreTokenizedInput]) -> PreTokenizedInput:
+            if isinstance(t, str):
+                return t.split(self.encode_sep_token)
+
+            return [w for word in t for w in split_sep(word)]
+
+        def split_maybe_tupled(
+            t: Union[TextInput, TextInputPair, PreTokenizedInput, PreTokenizedInputPair]
+        ) -> Union[PreTokenizedInputPair, PreTokenizedInput]:
+            if isinstance(t, tuple):
+                return split_sep(t[0]), split_sep(t[1])
+
+            return split_sep(t)
+
+        return super()._batch_encode_plus(
+            list(map(split_maybe_tupled, batch_text_or_text_pairs)),  # pyright: ignore
+            add_special_tokens,
+            padding_strategy,
+            truncation_strategy,
+            max_length,
+            stride,
+            True,
+            pad_to_multiple_of,
+            return_tensors,
+            return_token_type_ids,
+            return_attention_mask,
+            return_overflowing_tokens,
+            return_special_tokens_mask,
+            return_offsets_mapping,
+            return_length,
+            verbose,
+        )
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        # multiple breaking changes
+        return False
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        legacy_format: Optional[bool] = None,
+        filename_prefix: Optional[str] = None,
+        push_to_hub: bool = False,
+        **kwargs,
+    ) -> Tuple[str]:
+        warnings.warn(
+            f"{type(self)} does not support saving the slow tokenizer. "
+            "Saving into the same directory may break the slow tokenizer. "
+            "Please keep a backup of the original tokenizer beforehand."
+        )
+        return super().save_pretrained(
+            save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs
+        )
+
+    def build_single_message(self, role, metadata, message):
+        assert role in ["system", "user", "assistant", "observation"], role
+        return f"<|{role}|>{self.encode_sep_token}{metadata}\n{self.encode_sep_token}{message}"
+
+    def build_chat_text(self, query, history=None, role="user", metadata=""):
+        inputs = []
+
+        for item in history or []:
+            content = item["content"]
+
+            if item["role"] == "system" and "tools" in item:
+                content += "\n" + json.dumps(
+                    item["tools"], indent=4, ensure_ascii=False
+                )
+
+            inputs.append(
+                self.build_single_message(
+                    item["role"], item.get("metadata", ""), content
+                )
+            )
+
+        inputs.append(self.build_single_message(role, metadata, query))
+        inputs.append("<|assistant|>")
+
+        return "".join(inputs)
+
+    def build_chat_input(self, *args, **kwargs):
+        return self.batch_encode_plus(
+            [self.build_chat_text(*args, **kwargs)],
+            return_tensors="pt",
+        )
+
+
+ChatGLMTokenizer.register_for_auto_class()
+ChatGLMTokenizerFast.register_for_auto_class()
+
+
+class ChatGLMTokenizerConverter(SpmConverter):
+    handle_byte_fallback = True
+
+    def normalizer(self, proto):
+        return normalizers.Sequence(
+            [
+                normalizers.Prepend(prepend="▁"),
+                normalizers.Replace(pattern=" ", content="▁"),
+            ]
+        )
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        # don't use Metaspace here: it would prevent consecutive spaces from
+        # merging into a single token
+
+        # splitting on `encode_sep_token` at this stage is buggy, so we skip it:
+        # return pre_tokenizers.Split(ENCODE_SEP_TOKEN_FOR_FAST, "removed")
+
+        return None
+
+    def decoder(self, replacement, add_prefix_space):
+        return decoders.Sequence(
+            [
+                decoders.ByteFallback(),
+                super().decoder(replacement, add_prefix_space),
+            ]
+        )
+
+    def tokenizer(self, proto):
+        tokenizer = super().tokenizer(proto)
+
+        tokenizer.model.byte_fallback = True
+
+        special_tokens = [
+            "<unk>",
+            "<s>",
+            "</s>",
+            *ADDITIONAL_SPECIAL_TOKENS,
+        ]
+
+        tokenizer.add_special_tokens(
+            [
+                AddedToken(token, special=True, normalized=False)
+                for token in special_tokens
+            ]
+        )
+
+        return tokenizer
+
+    def converted(self):
+        tokenizer = super().converted()
+
+        # Post processors
+        prefix_token_ids = list(map(tokenizer.token_to_id, PREFIX_TOKENS))
+        assert all(i is not None for i in prefix_token_ids)
+        prefix_template = " ".join(PREFIX_TOKENS)
+
+        template_special_tokens = list(frozenset(zip(PREFIX_TOKENS, prefix_token_ids)))
+
+        if "</s>" not in PREFIX_TOKENS:
+            eos_token_id = tokenizer.token_to_id("</s>")
+            assert eos_token_id is not None
+            template_special_tokens.append(("</s>", eos_token_id))
+
+        post = processors.TemplateProcessing(
+            single=f"{prefix_template} $A",
+            pair=f"{prefix_template} $A $B:1 </s>:1",
+            special_tokens=template_special_tokens,
+        )
+        if tokenizer.post_processor is None:
+            tokenizer.post_processor = post
+        else:
+            tokenizer.post_processor = processors.Sequence(
+                [tokenizer.post_processor, post]
+            )
+
+        return tokenizer
+
+
+SLOW_TO_FAST_CONVERTERS[ChatGLMTokenizer.__name__] = ChatGLMTokenizerConverter
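Because the last line registers ChatGLMTokenizerConverter in SLOW_TO_FAST_CONVERTERS, the stock conversion helper in transformers should be able to produce a tokenizer.json from the slow tokenizer. A minimal conversion sketch, assuming tokenization_chatglm.py from this PR is importable and "./chatglm3-6b" is a hypothetical local checkout:

    from transformers.convert_slow_tokenizer import convert_slow_tokenizer
    from tokenization_chatglm import ChatGLMTokenizer  # module from this PR

    # convert_slow_tokenizer dispatches on the class name registered in
    # SLOW_TO_FAST_CONVERTERS and returns a tokenizers.Tokenizer
    slow = ChatGLMTokenizer.from_pretrained("./chatglm3-6b")  # hypothetical path
    backend = convert_slow_tokenizer(slow)
    backend.save("./chatglm3-6b/tokenizer.json")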
tokenizer_config.json CHANGED
@@ -1,12 +1,12 @@
 {
-  "name_or_path": "THUDM/chatglm2-6b",
+  "name_or_path": "THUDM/chatglm3-6b",
   "remove_space": false,
   "do_lower_case": false,
   "tokenizer_class": "ChatGLMTokenizer",
   "auto_map": {
     "AutoTokenizer": [
       "tokenization_chatglm.ChatGLMTokenizer",
-      null
-    ]
+      "tokenization_chatglm.ChatGLMTokenizerFast"
+    ]
   }
 }
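Finally, a sketch of how the encode-sep mechanism is intended to behave once the fast tokenizer loads, under the same hypothetical local path as above (exact token ids depend on the SentencePiece vocabulary):

    from tokenization_chatglm import ChatGLMTokenizerFast  # module from this PR

    fast = ChatGLMTokenizerFast.from_pretrained("./chatglm3-6b")  # hypothetical path

    # build_single_message joins the role header, metadata, and message with
    # the internal <!encode-sep!> marker; _batch_encode_plus then splits on it
    # so no token can straddle a segment boundary
    text = fast.build_single_message("user", "", "What is a fast tokenizer?")
    assert fast.encode_sep_token in text

    # build_chat_input wraps build_chat_text and returns PyTorch tensors
    batch = fast.build_chat_input("What is a fast tokenizer?")
    print(batch["input_ids"].shape)  # (1, sequence_length)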