MLKV: Multi-Layer Key-Value Heads for Memory Efficient Transformer Decoding
Abstract
Auto-regressive inference of transformers benefits greatly from Key-Value (KV) caching, but the cache can become a major memory bottleneck as model size, batch size, and sequence length grow. We introduce Multi-Layer Key-Value (MLKV) sharing, a novel approach that extends KV sharing across transformer layers to reduce memory usage beyond what is possible with Multi-Query Attention (MQA) and Grouped-Query Attention (GQA). Evaluations on various NLP benchmarks and inference metrics using uptrained Pythia-160M variants demonstrate that MLKV significantly reduces memory usage with minimal performance loss, shrinking the KV cache by a factor of up to 6x compared to MQA. These results highlight MLKV's potential for efficient deployment of transformer models at scale. We provide code at https://github.com/zaydzuhri/pythia-mlkv
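To make the memory argument concrete, here is a rough back-of-the-envelope sketch of KV cache size, assuming Pythia-160M-like dimensions (12 layers, 12 attention heads, head dimension 64); the sequence length, batch size, and fp16 dtype are arbitrary illustrative choices, not the paper's benchmark settings.

```python
# Back-of-the-envelope KV cache size; all settings are illustrative assumptions.
def kv_cache_bytes(total_kv_heads, head_dim=64, seq_len=2048, batch=8, bytes_per=2):
    # 2 accounts for storing both K and V for every cached position.
    return 2 * total_kv_heads * head_dim * seq_len * batch * bytes_per

mha  = kv_cache_bytes(total_kv_heads=12 * 12)  # 12 KV heads per layer x 12 layers
mqa  = kv_cache_bytes(total_kv_heads=12)       # 1 KV head per layer
mlkv = kv_cache_bytes(total_kv_heads=2)        # 2 KV heads shared across all 12 layers
print(mha / mqa, mqa / mlkv)                   # 12.0 6.0
```

Under these assumptions, MQA already cuts the cache 12x relative to vanilla multi-head attention, and sharing two KV heads across all layers (MLKV) would cut it a further 6x, which is where the factor quoted above could come from.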
Community
Our proposed KV sharing method, Multi-Layer Key-Value (MLKV), provides the option to further reduce KV cache size in transformers beyond what is possible with GQA and MQA. By sharing KV heads not only within a layer but also between layers, we can reduce the total KV head count below the number of layers in the transformer. Our experiments show that reductions in cache size of up to 6x compared to MQA are possible and offer a fair accuracy/memory trade-off. We recommend sharing across every second layer (a KV head count equal to half the number of layers) for a 2x reduction from MQA with very minimal loss in accuracy, but ultimately leave it to architecture designers to decide whether an even lower number of KV heads is needed for more memory-constrained use cases. A minimal sketch of the layer-sharing mechanism is given below.
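The sketch below shows the layer-sharing idea in PyTorch (forward pass only, without decoding-time cache management): layers that "own" a single MQA-style KV head project K/V from their own input, while the remaining layers reuse the most recent owner's K/V. All names here (MLKVAttention, MLKVStack, share_every) are illustrative assumptions, not taken from the paper's actual implementation.

```python
import math
import torch
import torch.nn as nn

class MLKVAttention(nn.Module):
    """Causal self-attention that either owns a single KV head or borrows one."""
    def __init__(self, d_model, n_heads, owns_kv):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        # Only "KV-owner" layers carry a K/V projection (one KV head, as in MQA).
        self.kv_proj = nn.Linear(d_model, 2 * self.head_dim) if owns_kv else None

    def forward(self, x, shared_kv=None):
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        if self.kv_proj is not None:
            k, v = self.kv_proj(x).split(self.head_dim, dim=-1)
            shared_kv = (k.unsqueeze(1), v.unsqueeze(1))  # (B, 1, T, head_dim)
        k, v = shared_kv  # non-owner layers reuse the most recent owner's K/V
        # One KV head is broadcast across all query heads (MQA-style grouping).
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), 1)
        probs = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)
        out = (probs @ v).transpose(1, 2).reshape(B, T, -1)
        return self.o_proj(out), shared_kv

class MLKVStack(nn.Module):
    """Stack where every `share_every`-th layer computes K/V and the rest reuse it,
    so the KV cache only needs n_layers / share_every entries."""
    def __init__(self, d_model=768, n_heads=12, n_layers=12, share_every=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [MLKVAttention(d_model, n_heads, owns_kv=(i % share_every == 0))
             for i in range(n_layers)]
        )

    def forward(self, x):
        kv = None
        for layer in self.layers:
            x, kv = layer(x, shared_kv=kv)
        return x

if __name__ == "__main__":
    out = MLKVStack(share_every=2)(torch.randn(2, 16, 768))
    print(out.shape)  # torch.Size([2, 16, 768])
```

With share_every=2 this corresponds to the recommended setting above (KV head count equal to half the number of layers); larger values of share_every shrink the cache further at the cost of more accuracy.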
It is worth noting that layer-wise KV sharing is not new: e.g. https://arxiv.org/abs/2002.09402 (I'm not sure that this was even the first)