import gradio as gr
from transformers import AutoConfig

# Credits: This implementation is derived from and builds upon the excellent work by gaunernst
# Original implementation: https://huggingface.co/spaces/gaunernst/kv-cache-calculator


def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
    """Estimate the KV cache size (in GB) for a Hugging Face model.

    Loads the model config from the Hub, works out how many cache elements are
    stored per token for its attention mechanism (MHA, GQA, or MLA), and scales
    that by context length, number of users, and the chosen cache data type.
    Returns the size in GB together with a table of the relevant config values.
    """
    hf_token = hf_token.strip()
    try:
        cfg = AutoConfig.from_pretrained(
            name,
            trust_remote_code=True,
            token=hf_token or None,
        )
    except Exception as e:
        # Surface config-loading problems (bad model ID, missing token, etc.) in the UI
        raise gr.Error(str(e))

    architectures = cfg.architectures or []
    use_mla = any(arch.startswith(("DeepseekV2", "DeepseekV3")) for arch in architectures)

    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    num_layers = cfg.num_hidden_layers
    
    # Determine attention mechanism type
    num_attention_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)
    
    if use_mla:
        attention_type = "MLA"
    elif num_kv_heads == num_attention_heads:
        attention_type = "MHA"
    else:
        attention_type = "GQA"
    
    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
        ["attention_type", attention_type],
        ["num_attention_heads", num_attention_heads],
        ["num_kv_heads", num_kv_heads],
    ]
    if ctx_len > cfg.max_position_embeddings:
        gr.Warning(
            "Requested context length is larger than the max value supported by the model"
        )

    # Calculate KV cache elements per token based on attention mechanism
    if use_mla:
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
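        # MLA caches a single compressed KV latent (kv_lora_rank) plus the decoupled
        # RoPE key (qk_rope_head_dim) per layer, shared across all heads, so the
        # element count has no num_kv_heads factor and no separate K/V pair.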
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)

        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
        model_config.append(["calc_formula", f"{num_layers} * ({kv_lora_rank} + {qk_rope_head_dim})"])

    else:
        # Some configs set head_dim explicitly (possibly to None); fall back to hidden_size / num_heads
        head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // num_attention_heads
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2  # 2 for key and value

        model_config.append(["head_dim", head_dim])
        if attention_type == "GQA":
            kv_ratio = num_attention_heads // num_kv_heads
            model_config.append(["gqa_ratio", f"{kv_ratio}:1"])
        model_config.append(["calc_formula", f"{num_layers} * {num_kv_heads} * {head_dim} * 2"])

    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        # 1 byte per element plus a 2-byte scale amortized over the token (per-token scaling assumption)
        nbytes_per_elem = 1 + 2 / cfg.hidden_size
    elif dtype == "fp4":
        # 0.5 bytes per 4-bit value plus a 2-byte scale per 32-element block (MXFP4)
        nbytes_per_elem = 0.5 + 2 / 32
    else:
        raise gr.Error(f"Unsupported KV cache data type: {dtype}")

    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config
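

# Worked example of the per-token arithmetic above (illustrative numbers, not
# tied to any specific checkpoint): a GQA model with 48 layers, 8 KV heads and
# head_dim 128 stores 48 * 8 * 128 * 2 = 98,304 cache elements per token.
# At fp16/bf16 (2 bytes per element), a 128,000-token context for one user
# takes roughly 98,304 * 128,000 * 2 / 1e9 ≈ 25.2 GB of KV cache.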


# Minimal description for iframe embedding
DESCRIPTION = (
    "Calculate KV cache memory requirements for transformer models. "
    "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
)

demo = gr.Interface(
    title="KV Cache Calculator",
    description=DESCRIPTION,
    fn=calculate,
    inputs=[
        gr.Textbox(label="Model ID", value="Qwen/Qwen3-30B-A3B", placeholder="e.g., Qwen/Qwen3-30B-A3B"),
        gr.Number(label="Context Length", value=128_000, minimum=1),
        gr.Number(label="Number of Users", value=1, minimum=1),
        gr.Dropdown(label="KV Cache Data Type", choices=["fp16/bf16", "fp8", "fp4"], value="fp16/bf16"),
        gr.Textbox(label="HuggingFace Token (optional)", type="password", placeholder="For gated models"),
    ],
    outputs=[
        gr.Number(label="KV Cache Size (GB)", precision=2),
        gr.Dataframe(
            label="Model Configuration", 
            headers=["Parameter", "Value"], 
            datatype=["str", "str"],
            wrap=True
        ),
    ],
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 800px !important;
        margin: 0 auto !important;
    }
    """,
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        # No extra filesystem paths are served; disable the FastAPI docs endpoints
        allowed_paths=[],
        app_kwargs={"docs_url": None, "redoc_url": None},
    )
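

# Programmatic use (a minimal sketch: fetching the config requires network access
# to the Hugging Face Hub, and the model ID below is just the UI's default value):
#   size_gb, config_rows = calculate("Qwen/Qwen3-30B-A3B", 32_768, 1, "fp16/bf16", "")
#   print(f"Estimated KV cache: {size_gb:.2f} GB")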