import time
from typing import List

import gradio as gr
from huggingface_hub import list_models
from transformers import AutoConfig

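
# In-memory cache of Hub search results: maps a query key to (matching model ids, timestamp).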
search_cache = {}

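# Curated models shown by default and used as a fallback when the query is empty or the Hub search fails.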
POPULAR_MODELS = [
    "Qwen/Qwen3-30B-A3B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
    "microsoft/DialoGPT-medium",
    "microsoft/DialoGPT-large",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "deepseek-ai/DeepSeek-V2-Chat",
    "deepseek-ai/DeepSeek-V3-Base",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "Qwen/QwQ-32B-Preview",
    "Qwen/Qwen2.5-72B-Instruct",
]


def search_models(query: str, max_results: int = 50) -> List[str]:
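    """Search the HF Hub for text-generation models matching ``query``.

    Curated popular models that match are listed first, and results are cached
    for five minutes so repeated keystrokes do not hammer the Hub API.
    """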
    if not query or len(query.strip()) < 1:
        return POPULAR_MODELS[:15]

    query = query.strip()
    cache_key = f"{query.lower()}_{max_results}"

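    # Serve a cached result if the same query was searched within the last five minutes.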
    current_time = time.time()
    if cache_key in search_cache:
        cached_result, cache_time = search_cache[cache_key]
        if current_time - cache_time < 300:
            return cached_result

    try:
        print(f"Searching HF Hub for: {query}")
        models = list_models(
            search=query,
            task="text-generation",
            library="transformers",
            sort="downloads",
            direction=-1,
            limit=max_results * 2,
            full=False,
        )

        all_matches = []
        seen_models = set()

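        # List curated popular matches first, then fill the remainder with Hub search results.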
        for model in POPULAR_MODELS:
            if query.lower() in model.lower() and model not in seen_models:
                all_matches.append(model)
                seen_models.add(model)

        for model in models:
            if model.id not in seen_models and len(all_matches) < max_results:
                all_matches.append(model.id)
                seen_models.add(model.id)

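        # If the task-filtered search returned few matches, retry without the task filter
        # and keep ids containing chat/instruct/base/model.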
        if len(all_matches) < max_results // 2:
            try:
                broader_models = list_models(
                    search=query,
                    library="transformers",
                    sort="downloads",
                    direction=-1,
                    limit=max_results * 2,
                )
                for model in broader_models:
                    if model.id not in seen_models and len(all_matches) < max_results:
                        model_id_lower = model.id.lower()
                        if any(keyword in model_id_lower for keyword in ["chat", "instruct", "base", "model"]):
                            all_matches.append(model.id)
                            seen_models.add(model.id)
            except Exception as e:
                print(f"Broader search failed: {e}")

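        # Trim to max_results, cache the result, and evict the oldest entry once more than 20 queries are cached.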
        result = all_matches[:max_results]
        search_cache[cache_key] = (result, current_time)
        if len(search_cache) > 20:
            oldest_key = min(search_cache.keys(), key=lambda k: search_cache[k][1])
            del search_cache[oldest_key]
        return result

    except Exception as e:
        print(f"Search error: {e}")
        popular_matches = [model for model in POPULAR_MODELS if query.lower() in model.lower()]
        return popular_matches if popular_matches else POPULAR_MODELS[:15]


def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
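    """Compute the KV cache size in GB for ``num_users`` concurrent sequences of ``ctx_len`` tokens.

    Returns the size together with a table of the config values used in the calculation.
    """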
    hf_token = hf_token.strip()
    try:
        cfg = AutoConfig.from_pretrained(
            name,
            trust_remote_code=True,
            token=hf_token or None,
        )
    except Exception as e:
        raise gr.Error(str(e))

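    # DeepSeek V2/V3 use Multi-head Latent Attention (MLA), which caches a compressed latent instead of per-head K/V.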
    architectures = getattr(cfg, "architectures", None) or [""]
    use_mla = architectures[0].startswith(("DeepseekV2", "DeepseekV3"))

    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    num_layers = cfg.num_hidden_layers
    num_attention_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)

    if use_mla:
        attention_type = "MLA"
    elif num_kv_heads == num_attention_heads:
        attention_type = "MHA"
    else:
        attention_type = "GQA"

    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
        ["attention_type", attention_type],
        ["num_attention_heads", num_attention_heads],
        ["num_kv_heads", num_kv_heads],
    ]
    if ctx_len > cfg.max_position_embeddings:
        gr.Warning(
            "Requested context length is larger than the max value supported by the model"
        )

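    # Per-token cache elements: MLA stores a compressed KV latent (kv_lora_rank) plus a decoupled
    # RoPE key (qk_rope_head_dim) per layer; MHA/GQA store full K and V for every KV head.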
    if use_mla:
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)

        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
        model_config.append(["calc_formula", f"{num_layers} * ({kv_lora_rank} + {qk_rope_head_dim})"])
    else:
        head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // num_attention_heads
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2

        model_config.append(["head_dim", head_dim])
        if attention_type == "GQA":
            kv_ratio = num_attention_heads // num_kv_heads
            model_config.append(["gqa_ratio", f"{kv_ratio}:1"])
        model_config.append(["calc_formula", f"{num_layers} * {num_kv_heads} * {head_dim} * 2"])

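    # Bytes per cached element; the extra fp8/fp4 terms approximate the overhead of a 2-byte scale
    # stored per hidden_size elements (fp8) or per 32-element block (fp4).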
    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        nbytes_per_elem = 1 + 2 / cfg.hidden_size
    elif dtype == "fp4":
        nbytes_per_elem = 0.5 + 2 / 32
    else:
        raise gr.Error(f"Unsupported KV cache dtype: {dtype}")

    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config


DESCRIPTION = (
    "Calculate KV cache memory requirements for transformer models. "
    "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
)


def search_and_update_models(query):
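    """Refresh the dropdown with Hub search results, keeping the raw query available as a custom value."""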
    if not query or len(query.strip()) < 2:
        return gr.Dropdown(choices=POPULAR_MODELS)

    search_results = search_models(query.strip(), max_results=50)
    if query.strip() not in search_results:
        search_results.insert(0, query.strip())
    return gr.Dropdown(choices=search_results, value=query.strip())


CUSTOM_CSS = """
.gradio-container {
    max-width: 1000px !important;
    margin: 0 auto !important;
}
"""


with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.Markdown("# KV Cache Calculator")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_search = gr.Textbox(
                label="🔍 Search Models",
                placeholder="Type model name (e.g., llama, qwen, mistral...)",
                value="Qwen/Qwen3-30B-A3B",
                info="Search the entire HuggingFace Hub database",
            )

            model_dropdown = gr.Dropdown(
                label="📋 Select Model",
                choices=POPULAR_MODELS,
                value="Qwen/Qwen3-30B-A3B",
                allow_custom_value=True,
                info="Models matching your search - or type a custom model ID",
            )

            with gr.Row():
                gr.Markdown("**💡 Tip:** Search updates the dropdown with real HF Hub results")

            ctx_len = gr.Number(label="Context Length", value=128_000, minimum=1)
            num_users = gr.Number(label="Number of Users", value=1, minimum=1)
            dtype = gr.Dropdown(
                label="KV Cache Data Type",
                choices=["fp16/bf16", "fp8", "fp4"],
                value="fp16/bf16",
            )
            hf_token = gr.Textbox(
                label="HuggingFace Token (optional)",
                type="password",
                placeholder="For gated models",
            )

            calculate_btn = gr.Button("Calculate KV Cache Size", variant="primary")

        with gr.Column():
            cache_size = gr.Number(label="KV Cache Size (GB)", precision=2)
            model_config = gr.Dataframe(
                label="Model Configuration",
                headers=["Parameter", "Value"],
                datatype=["str", "str"],
                wrap=True,
            )

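    # Re-run the Hub search and refresh the dropdown choices whenever the search text changes.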
    model_search.change(
        fn=search_and_update_models,
        inputs=[model_search],
        outputs=[model_dropdown],
        show_progress="hidden",
    )

    calculate_btn.click(
        fn=calculate,
        inputs=[model_dropdown, ctx_len, num_users, dtype, hf_token],
        outputs=[cache_size, model_config],
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        allowed_paths=[],
        app_kwargs={"docs_url": None, "redoc_url": None},
    )