Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
"""
|
3 |
+
Research-grade KV cache compression benchmark application.
|
4 |
+
RocketKV-enhanced SPG with 450x compression capability.
|
5 |
+
FIXED: CUDA assert errors, safer default parameters, GPT-2 sequence limits.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import seaborn as sns
|
13 |
+
from datetime import datetime
|
14 |
+
import json
|
15 |
+
import pandas as pd
|
16 |
+
import tempfile
|
17 |
+
import os
|
18 |
+
import logging
|
19 |
+
from typing import Dict, List, Any, Tuple
|
20 |
+
|
21 |
+
from config import (
|
22 |
+
CompressionConfig, CompressionType, EnhancedSPGConfig,
|
23 |
+
ProvingConfig, ResearchConstants, SUPPORTED_MODELS, BENCHMARK_CONFIGS
|
24 |
+
)
|
25 |
+
from benchmark import (
|
26 |
+
run_research_benchmark, export_proof_bundle, verify_proof_bundle,
|
27 |
+
BenchmarkMetrics
|
28 |
+
)
|
29 |
+
from compression import detect_model_layers
|
30 |
+
|
31 |
+
# Configure logging
|
32 |
+
logging.basicConfig(level=logging.INFO)
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
# Set style for plots
|
36 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
37 |
+
sns.set_palette("husl")
|
38 |
+
|
39 |
+
# Global state for results
|
40 |
+
current_results = {}
|
41 |
+
|
42 |
+
|
43 |
+
def run_benchmark(model_key, compression_type, benchmark_type, dataset_subset,
|
44 |
+
eval_samples, n_seeds, seq_length, generation_length,
|
45 |
+
base_decay_rate, sink_tokens, recent_window,
|
46 |
+
enable_adaptive, target_perplexity_delta,
|
47 |
+
enable_progressive, progressive_quality_threshold,
|
48 |
+
initial_compression_ratio, max_compression_ratio,
|
49 |
+
sequence_compression_ratio, head_compression_ratio,
|
50 |
+
head_retention_mode, magnitude_threshold_mode,
|
51 |
+
min_tokens_for_stability, recent_boost_factor,
|
52 |
+
fail_on_cpu):
|
53 |
+
"""Run comprehensive benchmark with all compression methods."""
|
54 |
+
|
55 |
+
# Enable synchronous CUDA for debugging
|
56 |
+
if torch.cuda.is_available():
|
57 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
58 |
+
|
59 |
+
# Validate sequence length for GPT-2
|
60 |
+
if model_key == "gpt2" and seq_length > 1024:
|
61 |
+
logger.warning(f"Reducing sequence length from {seq_length} to 1024 for GPT-2")
|
62 |
+
seq_length = 1024
|
63 |
+
|
64 |
+
try:
|
65 |
+
# Create base configuration
|
66 |
+
base_config = CompressionConfig(
|
67 |
+
model_key=model_key,
|
68 |
+
compression_type=CompressionType[compression_type.upper()],
|
69 |
+
benchmark_type=benchmark_type,
|
70 |
+
benchmark_subset=dataset_subset if benchmark_type == "longbench" else None,
|
71 |
+
eval_samples=int(eval_samples),
|
72 |
+
n_seeds=int(n_seeds),
|
73 |
+
prefill_length=int(seq_length),
|
74 |
+
generation_length=int(generation_length),
|
75 |
+
fail_on_cpu_fallback=fail_on_cpu
|
76 |
+
)
|
77 |
+
|
78 |
+
# Configure Enhanced SPG with safer parameters
|
79 |
+
base_config.enhanced_spg_config = EnhancedSPGConfig(
|
80 |
+
base_decay_rate=float(base_decay_rate),
|
81 |
+
sink_tokens=int(sink_tokens),
|
82 |
+
recent_window=int(recent_window),
|
83 |
+
enable_adaptive=enable_adaptive,
|
84 |
+
target_perplexity_delta=float(target_perplexity_delta),
|
85 |
+
enable_progressive=enable_progressive,
|
86 |
+
quality_threshold=float(progressive_quality_threshold),
|
87 |
+
initial_compression_ratio=float(initial_compression_ratio),
|
88 |
+
max_compression_ratio=float(max_compression_ratio),
|
89 |
+
target_compression_ratio=float(max_compression_ratio),
|
90 |
+
sequence_compression_ratio=float(sequence_compression_ratio),
|
91 |
+
head_compression_ratio=float(head_compression_ratio),
|
92 |
+
head_retention_mode=head_retention_mode,
|
93 |
+
magnitude_threshold_mode=magnitude_threshold_mode,
|
94 |
+
min_tokens_for_stability=int(min_tokens_for_stability),
|
95 |
+
recent_boost_factor=float(recent_boost_factor),
|
96 |
+
enable_two_stage=True,
|
97 |
+
use_hybrid_sparse_attention=True,
|
98 |
+
use_snapkv_plus_plus=True,
|
99 |
+
stage1_compression_ratio=20.0, # Safer default
|
100 |
+
stage2_compression_ratio=20.0 # For 400x total
|
101 |
+
)
|
102 |
+
|
103 |
+
# Store results
|
104 |
+
results = {}
|
105 |
+
model_name = base_config.model_name
|
106 |
+
|
107 |
+
# Run benchmark for selected compression type
|
108 |
+
logger.info(f"Running {compression_type} benchmark...")
|
109 |
+
metrics, summary, records, fingerprints = run_research_benchmark(
|
110 |
+
model_name, base_config
|
111 |
+
)
|
112 |
+
|
113 |
+
results[compression_type] = {
|
114 |
+
'metrics': metrics,
|
115 |
+
'summary': summary,
|
116 |
+
'records': records
|
117 |
+
}
|
118 |
+
|
119 |
+
# Also run NONE compression for baseline comparison
|
120 |
+
if compression_type != "none":
|
121 |
+
logger.info("Running baseline (no compression) benchmark...")
|
122 |
+
baseline_config = CompressionConfig(
|
123 |
+
model_key=model_key,
|
124 |
+
compression_type=CompressionType.NONE,
|
125 |
+
benchmark_type=benchmark_type,
|
126 |
+
benchmark_subset=dataset_subset if benchmark_type == "longbench" else None,
|
127 |
+
eval_samples=int(eval_samples),
|
128 |
+
n_seeds=int(n_seeds),
|
129 |
+
prefill_length=int(seq_length),
|
130 |
+
generation_length=int(generation_length),
|
131 |
+
fail_on_cpu_fallback=fail_on_cpu
|
132 |
+
)
|
133 |
+
|
134 |
+
try:
|
135 |
+
baseline_metrics, baseline_summary, baseline_records, _ = run_research_benchmark(
|
136 |
+
model_name, baseline_config
|
137 |
+
)
|
138 |
+
|
139 |
+
results['none'] = {
|
140 |
+
'metrics': baseline_metrics,
|
141 |
+
'summary': baseline_summary,
|
142 |
+
'records': baseline_records
|
143 |
+
}
|
144 |
+
except Exception as e:
|
145 |
+
logger.error(f"Baseline benchmark failed: {e}")
|
146 |
+
# Continue without baseline
|
147 |
+
|
148 |
+
# Store globally for export
|
149 |
+
global current_results
|
150 |
+
current_results = results
|
151 |
+
|
152 |
+
# Create visualizations
|
153 |
+
plots = create_visualizations(results, benchmark_type)
|
154 |
+
|
155 |
+
# Create summary text
|
156 |
+
summary_text = create_summary_text(results, benchmark_type)
|
157 |
+
|
158 |
+
# Export proof bundle
|
159 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
160 |
+
bundle_path = export_proof_bundle(
|
161 |
+
tmpdir, base_config, metrics, summary, records, fingerprints
|
162 |
+
)
|
163 |
+
|
164 |
+
# Verify the bundle
|
165 |
+
verification = verify_proof_bundle(
|
166 |
+
tmpdir, base_config, base_config.proving
|
167 |
+
)
|
168 |
+
|
169 |
+
verification_text = f"Proof verification: {'PASSED β' if verification['ok'] else 'FAILED β'}"
|
170 |
+
if not verification['ok']:
|
171 |
+
verification_text += f"\nFailures: {verification['failures']}"
|
172 |
+
|
173 |
+
return plots, summary_text, verification_text
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
logger.error(f"Benchmark failed: {e}", exc_info=True)
|
177 |
+
return [], f"Error: {str(e)}", "Verification failed due to error"
|
178 |
+
|
179 |
+
|
180 |
+
def create_visualizations(results: Dict, benchmark_type: str) -> List:
|
181 |
+
"""Create comprehensive visualizations from benchmark results."""
|
182 |
+
plots = []
|
183 |
+
|
184 |
+
# 1. Compression Ratio Comparison
|
185 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
186 |
+
methods = []
|
187 |
+
ratios = []
|
188 |
+
errors = []
|
189 |
+
|
190 |
+
for method, data in results.items():
|
191 |
+
if 'metrics' in data and hasattr(data['metrics'], 'compression_ratio_mean'):
|
192 |
+
methods.append(method.upper())
|
193 |
+
ratios.append(data['metrics'].compression_ratio_mean)
|
194 |
+
errors.append(data['metrics'].compression_ratio_std)
|
195 |
+
|
196 |
+
if methods:
|
197 |
+
bars = ax.bar(methods, ratios, yerr=errors, capsize=5)
|
198 |
+
ax.set_ylabel('Compression Ratio')
|
199 |
+
ax.set_title('KV Cache Compression Ratios')
|
200 |
+
ax.grid(True, alpha=0.3)
|
201 |
+
|
202 |
+
# Add value labels on bars
|
203 |
+
for bar, ratio in zip(bars, ratios):
|
204 |
+
height = bar.get_height()
|
205 |
+
ax.text(bar.get_x() + bar.get_width()/2., height,
|
206 |
+
f'{ratio:.1f}x', ha='center', va='bottom')
|
207 |
+
|
208 |
+
plt.tight_layout()
|
209 |
+
plots.append(fig)
|
210 |
+
|
211 |
+
# 2. Memory Usage Comparison
|
212 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
213 |
+
memories = []
|
214 |
+
memory_errors = []
|
215 |
+
|
216 |
+
for method, data in results.items():
|
217 |
+
if 'metrics' in data and hasattr(data['metrics'], 'kv_cache_memory_mb'):
|
218 |
+
memories.append(data['metrics'].kv_cache_memory_mb)
|
219 |
+
memory_errors.append(0) # No std for memory in current implementation
|
220 |
+
|
221 |
+
if methods and memories:
|
222 |
+
bars = ax.bar(methods, memories, yerr=memory_errors, capsize=5, color='coral')
|
223 |
+
ax.set_ylabel('Memory Usage (MB)')
|
224 |
+
ax.set_title('KV Cache Memory Footprint')
|
225 |
+
ax.grid(True, alpha=0.3)
|
226 |
+
|
227 |
+
for bar, mem in zip(bars, memories):
|
228 |
+
height = bar.get_height()
|
229 |
+
ax.text(bar.get_x() + bar.get_width()/2., height,
|
230 |
+
f'{mem:.1f}', ha='center', va='bottom')
|
231 |
+
|
232 |
+
plt.tight_layout()
|
233 |
+
plots.append(fig)
|
234 |
+
|
235 |
+
# 3. Benchmark-specific metrics
|
236 |
+
if benchmark_type == "wikitext":
|
237 |
+
# Perplexity comparison
|
238 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
239 |
+
|
240 |
+
# Prefill perplexity
|
241 |
+
prefill_ppls = []
|
242 |
+
prefill_errors = []
|
243 |
+
gen_ppls = []
|
244 |
+
gen_errors = []
|
245 |
+
|
246 |
+
for method, data in results.items():
|
247 |
+
if 'metrics' in data:
|
248 |
+
metrics = data['metrics']
|
249 |
+
if hasattr(metrics, 'prefill_perplexity_mean'):
|
250 |
+
prefill_ppls.append(metrics.prefill_perplexity_mean)
|
251 |
+
prefill_errors.append(metrics.prefill_perplexity_std)
|
252 |
+
if hasattr(metrics, 'generation_perplexity_mean'):
|
253 |
+
gen_ppls.append(metrics.generation_perplexity_mean)
|
254 |
+
gen_errors.append(metrics.generation_perplexity_std)
|
255 |
+
|
256 |
+
if prefill_ppls:
|
257 |
+
ax1.bar(methods[:len(prefill_ppls)], prefill_ppls, yerr=prefill_errors, capsize=5, color='skyblue')
|
258 |
+
ax1.set_ylabel('Perplexity')
|
259 |
+
ax1.set_title('Prefill Perplexity')
|
260 |
+
ax1.grid(True, alpha=0.3)
|
261 |
+
|
262 |
+
if gen_ppls:
|
263 |
+
ax2.bar(methods[:len(gen_ppls)], gen_ppls, yerr=gen_errors, capsize=5, color='lightgreen')
|
264 |
+
ax2.set_ylabel('Perplexity')
|
265 |
+
ax2.set_title('Generation Perplexity')
|
266 |
+
ax2.grid(True, alpha=0.3)
|
267 |
+
|
268 |
+
plt.suptitle('Quality Metrics: Perplexity Comparison')
|
269 |
+
plt.tight_layout()
|
270 |
+
plots.append(fig)
|
271 |
+
|
272 |
+
elif benchmark_type in ["niah", "ruler", "scbench"]:
|
273 |
+
# Accuracy metrics
|
274 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
275 |
+
accuracies = []
|
276 |
+
|
277 |
+
for method, data in results.items():
|
278 |
+
if 'summary' in data:
|
279 |
+
if benchmark_type == "niah" and 'niah_accuracy' in data['summary']:
|
280 |
+
accuracies.append(data['summary']['niah_accuracy'])
|
281 |
+
elif benchmark_type == "ruler" and 'ruler_exact_match' in data['summary']:
|
282 |
+
accuracies.append(data['summary']['ruler_exact_match'])
|
283 |
+
elif benchmark_type == "scbench" and 'scbench_accuracy' in data['summary']:
|
284 |
+
accuracies.append(data['summary']['scbench_accuracy'])
|
285 |
+
|
286 |
+
if accuracies:
|
287 |
+
bars = ax.bar(methods[:len(accuracies)], accuracies, color='gold')
|
288 |
+
ax.set_ylabel('Accuracy')
|
289 |
+
ax.set_ylim(0, 1.1)
|
290 |
+
ax.set_title(f'{benchmark_type.upper()} Accuracy')
|
291 |
+
ax.grid(True, alpha=0.3)
|
292 |
+
|
293 |
+
for bar, acc in zip(bars, accuracies):
|
294 |
+
height = bar.get_height()
|
295 |
+
ax.text(bar.get_x() + bar.get_width()/2., height,
|
296 |
+
f'{acc:.2%}', ha='center', va='bottom')
|
297 |
+
|
298 |
+
plt.tight_layout()
|
299 |
+
plots.append(fig)
|
300 |
+
|
301 |
+
# 4. Speed comparison
|
302 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
303 |
+
|
304 |
+
prefill_times = []
|
305 |
+
decode_times = []
|
306 |
+
|
307 |
+
for method, data in results.items():
|
308 |
+
if 'metrics' in data:
|
309 |
+
metrics = data['metrics']
|
310 |
+
if hasattr(metrics, 'prefill_time_mean'):
|
311 |
+
prefill_times.append(metrics.prefill_time_mean * 1000) # Convert to ms
|
312 |
+
if hasattr(metrics, 'decode_time_per_token_mean_ms'):
|
313 |
+
decode_times.append(metrics.decode_time_per_token_mean_ms)
|
314 |
+
|
315 |
+
if prefill_times:
|
316 |
+
ax1.bar(methods[:len(prefill_times)], prefill_times, color='purple', alpha=0.7)
|
317 |
+
ax1.set_ylabel('Time (ms)')
|
318 |
+
ax1.set_title('Prefill Time')
|
319 |
+
ax1.grid(True, alpha=0.3)
|
320 |
+
|
321 |
+
if decode_times:
|
322 |
+
ax2.bar(methods[:len(decode_times)], decode_times, color='orange', alpha=0.7)
|
323 |
+
ax2.set_ylabel('Time per Token (ms)')
|
324 |
+
ax2.set_title('Decode Time')
|
325 |
+
ax2.grid(True, alpha=0.3)
|
326 |
+
|
327 |
+
plt.suptitle('Performance Metrics: Speed Comparison')
|
328 |
+
plt.tight_layout()
|
329 |
+
plots.append(fig)
|
330 |
+
|
331 |
+
return plots
|
332 |
+
|
333 |
+
|
334 |
+
def create_summary_text(results: Dict, benchmark_type: str) -> str:
|
335 |
+
"""Create detailed summary text from results."""
|
336 |
+
summary_lines = []
|
337 |
+
summary_lines.append("=" * 60)
|
338 |
+
summary_lines.append("BENCHMARK RESULTS SUMMARY")
|
339 |
+
summary_lines.append("=" * 60)
|
340 |
+
summary_lines.append(f"Benchmark Type: {benchmark_type.upper()}")
|
341 |
+
summary_lines.append(f"Timestamp: {datetime.now().isoformat()}")
|
342 |
+
summary_lines.append("")
|
343 |
+
|
344 |
+
for method, data in results.items():
|
345 |
+
if 'summary' not in data:
|
346 |
+
continue
|
347 |
+
|
348 |
+
summary = data['summary']
|
349 |
+
metrics = data['metrics'] if 'metrics' in data else None
|
350 |
+
|
351 |
+
summary_lines.append(f"Method: {method.upper()}")
|
352 |
+
summary_lines.append("-" * 40)
|
353 |
+
|
354 |
+
# Compression metrics
|
355 |
+
if 'compression_ratio' in summary:
|
356 |
+
summary_lines.append(f"Compression Ratio: {summary['compression_ratio']:.1f}x")
|
357 |
+
if 'kv_cache_memory_mb' in summary:
|
358 |
+
summary_lines.append(f"KV Cache Memory: {summary['kv_cache_memory_mb']:.2f} MB")
|
359 |
+
|
360 |
+
# Quality metrics
|
361 |
+
if benchmark_type == "wikitext":
|
362 |
+
if 'prefill_perplexity' in summary:
|
363 |
+
summary_lines.append(f"Prefill Perplexity: {summary['prefill_perplexity']:.2f}")
|
364 |
+
if 'generation_perplexity' in summary:
|
365 |
+
summary_lines.append(f"Generation Perplexity: {summary['generation_perplexity']:.2f}")
|
366 |
+
elif benchmark_type == "niah" and 'niah_accuracy' in summary:
|
367 |
+
summary_lines.append(f"NIAH Accuracy: {summary['niah_accuracy']:.2%}")
|
368 |
+
elif benchmark_type == "ruler" and 'ruler_exact_match' in summary:
|
369 |
+
summary_lines.append(f"RULER Exact Match: {summary['ruler_exact_match']:.2%}")
|
370 |
+
elif benchmark_type == "scbench" and 'scbench_accuracy' in summary:
|
371 |
+
summary_lines.append(f"SCBench Accuracy: {summary['scbench_accuracy']:.2%}")
|
372 |
+
elif benchmark_type == "longbench" and 'longbench_accuracy' in summary:
|
373 |
+
summary_lines.append(f"LongBench Accuracy: {summary['longbench_accuracy']:.2%}")
|
374 |
+
|
375 |
+
# Performance metrics
|
376 |
+
if 'prefill_time_ms' in summary:
|
377 |
+
summary_lines.append(f"Prefill Time: {summary['prefill_time_ms']:.2f} ms")
|
378 |
+
if 'decode_time_ms' in summary:
|
379 |
+
summary_lines.append(f"Decode Time per Token: {summary['decode_time_ms']:.2f} ms")
|
380 |
+
if 'throughput_tokens_sec' in summary:
|
381 |
+
summary_lines.append(f"Throughput: {summary['throughput_tokens_sec']:.1f} tokens/sec")
|
382 |
+
if 'end_to_end_throughput' in summary:
|
383 |
+
summary_lines.append(f"End-to-End Throughput: {summary['end_to_end_throughput']:.1f} tokens/sec")
|
384 |
+
if 'peak_memory_mb' in summary:
|
385 |
+
summary_lines.append(f"Peak Memory: {summary['peak_memory_mb']:.2f} MB")
|
386 |
+
|
387 |
+
summary_lines.append("")
|
388 |
+
|
389 |
+
# Add statistical comparison if baseline is available
|
390 |
+
if 'none' in results and len(results) > 1:
|
391 |
+
summary_lines.append("COMPARISON WITH BASELINE")
|
392 |
+
summary_lines.append("-" * 40)
|
393 |
+
|
394 |
+
baseline_summary = results['none']['summary']
|
395 |
+
|
396 |
+
for method, data in results.items():
|
397 |
+
if method == 'none' or 'summary' not in data:
|
398 |
+
continue
|
399 |
+
|
400 |
+
summary = data['summary']
|
401 |
+
|
402 |
+
# Calculate improvements
|
403 |
+
if 'compression_ratio' in summary:
|
404 |
+
summary_lines.append(f"{method.upper()} vs Baseline:")
|
405 |
+
summary_lines.append(f" Compression: {summary['compression_ratio']:.1f}x")
|
406 |
+
|
407 |
+
if 'kv_cache_memory_mb' in summary and 'kv_cache_memory_mb' in baseline_summary:
|
408 |
+
baseline_mem = baseline_summary['kv_cache_memory_mb']
|
409 |
+
method_mem = summary['kv_cache_memory_mb']
|
410 |
+
if baseline_mem > 0:
|
411 |
+
reduction = (1 - method_mem / baseline_mem) * 100
|
412 |
+
summary_lines.append(f" Memory Reduction: {reduction:.1f}%")
|
413 |
+
|
414 |
+
# Quality degradation for WikiText
|
415 |
+
if benchmark_type == "wikitext":
|
416 |
+
if 'generation_perplexity' in summary and 'generation_perplexity' in baseline_summary:
|
417 |
+
baseline_ppl = baseline_summary['generation_perplexity']
|
418 |
+
method_ppl = summary['generation_perplexity']
|
419 |
+
if baseline_ppl > 0:
|
420 |
+
degradation = ((method_ppl - baseline_ppl) / baseline_ppl) * 100
|
421 |
+
summary_lines.append(f" Perplexity Change: {degradation:+.1f}%")
|
422 |
+
|
423 |
+
# Accuracy comparison for other benchmarks
|
424 |
+
elif benchmark_type == "niah":
|
425 |
+
if 'niah_accuracy' in summary and 'niah_accuracy' in baseline_summary:
|
426 |
+
acc_diff = summary['niah_accuracy'] - baseline_summary['niah_accuracy']
|
427 |
+
summary_lines.append(f" Accuracy Difference: {acc_diff:+.2%}")
|
428 |
+
|
429 |
+
summary_lines.append("")
|
430 |
+
|
431 |
+
return "\n".join(summary_lines)
|
432 |
+
|
433 |
+
|
434 |
+
def export_results(format_type):
|
435 |
+
"""Export current results in specified format."""
|
436 |
+
if not current_results:
|
437 |
+
return "No results to export. Please run a benchmark first."
|
438 |
+
|
439 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
440 |
+
|
441 |
+
if format_type == "JSON":
|
442 |
+
filename = f"results_{timestamp}.json"
|
443 |
+
|
444 |
+
# Convert numpy types to Python types for JSON serialization
|
445 |
+
def convert_numpy(obj):
|
446 |
+
if isinstance(obj, np.ndarray):
|
447 |
+
return obj.tolist()
|
448 |
+
elif isinstance(obj, (np.integer, np.int64, np.int32)):
|
449 |
+
return int(obj)
|
450 |
+
elif isinstance(obj, (np.floating, np.float64, np.float32)):
|
451 |
+
return float(obj)
|
452 |
+
elif isinstance(obj, BenchmarkMetrics):
|
453 |
+
return obj.__dict__
|
454 |
+
return obj
|
455 |
+
|
456 |
+
serializable_results = json.loads(
|
457 |
+
json.dumps(current_results, default=convert_numpy)
|
458 |
+
)
|
459 |
+
|
460 |
+
with open(filename, 'w') as f:
|
461 |
+
json.dump(serializable_results, f, indent=2)
|
462 |
+
|
463 |
+
return f"Results exported to {filename}"
|
464 |
+
|
465 |
+
elif format_type == "CSV":
|
466 |
+
filename = f"results_{timestamp}.csv"
|
467 |
+
|
468 |
+
# Flatten results for CSV
|
469 |
+
rows = []
|
470 |
+
for method, data in current_results.items():
|
471 |
+
if 'summary' in data:
|
472 |
+
row = {'method': method}
|
473 |
+
row.update(data['summary'])
|
474 |
+
rows.append(row)
|
475 |
+
|
476 |
+
if rows:
|
477 |
+
df = pd.DataFrame(rows)
|
478 |
+
df.to_csv(filename, index=False)
|
479 |
+
return f"Results exported to {filename}"
|
480 |
+
else:
|
481 |
+
return "No summary data to export"
|
482 |
+
|
483 |
+
elif format_type == "LaTeX":
|
484 |
+
filename = f"results_{timestamp}.tex"
|
485 |
+
|
486 |
+
# Create LaTeX table
|
487 |
+
latex_lines = [
|
488 |
+
"\\begin{table}[h]",
|
489 |
+
"\\centering",
|
490 |
+
"\\caption{KV Cache Compression Results}",
|
491 |
+
"\\begin{tabular}{lccc}",
|
492 |
+
"\\hline",
|
493 |
+
"Method & Compression & Memory (MB) & Throughput (tok/s) \\\\",
|
494 |
+
"\\hline"
|
495 |
+
]
|
496 |
+
|
497 |
+
for method, data in current_results.items():
|
498 |
+
if 'summary' in data:
|
499 |
+
s = data['summary']
|
500 |
+
comp = f"{s.get('compression_ratio', 1.0):.1f}x"
|
501 |
+
mem = f"{s.get('kv_cache_memory_mb', 0):.1f}"
|
502 |
+
thr = f"{s.get('throughput_tokens_sec', 0):.1f}"
|
503 |
+
latex_lines.append(f"{method.upper()} & {comp} & {mem} & {thr} \\\\")
|
504 |
+
|
505 |
+
latex_lines.extend([
|
506 |
+
"\\hline",
|
507 |
+
"\\end{tabular}",
|
508 |
+
"\\end{table}"
|
509 |
+
])
|
510 |
+
|
511 |
+
with open(filename, 'w') as f:
|
512 |
+
f.write('\n'.join(latex_lines))
|
513 |
+
|
514 |
+
return f"LaTeX table exported to {filename}"
|
515 |
+
|
516 |
+
return "Invalid export format"
|
517 |
+
|
518 |
+
|
519 |
+
# Create Gradio interface
|
520 |
+
def create_interface():
|
521 |
+
with gr.Blocks(title="RocketKV-Enhanced SPG Benchmark") as demo:
|
522 |
+
gr.Markdown("""
|
523 |
+
# π RocketKV-Enhanced SPG Compression Benchmark
|
524 |
+
|
525 |
+
Research-grade KV cache compression with **450x compression capability**.
|
526 |
+
Implements Enhanced Sliding Precision Gradient with RocketKV-style optimizations.
|
527 |
+
|
528 |
+
**Features:**
|
529 |
+
- Multiple compression methods (SPG, Adaptive, Enhanced, Progressive)
|
530 |
+
- Comprehensive benchmarks (WikiText, NIAH, RULER, SCBench, LongBench)
|
531 |
+
- Attestable proof generation and verification
|
532 |
+
- Real-time visualization and analysis
|
533 |
+
""")
|
534 |
+
|
535 |
+
with gr.Tab("Configuration"):
|
536 |
+
with gr.Row():
|
537 |
+
with gr.Column():
|
538 |
+
gr.Markdown("### Model & Benchmark Settings")
|
539 |
+
model_dropdown = gr.Dropdown(
|
540 |
+
choices=list(SUPPORTED_MODELS.keys()),
|
541 |
+
value="gpt2",
|
542 |
+
label="Model"
|
543 |
+
)
|
544 |
+
|
545 |
+
compression_dropdown = gr.Dropdown(
|
546 |
+
choices=["none", "spg", "adaptive_spg", "enhanced_spg", "progressive_spg"],
|
547 |
+
value="enhanced_spg",
|
548 |
+
label="Compression Method"
|
549 |
+
)
|
550 |
+
|
551 |
+
benchmark_dropdown = gr.Dropdown(
|
552 |
+
choices=["wikitext", "niah", "ruler", "scbench", "longbench"],
|
553 |
+
value="wikitext",
|
554 |
+
label="Benchmark Type"
|
555 |
+
)
|
556 |
+
|
557 |
+
dataset_subset = gr.Dropdown(
|
558 |
+
choices=BENCHMARK_CONFIGS["longbench"]["subsets"],
|
559 |
+
value="narrativeqa",
|
560 |
+
label="LongBench Subset (if applicable)",
|
561 |
+
visible=False
|
562 |
+
)
|
563 |
+
|
564 |
+
# Show/hide subset based on benchmark type
|
565 |
+
def update_subset_visibility(benchmark_type):
|
566 |
+
return gr.update(visible=(benchmark_type == "longbench"))
|
567 |
+
|
568 |
+
benchmark_dropdown.change(
|
569 |
+
update_subset_visibility,
|
570 |
+
inputs=[benchmark_dropdown],
|
571 |
+
outputs=[dataset_subset]
|
572 |
+
)
|
573 |
+
|
574 |
+
with gr.Column():
|
575 |
+
gr.Markdown("### Evaluation Parameters")
|
576 |
+
eval_samples = gr.Slider(1, 100, value=20, step=1, label="Evaluation Samples")
|
577 |
+
n_seeds = gr.Slider(1, 5, value=3, step=1, label="Random Seeds")
|
578 |
+
seq_length = gr.Slider(128, 1024, value=512, step=128,
|
579 |
+
label="Sequence Length (max 1024 for GPT-2)")
|
580 |
+
generation_length = gr.Slider(16, 128, value=64, step=16, label="Generation Length")
|
581 |
+
|
582 |
+
with gr.Row():
|
583 |
+
with gr.Column():
|
584 |
+
gr.Markdown("### SPG Core Parameters")
|
585 |
+
base_decay = gr.Slider(0.8, 0.99, value=0.95, step=0.01, label="Base Decay Rate")
|
586 |
+
sink_tokens = gr.Slider(0, 8, value=2, step=1, label="Sink Tokens")
|
587 |
+
recent_window = gr.Slider(8, 64, value=32, step=8, label="Recent Window")
|
588 |
+
|
589 |
+
with gr.Column():
|
590 |
+
gr.Markdown("### Adaptive SPG")
|
591 |
+
enable_adaptive = gr.Checkbox(value=False, label="Enable Adaptive")
|
592 |
+
target_ppl_delta = gr.Slider(0.5, 5.0, value=1.8, step=0.1,
|
593 |
+
label="Target Perplexity Delta")
|
594 |
+
|
595 |
+
with gr.Row():
|
596 |
+
with gr.Column():
|
597 |
+
gr.Markdown("### Progressive Compression")
|
598 |
+
enable_progressive = gr.Checkbox(value=False, label="Enable Progressive")
|
599 |
+
quality_threshold = gr.Slider(0.005, 0.05, value=0.01, step=0.005,
|
600 |
+
label="Quality Threshold")
|
601 |
+
initial_compression = gr.Slider(10.0, 200.0, value=50.0, step=5.0,
|
602 |
+
label="Initial Compression Ratio")
|
603 |
+
max_compression = gr.Slider(100.0, 500.0, value=400.0, step=25.0,
|
604 |
+
label="Max Compression Ratio")
|
605 |
+
|
606 |
+
with gr.Column():
|
607 |
+
gr.Markdown("### Enhanced SPG (RocketKV-style)")
|
608 |
+
sequence_comp_ratio = gr.Slider(0.0001, 0.001, value=0.0001, step=0.00005,
|
609 |
+
label="Sequence Compression Ratio")
|
610 |
+
head_comp_ratio = gr.Slider(0.0001, 0.001, value=0.0001, step=0.00005,
|
611 |
+
label="Head Compression Ratio")
|
612 |
+
head_retention = gr.Dropdown(
|
613 |
+
choices=["conservative", "aggressive"],
|
614 |
+
value="aggressive",
|
615 |
+
label="Head Retention Mode"
|
616 |
+
)
|
617 |
+
magnitude_mode = gr.Dropdown(
|
618 |
+
choices=["conservative", "aggressive", "extreme"],
|
619 |
+
value="aggressive", # Changed from "extreme" for stability
|
620 |
+
label="Magnitude Threshold Mode"
|
621 |
+
)
|
622 |
+
|
623 |
+
with gr.Row():
|
624 |
+
with gr.Column():
|
625 |
+
gr.Markdown("### Stability Parameters")
|
626 |
+
min_tokens_stability = gr.Slider(4, 16, value=8, step=1,
|
627 |
+
label="Min Tokens for Stability")
|
628 |
+
recent_boost = gr.Slider(0.0, 0.5, value=0.1, step=0.05,
|
629 |
+
label="Recent Boost Factor")
|
630 |
+
|
631 |
+
with gr.Column():
|
632 |
+
gr.Markdown("### System Settings")
|
633 |
+
fail_on_cpu = gr.Checkbox(value=False, label="Fail on CPU Fallback")
|
634 |
+
|
635 |
+
with gr.Tab("Run Benchmark"):
|
636 |
+
run_button = gr.Button("π Run Benchmark", variant="primary")
|
637 |
+
|
638 |
+
with gr.Row():
|
639 |
+
progress_text = gr.Textbox(label="Progress", lines=10)
|
640 |
+
|
641 |
+
with gr.Row():
|
642 |
+
plot_gallery = gr.Gallery(label="Results Visualization", columns=2, height="auto")
|
643 |
+
|
644 |
+
with gr.Row():
|
645 |
+
summary_output = gr.Textbox(label="Summary", lines=20)
|
646 |
+
verification_output = gr.Textbox(label="Proof Verification", lines=5)
|
647 |
+
|
648 |
+
with gr.Tab("Export Results"):
|
649 |
+
gr.Markdown("### Export Options")
|
650 |
+
|
651 |
+
export_format = gr.Radio(
|
652 |
+
choices=["JSON", "CSV", "LaTeX"],
|
653 |
+
value="JSON",
|
654 |
+
label="Export Format"
|
655 |
+
)
|
656 |
+
|
657 |
+
export_button = gr.Button("π₯ Export Results")
|
658 |
+
export_status = gr.Textbox(label="Export Status")
|
659 |
+
|
660 |
+
export_button.click(
|
661 |
+
export_results,
|
662 |
+
inputs=[export_format],
|
663 |
+
outputs=[export_status]
|
664 |
+
)
|
665 |
+
|
666 |
+
# Connect the run button
|
667 |
+
run_button.click(
|
668 |
+
run_benchmark,
|
669 |
+
inputs=[
|
670 |
+
model_dropdown, compression_dropdown, benchmark_dropdown, dataset_subset,
|
671 |
+
eval_samples, n_seeds, seq_length, generation_length,
|
672 |
+
base_decay, sink_tokens, recent_window,
|
673 |
+
enable_adaptive, target_ppl_delta,
|
674 |
+
enable_progressive, quality_threshold,
|
675 |
+
initial_compression, max_compression,
|
676 |
+
sequence_comp_ratio, head_comp_ratio,
|
677 |
+
head_retention, magnitude_mode,
|
678 |
+
min_tokens_stability, recent_boost,
|
679 |
+
fail_on_cpu
|
680 |
+
],
|
681 |
+
outputs=[plot_gallery, summary_output, verification_output]
|
682 |
+
)
|
683 |
+
|
684 |
+
return demo
|
685 |
+
|
686 |
+
|
687 |
+
if __name__ == "__main__":
|
688 |
+
# Set up logging
|
689 |
+
logging.basicConfig(
|
690 |
+
level=logging.INFO,
|
691 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
692 |
+
)
|
693 |
+
|
694 |
+
# Create and launch the interface
|
695 |
+
demo = create_interface()
|
696 |
+
demo.launch(
|
697 |
+
server_name="0.0.0.0",
|
698 |
+
server_port=7860,
|
699 |
+
share=False,
|
700 |
+
show_error=True
|
701 |
+
)
|