AndreHathora committed
Commit baf381a · 1 Parent(s): fb095c2

Implement real-time HuggingFace Hub search functionality

- Added live search of the entire HF Hub database via the Hub API (see the sketch after this list)
- Implemented caching system for better performance
- Fixed textbox glitching by removing feedback loop
- Search now returns actual models from the HF Hub, not just a filtered static list
- Enhanced search with multi-tier approach (text-generation + broader search)
- Popular models prioritized in search results
- Added huggingface_hub dependency for API access
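
A minimal, self-contained sketch of the search-with-cache flow this commit adds, assuming `huggingface_hub.list_models` as used in the diff below; the helper name `live_search`, the trimmed `POPULAR` list, and `CACHE_TTL_S` are illustrative names, not the commit's exact code:

```python
import time
from typing import List

from huggingface_hub import list_models

POPULAR = ["Qwen/Qwen3-30B-A3B", "meta-llama/Llama-3.1-8B-Instruct"]  # trimmed example list
CACHE_TTL_S = 300   # assumption: 5-minute TTL, mirroring the diff's cache window
_cache = {}         # query key -> (results, timestamp)

def live_search(query: str, max_results: int = 50) -> List[str]:
    """Live Hub search with popular-model priority and a small TTL cache."""
    query = query.strip()
    if not query:
        return POPULAR[:15]

    key = f"{query.lower()}_{max_results}"
    cached = _cache.get(key)
    if cached and time.time() - cached[1] < CACHE_TTL_S:
        return cached[0]  # serve the cached result instead of re-hitting the API

    # Matching popular models go first, then live Hub results sorted by downloads.
    results = [m for m in POPULAR if query.lower() in m.lower()]
    try:
        for info in list_models(search=query, task="text-generation",
                                library="transformers", sort="downloads",
                                direction=-1, limit=max_results):
            if info.id not in results:
                results.append(info.id)
            if len(results) >= max_results:
                break
    except Exception as exc:  # network/API failure: fall back to the static list
        print(f"Hub search failed: {exc}")

    results = results[:max_results]
    _cache[key] = (results, time.time())
    return results
```

The actual `search_models` in app.py (below) additionally runs a broader second-tier search when the text-generation query returns too few matches and evicts the oldest cache entry once the cache grows past 20 keys.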

Files changed (2)
  1. app.py +169 -35
  2. requirements.txt +2 -1
app.py CHANGED

@@ -1,9 +1,101 @@
 import gradio as gr
 from transformers import AutoConfig
+from huggingface_hub import list_models
+import asyncio
+from typing import List
+import time
+from functools import lru_cache

 # Credits: This implementation is derived from and builds upon the excellent work by gaunernst
 # Original implementation: https://huggingface.co/spaces/gaunernst/kv-cache-calculator

+search_cache = {}
+
+POPULAR_MODELS = [
+    "Qwen/Qwen3-30B-A3B",
+    "meta-llama/Llama-3.1-8B-Instruct",
+    "meta-llama/Llama-3.1-70B-Instruct",
+    "microsoft/DialoGPT-medium",
+    "microsoft/DialoGPT-large",
+    "mistralai/Mistral-7B-Instruct-v0.3",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "deepseek-ai/DeepSeek-V2-Chat",
+    "deepseek-ai/DeepSeek-V3-Base",
+    "google/gemma-2-9b",
+    "google/gemma-2-27b",
+    "Qwen/QwQ-32B-Preview",
+    "Qwen/Qwen2.5-72B-Instruct",
+    "anthropic/claude-3-haiku-20240307",
+]
+
+def search_models(query: str, max_results: int = 50) -> List[str]:
+    if not query or len(query.strip()) < 1:
+        return POPULAR_MODELS[:15]
+
+    query = query.strip()
+    cache_key = f"{query.lower()}_{max_results}"
+
+    current_time = time.time()
+    if cache_key in search_cache:
+        cached_result, cache_time = search_cache[cache_key]
+        if current_time - cache_time < 300:
+            return cached_result
+
+    try:
+        print(f"Searching HF Hub for: {query}")
+        models = list_models(
+            search=query,
+            task="text-generation",
+            library="transformers",
+            sort="downloads",
+            direction=-1,
+            limit=max_results * 2,
+            full=False
+        )
+
+        all_matches = []
+        seen_models = set()
+
+        for model in POPULAR_MODELS:
+            if query.lower() in model.lower() and model not in seen_models:
+                all_matches.append(model)
+                seen_models.add(model)
+
+        for model in models:
+            if model.id not in seen_models and len(all_matches) < max_results:
+                all_matches.append(model.id)
+                seen_models.add(model.id)
+
+        if len(all_matches) < max_results // 2:
+            try:
+                broader_models = list_models(
+                    search=query,
+                    library="transformers",
+                    sort="downloads",
+                    direction=-1,
+                    limit=max_results * 2
+                )
+                for model in broader_models:
+                    if model.id not in seen_models and len(all_matches) < max_results:
+                        model_id_lower = model.id.lower()
+                        if any(keyword in model_id_lower for keyword in ['chat', 'instruct', 'base', 'model']):
+                            all_matches.append(model.id)
+                            seen_models.add(model.id)
+            except Exception as e:
+                print(f"Broader search failed: {e}")
+
+        result = all_matches[:max_results]
+        search_cache[cache_key] = (result, current_time)
+        if len(search_cache) > 20:
+            oldest_key = min(search_cache.keys(), key=lambda k: search_cache[k][1])
+            del search_cache[oldest_key]
+        return result
+
+    except Exception as e:
+        print(f"Search error: {e}")
+        popular_matches = [model for model in POPULAR_MODELS if query.lower() in model.lower()]
+        return popular_matches if popular_matches else POPULAR_MODELS[:15]
+

 def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
     hf_token = hf_token.strip()
@@ -22,8 +114,6 @@ def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str
         cfg = cfg.text_config

     num_layers = cfg.num_hidden_layers
-
-    # Determine attention mechanism type
     num_attention_heads = cfg.num_attention_heads
     num_kv_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)

@@ -46,7 +136,6 @@ def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str
             "Requested context length is larger than the max value supported by the model"
         )

-    # Calculate KV cache elements per token based on attention mechanism
     if use_mla:
         kv_lora_rank = cfg.kv_lora_rank
         qk_rope_head_dim = cfg.qk_rope_head_dim
@@ -58,7 +147,7 @@ def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str

     else:
         head_dim = getattr(cfg, "head_dim", cfg.hidden_size // num_attention_heads)
-        nelems_per_token = num_layers * num_kv_heads * head_dim * 2  # 2 for key and value
+        nelems_per_token = num_layers * num_kv_heads * head_dim * 2

         model_config.append(["head_dim", head_dim])
         if attention_type == "GQA":
@@ -77,41 +166,87 @@ def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str
     return kv_cache_size, model_config


-# Minimal description for iframe embedding
 DESCRIPTION = (
     "Calculate KV cache memory requirements for transformer models. "
     "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
 )

-demo = gr.Interface(
-    title="KV Cache Calculator",
-    description=DESCRIPTION,
-    fn=calculate,
-    inputs=[
-        gr.Textbox(label="Model ID", value="Qwen/Qwen3-30B-A3B", placeholder="e.g., Qwen/Qwen3-30B-A3B"),
-        gr.Number(label="Context Length", value=128_000, minimum=1),
-        gr.Number(label="Number of Users", value=1, minimum=1),
-        gr.Dropdown(label="KV Cache Data Type", choices=["fp16/bf16", "fp8", "fp4"], value="fp16/bf16"),
-        gr.Textbox(label="HuggingFace Token (optional)", type="password", placeholder="For gated models"),
-    ],
-    outputs=[
-        gr.Number(label="KV Cache Size (GB)", precision=2),
-        gr.Dataframe(
-            label="Model Configuration",
-            headers=["Parameter", "Value"],
-            datatype=["str", "str"],
-            wrap=True
-        ),
-    ],
-    theme=gr.themes.Soft(),
-    css="""
-    .gradio-container {
-        max-width: 800px !important;
-        margin: 0 auto !important;
-    }
-    """,
-    analytics_enabled=False,
-)
+def search_and_update_models(query):
+    if not query or len(query.strip()) < 2:
+        return gr.Dropdown(choices=POPULAR_MODELS)
+
+    search_results = search_models(query.strip(), max_results=50)
+    if query.strip() not in search_results:
+        search_results.insert(0, query.strip())
+    return gr.Dropdown(choices=search_results, value=query.strip())
+
+with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# KV Cache Calculator")
+    gr.Markdown(DESCRIPTION)
+
+    with gr.Row():
+        with gr.Column():
+            model_search = gr.Textbox(
+                label="🔍 Search Models",
+                placeholder="Type model name (e.g., llama, qwen, mistral...)",
+                value="Qwen/Qwen3-30B-A3B",
+                info="Search the entire HuggingFace Hub database"
+            )
+
+            model_dropdown = gr.Dropdown(
+                label="📋 Select Model",
+                choices=POPULAR_MODELS,
+                value="Qwen/Qwen3-30B-A3B",
+                allow_custom_value=True,
+                info="Models matching your search - or type a custom model ID"
+            )
+
+            with gr.Row():
+                gr.Markdown("**💡 Tip:** Search updates the dropdown with real HF Hub results")
+
+            ctx_len = gr.Number(label="Context Length", value=128_000, minimum=1)
+            num_users = gr.Number(label="Number of Users", value=1, minimum=1)
+            dtype = gr.Dropdown(
+                label="KV Cache Data Type",
+                choices=["fp16/bf16", "fp8", "fp4"],
+                value="fp16/bf16"
+            )
+            hf_token = gr.Textbox(
+                label="HuggingFace Token (optional)",
+                type="password",
+                placeholder="For gated models"
+            )
+
+            calculate_btn = gr.Button("Calculate KV Cache Size", variant="primary")
+
+        with gr.Column():
+            cache_size = gr.Number(label="KV Cache Size (GB)", precision=2)
+            model_config = gr.Dataframe(
+                label="Model Configuration",
+                headers=["Parameter", "Value"],
+                datatype=["str", "str"],
+                wrap=True
+            )
+
+    model_search.change(
+        fn=search_and_update_models,
+        inputs=[model_search],
+        outputs=[model_dropdown],
+        show_progress=False
+    )
+
+    calculate_btn.click(
+        fn=calculate,
+        inputs=[model_dropdown, ctx_len, num_users, dtype, hf_token],
+        outputs=[cache_size, model_config]
+    )
+
+demo.css = """
+.gradio-container {
+    max-width: 1000px !important;
+    margin: 0 auto !important;
+}
+"""

 if __name__ == "__main__":
     demo.launch(
@@ -119,7 +254,6 @@ if __name__ == "__main__":
         server_port=7860,
         share=False,
         show_error=True,
-        # Enable embedding in iframes
         allowed_paths=[],
         app_kwargs={"docs_url": None, "redoc_url": None}
     )
requirements.txt CHANGED

@@ -1 +1,2 @@
-transformers
+transformers>=4.21.0
+huggingface_hub>=0.16.0