x54-729 committed
Commit e0d5d9e · 1 parent: b724d24

replace get_max_length

Files changed (1): modeling_internlm2.py (+3 -3)
modeling_internlm2.py CHANGED

@@ -1081,7 +1081,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1274,8 +1274,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
         if isinstance(past_key_values, Cache):
             past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
             max_cache_length = (
-                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                if past_key_values.get_max_length() is not None
+                torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device)
+                if past_key_values.get_max_cache_shape() is not None
                 else None
             )
             cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
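Context for the change: newer transformers releases renamed Cache.get_max_length() to Cache.get_max_cache_shape(); the old method was deprecated and later removed, which breaks remote-code models such as InternLM2 that still call it. This commit simply switches to the new name. A minimal compatibility sketch that tolerates both API generations (the helper name _max_cache_len is hypothetical, not part of this repo):

from transformers.cache_utils import Cache

def _max_cache_len(past_key_values: Cache):
    """Return the cache's maximum capacity, or None for growable caches."""
    if hasattr(past_key_values, "get_max_cache_shape"):
        # Newer transformers: renamed API.
        return past_key_values.get_max_cache_shape()
    # Older transformers: the deprecated name is still present.
    return past_key_values.get_max_length()

With such a helper, both hunks above could call _max_cache_len(past_key_values) and keep working on either side of the rename; the committed fix instead pins the code to the new API only.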