duzx16 committed
Commit: eb3e683
Parent(s): cdb65fd

Add prefix prompt

Files changed:
- config.json (+1, -1)
- configuration_chatglm.py (+5, -1)
- modeling_chatglm.py (+100, -19)
- tokenization_chatglm.py (+21, -3)
config.json
CHANGED
@@ -37,5 +37,5 @@
   "transformers_version": "4.27.1",
   "tie_word_embeddings": false,
   "eos_token_id": 2,
-  "pad_token_id":
+  "pad_token_id": 0
 }
configuration_chatglm.py
CHANGED
@@ -20,7 +20,6 @@ class ChatGLMConfig(PretrainedConfig):
         post_layer_norm=True,
         add_bias_linear=False,
         add_qkv_bias=False,
-        interleaved_qkv=False,
         bias_dropout_fusion=True,
         multi_query_attention=False,
         multi_query_group_num=1,
@@ -28,9 +27,12 @@ class ChatGLMConfig(PretrainedConfig):
         attention_softmax_in_fp32=True,
         fp32_residual_connection=False,
         quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
         **kwargs
     ):
         self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
         self.padded_vocab_size = padded_vocab_size
         self.hidden_size = hidden_size
         self.ffn_hidden_size = ffn_hidden_size
@@ -52,4 +54,6 @@ class ChatGLMConfig(PretrainedConfig):
         self.attention_softmax_in_fp32 = attention_softmax_in_fp32
         self.fp32_residual_connection = fp32_residual_connection
         self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
         super().__init__(**kwargs)
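The two new fields are what a prefix-prompt (P-tuning v2 style) fine-tuning script would set. A minimal sketch of how they are used; the repo id and the chosen prefix length are illustrative assumptions, not part of this commit:

from transformers import AutoConfig, AutoModel

# Illustrative repo id and values; only `pre_seq_len` and `prefix_projection`
# are fields actually introduced by this commit.
config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
config.pre_seq_len = 128          # length of the trainable prefix
config.prefix_projection = False  # False: plain embedding table; True: two-layer MLP

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", config=config, trust_remote_code=True)
# With pre_seq_len set, ChatGLMModel freezes the existing parameters and only the
# newly created PrefixEncoder remains trainable (see modeling_chatglm.py below).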
modeling_chatglm.py
CHANGED
@@ -56,6 +56,37 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
         return scores


+class PrefixEncoder(torch.nn.Module):
+    """
+    The torch.nn model to encode the prefix
+    Input shape: (batch-size, prefix-length)
+    Output shape: (batch-size, prefix-length, 2*layers*hidden)
+    """
+
+    def __init__(self, config: ChatGLMConfig):
+        super().__init__()
+        self.prefix_projection = config.prefix_projection
+        if self.prefix_projection:
+            # Use a two-layer MLP to encode the prefix
+            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+            self.trans = torch.nn.Sequential(
+                torch.nn.Linear(config.hidden_size, config.hidden_size),
+                torch.nn.Tanh(),
+                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
+            )
+        else:
+            self.embedding = torch.nn.Embedding(config.pre_seq_len,
+                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
+
+    def forward(self, prefix: torch.Tensor):
+        if self.prefix_projection:
+            prefix_tokens = self.embedding(prefix)
+            past_key_values = self.trans(prefix_tokens)
+        else:
+            past_key_values = self.embedding(prefix)
+        return past_key_values
+
+
 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
     num_partitions: int,
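A self-contained sketch of the non-projection branch of PrefixEncoder and its output shape; the config values below are made up for illustration:

import torch

# Illustrative sizes only; the real values come from ChatGLMConfig.
pre_seq_len, num_layers, kv_channels, multi_query_group_num = 16, 4, 32, 2

# Without projection: one embedding table mapping each of the pre_seq_len prefix
# positions to num_layers * kv_channels * multi_query_group_num * 2 values
# (a key and a value vector per layer, in the multi-query layout).
embedding = torch.nn.Embedding(pre_seq_len,
                               num_layers * kv_channels * multi_query_group_num * 2)
prefix = torch.arange(pre_seq_len).unsqueeze(0)   # (batch=1, prefix_len)
out = embedding(prefix)
print(out.shape)  # torch.Size([1, 16, 512]) -> 4 layers * 32 channels * 2 groups * 2 (k and v)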
@@ -375,11 +406,11 @@ class SelfAttention(torch.nn.Module):
             key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

         # adjust key and value for inference
+        if kv_cache is not None:
+            cache_k, cache_v = kv_cache
+            key_layer = torch.cat((cache_k, key_layer), dim=0)
+            value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            if kv_cache is not None:
-                cache_k, cache_v = kv_cache
-                key_layer = torch.cat((cache_k, key_layer), dim=0)
-                value_layer = torch.cat((cache_v, value_layer), dim=0)
             kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
@@ -566,6 +597,8 @@ class GLMTransformer(torch.nn.Module):
         self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                              dtype=config.torch_dtype)

+        self.gradient_checkpointing = False
+
     def _get_layer(self, layer_number):
         return self.layers[layer_number]

@@ -577,6 +610,13 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+            use_cache = False
+
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
         for index in range(self.num_layers):
@@ -584,14 +624,24 @@ class GLMTransformer(torch.nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer = self._get_layer(index)
-            layer_ret = layer(
-                hidden_states,
-                attention_mask,
-                rotary_pos_emb,
-                kv_cache=kv_caches[index],
-                use_cache=use_cache
-            )
-            hidden_states, kv_cache = layer_ret
+            if self.gradient_checkpointing and self.training:
+                layer_ret = torch.utils.checkpoint.checkpoint(
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_caches[index],
+                    use_cache
+                )
+            else:
+                layer_ret = layer(
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    kv_cache=kv_caches[index],
+                    use_cache=use_cache
+                )
+            hidden_states, kv_cache = layer_ret
             if use_cache:
                 presents = presents + (kv_cache,)

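The new flag can be flipped directly, or, assuming the pretrained-model class declares gradient-checkpointing support, through the standard transformers helper. A minimal sketch; `model` is assumed to be a ChatGLMForConditionalGeneration loaded from this repo with trust_remote_code=True:

# Direct: GLMTransformer is reachable at model.transformer.encoder (see the
# ChatGLMModel hunks below), and this commit gives it a gradient_checkpointing flag.
model.transformer.encoder.gradient_checkpointing = True

# Equivalent via the transformers API, assuming supports_gradient_checkpointing is
# enabled for this model class; it routes through _set_gradient_checkpointing.
# model.gradient_checkpointing_enable()

model.train()
# While training, each layer call is wrapped in torch.utils.checkpoint.checkpoint(...)
# and use_cache is forced to False (see the warning added above).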
@@ -645,7 +695,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         return position_ids

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module,
+        if isinstance(module, GLMTransformer):
             module.gradient_checkpointing = value


@@ -688,6 +738,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if device is not None:
             init_kwargs["device"] = device
         self.embedding = init_method(Embedding, config, **init_kwargs)
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels

         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -700,11 +753,33 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                         dtype=config.torch_dtype, **init_kwargs)
-        self.
+        self.pre_seq_len = config.pre_seq_len
+        self.prefix_projection = config.prefix_projection
+        if self.pre_seq_len is not None:
+            for param in self.parameters():
+                param.requires_grad = False
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+            self.prefix_encoder = PrefixEncoder(config)
+            self.dropout = torch.nn.Dropout(0.1)

     def get_input_embeddings(self):
         return self.embedding.word_embeddings

+    def get_prompt(self, batch_size, device, dtype=torch.half):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+        past_key_values = past_key_values.view(
+            batch_size,
+            self.pre_seq_len,
+            self.num_layers * 2,
+            self.multi_query_group_num,
+            self.kv_channels
+        )
+        # seq_len, b, nh, hidden_size
+        past_key_values = self.dropout(past_key_values)
+        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+        return past_key_values
+
     def forward(
         self,
         input_ids,
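To make the reshaping in get_prompt concrete, here is a standalone sketch with made-up sizes showing how the PrefixEncoder output becomes one (key, value) pair per layer:

import torch

# Illustrative sizes only.
batch_size, pre_seq_len = 2, 16
num_layers, multi_query_group_num, kv_channels = 4, 2, 32

# PrefixEncoder output: (batch, pre_seq_len, num_layers * kv_channels * groups * 2)
flat = torch.randn(batch_size, pre_seq_len,
                   num_layers * kv_channels * multi_query_group_num * 2)

past_key_values = flat.view(batch_size, pre_seq_len, num_layers * 2,
                            multi_query_group_num, kv_channels)
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)

print(len(past_key_values))      # 4 -> one (key, value) pair per layer
print(past_key_values[0].shape)  # torch.Size([2, 16, 2, 2, 32])
                                 # (k/v, pre_seq_len, batch, groups, kv_channels)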
@@ -728,6 +803,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)

+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
+                                                  dtype=inputs_embeds.dtype)
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                                            attention_mask], dim=-1)
+
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                 full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
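The padding mask is widened to cover the injected prefix. A small sketch of the shapes involved (sizes are illustrative):

import torch

batch_size, seq_length, pre_seq_len = 2, 5, 16
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

# Prepend ones for the pre_seq_len virtual prefix tokens, mirroring the
# torch.cat in ChatGLMModel.forward above.
attention_mask = torch.cat(
    [attention_mask.new_ones((batch_size, pre_seq_len)), attention_mask], dim=-1
)
print(attention_mask.shape)  # torch.Size([2, 21]) -> pre_seq_len + seq_length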
@@ -913,10 +996,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response

     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        prompt = ""
-        for i, (old_query, response) in enumerate(history):
-            prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
-        prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+        prompt = tokenizer.build_prompt(query, history=history)
         inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         return inputs
@@ -933,7 +1013,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         return inputs

-
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
              do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
@@ -969,6 +1048,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             inputs = self.build_stream_inputs(tokenizer, query, history=history)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
+            if self.transformer.pre_seq_len is not None:
+                past_length -= self.transformer.pre_seq_len
             inputs.position_ids += past_length
             attention_mask = inputs.attention_mask
             attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
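Why the subtraction: the prefix key/values occupy the first pre_seq_len slots of the cache's sequence dimension, so the cached length overstates the number of real tokens and the position offset is reduced accordingly. A small illustrative check (numbers made up):

pre_seq_len = 16
cached_seq_dim = 16 + 40          # e.g. past_key_values[0][0].shape[0]: prefix + 40 real tokens

past_length = cached_seq_dim
if pre_seq_len is not None:
    past_length -= pre_seq_len    # positions should only advance by real tokens
print(past_length)                # 40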
tokenization_chatglm.py
CHANGED
@@ -17,7 +17,7 @@ class SPTokenizer:
         self.n_words: int = self.sp_model.vocab_size()
         self.bos_id: int = self.sp_model.bos_id()
         self.eos_id: int = self.sp_model.eos_id()
-        self.pad_id: int = self.sp_model.
+        self.pad_id: int = self.sp_model.unk_id()
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

         special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
@@ -55,7 +55,7 @@ class SPTokenizer:

     def convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index in self.index_special_tokens:
+        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
             return ""
         return self.sp_model.IdToPiece(index)

@@ -69,6 +69,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         super().__init__(padding_side=padding_side, **kwargs)
         self.name = "GLMTokenizer"

+        self.vocab_file = vocab_file
         self.tokenizer = SPTokenizer(vocab_file)
         self.special_tokens = {
             "<bos>": self.tokenizer.bos_id,
@@ -84,12 +85,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):

     @property
     def pad_token(self) -> str:
-        return "
+        return "<unk>"

     @property
     def pad_token_id(self):
         return self.get_command("<pad>")

+    @property
+    def eos_token(self) -> str:
+        return "</s>"
+
+    @property
+    def eos_token_id(self):
+        return self.get_command("<eos>")
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_words
@@ -146,6 +155,15 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         return prefix_tokens

+    def build_prompt(self, query, history=None):
+        if history is None:
+            history = []
+        prompt = ""
+        for i, (old_query, response) in enumerate(history):
+            prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
+        prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+        return prompt
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]: