# coding=utf-8
# Copyright 2024 Nvidia Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Optional, Dict, Any, Tuple

import torch
from transformers.cache_utils import Cache  # used to let GenerationMixin know that we use a Cache object

from .configuration_decilm import DeciLMConfig
from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, SinkCache, StaticCache, SlidingWindowCache


class VariableCache(Cache_4_44_2, Cache):
    """
    A Cache object that supports a different Cache implementation for every layer,
    including layers without any kv-cache.
    Implemented as a list of Cache objects, each of which represents a "model" with a single layer.
    The default implementation for the layer caches is StaticCache.
    The cache of each layer is allocated on the same GPU as the layer itself.
    """

    def __init__(
            self,
            *,  # keyword-only, no positional args allowed, to avoid mix-ups with newer transformers versions
            config: DeciLMConfig,
            batch_size: Optional[int] = None,
            max_cache_len: Optional[int] = None,
            dtype: torch.dtype = torch.float32,
            max_batch_size: Optional[int] = None,
            **kwargs,
    ) -> None:
        Cache_4_44_2.__init__(self)

        self.config = deepcopy(config)
        self.max_batch_size = batch_size or max_batch_size
        self.batch_size = self.max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.dtype = dtype

        self.layer_caches: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
        self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
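        """
        Adds new key and value states to the cache of the given layer.

        The per-layer cache is created lazily on the first update, on the same device
        as the incoming key states. The update is delegated to the single-layer cache
        (hence layer_idx=0), and the returned tensors are trimmed to the current
        sequence length.
        """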
        if self.layer_caches[layer_idx] is None:
            self.layer_devices[layer_idx] = key_states.device
            self._init_layer_cache(layer_idx)

        layer_cache = self.layer_caches[layer_idx]
        assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"

        k_out, v_out = layer_cache.update(key_states=key_states,
                                          value_states=value_states,
                                          layer_idx=0,
                                          cache_kwargs=cache_kwargs)
        seq_len = self.get_seq_length(layer_idx)
        k_out = k_out[:, :, :seq_len, :]
        v_out = v_out[:, :, :seq_len, :]
        return k_out, v_out

    def _init_layer_cache(self, layer_idx: int) -> None:
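        """
        Builds the cache for a single layer according to its block config:
        no cache for no-op or linear-replaced attention, SlidingWindowCache for
        windowed non-sink attention, SinkCache for sink attention that is not
        unshifted, and StaticCache otherwise.
        """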
        block_config = self.config.block_configs[layer_idx]
        attention_config = block_config.attention

        if attention_config.no_op or attention_config.replace_with_linear:
            return None

        device = self.layer_devices[layer_idx]
        assert device is not None, f"Trying to init layer cache for {layer_idx=} without device"

        config = deepcopy(self.config)
        config.num_hidden_layers = 1
        config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group

        if attention_config.window_length is not None:
            if not attention_config.is_sink:
                config.sliding_window = attention_config.window_length
                self.layer_caches[layer_idx] = SlidingWindowCache(config=config,
                                                                  max_batch_size=self.max_batch_size,
                                                                  max_cache_len=self.max_cache_len,
                                                                  device=device,
                                                                  dtype=self.dtype)
                return
            elif not attention_config.unshifted_sink:
                self.layer_caches[layer_idx] = SinkCache(window_length=attention_config.window_length,
                                                         num_sink_tokens=attention_config.num_sink_tokens)
                return

        self.layer_caches[layer_idx] = StaticCache(config=config,
                                                   max_batch_size=self.max_batch_size,
                                                   max_cache_len=self.max_cache_len,
                                                   device=device,
                                                   dtype=self.dtype)

    def _get_first_real_cache(self) -> Cache:
        for layer_cache in self.layer_caches:
            if layer_cache is not None:
                return layer_cache
        raise ValueError("No real cache found, all layer caches are None.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
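        """
        Returns the cached sequence length. If layer 0 has no cache (e.g. it is a
        cache-less layer), falls back to the first layer that does have one,
        or 0 if no layer cache has been initialized yet.
        """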
        if layer_idx == 0 and self.layer_caches[0] is None:
            try:
                layer_cache = self._get_first_real_cache()
            except ValueError:
                return 0
        else:
            layer_cache = self.layer_caches[layer_idx]
        return layer_cache.get_seq_length()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
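        """Resets every layer cache, re-initializing layers whose caches do not support in-place reset."""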
        for layer_idx in range(len(self.layer_caches)):
            layer_cache = self.layer_caches[layer_idx]
            if hasattr(layer_cache, "reset"):
                layer_cache.reset()
            else:
                self._init_layer_cache(layer_idx)