# Llama-3_1-Nemotron-51B-Instruct / variable_cache.py
# Commit b5dfaf4 (itlevy, verified): "transformers>=4.44.2, backward compat"
# File size: 4.5 kB
# coding=utf-8
# Copyright 2024 Nvidia Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Optional, Dict, Any, Tuple
import torch
from transformers.cache_utils import Cache # used to let GenerationMixin know that we use a Cache object
from .configuration_decilm import DeciLMConfig, AttentionConfig
from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, StaticCache
class VariableCache(Cache_4_44_2, Cache):
    """
    A Cache object that supports a different Cache implementation for every layer,
    including layers without any kv-cache.
    Implemented using a list of Cache objects, each represents a "model" with 1 layer.
    The default implementation for the layer caches is StaticCache.
    The cache of each layer is allocated to the same gpu as the layer itself.

    Layer caches are created lazily: a layer's entry in ``layer_caches`` stays
    ``None`` until the first ``update`` call for that layer, and stays ``None``
    forever for layers whose attention is a no-op or replaced with a linear layer.
    """

    def __init__(self,
                 config: DeciLMConfig,
                 max_batch_size: int,
                 max_cache_len: int | None,
                 device: torch.device | str | None = None,
                 dtype: torch.dtype | None = None,
                 ):
        """
        Create an empty per-layer cache container.

        Args:
            config: Model config; ``num_hidden_layers``, ``block_configs``,
                ``num_attention_heads`` and ``max_position_embeddings`` are read from it.
            max_batch_size: Passed through to every per-layer cache.
            max_cache_len: Maximum cache length; ``None`` falls back to
                ``config.max_position_embeddings``.
            device: Accepted for signature compatibility but not used here; each
                layer cache is created on the device of the key states it first receives.
            dtype: Passed through to every per-layer cache.
        """
        Cache_4_44_2.__init__(self)

        self.config = config
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.dtype = dtype

        # One slot per layer; filled lazily by `update`.
        self.layer_caches: list[Cache | None] = [None] * config.num_hidden_layers

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store new key/value states for ``layer_idx`` and return the cached states.

        The layer's cache is created on first use, on ``key_states.device``.
        The tensors returned by the layer cache are sliced along dim 2 to the
        layer's current sequence length before being returned.

        Raises:
            AssertionError: If the layer is configured without a kv-cache
                (attention is a no-op or replaced with a linear layer).
        """
        layer_cache = self.layer_caches[layer_idx]

        if layer_cache is None:
            block_config = self.config.block_configs[layer_idx]
            layer_cache = self._init_layer_cache(attention_config=block_config.attention,
                                                 device=key_states.device)
            if layer_cache is None:
                # Raised explicitly (instead of `assert`) so the check survives `python -O`;
                # the exception type is kept for backward compatibility.
                raise AssertionError("Trying to update the cache of a cache-less layer")
            self.layer_caches[layer_idx] = layer_cache

        # Every layer cache is a stand-alone single-layer cache, hence layer_idx=0.
        k_out, v_out = layer_cache.update(key_states=key_states,
                                          value_states=value_states,
                                          layer_idx=0,
                                          cache_kwargs=cache_kwargs)
        seq_len = self.get_seq_length(layer_idx)
        k_out = k_out[:, :, :seq_len, :]
        v_out = v_out[:, :, :seq_len, :]
        return k_out, v_out

    def _init_layer_cache(self,
                          attention_config: AttentionConfig,
                          device: torch.device,
                          ) -> Cache | None:
        """
        Build the cache for one layer, or return ``None`` for a cache-less layer.

        A layer is cache-less when its attention config is ``no_op`` or
        ``replace_with_linear``. Otherwise a ``StaticCache`` is built from a copy
        of the model config whose ``num_key_value_heads`` is set to
        ``num_attention_heads // n_heads_in_group`` for this layer.
        """
        if attention_config.no_op or attention_config.replace_with_linear:
            return None

        # Copy so the per-layer override does not leak into the shared model config.
        config = deepcopy(self.config)
        config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
        return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)

    def _get_first_real_cache(self) -> Cache:
        """
        Return the first layer cache that is not ``None``.

        Raises:
            ValueError: If every entry of ``layer_caches`` is ``None``.
        """
        for layer_cache in self.layer_caches:
            if layer_cache is not None:
                return layer_cache
        raise ValueError("No real cache found, all layer caches are None.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the sequence length currently cached for ``layer_idx``.

        ``layer_idx`` 0 (the default; ``None`` is treated the same way) falls back
        to the first layer that has a real cache when layer 0 itself has none, and
        returns 0 when no layer has one. Any other layer without a cache
        (cache-less, or not yet created) reports 0.
        """
        if layer_idx is None:
            layer_idx = 0

        layer_cache = self.layer_caches[layer_idx]
        if layer_cache is None:
            if layer_idx != 0:
                # Nothing is cached for this layer; report an empty cache
                # instead of failing on `None.get_seq_length()`.
                return 0
            try:
                layer_cache = self._get_first_real_cache()
            except ValueError:
                return 0

        return layer_cache.get_seq_length()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Reset every layer cache that exposes a ``reset`` method; ``None`` slots are skipped."""
        for layer_cache in self.layer_caches:
            if hasattr(layer_cache, "reset"):
                layer_cache.reset()