Commit eb3e683 by duzx16
Parent: cdb65fd

Add prefix prompt

config.json CHANGED
@@ -37,5 +37,5 @@
   "transformers_version": "4.27.1",
   "tie_word_embeddings": false,
   "eos_token_id": 2,
-  "pad_token_id": 2
+  "pad_token_id": 0
 }
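
Note: together with the tokenizer change below, padding now uses id 0 ("<unk>") rather than the EOS id 2, so padded positions can no longer be confused with end-of-sequence during batched generation or loss masking. A minimal sketch of the effect, assuming the tokenizer from this repo is loadable by its Hub id (the repo id and the printed values are illustrative, not part of this commit):

from transformers import AutoTokenizer

# Load the remote-code tokenizer shipped with this repo (repo id assumed).
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

# Pad a ragged batch; the shorter sequence should now be filled with pad_token_id == 0.
batch = tokenizer(["Hello", "Please introduce ChatGLM2 briefly"], padding=True, return_tensors="pt")
print(tokenizer.pad_token, tokenizer.pad_token_id)  # expected: <unk> 0
print(batch["input_ids"])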
configuration_chatglm.py CHANGED
@@ -20,7 +20,6 @@ class ChatGLMConfig(PretrainedConfig):
         post_layer_norm=True,
         add_bias_linear=False,
         add_qkv_bias=False,
-        interleaved_qkv=False,
         bias_dropout_fusion=True,
         multi_query_attention=False,
         multi_query_group_num=1,
@@ -28,9 +27,12 @@ class ChatGLMConfig(PretrainedConfig):
         attention_softmax_in_fp32=True,
         fp32_residual_connection=False,
         quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
         **kwargs
     ):
         self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
         self.padded_vocab_size = padded_vocab_size
         self.hidden_size = hidden_size
         self.ffn_hidden_size = ffn_hidden_size
@@ -52,4 +54,6 @@ class ChatGLMConfig(PretrainedConfig):
         self.attention_softmax_in_fp32 = attention_softmax_in_fp32
         self.fp32_residual_connection = fp32_residual_connection
         self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
         super().__init__(**kwargs)
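
Note: pre_seq_len and prefix_projection are the two switches for the new prefix-prompt (P-Tuning v2 style) path in modeling_chatglm.py. pre_seq_len sets the number of virtual prefix tokens; prefix_projection chooses between a plain embedding table and a two-layer MLP inside PrefixEncoder. A minimal sketch of turning it on, assuming the config is loaded from this repo by Hub id (the id and the value 128 are illustrative):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
config.pre_seq_len = 128           # train 128 virtual prefix tokens per layer
config.prefix_projection = False   # False: embed prefix ids straight into past key/values
                                   # True: pass them through the two-layer MLP first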
modeling_chatglm.py CHANGED
@@ -56,6 +56,37 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores


+class PrefixEncoder(torch.nn.Module):
+    """
+    The torch.nn model to encode the prefix
+    Input shape: (batch-size, prefix-length)
+    Output shape: (batch-size, prefix-length, 2*layers*hidden)
+    """
+
+    def __init__(self, config: ChatGLMConfig):
+        super().__init__()
+        self.prefix_projection = config.prefix_projection
+        if self.prefix_projection:
+            # Use a two-layer MLP to encode the prefix
+            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+            self.trans = torch.nn.Sequential(
+                torch.nn.Linear(config.hidden_size, config.hidden_size),
+                torch.nn.Tanh(),
+                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+            )
+        else:
+            self.embedding = torch.nn.Embedding(config.pre_seq_len,
+                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
+
+    def forward(self, prefix: torch.Tensor):
+        if self.prefix_projection:
+            prefix_tokens = self.embedding(prefix)
+            past_key_values = self.trans(prefix_tokens)
+        else:
+            past_key_values = self.embedding(prefix)
+        return past_key_values
+
+
 def split_tensor_along_last_dim(
         tensor: torch.Tensor,
         num_partitions: int,
@@ -375,11 +406,11 @@ class SelfAttention(torch.nn.Module):
             key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

         # adjust key and value for inference
+        if kv_cache is not None:
+            cache_k, cache_v = kv_cache
+            key_layer = torch.cat((cache_k, key_layer), dim=0)
+            value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            if kv_cache is not None:
-                cache_k, cache_v = kv_cache
-                key_layer = torch.cat((cache_k, key_layer), dim=0)
-                value_layer = torch.cat((cache_v, value_layer), dim=0)
             kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
@@ -566,6 +597,8 @@ class GLMTransformer(torch.nn.Module):
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)

+        self.gradient_checkpointing = False
+
     def _get_layer(self, layer_number):
         return self.layers[layer_number]

@@ -577,6 +610,13 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
         for index in range(self.num_layers):
@@ -584,14 +624,24 @@ class GLMTransformer(torch.nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer = self._get_layer(index)
-
-            hidden_states, kv_cache = layer(
-                hidden_states,
-                attention_mask,
-                rotary_pos_emb,
-                kv_cache=kv_caches[index],
-                use_cache=use_cache
-            )
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_caches[index],
+                    use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=kv_caches[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
             if use_cache:
                 presents = presents + (kv_cache,)

@@ -645,7 +695,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         return position_ids

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, ChatGLMModel):
+        if isinstance(module, GLMTransformer):
             module.gradient_checkpointing = value


@@ -688,6 +738,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if device is not None:
             init_kwargs["device"] = device
         self.embedding = init_method(Embedding, config, **init_kwargs)
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels

         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -700,11 +753,33 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                         dtype=config.torch_dtype, **init_kwargs)
-        self.gradient_checkpointing = False
+        self.pre_seq_len = config.pre_seq_len
+        self.prefix_projection = config.prefix_projection
+        if self.pre_seq_len is not None:
+            for param in self.parameters():
+                param.requires_grad = False
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+            self.prefix_encoder = PrefixEncoder(config)
+            self.dropout = torch.nn.Dropout(0.1)

     def get_input_embeddings(self):
         return self.embedding.word_embeddings

+    def get_prompt(self, batch_size, device, dtype=torch.half):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+        past_key_values = past_key_values.view(
+            batch_size,
+            self.pre_seq_len,
+            self.num_layers * 2,
+            self.multi_query_group_num,
+            self.kv_channels
+        )
+        # seq_len, b, nh, hidden_size
+        past_key_values = self.dropout(past_key_values)
+        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+        return past_key_values
+
     def forward(
             self,
             input_ids,
@@ -728,6 +803,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)

+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
+                                                  dtype=inputs_embeds.dtype)
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                                            attention_mask], dim=-1)
+
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                 full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
@@ -913,10 +996,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response

     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        prompt = ""
-        for i, (old_query, response) in enumerate(history):
-            prompt += "[Round {}]\n\n问：{}\n\n答：{}\n\n".format(i + 1, old_query, response)
-        prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
+        prompt = tokenizer.build_prompt(query, history=history)
         inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         return inputs
@@ -933,7 +1013,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         return inputs

-
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
              do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
@@ -969,6 +1048,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             inputs = self.build_stream_inputs(tokenizer, query, history=history)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
+            if self.transformer.pre_seq_len is not None:
+                past_length -= self.transformer.pre_seq_len
             inputs.position_ids += past_length
             attention_mask = inputs.attention_mask
             attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
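
Note: with pre_seq_len set in the config, ChatGLMModel freezes every backbone parameter, leaves only the PrefixEncoder (plus its dropout) trainable, and get_prompt() fabricates per-layer past key/values that the attention layers consume like a real KV cache, which is why stream_chat now subtracts pre_seq_len from past_length; the same commit also moves gradient checkpointing onto GLMTransformer so _set_gradient_checkpointing toggles the encoder directly. A minimal sketch of the runtime behaviour, assuming this repo is loadable by its Hub id and enough memory is available (repo id, pre_seq_len value, and printed output are illustrative):

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
config.pre_seq_len = 128  # enable the prefix-prompt path

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, trust_remote_code=True)

# Only the prefix encoder should remain trainable; the backbone is frozen in __init__.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected: transformer.prefix_encoder.* only (just the embedding unless prefix_projection=True)

# get_prompt() returns one tensor per layer, key and value stacked on dim 0:
# (2, pre_seq_len, batch, multi_query_group_num, kv_channels)
past = model.transformer.get_prompt(batch_size=1, device=model.device, dtype=model.dtype)
print(len(past), past[0].shape)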
tokenization_chatglm.py CHANGED
@@ -17,7 +17,7 @@ class SPTokenizer:
         self.n_words: int = self.sp_model.vocab_size()
         self.bos_id: int = self.sp_model.bos_id()
         self.eos_id: int = self.sp_model.eos_id()
-        self.pad_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.unk_id()
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

         special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
@@ -55,7 +55,7 @@ class SPTokenizer:

     def convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index in self.index_special_tokens:
+        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
             return ""
         return self.sp_model.IdToPiece(index)

@@ -69,6 +69,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         super().__init__(padding_side=padding_side, **kwargs)
         self.name = "GLMTokenizer"

+        self.vocab_file = vocab_file
         self.tokenizer = SPTokenizer(vocab_file)
         self.special_tokens = {
             "<bos>": self.tokenizer.bos_id,
@@ -84,12 +85,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):

     @property
     def pad_token(self) -> str:
-        return "</s>"
+        return "<unk>"

     @property
     def pad_token_id(self):
         return self.get_command("<pad>")

+    @property
+    def eos_token(self) -> str:
+        return "</s>"
+
+    @property
+    def eos_token_id(self):
+        return self.get_command("<eos>")
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_words
@@ -146,6 +155,15 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         return prefix_tokens

+    def build_prompt(self, query, history=None):
+        if history is None:
+            history = []
+        prompt = ""
+        for i, (old_query, response) in enumerate(history):
+            prompt += "[Round {}]\n\n问：{}\n\n答：{}\n\n".format(i + 1, old_query, response)
+        prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
+        return prompt
+
     def build_inputs_with_special_tokens(
             self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
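
Note: build_prompt() now owns the chat template that ChatGLMForConditionalGeneration.build_inputs() used to assemble inline, so the same "[Round N]" / 问 / 答 format can be reproduced outside the model, for example when preparing fine-tuning data. A minimal usage sketch, assuming the tokenizer is loaded from this repo by Hub id (the id and the example strings are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

history = [("Hello", "Hi, I am the ChatGLM2-6B assistant.")]
prompt = tokenizer.build_prompt("What can you do?", history=history)
print(prompt)
# Expected shape of the output: one "[Round i]\n\n问：...\n\n答：..." block per history turn,
# then a final "[Round 2]\n\n问：What can you do?\n\n答：" left open for the model to complete.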