Update modeling_kangaroo.py
Browse files- modeling_kangaroo.py +3 -3
modeling_kangaroo.py
CHANGED
@@ -1020,7 +1020,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|
1020 |
min_dtype = torch.finfo(dtype).min
|
1021 |
sequence_length = input_tensor.shape[1]
|
1022 |
if using_static_cache:
|
1023 |
-
target_length = past_key_values.get_max_length()
|
1024 |
else:
|
1025 |
target_length = (
|
1026 |
attention_mask.shape[-1]
|
@@ -1308,8 +1308,8 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
|
|
1308 |
if isinstance(past_key_values, Cache):
|
1309 |
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
1310 |
max_cache_length = (
|
1311 |
-
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
1312 |
-
if past_key_values.get_max_length() is not None
|
1313 |
else None
|
1314 |
)
|
1315 |
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
|
|
1020 |
min_dtype = torch.finfo(dtype).min
|
1021 |
sequence_length = input_tensor.shape[1]
|
1022 |
if using_static_cache:
|
1023 |
+
target_length = past_key_values.get_seq_length()
|
1024 |
else:
|
1025 |
target_length = (
|
1026 |
attention_mask.shape[-1]
|
|
|
1308 |
if isinstance(past_key_values, Cache):
|
1309 |
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
1310 |
max_cache_length = (
|
1311 |
+
torch.tensor(past_key_values.get_seq_length(), device=input_ids.device)
|
1312 |
+
if past_key_values.get_seq_length() is not None
|
1313 |
else None
|
1314 |
)
|
1315 |
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|