File size: 4,502 Bytes
b5dfaf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# coding=utf-8
# Copyright 2024 Nvidia Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Optional, Dict, Any, Tuple

import torch
from transformers.cache_utils import Cache  # used to let GenerationMixin know that we use a Cache object

from .configuration_decilm import DeciLMConfig, AttentionConfig
from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, StaticCache


class VariableCache(Cache_4_44_2, Cache):
    """
    A Cache object that supports a different Cache implementation for every layer,
    including layers without any kv-cache.
    Implemented using a list of Cache objects, each represents a "model" with 1 layer.
    The default implementation for the layer caches is StaticCache.
    The cache of each layer is allocated to the same gpu as the layer itself.

    Layer caches are created lazily: every entry of ``layer_caches`` starts as
    ``None`` and is replaced by a real cache on the first ``update`` call for
    that layer. Layers whose attention config is ``no_op`` or
    ``replace_with_linear`` never get a cache and stay ``None``.
    """

    def __init__(
            self,
            config: DeciLMConfig,
            max_batch_size: int,
            max_cache_len: int | None,
            device: torch.device | str | None = None,
            dtype: torch.dtype | None = None,
    ):
        """
        Args:
            config: Model configuration; ``num_hidden_layers``, ``block_configs``,
                ``num_attention_heads`` and ``max_position_embeddings`` are read from it.
            max_batch_size: Batch size forwarded to every per-layer cache.
            max_cache_len: Cache length forwarded to every per-layer cache.
                ``None`` falls back to ``config.max_position_embeddings``.
            device: Accepted but not used by this class; each layer cache is
                created on the device of the key states first passed to ``update``.
            dtype: dtype forwarded to every per-layer cache.
        """
        Cache_4_44_2.__init__(self)

        self.config = config
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.dtype = dtype

        # One slot per layer; None means "not created yet" or "layer has no kv-cache".
        self.layer_caches: list[Cache | None] = [None] * config.num_hidden_layers

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store new key/value states in the cache of layer ``layer_idx``.

        The layer cache is created on first use, on the device of ``key_states``.

        Args:
            key_states: Key states forwarded to the layer cache.
            value_states: Value states forwarded to the layer cache.
            layer_idx: Index of the layer whose cache is updated.
            cache_kwargs: Extra arguments forwarded unchanged to the layer cache.

        Returns:
            The key and value tensors returned by the layer cache, truncated along
            the third dimension to the layer's current sequence length.

        Raises:
            AssertionError: If the layer is configured without a kv-cache.
        """
        layer_cache = self.layer_caches[layer_idx]

        if layer_cache is None:
            block_config = self.config.block_configs[layer_idx]
            layer_cache = self._init_layer_cache(attention_config=block_config.attention, device=key_states.device)
            assert layer_cache is not None, "Trying to update the cache of a cache-less layer"
            self.layer_caches[layer_idx] = layer_cache

        # Every layer cache is a stand-alone single-layer cache, hence layer_idx=0.
        k_out, v_out = layer_cache.update(key_states=key_states,
                                          value_states=value_states,
                                          layer_idx=0,
                                          cache_kwargs=cache_kwargs)
        # Keep only the positions filled so far.
        seq_len = self.get_seq_length(layer_idx)
        k_out = k_out[:, :, :seq_len, :]
        v_out = v_out[:, :, :seq_len, :]
        return k_out, v_out

    def _init_layer_cache(
            self,
            attention_config: AttentionConfig,
            device: torch.device,
    ) -> Cache | None:
        """
        Build the cache for a single layer.

        Returns ``None`` when the layer's attention is a no-op or is replaced
        with a linear layer. Otherwise returns a ``StaticCache`` built from a
        copy of the model config whose ``num_key_value_heads`` is set to
        ``num_attention_heads // n_heads_in_group`` for this layer.
        """
        if attention_config.no_op or attention_config.replace_with_linear:
            return None
        # Copy so the per-layer override does not leak into the shared model config.
        config = deepcopy(self.config)
        config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
        return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)

    def _get_first_real_cache(self) -> Cache:
        """
        Return the first layer cache that is not ``None``.

        Raises:
            ValueError: If every layer cache is ``None``.
        """
        for layer_cache in self.layer_caches:
            if layer_cache is not None:
                return layer_cache
        raise ValueError("No real cache found, all layer caches are None.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the sequence length reported by the cache of layer ``layer_idx``.

        For layer 0 (the default, also used when ``layer_idx`` is ``None``) that
        has no cache, the first existing layer cache is queried instead, so the
        default call still reports a length when layer 0 is cache-less.
        Any other layer without a cache reports 0.
        """
        if layer_idx is None:
            layer_idx = 0

        layer_cache = self.layer_caches[layer_idx]
        if layer_cache is None and layer_idx == 0:
            try:
                layer_cache = self._get_first_real_cache()
            except ValueError:
                return 0

        if layer_cache is None:
            # Cache-less or not-yet-initialised layer: nothing has been stored.
            return 0
        return layer_cache.get_seq_length()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Call ``reset()`` on every layer cache that provides it; ``None`` slots are skipped."""
        for layer_cache in self.layer_caches:
            if hasattr(layer_cache, "reset"):
                layer_cache.reset()