import gradio as gr
from transformers import AutoConfig
from huggingface_hub import list_models
import time
from typing import List
# Credits: This implementation is derived from and builds upon the excellent work by gaunernst
# Original implementation: https://huggingface.co/spaces/gaunernst/kv-cache-calculator
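# Simple in-memory TTL cache for Hub searches: maps "query_maxresults" keys to
# (results, timestamp) tuples. Entries are reused for 5 minutes, and the oldest
# entry is evicted once the cache grows past 20 keys (see search_models below).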
search_cache = {}
POPULAR_MODELS = [
"Qwen/Qwen3-30B-A3B",
"meta-llama/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct",
"microsoft/DialoGPT-medium",
"microsoft/DialoGPT-large",
"mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"deepseek-ai/DeepSeek-V2-Chat",
"deepseek-ai/DeepSeek-V3-Base",
"google/gemma-2-9b",
"google/gemma-2-27b",
"Qwen/QwQ-32B-Preview",
"Qwen/Qwen2.5-72B-Instruct",
"anthropic/claude-3-haiku-20240307",
]
def search_models(query: str, max_results: int = 50) -> List[str]:
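    """Search the HuggingFace Hub for text-generation model IDs matching `query`.

    Matches from POPULAR_MODELS are listed first, Hub results are de-duplicated
    against them, and each (query, max_results) result set is cached for 5 minutes.
    Falls back to POPULAR_MODELS if the Hub request fails.
    """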
if not query or len(query.strip()) < 1:
return POPULAR_MODELS[:15]
query = query.strip()
cache_key = f"{query.lower()}_{max_results}"
current_time = time.time()
if cache_key in search_cache:
cached_result, cache_time = search_cache[cache_key]
if current_time - cache_time < 300:
return cached_result
try:
print(f"Searching HF Hub for: {query}")
models = list_models(
search=query,
task="text-generation",
library="transformers",
sort="downloads",
direction=-1,
limit=max_results * 2,
full=False
)
all_matches = []
seen_models = set()
for model in POPULAR_MODELS:
if query.lower() in model.lower() and model not in seen_models:
all_matches.append(model)
seen_models.add(model)
for model in models:
if model.id not in seen_models and len(all_matches) < max_results:
all_matches.append(model.id)
seen_models.add(model.id)
if len(all_matches) < max_results // 2:
try:
broader_models = list_models(
search=query,
library="transformers",
sort="downloads",
direction=-1,
limit=max_results * 2
)
for model in broader_models:
if model.id not in seen_models and len(all_matches) < max_results:
model_id_lower = model.id.lower()
if any(keyword in model_id_lower for keyword in ['chat', 'instruct', 'base', 'model']):
all_matches.append(model.id)
seen_models.add(model.id)
except Exception as e:
print(f"Broader search failed: {e}")
result = all_matches[:max_results]
search_cache[cache_key] = (result, current_time)
if len(search_cache) > 20:
oldest_key = min(search_cache.keys(), key=lambda k: search_cache[k][1])
del search_cache[oldest_key]
return result
except Exception as e:
print(f"Search error: {e}")
popular_matches = [model for model in POPULAR_MODELS if query.lower() in model.lower()]
return popular_matches if popular_matches else POPULAR_MODELS[:15]
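# Example (illustrative only -- actual results depend on live Hub contents and ranking):
#   search_models("llama", max_results=3)
#   -> e.g. ["meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", ...]
# Repeating the same call within 5 minutes is served from search_cache without hitting the Hub.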
def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
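    """Estimate the KV cache size in GB for `num_users` concurrent sequences of `ctx_len` tokens.

    Returns (kv_cache_size_gb, model_config), where model_config is a list of
    [parameter, value] rows displayed in the UI table.
    """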
hf_token = hf_token.strip()
try:
cfg = AutoConfig.from_pretrained(
name,
trust_remote_code=True,
token=hf_token or None,
)
    except Exception as e:
        raise gr.Error(str(e))
    use_mla = bool(cfg.architectures) and cfg.architectures[0].startswith(("DeepseekV2", "DeepseekV3"))
if hasattr(cfg, "text_config"):
cfg = cfg.text_config
num_layers = cfg.num_hidden_layers
num_attention_heads = cfg.num_attention_heads
num_kv_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)
if use_mla:
attention_type = "MLA"
elif num_kv_heads == num_attention_heads:
attention_type = "MHA"
else:
attention_type = "GQA"
model_config = [
["num_layers", num_layers],
["max_ctx_len", cfg.max_position_embeddings],
["attention_type", attention_type],
["num_attention_heads", num_attention_heads],
["num_kv_heads", num_kv_heads],
]
if ctx_len > cfg.max_position_embeddings:
gr.Warning(
"Requested context length is larger than the max value supported by the model"
)
if use_mla:
kv_lora_rank = cfg.kv_lora_rank
qk_rope_head_dim = cfg.qk_rope_head_dim
nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)
model_config.append(["kv_lora_rank", kv_lora_rank])
model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
model_config.append(["calc_formula", f"{num_layers} * ({kv_lora_rank} + {qk_rope_head_dim})"])
else:
head_dim = getattr(cfg, "head_dim", cfg.hidden_size // num_attention_heads)
nelems_per_token = num_layers * num_kv_heads * head_dim * 2
model_config.append(["head_dim", head_dim])
if attention_type == "GQA":
kv_ratio = num_attention_heads // num_kv_heads
model_config.append(["gqa_ratio", f"{kv_ratio}:1"])
model_config.append(["calc_formula", f"{num_layers} * {num_kv_heads} * {head_dim} * 2"])
    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        nbytes_per_elem = 1 + 2 / cfg.hidden_size  # 1 byte per value, plus an assumed 2-byte per-token scale
    elif dtype == "fp4":
        nbytes_per_elem = 0.5 + 2 / 32  # 4-bit values, plus a 2-byte scale per 32-element block (MXFP4-style)
    else:
        raise gr.Error(f"Unsupported KV cache dtype: {dtype}")
kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
return kv_cache_size, model_config
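# Sanity checks, hand-computed from the published model configs (assumed values, not fetched here):
# GQA -- meta-llama/Llama-3.1-8B-Instruct: 32 layers, 8 KV heads, head_dim 128
#   nelems_per_token = 32 * 8 * 128 * 2 = 65,536
#   fp16, ctx_len=128,000, 1 user: 65,536 * 128,000 * 1 * 2 / 1e9 ~= 16.8 GB
# MLA -- deepseek-ai/DeepSeek-V3-Base: 61 layers, kv_lora_rank 512, qk_rope_head_dim 64
#   nelems_per_token = 61 * (512 + 64) = 35,136
#   fp16, ctx_len=128,000, 1 user: 35,136 * 128,000 * 1 * 2 / 1e9 ~= 9.0 GB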
DESCRIPTION = (
"Calculate KV cache memory requirements for transformer models. "
"Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
)
def search_and_update_models(query):
if not query or len(query.strip()) < 2:
return gr.Dropdown(choices=POPULAR_MODELS)
search_results = search_models(query.strip(), max_results=50)
if query.strip() not in search_results:
search_results.insert(0, query.strip())
return gr.Dropdown(choices=search_results, value=query.strip())
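# Example behaviour (illustrative): typing "qwen" in the search box calls
# search_and_update_models("qwen"), which returns a gr.Dropdown update whose choices start
# with the raw query "qwen" (kept selectable as a custom ID) followed by matching Hub model IDs.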
with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft()) as demo:
gr.Markdown("# KV Cache Calculator")
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column():
model_search = gr.Textbox(
label="πŸ” Search Models",
placeholder="Type model name (e.g., llama, qwen, mistral...)",
value="Qwen/Qwen3-30B-A3B",
info="Search the entire HuggingFace Hub database"
)
model_dropdown = gr.Dropdown(
label="πŸ“‹ Select Model",
choices=POPULAR_MODELS,
value="Qwen/Qwen3-30B-A3B",
allow_custom_value=True,
info="Models matching your search - or type a custom model ID"
)
with gr.Row():
gr.Markdown("**πŸ’‘ Tip:** Search updates the dropdown with real HF Hub results")
ctx_len = gr.Number(label="Context Length", value=128_000, minimum=1)
num_users = gr.Number(label="Number of Users", value=1, minimum=1)
dtype = gr.Dropdown(
label="KV Cache Data Type",
choices=["fp16/bf16", "fp8", "fp4"],
value="fp16/bf16"
)
hf_token = gr.Textbox(
label="HuggingFace Token (optional)",
type="password",
placeholder="For gated models"
)
calculate_btn = gr.Button("Calculate KV Cache Size", variant="primary")
with gr.Column():
cache_size = gr.Number(label="KV Cache Size (GB)", precision=2)
model_config = gr.Dataframe(
label="Model Configuration",
headers=["Parameter", "Value"],
datatype=["str", "str"],
wrap=True
)
model_search.change(
fn=search_and_update_models,
inputs=[model_search],
outputs=[model_dropdown],
show_progress=False
)
calculate_btn.click(
fn=calculate,
inputs=[model_dropdown, ctx_len, num_users, dtype, hf_token],
outputs=[cache_size, model_config]
)
demo.css = """
.gradio-container {
max-width: 1000px !important;
margin: 0 auto !important;
}
"""
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True,
allowed_paths=[],
app_kwargs={"docs_url": None, "redoc_url": None}
)