import gradio as gr
from transformers import AutoConfig
# Credits: This implementation is derived from and builds upon the excellent work by gaunernst
# Original implementation: https://huggingface.co/spaces/gaunernst/kv-cache-calculator


def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
    hf_token = hf_token.strip()
    try:
        cfg = AutoConfig.from_pretrained(
            name,
            trust_remote_code=True,
            token=hf_token or None,
        )
    except Exception as e:
        raise gr.Error(str(e))

    use_mla = cfg.architectures[0].startswith(("DeepseekV2", "DeepseekV3"))
    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config
    num_layers = cfg.num_hidden_layers

    # Determine attention mechanism type
    num_attention_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)
    if use_mla:
        attention_type = "MLA"
    elif num_kv_heads == num_attention_heads:
        attention_type = "MHA"
    else:
        attention_type = "GQA"

    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
        ["attention_type", attention_type],
        ["num_attention_heads", num_attention_heads],
        ["num_kv_heads", num_kv_heads],
    ]
    if ctx_len > cfg.max_position_embeddings:
        gr.Warning(
            "Requested context length is larger than the max value supported by the model"
        )

    # Calculate KV cache elements per token based on attention mechanism
    if use_mla:
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)
        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
        model_config.append(["calc_formula", f"{num_layers} * ({kv_lora_rank} + {qk_rope_head_dim})"])
    else:
        # Some configs set head_dim to None, so fall back to hidden_size // num_attention_heads
        head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // num_attention_heads
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2  # 2 for key and value
        model_config.append(["head_dim", head_dim])
        if attention_type == "GQA":
            kv_ratio = num_attention_heads // num_kv_heads
            model_config.append(["gqa_ratio", f"{kv_ratio}:1"])
        model_config.append(["calc_formula", f"{num_layers} * {num_kv_heads} * {head_dim} * 2"])

    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        nbytes_per_elem = 1 + 2 / cfg.hidden_size  # assume per-token scaling
    elif dtype == "fp4":
        nbytes_per_elem = 0.5 + 2 / 32  # 4-bit values plus a scaling factor every 32 elements (MXFP4)
    else:
        raise gr.Error(f"Unsupported KV cache data type: {dtype}")
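
    # For a sense of scale (hidden_size of 4096 chosen purely for illustration), the
    # effective bytes per cached element work out to: fp16/bf16 -> 2.0,
    # fp8 -> 1 + 2/4096 ≈ 1.0005, fp4 (MXFP4) -> 0.5 + 2/32 = 0.5625,
    # so the scaling-factor overhead is minor compared to the values themselves.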
    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config
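
# Rough sanity check of the non-MLA path above, using a made-up GQA configuration
# (32 layers, 8 KV heads, head_dim 128) that is not tied to any particular model:
#   nelems_per_token = 32 * 8 * 128 * 2 = 65,536
#   at fp16/bf16 (2 bytes per element), a 4,096-token context and 1 user:
#   65,536 * 4,096 * 1 * 2 / 1e9 ≈ 0.54 GB of KV cache
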
# Minimal description for iframe embedding
DESCRIPTION = (
    "Calculate KV cache memory requirements for transformer models. "
    "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
)

demo = gr.Interface(
    title="KV Cache Calculator",
    description=DESCRIPTION,
    fn=calculate,
    inputs=[
        gr.Textbox(label="Model ID", value="Qwen/Qwen3-30B-A3B", placeholder="e.g., Qwen/Qwen3-30B-A3B"),
        gr.Number(label="Context Length", value=128_000, minimum=1),
        gr.Number(label="Number of Users", value=1, minimum=1),
        gr.Dropdown(label="KV Cache Data Type", choices=["fp16/bf16", "fp8", "fp4"], value="fp16/bf16"),
        gr.Textbox(label="HuggingFace Token (optional)", type="password", placeholder="For gated models"),
    ],
    outputs=[
        gr.Number(label="KV Cache Size (GB)", precision=2),
        gr.Dataframe(
            label="Model Configuration",
            headers=["Parameter", "Value"],
            datatype=["str", "str"],
            wrap=True,
        ),
    ],
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 800px !important;
        margin: 0 auto !important;
    }
    """,
    analytics_enabled=False,
)
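
# Optional: the calculator can also be called without the UI (it needs network access
# to fetch the model config from the Hub), for example:
#   size_gb, config_rows = calculate("Qwen/Qwen3-30B-A3B", 4096, 1, "fp16/bf16", "")
#   print(f"KV cache: {size_gb:.2f} GB")
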
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        # Serve no extra filesystem paths; disable the FastAPI docs endpoints
        allowed_paths=[],
        app_kwargs={"docs_url": None, "redoc_url": None},
    )