import gradio as gr
from transformers import AutoConfig


def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
    hf_token = hf_token.strip()
    try:
        cfg = AutoConfig.from_pretrained(
            name,
            trust_remote_code=True,
            token=hf_token or None,
        )
    except Exception as e:
        # gr.Error expects a message string rather than the exception object.
        raise gr.Error(str(e))

    # DeepSeek V2/V3 architectures use Multi-head Latent Attention (MLA),
    # which caches a compressed latent instead of full per-head keys/values.
    use_mla = cfg.architectures[0].startswith(("DeepseekV2", "DeepseekV3"))

    # Multimodal configs nest the language-model settings under text_config.
    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    num_layers = cfg.num_hidden_layers
    num_attention_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)
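
    # Attention-variant taxonomy (descriptive note, not read from the config):
    #   MHA: every query head has its own key/value head (num_kv_heads == num_attention_heads)
    #   GQA: groups of query heads share a smaller set of key/value heads
    #   MLA: keys/values are compressed into a low-rank latent (DeepSeek V2/V3)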
    if use_mla:
        attention_type = "MLA"
    elif num_kv_heads == num_attention_heads:
        attention_type = "MHA"
    else:
        attention_type = "GQA"

    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
        ["attention_type", attention_type],
        ["num_attention_heads", num_attention_heads],
        ["num_kv_heads", num_kv_heads],
    ]
    if ctx_len > cfg.max_position_embeddings:
        gr.Warning(
            "Requested context length is larger than the max value supported by the model"
        )

    if use_mla:
        # MLA caches one compressed KV latent plus the decoupled RoPE key per
        # token per layer, shared across heads, so there is no per-KV-head
        # factor and no separate K and V.
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)

        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
        model_config.append(
            ["calc_formula", f"{num_layers} * ({kv_lora_rank} + {qk_rope_head_dim})"]
        )
    else:
        # MHA/GQA cache both K and V for every KV head, hence the factor of 2.
        head_dim = getattr(cfg, "head_dim", cfg.hidden_size // num_attention_heads)
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2

        model_config.append(["head_dim", head_dim])
        if attention_type == "GQA":
            kv_ratio = num_attention_heads // num_kv_heads
            model_config.append(["gqa_ratio", f"{kv_ratio}:1"])
        model_config.append(
            ["calc_formula", f"{num_layers} * {num_kv_heads} * {head_dim} * 2"]
        )
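
    # Illustrative magnitudes (rough figures based on typical published configs;
    # not computed by this app): a Llama-3-8B-style GQA config (32 layers, 8 KV
    # heads, head_dim 128) caches 32 * 8 * 128 * 2 = 65,536 elements per token,
    # about 0.13 MB/token at fp16, while a DeepSeek-V3-style MLA config
    # (61 layers, kv_lora_rank 512, qk_rope_head_dim 64) caches
    # 61 * (512 + 64) = 35,136 elements per token, roughly 70 KB/token at fp16.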

    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        # 1 byte per element plus a 2-byte scale amortized over hidden_size elements.
        nbytes_per_elem = 1 + 2 / cfg.hidden_size
    elif dtype == "fp4":
        # 0.5 byte per element plus a 2-byte scale per 32-element block.
        nbytes_per_elem = 0.5 + 2 / 32
    else:
        # Defensive guard for direct (non-UI) callers passing an unknown dtype.
        raise gr.Error(f"Unsupported KV cache data type: {dtype}")

    # Size in decimal gigabytes (1 GB = 1e9 bytes).
    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config
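
# Example of calling calculate() directly, without the UI (illustrative only;
# fetching the config requires network access, and gated repos need a token):
#
#   size_gb, table = calculate("Qwen/Qwen3-30B-A3B", 128_000, 1, "fp16/bf16", "")
#   print(f"KV cache: {size_gb:.2f} GB")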

DESCRIPTION = (
    "Calculate KV cache memory requirements for transformer models. "
    "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
)

demo = gr.Interface(
    title="KV Cache Calculator",
    description=DESCRIPTION,
    fn=calculate,
    inputs=[
        gr.Textbox(label="Model ID", value="Qwen/Qwen3-30B-A3B", placeholder="e.g., Qwen/Qwen3-30B-A3B"),
        gr.Number(label="Context Length", value=128_000, minimum=1),
        gr.Number(label="Number of Users", value=1, minimum=1),
        gr.Dropdown(label="KV Cache Data Type", choices=["fp16/bf16", "fp8", "fp4"], value="fp16/bf16"),
        gr.Textbox(label="HuggingFace Token (optional)", type="password", placeholder="For gated models"),
    ],
    outputs=[
        gr.Number(label="KV Cache Size (GB)", precision=2),
        gr.Dataframe(
            label="Model Configuration",
            headers=["Parameter", "Value"],
            datatype=["str", "str"],
            wrap=True,
        ),
    ],
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 800px !important;
        margin: 0 auto !important;
    }
    """,
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (e.g. inside a container)
        server_port=7860,
        share=False,
        show_error=True,
        allowed_paths=[],
        # Disable the underlying FastAPI docs endpoints (/docs and /redoc).
        app_kwargs={"docs_url": None, "redoc_url": None},
    )