duzx16 commited on
Commit
8049563
·
1 Parent(s): 71189e7

Fix prefix projection

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +7 -6
modeling_chatglm.py CHANGED
@@ -68,11 +68,12 @@ class PrefixEncoder(torch.nn.Module):
68
  self.prefix_projection = config.prefix_projection
69
  if self.prefix_projection:
70
  # Use a two-layer MLP to encode the prefix
71
- self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
 
72
  self.trans = torch.nn.Sequential(
73
- torch.nn.Linear(config.hidden_size, config.hidden_size),
74
  torch.nn.Tanh(),
75
- torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
76
  )
77
  else:
78
  self.embedding = torch.nn.Embedding(config.pre_seq_len,
@@ -1013,7 +1014,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1013
  inputs = inputs.to(self.device)
1014
  return inputs
1015
 
1016
- @torch.no_grad()
1017
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
1018
  do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
1019
  if history is None:
@@ -1031,7 +1032,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1031
  history = history + [(query, response)]
1032
  return response, history
1033
 
1034
- @torch.no_grad()
1035
  def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
1036
  max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1037
  return_past_key_values=False, **kwargs):
@@ -1068,7 +1069,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1068
  else:
1069
  yield response, new_history
1070
 
1071
- @torch.no_grad()
1072
  def stream_generate(
1073
  self,
1074
  input_ids,
 
68
  self.prefix_projection = config.prefix_projection
69
  if self.prefix_projection:
70
  # Use a two-layer MLP to encode the prefix
71
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
72
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
73
  self.trans = torch.nn.Sequential(
74
+ torch.nn.Linear(kv_size, config.hidden_size),
75
  torch.nn.Tanh(),
76
+ torch.nn.Linear(config.hidden_size, kv_size)
77
  )
78
  else:
79
  self.embedding = torch.nn.Embedding(config.pre_seq_len,
 
1014
  inputs = inputs.to(self.device)
1015
  return inputs
1016
 
1017
+ @torch.inference_mode()
1018
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
1019
  do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
1020
  if history is None:
 
1032
  history = history + [(query, response)]
1033
  return response, history
1034
 
1035
+ @torch.inference_mode()
1036
  def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
1037
  max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1038
  return_past_key_values=False, **kwargs):
 
1069
  else:
1070
  yield response, new_history
1071
 
1072
+ @torch.inference_mode()
1073
  def stream_generate(
1074
  self,
1075
  input_ids,