katuni4ka commited on
Commit
9e21dac
1 Parent(s): 91a0561

compatibility with new transformers

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +16 -3
modeling_chatglm.py CHANGED
@@ -45,6 +45,9 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
  # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
46
  ]
47
 
 
 
 
48
 
49
  def default_init(cls, *args, **kwargs):
50
  return cls(*args, **kwargs)
@@ -872,9 +875,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
872
  standardize_cache_format: bool = False,
873
  ) -> Dict[str, Any]:
874
  # update past_key_values
875
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
876
- outputs, standardize_cache_format=standardize_cache_format
877
- )
 
 
 
 
 
 
 
 
 
 
878
 
879
  # update attention mask
880
  if "attention_mask" in model_kwargs:
 
45
  # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
46
  ]
47
 
48
+ is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
49
+ is_transformers_4_44_or_higher = int(transformers.__version__.split(".")[1]) >= 44
50
+
51
 
52
  def default_init(cls, *args, **kwargs):
53
  return cls(*args, **kwargs)
 
875
  standardize_cache_format: bool = False,
876
  ) -> Dict[str, Any]:
877
  # update past_key_values
878
+ if is_transformers_4_44_or_higher:
879
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
880
+ outputs
881
+ )[1]
882
+ elif is_transformers_4_42_or_higher:
883
+ # update past_key_values
884
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
885
+ outputs, standardize_cache_format=standardize_cache_format
886
+ )[1]
887
+ else:
888
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
889
+ outputs, standardize_cache_format=standardize_cache_format
890
+ )
891
 
892
  # update attention mask
893
  if "attention_mask" in model_kwargs: