import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from packaging import version

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import is_torchdynamo_compiling, logging


logger = logging.get_logger(__name__)


class Cache(torch.nn.Module):
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        #   length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.

        Assumes the subclass keeps its per-layer tensors in `self.key_cache` / `self.value_cache` lists; each tensor
        is re-indexed along dimension 0 with `beam_idx` (moved to that tensor's device first).
        """
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        """Deprecated: returns `self._seen_tokens` when the subclass defines it, otherwise `None`."""
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None


@dataclass
class CacheConfig:
    """
    Base class for cache configs
    """

    cache_implementation: None

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a CacheConfig instance from a dictionary of parameters.

        Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
            **kwargs: Additional keyword arguments to override dictionary values. Only keys that already exist as
                attributes on the constructed config are applied; any other keys are ignored.

        Returns:
            CacheConfig: Instance of CacheConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        # `kwargs` is a fresh local dict, so there is nothing to pop afterwards: just apply the known overrides.
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved. The file is
                written as UTF-8, with keys sorted and a 2-space indent, followed by a trailing newline.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance (deep copy).
        """
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        return json.dumps(self.__dict__, indent=2) + "\n"

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        unused_kwargs = {}
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                # Collected into a new dict so the caller's input is never modified
                unused_kwargs[key] = value
        return unused_kwargs


class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`, *optional*):
            The maximum sequence length with which the model will be used. If `None`, falls back to
            `config.max_position_embeddings`.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_key_values = outputs.past_key_values  # access cache filled with key/values from generation
        ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: Optional[int],
        device,
        dtype=None,
    ) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads. A missing or
        # `None` `head_dim` falls back to the derived value (a `None` here would otherwise produce an invalid shape).
        head_dim = getattr(config, "head_dim", None)
        self.head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads

        self.dtype = dtype if dtype is not None else torch.float32
        # Configs that do not define `num_key_value_heads` (or set it to `None`) use one KV head per attention head.
        num_key_value_heads = getattr(config, "num_key_value_heads", None)
        self.num_key_value_heads = config.num_attention_heads if num_key_value_heads is None else num_key_value_heads

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for idx in range(config.num_hidden_layers):
            # Allocate each layer's tensors exactly once, with the resolved `self.dtype`, and reuse them as the
            # registered buffers (previously a second pair was allocated with the raw, possibly-`None`, `dtype`).
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            # Notes:
            # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
            #     breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
            #     it is not needed anyway)
            # 2. `torch.export()` requires mutations to be registered as buffers.
            if not is_torchdynamo_compiling():
                self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                self.register_buffer(f"value_cache_{idx}", new_layer_value_cache)
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know where to write in the cache. If `cache_kwargs` is `None` or has no `cache_position`, the
                whole layer cache is overwritten in place with `copy_` (the new states must be broadcastable to the
                cache shape).

        Return:
            A tuple containing the updated key and value states.
        """
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        # Follow the incoming states' device; `.to` is a no-op when the device already matches.
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
            # operation, that avoids copies and uses less memory.
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states that were seen by the model.

        Note: the value is produced by `.sum()` on a tensor, so it is a 0-dim integer tensor rather than a Python
        `int` (kept that way so it stays traceable under compilation).
        """
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()