AndreHathora committed on
Commit
160a197
·
1 Parent(s): baf381a
Files changed (1)
  1. app.py +739 -54
app.py CHANGED
@@ -5,6 +5,12 @@ import asyncio
5
  from typing import List
6
  import time
7
  from functools import lru_cache
8
 
9
  # Credits: This implementation is derived from and builds upon the excellent work by gaunernst
10
  # Original implementation: https://huggingface.co/spaces/gaunernst/kv-cache-calculator
@@ -28,7 +34,270 @@ POPULAR_MODELS = [
28
  "anthropic/claude-3-haiku-20240307",
29
  ]
30
 
31
- def search_models(query: str, max_results: int = 50) -> List[str]:
32
  if not query or len(query.strip()) < 1:
33
  return POPULAR_MODELS[:15]
34
 
@@ -43,15 +312,6 @@ def search_models(query: str, max_results: int = 50) -> List[str]:
43
 
44
  try:
45
  print(f"Searching HF Hub for: {query}")
46
- models = list_models(
47
- search=query,
48
- task="text-generation",
49
- library="transformers",
50
- sort="downloads",
51
- direction=-1,
52
- limit=max_results * 2,
53
- full=False
54
- )
55
 
56
  all_matches = []
57
  seen_models = set()
@@ -61,34 +321,28 @@ def search_models(query: str, max_results: int = 50) -> List[str]:
61
  all_matches.append(model)
62
  seen_models.add(model)
63
 
64
  for model in models:
65
  if model.id not in seen_models and len(all_matches) < max_results:
66
  all_matches.append(model.id)
67
  seen_models.add(model.id)
68
 
69
- if len(all_matches) < max_results // 2:
70
- try:
71
- broader_models = list_models(
72
- search=query,
73
- library="transformers",
74
- sort="downloads",
75
- direction=-1,
76
- limit=max_results * 2
77
- )
78
- for model in broader_models:
79
- if model.id not in seen_models and len(all_matches) < max_results:
80
- model_id_lower = model.id.lower()
81
- if any(keyword in model_id_lower for keyword in ['chat', 'instruct', 'base', 'model']):
82
- all_matches.append(model.id)
83
- seen_models.add(model.id)
84
- except Exception as e:
85
- print(f"Broader search failed: {e}")
86
-
87
  result = all_matches[:max_results]
88
  search_cache[cache_key] = (result, current_time)
89
- if len(search_cache) > 20:
 
90
  oldest_key = min(search_cache.keys(), key=lambda k: search_cache[k][1])
91
  del search_cache[oldest_key]
 
92
  return result
93
 
94
  except Exception as e:
@@ -98,6 +352,10 @@ def search_models(query: str, max_results: int = 50) -> List[str]:
98
 
99
 
100
  def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
101
  hf_token = hf_token.strip()
102
  try:
103
  cfg = AutoConfig.from_pretrained(
@@ -163,7 +421,17 @@ def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str
163
  nbytes_per_elem = 0.5 + 2 / 32 # 4-bit weights + scaling factor every 32 elements (MXFP4)
164
 
165
  kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
166
- return kv_cache_size, model_config
167
 
168
 
169
  DESCRIPTION = (
@@ -171,14 +439,345 @@ DESCRIPTION = (
171
  "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
172
  )
173
 
174
- def search_and_update_models(query):
175
- if not query or len(query.strip()) < 2:
176
- return gr.Dropdown(choices=POPULAR_MODELS)
177
 
178
- search_results = search_models(query.strip(), max_results=50)
179
- if query.strip() not in search_results:
180
- search_results.insert(0, query.strip())
181
- return gr.Dropdown(choices=search_results, value=query.strip())
182
 
183
  with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft()) as demo:
184
  gr.Markdown("# KV Cache Calculator")
@@ -187,22 +786,20 @@ with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft()) as demo:
187
  with gr.Row():
188
  with gr.Column():
189
  model_search = gr.Textbox(
190
- label="🔍 Search Models",
191
- placeholder="Type model name (e.g., llama, qwen, mistral...)",
192
- value="Qwen/Qwen3-30B-A3B",
193
- info="Search the entire HuggingFace Hub database"
194
  )
195
 
196
  model_dropdown = gr.Dropdown(
197
- label="📋 Select Model",
198
- choices=POPULAR_MODELS,
199
- value="Qwen/Qwen3-30B-A3B",
200
- allow_custom_value=True,
201
- info="Models matching your search - or type a custom model ID"
202
  )
203
 
204
  with gr.Row():
205
- gr.Markdown("**💡 Tip:** Search updates the dropdown with real HF Hub results")
206
 
207
  ctx_len = gr.Number(label="Context Length", value=128_000, minimum=1)
208
  num_users = gr.Number(label="Number of Users", value=1, minimum=1)
@@ -228,27 +825,115 @@ with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft()) as demo:
228
  wrap=True
229
  )
230
 
231
- model_search.change(
232
- fn=search_and_update_models,
233
  inputs=[model_search],
234
- outputs=[model_dropdown],
235
  show_progress=False
236
  )
237
 
238
  calculate_btn.click(
239
- fn=calculate,
240
- inputs=[model_dropdown, ctx_len, num_users, dtype, hf_token],
241
- outputs=[cache_size, model_config]
242
  )
243
 
244
  demo.css = """
245
  .gradio-container {
246
- max-width: 1000px !important;
247
  margin: 0 auto !important;
248
  }
249
  """
250
 
251
  if __name__ == "__main__":
252
  demo.launch(
253
  server_name="0.0.0.0",
254
  server_port=7860,
 
5
  from typing import List
6
  import time
7
  from functools import lru_cache
8
+ import requests
9
+ import json
10
+ import re
11
+ from datetime import datetime, timedelta
12
+ import threading
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
 
15
  # Credits: This implementation is derived from and builds upon the excellent work by gaunernst
16
  # Original implementation: https://huggingface.co/spaces/gaunernst/kv-cache-calculator
 
34
  "anthropic/claude-3-haiku-20240307",
35
  ]
36
 
37
+ # Static GPU specifications (performance specs don't change, only prices do)
38
+ # All GPUs with SM_80+ compute capability (Flash Attention support)
39
+ GPU_SPECS = {
40
+ # Consumer RTX 30 Series (Ampere - GA102/GA104/GA106) - SM_8.6
41
+ "RTX 3060": {"memory_gb": 12, "compute_capability": "8.6", "tflops_fp32": 13.0, "category": "Consumer"},
42
+ "RTX 3060 Ti": {"memory_gb": 8, "compute_capability": "8.6", "tflops_fp32": 16.2, "category": "Consumer"},
43
+ "RTX 3070": {"memory_gb": 8, "compute_capability": "8.6", "tflops_fp32": 20.3, "category": "Consumer"},
44
+ "RTX 3070 Ti": {"memory_gb": 8, "compute_capability": "8.6", "tflops_fp32": 21.7, "category": "Consumer"},
45
+ "RTX 3080": {"memory_gb": 10, "compute_capability": "8.6", "tflops_fp32": 29.8, "category": "Consumer"},
46
+ "RTX 3080 Ti": {"memory_gb": 12, "compute_capability": "8.6", "tflops_fp32": 34.1, "category": "Consumer"},
47
+ "RTX 3090": {"memory_gb": 24, "compute_capability": "8.6", "tflops_fp32": 35.6, "category": "Consumer"},
48
+ "RTX 3090 Ti": {"memory_gb": 24, "compute_capability": "8.6", "tflops_fp32": 40.0, "category": "Consumer"},
49
+
50
+ # Consumer RTX 40 Series (Ada Lovelace - AD102/AD103/AD104/AD106/AD107) - SM_8.9
51
+ "RTX 4060": {"memory_gb": 8, "compute_capability": "8.9", "tflops_fp32": 15.1, "category": "Consumer"},
52
+ "RTX 4060 Ti": {"memory_gb": 8, "compute_capability": "8.9", "tflops_fp32": 22.1, "category": "Consumer"},
53
+ "RTX 4060 Ti 16GB": {"memory_gb": 16, "compute_capability": "8.9", "tflops_fp32": 22.1, "category": "Consumer"},
54
+ "RTX 4070": {"memory_gb": 12, "compute_capability": "8.9", "tflops_fp32": 29.1, "category": "Consumer"},
55
+ "RTX 4070 Super": {"memory_gb": 12, "compute_capability": "8.9", "tflops_fp32": 35.5, "category": "Consumer"},
56
+ "RTX 4070 Ti": {"memory_gb": 12, "compute_capability": "8.9", "tflops_fp32": 40.1, "category": "Consumer"},
57
+ "RTX 4070 Ti Super": {"memory_gb": 16, "compute_capability": "8.9", "tflops_fp32": 44.1, "category": "Consumer"},
58
+ "RTX 4080": {"memory_gb": 16, "compute_capability": "8.9", "tflops_fp32": 48.7, "category": "Consumer"},
59
+ "RTX 4080 Super": {"memory_gb": 16, "compute_capability": "8.9", "tflops_fp32": 52.2, "category": "Consumer"},
60
+ "RTX 4090": {"memory_gb": 24, "compute_capability": "8.9", "tflops_fp32": 83.0, "category": "Consumer"},
61
+
62
+ # Professional/Workstation RTX A Series (Ampere) - SM_8.6
63
+ "RTX A2000": {"memory_gb": 12, "compute_capability": "8.6", "tflops_fp32": 8.0, "category": "Workstation"},
64
+ "RTX A4000": {"memory_gb": 16, "compute_capability": "8.6", "tflops_fp32": 19.2, "category": "Workstation"},
65
+ "RTX A4500": {"memory_gb": 20, "compute_capability": "8.6", "tflops_fp32": 23.7, "category": "Workstation"},
66
+ "RTX A5000": {"memory_gb": 24, "compute_capability": "8.6", "tflops_fp32": 27.8, "category": "Workstation"},
67
+ "RTX A6000": {"memory_gb": 48, "compute_capability": "8.6", "tflops_fp32": 38.7, "category": "Workstation"},
68
+
69
+ # Professional RTX 6000 Ada (Ada Lovelace) - SM_8.9
70
+ "RTX 6000 Ada": {"memory_gb": 48, "compute_capability": "8.9", "tflops_fp32": 91.1, "category": "Workstation"},
71
+
72
+ # Datacenter A100 Series (Ampere) - SM_8.0
73
+ "A100 40GB": {"memory_gb": 40, "compute_capability": "8.0", "tflops_fp32": 19.5, "category": "Datacenter"},
74
+ "A100 80GB": {"memory_gb": 80, "compute_capability": "8.0", "tflops_fp32": 19.5, "category": "Datacenter"},
75
+
76
+ # Datacenter H100 Series (Hopper) - SM_9.0
77
+ "H100 80GB": {"memory_gb": 80, "compute_capability": "9.0", "tflops_fp32": 67.0, "category": "Datacenter"},
78
+ "H100 94GB": {"memory_gb": 94, "compute_capability": "9.0", "tflops_fp32": 67.0, "category": "Datacenter"},
79
+
80
+ # Datacenter H200 (Hopper) - SM_9.0
81
+ "H200 141GB": {"memory_gb": 141, "compute_capability": "9.0", "tflops_fp32": 67.0, "category": "Datacenter"},
82
+
83
+ # Datacenter B200 (Blackwell) - SM_10.0
84
+ "B200 192GB": {"memory_gb": 180, "compute_capability": "10.0", "tflops_fp32": 80.0, "category": "Datacenter"},
85
+
86
+ # Datacenter L40/L40S (Ada Lovelace) - SM_8.9
87
+ "L40": {"memory_gb": 48, "compute_capability": "8.9", "tflops_fp32": 91.6, "category": "Datacenter"},
88
+ "L40S": {"memory_gb": 48, "compute_capability": "8.9", "tflops_fp32": 91.6, "category": "Datacenter"},
89
+ }
90
+
91
+ # Price cache with timestamp
92
+ price_cache = {}
93
+ PRICE_CACHE_DURATION = timedelta(hours=6) # Cache prices for 6 hours
94
+
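For a sense of how these static specs and the cached or fallback prices feed the later ranking, a minimal illustrative sketch (rank_by_cost_per_tflop is a hypothetical helper; it assumes GPU_SPECS and get_fallback_price as defined in this commit):

def rank_by_cost_per_tflop(names):
    # Rank GPUs by fallback USD price per FP32 TFLOP (lower is better).
    rows = []
    for name in names:
        specs = GPU_SPECS[name]
        price = get_fallback_price(name)
        rows.append((name, price / specs["tflops_fp32"]))
    return sorted(rows, key=lambda r: r[1])

# e.g. rank_by_cost_per_tflop(["RTX 4090", "RTX 3090", "L40S"])
# -> roughly [("RTX 4090", ~18), ("RTX 3090", ~27), ("L40S", ~104)]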
95
+ def fetch_single_gpu_price(gpu_name):
96
+ """Fetch price for a single GPU (used in parallel)"""
97
+ try:
98
+ print(f"Fetching price for {gpu_name}...")
99
+ price = get_gpu_price_from_multiple_sources(gpu_name)
100
+ if price:
101
+ print(f"✓ Found price for {gpu_name}: ${price}")
102
+ return gpu_name, price
103
+ else:
104
+ print(f"✗ No price found for {gpu_name}, using fallback")
105
+ return gpu_name, get_fallback_price(gpu_name)
106
+ except Exception as e:
107
+ print(f"✗ Error fetching {gpu_name}: {e}")
108
+ return gpu_name, get_fallback_price(gpu_name)
109
+
110
+ def preload_gpu_prices():
111
+ """Pre-fetch all GPU prices in parallel on startup"""
112
+ print("🚀 Pre-loading GPU prices...")
113
+ start_time = time.time()
114
+
115
+ # Get list of GPUs to price
116
+ gpu_names = list(GPU_SPECS.keys())
117
+
118
+ # Use ThreadPoolExecutor for parallel requests
119
+ with ThreadPoolExecutor(max_workers=8) as executor:
120
+ # Submit all price fetch tasks
121
+ future_to_gpu = {executor.submit(fetch_single_gpu_price, gpu_name): gpu_name
122
+ for gpu_name in gpu_names}
123
+
124
+ # Collect results as they complete
125
+ for future in as_completed(future_to_gpu):
126
+ gpu_name, price = future.result()
127
+ # Store in cache with timestamp
128
+ cache_key = gpu_name.lower().replace(" ", "_")
129
+ price_cache[cache_key] = {
130
+ "price": price,
131
+ "timestamp": datetime.now()
132
+ }
133
+
134
+ end_time = time.time()
135
+ total_time = end_time - start_time
136
+ print(f"✅ Loaded prices for {len(gpu_names)} GPUs in {total_time:.1f} seconds")
137
+ print(f"💰 Cache contains {len(price_cache)} price entries")
138
+
139
+ def start_price_preloading():
140
+ """Start price preloading in background thread"""
141
+ def preload_worker():
142
+ preload_gpu_prices()
143
+
144
+ # Start preloading in background
145
+ preload_thread = threading.Thread(target=preload_worker, daemon=True)
146
+ preload_thread.start()
147
+ print("🔄 Price preloading started in background...")
148
+
149
+ def get_gpu_price_from_multiple_sources(gpu_name):
150
+ """Fetch GPU price from multiple sources with fallbacks"""
151
+ current_time = datetime.now()
152
+
153
+ # Check cache first
154
+ cache_key = gpu_name.lower().replace(" ", "_")
155
+ if cache_key in price_cache:
156
+ cached_data = price_cache[cache_key]
157
+ if current_time - cached_data["timestamp"] < PRICE_CACHE_DURATION:
158
+ return cached_data["price"]
159
+
160
+ price = None
161
+
162
+ try:
163
+ gpu_specs = GPU_SPECS.get(gpu_name, {})
164
+ gpu_category = gpu_specs.get("category", "Consumer")
165
+
166
+ if gpu_category == "Datacenter":
167
+ price = get_fallback_price(gpu_name)
168
+ else:
169
+ price = fetch_newegg_price(gpu_name)
170
+ if not price:
171
+ price = fetch_amazon_price(gpu_name)
172
+ if not price:
173
+ price = get_fallback_price(gpu_name)
174
+
175
+ except Exception as e:
176
+ print(f"Error fetching price for {gpu_name}: {e}")
177
+ price = get_fallback_price(gpu_name)
178
+
179
+ # Cache the result
180
+ if price:
181
+ price_cache[cache_key] = {
182
+ "price": price,
183
+ "timestamp": current_time
184
+ }
185
+
186
+ return price
187
+
188
+ def fetch_newegg_price(gpu_name):
189
+ """Fetch price from Newegg search (simplified approach)"""
190
+ try:
191
+ # Simple approach: search for GPU and extract price patterns
192
+ search_term = gpu_name.replace(" ", "+")
193
+ url = f"https://www.newegg.com/p/pl?d={search_term}"
194
+
195
+ headers = {
196
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
197
+ }
198
+
199
+ response = requests.get(url, headers=headers, timeout=2)
200
+ if response.status_code == 200:
201
+ # Look for price patterns in the HTML
202
+ price_patterns = [
203
+ r'\$([0-9,]+\.?\d*)',
204
+ r'price.*?(\d+[,.]?\d*)',
205
+ r'(\d{3,4})\.\d{2}'
206
+ ]
207
+
208
+ for pattern in price_patterns:
209
+ matches = re.findall(pattern, response.text)
210
+ if matches:
211
+ # Get the first reasonable price (between $200-$3000)
212
+ for match in matches:
213
+ try:
214
+ price = float(match.replace(',', ''))
215
+ if 200 <= price <= 3000:
216
+ return price
217
+ except:
218
+ continue
219
+ except:
220
+ pass
221
+ return None
222
+
223
+ def fetch_amazon_price(gpu_name):
224
+ """Fetch price from Amazon search (simplified approach)"""
225
+ try:
226
+ search_term = gpu_name.replace(" ", "+")
227
+ url = f"https://www.amazon.com/s?k={search_term}+graphics+card"
228
+
229
+ headers = {
230
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
231
+ }
232
+
233
+ response = requests.get(url, headers=headers, timeout=2)
234
+ if response.status_code == 200:
235
+ # Look for Amazon price patterns
236
+ price_patterns = [
237
+ r'\$([0-9,]+\.?\d*)',
238
+ r'a-price-whole.*?(\d+)',
239
+ ]
240
+
241
+ for pattern in price_patterns:
242
+ matches = re.findall(pattern, response.text)
243
+ if matches:
244
+ for match in matches:
245
+ try:
246
+ price = float(match.replace(',', ''))
247
+ if 200 <= price <= 3000:
248
+ return price
249
+ except:
250
+ continue
251
+ except:
252
+ pass
253
+ return None
254
+
255
+ def get_fallback_price(gpu_name):
256
+ """Fallback prices based on typical market values (updated periodically)"""
257
+ fallback_prices = {
258
+ # Consumer RTX 30 Series
259
+ "RTX 3060": 280,
260
+ "RTX 3060 Ti": 320,
261
+ "RTX 3070": 420,
262
+ "RTX 3070 Ti": 480,
263
+ "RTX 3080": 580,
264
+ "RTX 3080 Ti": 720,
265
+ "RTX 3090": 950,
266
+ "RTX 3090 Ti": 1100,
267
+
268
+ # Consumer RTX 40 Series
269
+ "RTX 4060": 300,
270
+ "RTX 4060 Ti": 380,
271
+ "RTX 4060 Ti 16GB": 480,
272
+ "RTX 4070": 580,
273
+ "RTX 4070 Super": 680,
274
+ "RTX 4070 Ti": 780,
275
+ "RTX 4070 Ti Super": 880,
276
+ "RTX 4080": 980,
277
+ "RTX 4080 Super": 880,
278
+ "RTX 4090": 1500,
279
+
280
+ # Professional/Workstation GPUs
281
+ "RTX A2000": 650,
282
+ "RTX A4000": 1200,
283
+ "RTX A4500": 2200,
284
+ "RTX A5000": 2800,
285
+ "RTX A6000": 4500,
286
+ "RTX 6000 Ada": 6800,
287
+
288
+ # Datacenter GPUs (estimated enterprise pricing)
289
+ "A100 40GB": 12000,
290
+ "A100 80GB": 15000,
291
+ "H100 80GB": 28000,
292
+ "H100 94GB": 32000,
293
+ "H200 141GB": 35000,
294
+ "B200 192GB": 45000,
295
+ "L40": 8500,
296
+ "L40S": 9500,
297
+ }
298
+ return fallback_prices.get(gpu_name, 1000)
299
+
300
+ def search_models_fast(query: str, max_results: int = 30) -> List[str]:
301
  if not query or len(query.strip()) < 1:
302
  return POPULAR_MODELS[:15]
303
 
 
312
 
313
  try:
314
  print(f"Searching HF Hub for: {query}")
 
 
316
  all_matches = []
317
  seen_models = set()
 
321
  all_matches.append(model)
322
  seen_models.add(model)
323
 
324
+ models = list_models(
325
+ search=query,
326
+ task="text-generation",
327
+ library="transformers",
328
+ sort="downloads",
329
+ direction=-1,
330
+ limit=max_results,
331
+ full=False
332
+ )
333
+
334
  for model in models:
335
  if model.id not in seen_models and len(all_matches) < max_results:
336
  all_matches.append(model.id)
337
  seen_models.add(model.id)
338
 
 
339
  result = all_matches[:max_results]
340
  search_cache[cache_key] = (result, current_time)
341
+
342
+ if len(search_cache) > 15:
343
  oldest_key = min(search_cache.keys(), key=lambda k: search_cache[k][1])
344
  del search_cache[oldest_key]
345
+
346
  return result
347
 
348
  except Exception as e:
 
352
 
353
 
354
  def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
355
+ if not name or not name.strip():
356
+ raise gr.Error("Please search for and select a model first")
357
+
358
+ name = name.strip()
359
  hf_token = hf_token.strip()
360
  try:
361
  cfg = AutoConfig.from_pretrained(
 
421
  nbytes_per_elem = 0.5 + 2 / 32 # 4-bit weights + scaling factor every 32 elements (MXFP4)
422
 
423
  kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
424
+
425
+ # Get GPU recommendations with complete memory analysis using actual config
426
+ gpu_recommendations = recommend_gpus(
427
+ kv_cache_size_gb=kv_cache_size,
428
+ config=cfg,
429
+ dtype=dtype,
430
+ ctx_len=ctx_len,
431
+ num_users=num_users
432
+ )
433
+
434
+ return kv_cache_size, model_config, gpu_recommendations
435
 
436
 
437
  DESCRIPTION = (
 
439
  "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
440
  )
441
 
442
+ def search_models_on_submit(search_query):
443
+ if not search_query or len(search_query.strip()) < 2:
444
+ return [
445
+ gr.Textbox(interactive=True),
446
+ gr.Dropdown(choices=[], value="", visible=False),
447
+ gr.Button(interactive=True)
448
+ ]
449
+
450
+ query_stripped = search_query.strip()
451
+
452
+ search_results = search_models_fast(query_stripped, max_results=30)
453
 
454
+ if query_stripped not in search_results:
455
+ search_results.insert(0, query_stripped)
456
+
457
+ return [
458
+ gr.Textbox(interactive=True, value=query_stripped),
459
+ gr.Dropdown(
460
+ choices=search_results,
461
+ value=query_stripped,
462
+ visible=True,
463
+ info=f"Found {len(search_results)} models - select one"
464
+ ),
465
+ gr.Button(interactive=True)
466
+ ]
467
+
468
+ def update_selection_from_dropdown(dropdown_value):
469
+ return gr.Textbox(value=dropdown_value)
470
+
471
+ def estimate_model_memory(config, dtype):
472
+ """Estimate model weight memory requirements in GB using actual config object"""
473
+ try:
474
+ if not config:
475
+ return 5.0 # Default fallback
476
+
477
+ # Extract parameters for calculation
478
+ num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', 32))
479
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', 4096))
480
+ vocab_size = getattr(config, 'vocab_size', 50000)
481
+ intermediate_size = getattr(config, 'intermediate_size', hidden_size * 4)
482
+
483
+ # DeepSeek V3 specific parameter calculation following the exact formula
484
+ # Check if this is DeepSeek V3 architecture
485
+ is_deepseek_v3 = (getattr(config, 'model_type', '') == 'deepseek_v3' or
486
+ any('deepseek' in arch.lower() for arch in getattr(config, 'architectures', [])))
487
+
488
+ if is_deepseek_v3 and hasattr(config, 'q_lora_rank'):
489
+ # DeepSeek V3 specific calculation
490
+ # Config constants
491
+ L = num_layers # 61
492
+ H = hidden_size # 7168
493
+ I = intermediate_size # 18432
494
+ I_moe = getattr(config, 'moe_intermediate_size', 2048) # 2048
495
+ n_h = getattr(config, 'num_attention_heads', 128) # 128
496
+ r_q = getattr(config, 'q_lora_rank', 1536) # 1536
497
+ r_kv = getattr(config, 'kv_lora_rank', 512) # 512
498
+ V = vocab_size # 129,280
499
+
500
+ # Additional config values
501
+ qk_nope_head_dim = getattr(config, 'qk_nope_head_dim', 128)
502
+ qk_rope_head_dim = getattr(config, 'qk_rope_head_dim', 64)
503
+ v_head_dim = getattr(config, 'v_head_dim', 128)
504
+
505
+ # Attention per layer calculation
506
+ # W_q,a: H × r_q
507
+ w_q_a = H * r_q
508
+ # W_q,b: r_q × n_h × (qk_nope + qk_rope)
509
+ w_q_b = r_q * n_h * (qk_nope_head_dim + qk_rope_head_dim)
510
+ # W_kv,a: H × (r_kv + qk_rope)
511
+ w_kv_a = H * (r_kv + qk_rope_head_dim)
512
+ # W_kv,b: r_kv × n_h × (qk_nope + v)
513
+ w_kv_b = r_kv * n_h * (qk_nope_head_dim + v_head_dim)
514
+ # W_o: (n_h × v) × H
515
+ w_o = (n_h * v_head_dim) * H
516
+
517
+ attention_per_layer = w_q_a + w_q_b + w_kv_a + w_kv_b + w_o
518
+ total_attention = L * attention_per_layer
519
+
520
+ # Dense FFN layers (first 3 layers)
521
+ dense_ffn_per_layer = 3 * H * I # 3 projections: gate, up, down
522
+ total_dense_ffn = 3 * dense_ffn_per_layer # 3 dense layers
523
+
524
+ # MoE FFN layers (remaining 58 layers)
525
+ moe_ffn_per_expert = 3 * H * I_moe
526
+ n_routed_experts = getattr(config, 'n_routed_experts', 256) # 256
527
+ n_shared_experts = getattr(config, 'n_shared_experts', 1) # 1
528
+ experts_per_moe_layer = n_routed_experts + n_shared_experts # 257
529
+ moe_ffn_per_layer = experts_per_moe_layer * moe_ffn_per_expert
530
+ moe_layers = L - 3 # 58 MoE layers
531
+ total_moe_ffn = moe_layers * moe_ffn_per_layer
532
+
533
+ # Embeddings + LM head (untied)
534
+ embeddings_and_head = 2 * V * H
535
+
536
+ # Total parameters
537
+ total_params = total_attention + total_dense_ffn + total_moe_ffn + embeddings_and_head
538
+
539
+ print(f"DEBUG: DeepSeek V3 parameter breakdown:")
540
+ print(f" Attention ({L} layers): {total_attention/1e9:.2f}B")
541
+ print(f" Dense FFN (3 layers): {total_dense_ffn/1e9:.2f}B")
542
+ print(f" MoE FFN ({moe_layers} layers): {total_moe_ffn/1e9:.2f}B")
543
+ print(f" Embeddings + Head: {embeddings_and_head/1e9:.2f}B")
544
+ print(f" Total calculated: {total_params/1e9:.1f}B parameters")
545
+
546
+ else:
547
+ # Fallback to standard transformer calculation for other models
548
+ num_attention_heads = getattr(config, 'num_attention_heads', hidden_size // 64)
549
+ num_kv_heads = getattr(config, 'num_key_value_heads', num_attention_heads)
550
+ head_dim = getattr(config, 'head_dim', hidden_size // num_attention_heads)
551
+
552
+ # Standard attention calculation
553
+ q_params = hidden_size * (num_attention_heads * head_dim)
554
+ kv_params = hidden_size * (num_kv_heads * head_dim) * 2
555
+ o_params = (num_attention_heads * head_dim) * hidden_size
556
+ attention_params_per_layer = q_params + kv_params + o_params
557
+ attention_params = num_layers * attention_params_per_layer
558
+
559
+ # Standard FFN calculation
560
+ ffn_params = num_layers * (2 * hidden_size * intermediate_size + intermediate_size * hidden_size)
561
+
562
+ # Embeddings
563
+ embedding_params = vocab_size * hidden_size
564
+
565
+ # Other parameters
566
+ other_params = num_layers * 2 * hidden_size + hidden_size
567
+
568
+ total_params = embedding_params + attention_params + ffn_params + other_params
569
+
570
+ print(f"DEBUG: Standard transformer parameter breakdown:")
571
+ print(f" Embeddings: {embedding_params/1e9:.1f}B")
572
+ print(f" Attention: {attention_params/1e9:.1f}B")
573
+ print(f" FFN: {ffn_params/1e9:.1f}B")
574
+ print(f" Other: {other_params/1e9:.1f}B")
575
+ print(f" Total calculated: {total_params/1e9:.1f}B parameters")
576
+
577
+ # Convert to memory based on user-selected dtype
578
+ if dtype == "fp16/bf16":
579
+ bytes_per_param = 2
580
+ elif dtype == "fp8":
581
+ bytes_per_param = 1
582
+ elif dtype == "fp4":
583
+ bytes_per_param = 0.5
584
+ else:
585
+ bytes_per_param = 4 # fp32 fallback
586
+
587
+ model_memory_gb = (total_params * bytes_per_param) / (1024**3)
588
+
589
+ # Add minimal overhead (5% for loading)
590
+ model_memory_gb *= 1.05
591
+
592
+ return model_memory_gb
593
+
594
+ except Exception as e:
595
+ print(f"Error estimating model memory from config: {e}")
596
+ return 70.0 # Conservative fallback for large models
597
+
598
+
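To gauge what the standard-transformer branch produces, a rough check for a hypothetical 7B-class dense config (32 layers, hidden 4096, intermediate 11008, vocab 32000, full MHA; illustrative numbers, not a specific model):

L, H, I, V = 32, 4096, 11008, 32_000
attn = L * (4 * H * H)            # Q, K, V, O projections (MHA), ~2.1B
ffn = L * (3 * H * I)             # gate/up/down projections,    ~4.3B
emb = V * H                       # embeddings,                  ~0.13B
total = attn + ffn + emb          # ~6.6B parameters
fp16_gib = total * 2 / 1024**3    # ~12.3 GiB before the 5% loading overhead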
599
+ def estimate_activation_memory(ctx_len, num_users, config):
600
+ """Estimate activation memory requirements in GB using actual config object"""
601
+ try:
602
+ if not config:
603
+ return 1.0 # Default fallback
604
+
605
+ # Extract parameters directly from config object
606
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', 4096))
607
+
608
+ batch_size = num_users
609
+
610
+ # For inference, activations are much smaller than training
611
+ # Only need to store activations for current forward pass, not gradients
612
+
613
+ # 1. Input/output activations: batch_size * ctx_len * hidden_size
614
+ io_activations = batch_size * ctx_len * hidden_size
615
+
616
+ # 2. Intermediate activations (only a few layers worth, not all)
617
+ # Most activations are computed and immediately used, not stored
618
+ intermediate_size = getattr(config, 'intermediate_size', hidden_size * 4)
619
+ stored_activations = batch_size * ctx_len * intermediate_size * 2 # Only ~2 layers worth
620
+
621
+ # 3. Attention scores for current layer (not all layers stored)
622
+ num_attention_heads = getattr(config, 'num_attention_heads', hidden_size // 64)
623
+ attention_scores = batch_size * num_attention_heads * ctx_len * ctx_len
624
+
625
+ # Total activation elements (much smaller for inference)
626
+ total_activation_elements = io_activations + stored_activations + attention_scores
627
+
628
+ # Convert to memory (fp16 = 2 bytes per element)
629
+ activation_memory_gb = (total_activation_elements * 2) / (1024**3)
630
+
631
+ # Cap at reasonable values for inference (activations shouldn't dominate)
632
+ max_reasonable_gb = max(5.0, ctx_len * batch_size / 10000) # Reasonable scaling
633
+ activation_memory_gb = min(activation_memory_gb, max_reasonable_gb)
634
+
635
+ return max(0.5, activation_memory_gb) # At least 500MB
636
+
637
+ except Exception as e:
638
+ print(f"Error estimating activation memory from config: {e}")
639
+ # Simple fallback based on context length
640
+ try:
641
+ # Much simpler formula for inference
642
+ fallback_gb = (num_users * ctx_len * 4096 * 4 * 2) / (1024**3) # Conservative
643
+ return min(10.0, max(0.5, fallback_gb)) # Cap at 10GB
644
+ except:
645
+ return 2.0 # Default 2GB
646
+
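The attention-scores term above grows quadratically with context length, which is why the min(...) cap matters; a quick magnitude check, assuming 32 heads, fp16, and one user at a 128k context:

num_heads, ctx = 32, 128_000
scores_elems = 1 * num_heads * ctx * ctx        # ~5.2e11 elements
scores_gib = scores_elems * 2 / 1024**3         # ~977 GiB if fully materialized
# so at long contexts the max_reasonable_gb cap, not the raw sum, sets the estimate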
647
+ def calculate_multi_gpu_configs(total_memory_needed, suitable_gpus):
648
+ """Calculate multi-GPU configurations for large models (power-of-2 for tensor parallelism)"""
649
+ multi_gpu_configs = []
650
+
651
+ # Power-of-2 configurations for tensor parallelism (TP) - max 8 for practical use
652
+ gpu_counts = [1, 2, 4, 8] # Only powers of 2, max 8 GPUs
653
+
654
+ # For large models, check all high-memory GPUs, not just top 3 cost-effective ones
655
+ gpus_to_check = suitable_gpus if total_memory_needed > 500 else suitable_gpus[:3]
656
+
657
+ for gpu in gpus_to_check:
658
+ for count in gpu_counts:
659
+ total_gpu_memory = gpu["memory_gb"] * count
660
+
661
+ if total_gpu_memory >= total_memory_needed:
662
+ # Calculate per-GPU memory utilization
663
+ memory_per_gpu = total_memory_needed / count
664
+ utilization = (memory_per_gpu / gpu["memory_gb"]) * 100
665
+
666
+ # Skip very inefficient configurations (< 30% utilization for multi-GPU)
667
+ if count > 1 and utilization < 30:
668
+ continue
669
+
670
+ # Calculate total cost
671
+ total_cost = gpu["price_usd"] * count
672
+ cost_per_tflop_total = total_cost / (gpu["tflops_fp32"] * count)
673
+
674
+ # Format configuration name with TP indication
675
+ if count == 1:
676
+ config_name = gpu['name']
677
+ else:
678
+ config_name = f"{count}x {gpu['name']} (TP={count})"
679
+
680
+ category_emoji = {
681
+ "Consumer": "🎮",
682
+ "Workstation": "🏢",
683
+ "Datacenter": "🏭"
684
+ }.get(gpu.get("category", "Consumer"), "🎮")
685
+
686
+ multi_gpu_configs.append({
687
+ "config": config_name,
688
+ "gpu_count": count,
689
+ "total_memory_gb": total_gpu_memory,
690
+ "memory_per_gpu": memory_per_gpu,
691
+ "utilization": utilization,
692
+ "total_cost": total_cost,
693
+ "cost_per_tflop": cost_per_tflop_total,
694
+ "category_emoji": category_emoji,
695
+ "base_gpu": gpu
696
+ })
697
+
698
+ # For single GPU, only add once
699
+ if count == 1:
700
+ break
701
+
702
+ # Sort by cost-effectiveness (total cost per TFLOP)
703
+ multi_gpu_configs.sort(key=lambda x: x["cost_per_tflop"])
704
+
705
+ return multi_gpu_configs[:8] # Return top 8 configurations
706
+
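The search above walks fixed power-of-2 GPU counts; the same sizing idea can be sketched directly, assuming only a total memory requirement and a per-GPU capacity (min_tp_gpus is a hypothetical helper, not part of this code):

def min_tp_gpus(total_memory_gb, gpu_memory_gb, max_gpus=8):
    # Smallest power-of-2 tensor-parallel degree (up to max_gpus) that fits the footprint.
    count = 1
    while count <= max_gpus:
        if count * gpu_memory_gb >= total_memory_gb:
            return count
        count *= 2
    return None  # does not fit within max_gpus

# e.g. min_tp_gpus(180, 80) -> 4 (each GPU then holds ~45 GB, ~56% utilization)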
707
+ def recommend_gpus(kv_cache_size_gb, config=None, dtype="fp16/bf16", ctx_len=128000, num_users=1):
708
+ """Recommend cost-effective GPU configurations (single and multi-GPU with tensor parallelism) for complete memory footprint"""
709
+ if not kv_cache_size_gb or kv_cache_size_gb <= 0:
710
+ print("DEBUG: KV cache size is 0 or invalid")
711
+ return []
712
+
713
+ # Calculate complete memory footprint using actual config object
714
+ model_memory_gb = estimate_model_memory(config, dtype)
715
+ activation_memory_gb = estimate_activation_memory(ctx_len, num_users, config)
716
+
717
+ # Total memory = Model weights + KV cache + Activations + Safety buffer
718
+ total_memory_needed = model_memory_gb + kv_cache_size_gb + activation_memory_gb + 1.0 # 1GB safety buffer
719
+
720
+ print(f"DEBUG: Memory breakdown - Model: {model_memory_gb:.1f}GB, KV: {kv_cache_size_gb:.1f}GB, Activations: {activation_memory_gb:.1f}GB, Total: {total_memory_needed:.1f}GB")
721
+
722
+ # Get all GPUs with real pricing (from cache or live fetch)
723
+ all_gpus = []
724
+
725
+ for gpu_name, specs in GPU_SPECS.items():
726
+ # Get real-time price (will use cache if available)
727
+ current_price = get_gpu_price_from_multiple_sources(gpu_name)
728
+ if current_price:
729
+ cost_per_tflop = current_price / specs["tflops_fp32"]
730
+ all_gpus.append({
731
+ "name": gpu_name,
732
+ "memory_gb": specs["memory_gb"],
733
+ "compute_capability": specs["compute_capability"],
734
+ "tflops_fp32": specs["tflops_fp32"],
735
+ "price_usd": current_price,
736
+ "cost_per_tflop": cost_per_tflop,
737
+ "category": specs.get("category", "Consumer")
738
+ })
739
+
740
+ print(f"DEBUG: Found {len(all_gpus)} GPUs with pricing")
741
+
742
+ if not all_gpus:
743
+ print("DEBUG: No GPUs found with pricing")
744
+ return []
745
+
746
+ # Sort by cost-effectiveness for single GPU evaluation
747
+ all_gpus.sort(key=lambda x: x["cost_per_tflop"])
748
+
749
+ # Calculate multi-GPU configurations
750
+ multi_gpu_configs = calculate_multi_gpu_configs(total_memory_needed, all_gpus)
751
+
752
+ print(f"DEBUG: Generated {len(multi_gpu_configs)} GPU configurations")
753
+
754
+ if not multi_gpu_configs:
755
+ print("DEBUG: No valid GPU configurations found")
756
+ return []
757
+
758
+ # Format recommendations
759
+ recommendations = []
760
+ for i, config in enumerate(multi_gpu_configs):
761
+ rank_icons = ["🥇", "🥈", "🥉", "🏅", "⭐", "💫", "🌟", "✨"]
762
+ rank = rank_icons[i] if i < len(rank_icons) else "💎"
763
+
764
+ price_source = "💲 Live" if config["base_gpu"]["name"].lower().replace(" ", "_") in price_cache else "📊 Est"
765
+
766
+ # Format configuration display
767
+ if config["gpu_count"] == 1:
768
+ config_display = f"{rank} {config['category_emoji']} {config['config']}"
769
+ memory_display = f"{config['total_memory_gb']:.0f} GB"
770
+ else:
771
+ config_display = f"{rank} {config['category_emoji']} {config['config']}"
772
+ memory_display = f"{config['total_memory_gb']:.0f} GB ({config['utilization']:.0f}% util)"
773
+
774
+ recommendations.append([
775
+ config_display,
776
+ f"{total_memory_needed:.1f}GB required",
777
+ f"{price_source} ${config['total_cost']:.0f}"
778
+ ])
779
+
780
+ return recommendations
781
 
782
  with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft()) as demo:
783
  gr.Markdown("# KV Cache Calculator")
 
786
  with gr.Row():
787
  with gr.Column():
788
  model_search = gr.Textbox(
789
+ label="🔍 Search Model",
790
+ placeholder="Type your model ID here.",
 
 
791
  )
792
 
793
  model_dropdown = gr.Dropdown(
794
+ label="📋 Select from Results",
795
+ choices=[],
796
+ value="",
797
+ visible=False,
798
+ info="Choose from search results"
799
  )
800
 
801
  with gr.Row():
802
+ gr.Markdown("**💡 Tip:** Type model names like 'llama', 'qwen', 'mistral', then press Enter to search")
803
 
804
  ctx_len = gr.Number(label="Context Length", value=128_000, minimum=1)
805
  num_users = gr.Number(label="Number of Users", value=1, minimum=1)
 
825
  wrap=True
826
  )
827
 
828
+ gpu_recommendations = gr.Dataframe(
829
+ label="💡 GPU Recommendations",
830
+ headers=["Configuration", "Memory Required", "Total Price"],
831
+ datatype=["str", "str", "str"],
832
+ wrap=False,
833
+ visible=False
834
+ )
835
+
836
+ model_search.submit(
837
+ fn=search_models_on_submit,
838
  inputs=[model_search],
839
+ outputs=[model_search, model_dropdown, calculate_btn],
840
+ show_progress="minimal"
841
+ )
842
+
843
+ model_dropdown.change(
844
+ fn=update_selection_from_dropdown,
845
+ inputs=[model_dropdown],
846
+ outputs=[model_search],
847
  show_progress=False
848
  )
849
 
850
+ def calculate_and_show_gpus(model_name, ctx_len, num_users, dtype, hf_token):
851
+ cache_size, model_config, gpu_recs = calculate(model_name, ctx_len, num_users, dtype, hf_token)
852
+
853
+ print(f"DEBUG: GPU recommendations count: {len(gpu_recs) if gpu_recs else 0}")
854
+ if gpu_recs:
855
+ print(f"DEBUG: First recommendation: {gpu_recs[0] if gpu_recs else 'None'}")
856
+
857
+ if gpu_recs:
858
+ return (
859
+ cache_size,
860
+ model_config,
861
+ gr.Dataframe(value=gpu_recs, visible=True)
862
+ )
863
+ else:
864
+ print("DEBUG: No GPU recommendations found, showing empty table")
865
+ return (
866
+ cache_size,
867
+ model_config,
868
+ gr.Dataframe(value=[], visible=False)
869
+ )
870
+
871
  calculate_btn.click(
872
+ fn=calculate_and_show_gpus,
873
+ inputs=[model_search, ctx_len, num_users, dtype, hf_token],
874
+ outputs=[cache_size, model_config, gpu_recommendations]
875
  )
876
 
877
  demo.css = """
878
  .gradio-container {
879
+ max-width: 1400px !important;
880
  margin: 0 auto !important;
881
  }
882
+
883
+ /* Make dataframes wider and prevent text wrapping */
884
+ .gradio-dataframe {
885
+ width: 100% !important;
886
+ min-width: 800px !important;
887
+ }
888
+
889
+ .gradio-dataframe table {
890
+ width: 100% !important;
891
+ table-layout: auto !important;
892
+ }
893
+
894
+ .gradio-dataframe td, .gradio-dataframe th {
895
+ white-space: nowrap !important;
896
+ padding: 8px 12px !important;
897
+ text-overflow: ellipsis !important;
898
+ min-width: 120px !important;
899
+ }
900
+
901
+ /* Style disabled textboxes to be clearly disabled */
902
+ .gradio-textbox:disabled,
903
+ .gradio-textbox[aria-disabled="true"] {
904
+ opacity: 0.6 !important;
905
+ background-color: #f5f5f5 !important;
906
+ color: #666 !important;
907
+ cursor: not-allowed !important;
908
+ border-color: #ccc !important;
909
+ }
910
+
911
+ /* Style placeholder text */
912
+ .gradio-textbox input::placeholder {
913
+ color: #999 !important;
914
+ font-style: italic;
915
+ }
916
+
917
+ /* Make disabled dropdowns more visually obvious */
918
+ .gradio-dropdown[data-testid="dropdown"]:disabled,
919
+ .gradio-dropdown[data-testid="dropdown"][aria-disabled="true"] {
920
+ opacity: 0.6 !important;
921
+ background-color: #f5f5f5 !important;
922
+ cursor: not-allowed !important;
923
+ }
924
+
925
+ /* Make disabled buttons more obvious too */
926
+ button:disabled {
927
+ opacity: 0.5 !important;
928
+ background-color: #e0e0e0 !important;
929
+ cursor: not-allowed !important;
930
+ }
931
  """
932
 
933
  if __name__ == "__main__":
934
+ # Start price preloading in background before launching the app
935
+ start_price_preloading()
936
+
937
  demo.launch(
938
  server_name="0.0.0.0",
939
  server_port=7860,