kfoughali committed
Commit 56bd642 · verified · 1 Parent(s): 9196642

Update app.py

Files changed (1)
  1. app.py +701 -0
app.py CHANGED
@@ -0,0 +1,701 @@
+ # app.py
+ """
+ Research-grade KV cache compression benchmark application.
+ RocketKV-enhanced SPG with 450x compression capability.
+ FIXED: CUDA assert errors, safer default parameters, GPT-2 sequence limits.
+ """
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from datetime import datetime
+ import json
+ import pandas as pd
+ import tempfile
+ import os
+ import logging
+ from typing import Dict, List, Any, Tuple
+
+ from config import (
+     CompressionConfig, CompressionType, EnhancedSPGConfig,
+     ProvingConfig, ResearchConstants, SUPPORTED_MODELS, BENCHMARK_CONFIGS
+ )
+ from benchmark import (
+     run_research_benchmark, export_proof_bundle, verify_proof_bundle,
+     BenchmarkMetrics
+ )
+ from compression import detect_model_layers
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Set style for plots
+ plt.style.use('seaborn-v0_8-darkgrid')
+ sns.set_palette("husl")
+
+ # Global state for results
+ current_results = {}
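+ # NOTE: module-level state is shared across every session of a running Gradio
+ # app, so concurrent users of a hosted Space would overwrite each other's
+ # results here; acceptable for a single-user research demo.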
+
+
+ def run_benchmark(model_key, compression_type, benchmark_type, dataset_subset,
+                   eval_samples, n_seeds, seq_length, generation_length,
+                   base_decay_rate, sink_tokens, recent_window,
+                   enable_adaptive, target_perplexity_delta,
+                   enable_progressive, progressive_quality_threshold,
+                   initial_compression_ratio, max_compression_ratio,
+                   sequence_compression_ratio, head_compression_ratio,
+                   head_retention_mode, magnitude_threshold_mode,
+                   min_tokens_for_stability, recent_boost_factor,
+                   fail_on_cpu):
+     """Run comprehensive benchmark with all compression methods."""
+
+     # Enable synchronous CUDA for debugging
+     if torch.cuda.is_available():
+         os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
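+         # CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so device-side
+         # asserts surface at the failing call instead of a later one; it costs
+         # throughput and is purely a debugging aid.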
+
+     # Validate sequence length for GPT-2
+     if model_key == "gpt2" and seq_length > 1024:
+         logger.warning(f"Reducing sequence length from {seq_length} to 1024 for GPT-2")
+         seq_length = 1024
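+         # GPT-2's learned positional embedding table has 1024 entries, so longer
+         # prefills would index out of range and trigger the CUDA asserts this
+         # clamp is meant to prevent.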
+
+     try:
+         # Create base configuration
+         base_config = CompressionConfig(
+             model_key=model_key,
+             compression_type=CompressionType[compression_type.upper()],
+             benchmark_type=benchmark_type,
+             benchmark_subset=dataset_subset if benchmark_type == "longbench" else None,
+             eval_samples=int(eval_samples),
+             n_seeds=int(n_seeds),
+             prefill_length=int(seq_length),
+             generation_length=int(generation_length),
+             fail_on_cpu_fallback=fail_on_cpu
+         )
+
+         # Configure Enhanced SPG with safer parameters
+         base_config.enhanced_spg_config = EnhancedSPGConfig(
+             base_decay_rate=float(base_decay_rate),
+             sink_tokens=int(sink_tokens),
+             recent_window=int(recent_window),
+             enable_adaptive=enable_adaptive,
+             target_perplexity_delta=float(target_perplexity_delta),
+             enable_progressive=enable_progressive,
+             quality_threshold=float(progressive_quality_threshold),
+             initial_compression_ratio=float(initial_compression_ratio),
+             max_compression_ratio=float(max_compression_ratio),
+             target_compression_ratio=float(max_compression_ratio),
+             sequence_compression_ratio=float(sequence_compression_ratio),
+             head_compression_ratio=float(head_compression_ratio),
+             head_retention_mode=head_retention_mode,
+             magnitude_threshold_mode=magnitude_threshold_mode,
+             min_tokens_for_stability=int(min_tokens_for_stability),
+             recent_boost_factor=float(recent_boost_factor),
+             enable_two_stage=True,
+             use_hybrid_sparse_attention=True,
+             use_snapkv_plus_plus=True,
+             stage1_compression_ratio=20.0,  # Safer default
+             stage2_compression_ratio=20.0   # For 400x total
+         )
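+         # The two stages compound multiplicatively: 20x * 20x = 400x overall,
+         # which stays below the 450x capability claimed in the module docstring
+         # and leaves headroom for stability.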
+
+         # Store results
+         results = {}
+         model_name = base_config.model_name
+
+         # Run benchmark for selected compression type
+         logger.info(f"Running {compression_type} benchmark...")
+         metrics, summary, records, fingerprints = run_research_benchmark(
+             model_name, base_config
+         )
+
+         results[compression_type] = {
+             'metrics': metrics,
+             'summary': summary,
+             'records': records
+         }
+
+         # Also run NONE compression for baseline comparison
+         if compression_type != "none":
+             logger.info("Running baseline (no compression) benchmark...")
+             baseline_config = CompressionConfig(
+                 model_key=model_key,
+                 compression_type=CompressionType.NONE,
+                 benchmark_type=benchmark_type,
+                 benchmark_subset=dataset_subset if benchmark_type == "longbench" else None,
+                 eval_samples=int(eval_samples),
+                 n_seeds=int(n_seeds),
+                 prefill_length=int(seq_length),
+                 generation_length=int(generation_length),
+                 fail_on_cpu_fallback=fail_on_cpu
+             )
+
+             try:
+                 baseline_metrics, baseline_summary, baseline_records, _ = run_research_benchmark(
+                     model_name, baseline_config
+                 )
+
+                 results['none'] = {
+                     'metrics': baseline_metrics,
+                     'summary': baseline_summary,
+                     'records': baseline_records
+                 }
+             except Exception as e:
+                 logger.error(f"Baseline benchmark failed: {e}")
+                 # Continue without baseline
+
+         # Store globally for export
+         global current_results
+         current_results = results
+
+         # Create visualizations
+         plots = create_visualizations(results, benchmark_type)
+
+         # Create summary text
+         summary_text = create_summary_text(results, benchmark_type)
+
+         # Export proof bundle
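+         # The bundle only exists for the lifetime of this TemporaryDirectory,
+         # so it must be verified before the with-block exits and the files
+         # are deleted.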
+         with tempfile.TemporaryDirectory() as tmpdir:
+             bundle_path = export_proof_bundle(
+                 tmpdir, base_config, metrics, summary, records, fingerprints
+             )
+
+             # Verify the bundle
+             verification = verify_proof_bundle(
+                 tmpdir, base_config, base_config.proving
+             )
+
+         verification_text = f"Proof verification: {'PASSED ✓' if verification['ok'] else 'FAILED ✗'}"
+         if not verification['ok']:
+             verification_text += f"\nFailures: {verification['failures']}"
+
+         return plots, summary_text, verification_text
+
+     except Exception as e:
+         logger.error(f"Benchmark failed: {e}", exc_info=True)
+         return [], f"Error: {str(e)}", "Verification failed due to error"
+
+
+ def create_visualizations(results: Dict, benchmark_type: str) -> List:
+     """Create comprehensive visualizations from benchmark results."""
+     plots = []
+
+     # 1. Compression Ratio Comparison
+     fig, ax = plt.subplots(figsize=(10, 6))
+     methods = []
+     ratios = []
+     errors = []
+
+     for method, data in results.items():
+         if 'metrics' in data and hasattr(data['metrics'], 'compression_ratio_mean'):
+             methods.append(method.upper())
+             ratios.append(data['metrics'].compression_ratio_mean)
+             errors.append(data['metrics'].compression_ratio_std)
+
+     if methods:
+         bars = ax.bar(methods, ratios, yerr=errors, capsize=5)
+         ax.set_ylabel('Compression Ratio')
+         ax.set_title('KV Cache Compression Ratios')
+         ax.grid(True, alpha=0.3)
+
+         # Add value labels on bars
+         for bar, ratio in zip(bars, ratios):
+             height = bar.get_height()
+             ax.text(bar.get_x() + bar.get_width()/2., height,
+                     f'{ratio:.1f}x', ha='center', va='bottom')
+
+     plt.tight_layout()
+     plots.append(fig)
+
+     # 2. Memory Usage Comparison
+     # Track method names alongside values so bars stay aligned even when a
+     # method lacks kv_cache_memory_mb (fixes a latent mismatch with the
+     # `methods` list built for plot 1).
+     fig, ax = plt.subplots(figsize=(10, 6))
+     mem_methods = []
+     memories = []
+     memory_errors = []
+
+     for method, data in results.items():
+         if 'metrics' in data and hasattr(data['metrics'], 'kv_cache_memory_mb'):
+             mem_methods.append(method.upper())
+             memories.append(data['metrics'].kv_cache_memory_mb)
+             memory_errors.append(0)  # No std for memory in current implementation
+
+     if mem_methods:
+         bars = ax.bar(mem_methods, memories, yerr=memory_errors, capsize=5, color='coral')
+         ax.set_ylabel('Memory Usage (MB)')
+         ax.set_title('KV Cache Memory Footprint')
+         ax.grid(True, alpha=0.3)
+
+         for bar, mem in zip(bars, memories):
+             height = bar.get_height()
+             ax.text(bar.get_x() + bar.get_width()/2., height,
+                     f'{mem:.1f}', ha='center', va='bottom')
+
+     plt.tight_layout()
+     plots.append(fig)
+
+     # 3. Benchmark-specific metrics
+     if benchmark_type == "wikitext":
+         # Perplexity comparison (methods[:len(...)] assumes any methods missing
+         # perplexity stats sit at the tail of the insertion order)
+         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
+
+         # Prefill perplexity
+         prefill_ppls = []
+         prefill_errors = []
+         gen_ppls = []
+         gen_errors = []
+
+         for method, data in results.items():
+             if 'metrics' in data:
+                 metrics = data['metrics']
+                 if hasattr(metrics, 'prefill_perplexity_mean'):
+                     prefill_ppls.append(metrics.prefill_perplexity_mean)
+                     prefill_errors.append(metrics.prefill_perplexity_std)
+                 if hasattr(metrics, 'generation_perplexity_mean'):
+                     gen_ppls.append(metrics.generation_perplexity_mean)
+                     gen_errors.append(metrics.generation_perplexity_std)
+
+         if prefill_ppls:
+             ax1.bar(methods[:len(prefill_ppls)], prefill_ppls, yerr=prefill_errors, capsize=5, color='skyblue')
+             ax1.set_ylabel('Perplexity')
+             ax1.set_title('Prefill Perplexity')
+             ax1.grid(True, alpha=0.3)
+
+         if gen_ppls:
+             ax2.bar(methods[:len(gen_ppls)], gen_ppls, yerr=gen_errors, capsize=5, color='lightgreen')
+             ax2.set_ylabel('Perplexity')
+             ax2.set_title('Generation Perplexity')
+             ax2.grid(True, alpha=0.3)
+
+         plt.suptitle('Quality Metrics: Perplexity Comparison')
+         plt.tight_layout()
+         plots.append(fig)
+
+     elif benchmark_type in ["niah", "ruler", "scbench"]:
+         # Accuracy metrics
+         fig, ax = plt.subplots(figsize=(10, 6))
+         accuracies = []
+
+         for method, data in results.items():
+             if 'summary' in data:
+                 if benchmark_type == "niah" and 'niah_accuracy' in data['summary']:
+                     accuracies.append(data['summary']['niah_accuracy'])
+                 elif benchmark_type == "ruler" and 'ruler_exact_match' in data['summary']:
+                     accuracies.append(data['summary']['ruler_exact_match'])
+                 elif benchmark_type == "scbench" and 'scbench_accuracy' in data['summary']:
+                     accuracies.append(data['summary']['scbench_accuracy'])
+
+         if accuracies:
+             bars = ax.bar(methods[:len(accuracies)], accuracies, color='gold')
+             ax.set_ylabel('Accuracy')
+             ax.set_ylim(0, 1.1)
+             ax.set_title(f'{benchmark_type.upper()} Accuracy')
+             ax.grid(True, alpha=0.3)
+
+             for bar, acc in zip(bars, accuracies):
+                 height = bar.get_height()
+                 ax.text(bar.get_x() + bar.get_width()/2., height,
+                         f'{acc:.2%}', ha='center', va='bottom')
+
+         plt.tight_layout()
+         plots.append(fig)
+
+     # 4. Speed comparison
+     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
+
+     prefill_times = []
+     decode_times = []
+
+     for method, data in results.items():
+         if 'metrics' in data:
+             metrics = data['metrics']
+             if hasattr(metrics, 'prefill_time_mean'):
+                 prefill_times.append(metrics.prefill_time_mean * 1000)  # Convert to ms
+             if hasattr(metrics, 'decode_time_per_token_mean_ms'):
+                 decode_times.append(metrics.decode_time_per_token_mean_ms)
+
+     if prefill_times:
+         ax1.bar(methods[:len(prefill_times)], prefill_times, color='purple', alpha=0.7)
+         ax1.set_ylabel('Time (ms)')
+         ax1.set_title('Prefill Time')
+         ax1.grid(True, alpha=0.3)
+
+     if decode_times:
+         ax2.bar(methods[:len(decode_times)], decode_times, color='orange', alpha=0.7)
+         ax2.set_ylabel('Time per Token (ms)')
+         ax2.set_title('Decode Time')
+         ax2.grid(True, alpha=0.3)
+
+     plt.suptitle('Performance Metrics: Speed Comparison')
+     plt.tight_layout()
+     plots.append(fig)
+
+     return plots
+
+
+ def create_summary_text(results: Dict, benchmark_type: str) -> str:
+     """Create detailed summary text from results."""
+     summary_lines = []
+     summary_lines.append("=" * 60)
+     summary_lines.append("BENCHMARK RESULTS SUMMARY")
+     summary_lines.append("=" * 60)
+     summary_lines.append(f"Benchmark Type: {benchmark_type.upper()}")
+     summary_lines.append(f"Timestamp: {datetime.now().isoformat()}")
+     summary_lines.append("")
+
+     for method, data in results.items():
+         if 'summary' not in data:
+             continue
+
+         summary = data['summary']
+         metrics = data['metrics'] if 'metrics' in data else None
+
+         summary_lines.append(f"Method: {method.upper()}")
+         summary_lines.append("-" * 40)
+
+         # Compression metrics
+         if 'compression_ratio' in summary:
+             summary_lines.append(f"Compression Ratio: {summary['compression_ratio']:.1f}x")
+         if 'kv_cache_memory_mb' in summary:
+             summary_lines.append(f"KV Cache Memory: {summary['kv_cache_memory_mb']:.2f} MB")
+
+         # Quality metrics
+         if benchmark_type == "wikitext":
+             if 'prefill_perplexity' in summary:
+                 summary_lines.append(f"Prefill Perplexity: {summary['prefill_perplexity']:.2f}")
+             if 'generation_perplexity' in summary:
+                 summary_lines.append(f"Generation Perplexity: {summary['generation_perplexity']:.2f}")
+         elif benchmark_type == "niah" and 'niah_accuracy' in summary:
+             summary_lines.append(f"NIAH Accuracy: {summary['niah_accuracy']:.2%}")
+         elif benchmark_type == "ruler" and 'ruler_exact_match' in summary:
+             summary_lines.append(f"RULER Exact Match: {summary['ruler_exact_match']:.2%}")
+         elif benchmark_type == "scbench" and 'scbench_accuracy' in summary:
+             summary_lines.append(f"SCBench Accuracy: {summary['scbench_accuracy']:.2%}")
+         elif benchmark_type == "longbench" and 'longbench_accuracy' in summary:
+             summary_lines.append(f"LongBench Accuracy: {summary['longbench_accuracy']:.2%}")
+
+         # Performance metrics
+         if 'prefill_time_ms' in summary:
+             summary_lines.append(f"Prefill Time: {summary['prefill_time_ms']:.2f} ms")
+         if 'decode_time_ms' in summary:
+             summary_lines.append(f"Decode Time per Token: {summary['decode_time_ms']:.2f} ms")
+         if 'throughput_tokens_sec' in summary:
+             summary_lines.append(f"Throughput: {summary['throughput_tokens_sec']:.1f} tokens/sec")
+         if 'end_to_end_throughput' in summary:
+             summary_lines.append(f"End-to-End Throughput: {summary['end_to_end_throughput']:.1f} tokens/sec")
+         if 'peak_memory_mb' in summary:
+             summary_lines.append(f"Peak Memory: {summary['peak_memory_mb']:.2f} MB")
+
+         summary_lines.append("")
+
+     # Add statistical comparison if baseline is available
+     if 'none' in results and len(results) > 1:
+         summary_lines.append("COMPARISON WITH BASELINE")
+         summary_lines.append("-" * 40)
+
+         baseline_summary = results['none']['summary']
+
+         for method, data in results.items():
+             if method == 'none' or 'summary' not in data:
+                 continue
+
+             summary = data['summary']
+
+             # Calculate improvements
+             if 'compression_ratio' in summary:
+                 summary_lines.append(f"{method.upper()} vs Baseline:")
+                 summary_lines.append(f"  Compression: {summary['compression_ratio']:.1f}x")
+
+             if 'kv_cache_memory_mb' in summary and 'kv_cache_memory_mb' in baseline_summary:
+                 baseline_mem = baseline_summary['kv_cache_memory_mb']
+                 method_mem = summary['kv_cache_memory_mb']
+                 if baseline_mem > 0:
+                     reduction = (1 - method_mem / baseline_mem) * 100
+                     summary_lines.append(f"  Memory Reduction: {reduction:.1f}%")
+
+             # Quality degradation for WikiText
+             if benchmark_type == "wikitext":
+                 if 'generation_perplexity' in summary and 'generation_perplexity' in baseline_summary:
+                     baseline_ppl = baseline_summary['generation_perplexity']
+                     method_ppl = summary['generation_perplexity']
+                     if baseline_ppl > 0:
+                         degradation = ((method_ppl - baseline_ppl) / baseline_ppl) * 100
+                         summary_lines.append(f"  Perplexity Change: {degradation:+.1f}%")
+
+             # Accuracy comparison for other benchmarks
+             elif benchmark_type == "niah":
+                 if 'niah_accuracy' in summary and 'niah_accuracy' in baseline_summary:
+                     acc_diff = summary['niah_accuracy'] - baseline_summary['niah_accuracy']
+                     summary_lines.append(f"  Accuracy Difference: {acc_diff:+.2%}")
+
+             summary_lines.append("")
+
+     return "\n".join(summary_lines)
+
+
+ def export_results(format_type):
+     """Export current results in specified format."""
+     if not current_results:
+         return "No results to export. Please run a benchmark first."
+
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+     if format_type == "JSON":
+         filename = f"results_{timestamp}.json"
+
+         # Convert numpy types to Python types for JSON serialization
+         def convert_numpy(obj):
+             if isinstance(obj, np.ndarray):
+                 return obj.tolist()
+             elif isinstance(obj, (np.integer, np.int64, np.int32)):
+                 return int(obj)
+             elif isinstance(obj, (np.floating, np.float64, np.float32)):
+                 return float(obj)
+             elif isinstance(obj, BenchmarkMetrics):
+                 return obj.__dict__
+             # Fall back to a string: json.dumps' `default` hook must return
+             # something serializable (returning the object unchanged would
+             # just fail again).
+             return str(obj)
+
+         serializable_results = json.loads(
+             json.dumps(current_results, default=convert_numpy)
+         )
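+         # Round-tripping through dumps/loads applies convert_numpy recursively
+         # to every nested value, leaving plain JSON-safe Python containers.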
+
+         with open(filename, 'w') as f:
+             json.dump(serializable_results, f, indent=2)
+
+         return f"Results exported to {filename}"
+
+     elif format_type == "CSV":
+         filename = f"results_{timestamp}.csv"
+
+         # Flatten results for CSV
+         rows = []
+         for method, data in current_results.items():
+             if 'summary' in data:
+                 row = {'method': method}
+                 row.update(data['summary'])
+                 rows.append(row)
+
+         if rows:
+             df = pd.DataFrame(rows)
+             df.to_csv(filename, index=False)
+             return f"Results exported to {filename}"
+         else:
+             return "No summary data to export"
+
+     elif format_type == "LaTeX":
+         filename = f"results_{timestamp}.tex"
+
+         # Create LaTeX table
+         latex_lines = [
+             "\\begin{table}[h]",
+             "\\centering",
+             "\\caption{KV Cache Compression Results}",
+             "\\begin{tabular}{lccc}",
+             "\\hline",
+             "Method & Compression & Memory (MB) & Throughput (tok/s) \\\\",
+             "\\hline"
+         ]
+
+         for method, data in current_results.items():
+             if 'summary' in data:
+                 s = data['summary']
+                 comp = f"{s.get('compression_ratio', 1.0):.1f}x"
+                 mem = f"{s.get('kv_cache_memory_mb', 0):.1f}"
+                 thr = f"{s.get('throughput_tokens_sec', 0):.1f}"
+                 latex_lines.append(f"{method.upper()} & {comp} & {mem} & {thr} \\\\")
+
+         latex_lines.extend([
+             "\\hline",
+             "\\end{tabular}",
+             "\\end{table}"
+         ])
+
+         with open(filename, 'w') as f:
+             f.write('\n'.join(latex_lines))
+
+         return f"LaTeX table exported to {filename}"
+
+     return "Invalid export format"
+
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="RocketKV-Enhanced SPG Benchmark") as demo:
+         gr.Markdown("""
+         # 🚀 RocketKV-Enhanced SPG Compression Benchmark
+
+         Research-grade KV cache compression with **450x compression capability**.
+         Implements Enhanced Sliding Precision Gradient with RocketKV-style optimizations.
+
+         **Features:**
+         - Multiple compression methods (SPG, Adaptive, Enhanced, Progressive)
+         - Comprehensive benchmarks (WikiText, NIAH, RULER, SCBench, LongBench)
+         - Attestable proof generation and verification
+         - Real-time visualization and analysis
+         """)
+
+         with gr.Tab("Configuration"):
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown("### Model & Benchmark Settings")
+                     model_dropdown = gr.Dropdown(
+                         choices=list(SUPPORTED_MODELS.keys()),
+                         value="gpt2",
+                         label="Model"
+                     )
+
+                     compression_dropdown = gr.Dropdown(
+                         choices=["none", "spg", "adaptive_spg", "enhanced_spg", "progressive_spg"],
+                         value="enhanced_spg",
+                         label="Compression Method"
+                     )
+
+                     benchmark_dropdown = gr.Dropdown(
+                         choices=["wikitext", "niah", "ruler", "scbench", "longbench"],
+                         value="wikitext",
+                         label="Benchmark Type"
+                     )
+
+                     dataset_subset = gr.Dropdown(
+                         choices=BENCHMARK_CONFIGS["longbench"]["subsets"],
+                         value="narrativeqa",
+                         label="LongBench Subset (if applicable)",
+                         visible=False
+                     )
+
+                     # Show/hide subset based on benchmark type
+                     def update_subset_visibility(benchmark_type):
+                         return gr.update(visible=(benchmark_type == "longbench"))
+
+                     benchmark_dropdown.change(
+                         update_subset_visibility,
+                         inputs=[benchmark_dropdown],
+                         outputs=[dataset_subset]
+                     )
+
+                 with gr.Column():
+                     gr.Markdown("### Evaluation Parameters")
+                     eval_samples = gr.Slider(1, 100, value=20, step=1, label="Evaluation Samples")
+                     n_seeds = gr.Slider(1, 5, value=3, step=1, label="Random Seeds")
+                     seq_length = gr.Slider(128, 1024, value=512, step=128,
+                                            label="Sequence Length (max 1024 for GPT-2)")
+                     generation_length = gr.Slider(16, 128, value=64, step=16, label="Generation Length")
+
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown("### SPG Core Parameters")
+                     base_decay = gr.Slider(0.8, 0.99, value=0.95, step=0.01, label="Base Decay Rate")
+                     sink_tokens = gr.Slider(0, 8, value=2, step=1, label="Sink Tokens")
+                     recent_window = gr.Slider(8, 64, value=32, step=8, label="Recent Window")
+
+                 with gr.Column():
+                     gr.Markdown("### Adaptive SPG")
+                     enable_adaptive = gr.Checkbox(value=False, label="Enable Adaptive")
+                     target_ppl_delta = gr.Slider(0.5, 5.0, value=1.8, step=0.1,
+                                                  label="Target Perplexity Delta")
+
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown("### Progressive Compression")
+                     enable_progressive = gr.Checkbox(value=False, label="Enable Progressive")
+                     quality_threshold = gr.Slider(0.005, 0.05, value=0.01, step=0.005,
+                                                   label="Quality Threshold")
+                     initial_compression = gr.Slider(10.0, 200.0, value=50.0, step=5.0,
+                                                     label="Initial Compression Ratio")
+                     max_compression = gr.Slider(100.0, 500.0, value=400.0, step=25.0,
+                                                 label="Max Compression Ratio")
+
+                 with gr.Column():
+                     gr.Markdown("### Enhanced SPG (RocketKV-style)")
+                     sequence_comp_ratio = gr.Slider(0.0001, 0.001, value=0.0001, step=0.00005,
+                                                     label="Sequence Compression Ratio")
+                     head_comp_ratio = gr.Slider(0.0001, 0.001, value=0.0001, step=0.00005,
+                                                 label="Head Compression Ratio")
+                     head_retention = gr.Dropdown(
+                         choices=["conservative", "aggressive"],
+                         value="aggressive",
+                         label="Head Retention Mode"
+                     )
+                     magnitude_mode = gr.Dropdown(
+                         choices=["conservative", "aggressive", "extreme"],
+                         value="aggressive",  # Changed from "extreme" for stability
+                         label="Magnitude Threshold Mode"
+                     )
+
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown("### Stability Parameters")
+                     min_tokens_stability = gr.Slider(4, 16, value=8, step=1,
+                                                      label="Min Tokens for Stability")
+                     recent_boost = gr.Slider(0.0, 0.5, value=0.1, step=0.05,
+                                              label="Recent Boost Factor")
+
+                 with gr.Column():
+                     gr.Markdown("### System Settings")
+                     fail_on_cpu = gr.Checkbox(value=False, label="Fail on CPU Fallback")
+
+         with gr.Tab("Run Benchmark"):
+             run_button = gr.Button("🚀 Run Benchmark", variant="primary")
+
+             with gr.Row():
+                 progress_text = gr.Textbox(label="Progress", lines=10)
+
+             with gr.Row():
+                 plot_gallery = gr.Gallery(label="Results Visualization", columns=2, height="auto")
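+                 # Caveat: run_benchmark returns Matplotlib figures; depending on
+                 # the Gradio version, gr.Gallery may expect images or file paths,
+                 # so figures might need saving to PNG first (or gr.Plot used instead).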
+
+             with gr.Row():
+                 summary_output = gr.Textbox(label="Summary", lines=20)
+                 verification_output = gr.Textbox(label="Proof Verification", lines=5)
+
+         with gr.Tab("Export Results"):
+             gr.Markdown("### Export Options")
+
+             export_format = gr.Radio(
+                 choices=["JSON", "CSV", "LaTeX"],
+                 value="JSON",
+                 label="Export Format"
+             )
+
+             export_button = gr.Button("📥 Export Results")
+             export_status = gr.Textbox(label="Export Status")
+
+             export_button.click(
+                 export_results,
+                 inputs=[export_format],
+                 outputs=[export_status]
+             )
+
+         # Connect the run button
+         run_button.click(
+             run_benchmark,
+             inputs=[
+                 model_dropdown, compression_dropdown, benchmark_dropdown, dataset_subset,
+                 eval_samples, n_seeds, seq_length, generation_length,
+                 base_decay, sink_tokens, recent_window,
+                 enable_adaptive, target_ppl_delta,
+                 enable_progressive, quality_threshold,
+                 initial_compression, max_compression,
+                 sequence_comp_ratio, head_comp_ratio,
+                 head_retention, magnitude_mode,
+                 min_tokens_stability, recent_boost,
+                 fail_on_cpu
+             ],
+             outputs=[plot_gallery, summary_output, verification_output]
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     # Set up logging
+     logging.basicConfig(
+         level=logging.INFO,
+         format='%(asctime)s - %(levelname)s - %(message)s'
+     )
+
+     # Create and launch the interface
+     demo = create_interface()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )