# serpent / benchmark.py
# (page header captured together with this file: kfoughali's picture |
#  Update benchmark.py | e49d439 verified | raw | history blame | 52.3 kB)
# benchmark.py
"""
Benchmarking, metrics, and proof generation for Enhanced SPG.
Supports LongBench, NIAH, RULER, SCBench benchmarks.
MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
FIXED: Generation errors, proper fallback handling.
"""
import torch
import torch.nn.functional as F
import numpy as np
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
DynamicCache
)
from datasets import load_dataset
from typing import Tuple, Optional, Dict, Any, List
from dataclasses import dataclass, field
from scipy import stats
import time
import json
import hashlib
import logging
import gc
import os
import sys
import platform
import subprocess
import zipfile
import pathlib
from datetime import datetime
import random
from config import (
CompressionConfig, CompressionType, ProvingConfig,
ResearchConstants, SUPPORTED_MODELS, BENCHMARK_CONFIGS
)
from compression import QuantizedKVCache, detect_model_layers
# Module-level logger shared by the functions in this module for progress and error reporting.
logger = logging.getLogger(__name__)
def set_seed(seed: int = 42) -> None:
    """Seed Python's ``random``, NumPy and PyTorch so runs are reproducible.

    On CUDA machines every GPU is seeded and cuDNN is switched to
    deterministic kernels with autotuning disabled.

    Args:
        seed: Non-negative integer applied to every generator.

    Raises:
        ValueError: If ``seed`` is not an ``int`` or is negative.
    """
    seed_is_valid = isinstance(seed, int) and seed >= 0
    if not seed_is_valid:
        raise ValueError(f"Seed must be non-negative integer, got {seed}")

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Deterministic kernel choice instead of cuDNN's fastest-per-shape autotuning.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    logger.info(f"Set all random seeds to {seed}")
def _peak_mem_bytes_all_gpus() -> int:
    """Return the sum of ``torch.cuda.max_memory_allocated`` over every visible GPU.

    The value is in bytes. The current device is synchronized first.

    Raises:
        RuntimeError: If CUDA is not available (fail fast rather than report 0).
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        peak_bytes = 0
        for device_idx in range(torch.cuda.device_count()):
            peak_bytes += torch.cuda.max_memory_allocated(device_idx)
        logger.debug(f"Peak GPU memory: {peak_bytes / 1024 / 1024:.1f} MB")
        return peak_bytes
    raise RuntimeError("CUDA memory tracking requested but CUDA is unavailable")
@dataclass
class BenchmarkMetrics:
    """Comprehensive metrics with proper statistical handling - NO ESTIMATES.

    Benchmark drivers append raw per-sample measurements to the list fields.
    ``calculate_statistics`` then fills in the derived mean / std / bootstrap
    CI / throughput fields.

    Units:
    - Peak-memory samples are stored in bytes and reported in MB.
    - Decode times are seconds per token and are reported in milliseconds.
    """
    # Prefill metrics (raw samples: seconds / bytes)
    prefill_times: List[float] = field(default_factory=list)
    prefill_peak_memories: List[float] = field(default_factory=list)
    prefill_time_mean: float = 0.0
    prefill_time_std: float = 0.0
    prefill_time_ci: Tuple[float, float] = (0.0, 0.0)
    prefill_peak_memory_mean_mb: float = 0.0
    prefill_peak_memory_std_mb: float = 0.0
    prefill_peak_memory_ci_mb: Tuple[float, float] = (0.0, 0.0)
    prefill_tokens_per_sec: float = 0.0
    # Decode metrics (raw samples: seconds per token / bytes)
    decode_times: List[float] = field(default_factory=list)
    decode_peak_memories: List[float] = field(default_factory=list)
    decode_time_per_token_mean_ms: float = 0.0
    decode_time_per_token_std_ms: float = 0.0
    decode_time_per_token_ci_ms: Tuple[float, float] = (0.0, 0.0)
    decode_time_p50_ms: float = 0.0
    decode_time_p95_ms: float = 0.0
    decode_peak_memory_mean_mb: float = 0.0
    decode_tokens_per_sec: float = 0.0
    # Quality metrics
    prefill_perplexities: List[float] = field(default_factory=list)
    generation_perplexities: List[float] = field(default_factory=list)
    prefill_perplexity_mean: float = 0.0
    prefill_perplexity_std: float = 0.0
    prefill_perplexity_ci: Tuple[float, float] = (0.0, 0.0)
    generation_perplexity_mean: float = 0.0
    generation_perplexity_std: float = 0.0
    generation_perplexity_ci: Tuple[float, float] = (0.0, 0.0)
    # Benchmark-specific metrics
    longbench_scores: List[Dict[str, float]] = field(default_factory=list)
    niah_retrieval_accuracy: List[float] = field(default_factory=list)
    ruler_exact_match: List[float] = field(default_factory=list)
    scbench_turn_accuracy: List[float] = field(default_factory=list)
    # Compression metrics (MEASURED ONLY - no estimates)
    compression_ratios: List[float] = field(default_factory=list)
    compression_ratio_mean: float = 0.0
    compression_ratio_std: float = 0.0
    kv_cache_memory_mb: float = 0.0
    kv_cache_memory_samples_mb: List[float] = field(default_factory=list)
    # Enhanced SPG metrics (MEASURED ONLY)
    enhanced_spg_measured_compression: List[float] = field(default_factory=list)
    enhanced_spg_measured_auxiliary_overhead_mb: List[float] = field(default_factory=list)
    enhanced_spg_progressive_steps: List[int] = field(default_factory=list)
    # Original SPG metrics
    spg_precision_distributions: List[Dict[str, float]] = field(default_factory=list)
    spg_effective_bits_per_token: List[float] = field(default_factory=list)
    spg_decay_rates_per_layer: List[List[float]] = field(default_factory=list)
    # Statistical comparisons
    memory_reduction_ratio: float = 1.0
    memory_reduction_pvalue: float = 1.0
    speedup_ratio: float = 1.0
    speedup_pvalue: float = 1.0
    prefill_perplexity_delta: float = 0.0
    generation_perplexity_delta: float = 0.0
    perplexity_pvalue: float = 1.0
    # End-to-end metrics
    end_to_end_throughput: float = 0.0
    end_to_end_latency_ms: float = 0.0

    def calculate_statistics(self, config: "CompressionConfig") -> None:
        """Populate every derived statistic from the raw sample lists.

        An empty sample list leaves its derived fields at their defaults.

        Args:
            config: Supplies ``prefill_length`` and ``generation_length`` (used
                for throughput), and ``seed``, ``n_bootstrap`` and
                ``confidence_level`` (used for bootstrap CIs).

        Raises:
            Whatever the numpy / bootstrap calls raise, after logging it.
        """
        try:
            if self.prefill_times:
                self.prefill_time_mean = float(np.mean(self.prefill_times))
                self.prefill_time_std = float(np.std(self.prefill_times))
                self.prefill_time_ci = self._bootstrap_ci(self.prefill_times, config)
                # NOTE(review): uses the configured prefill_length. Callers in this
                # module tokenize with max_length=min(prefill_length, 1024), so this
                # overstates throughput when prefill_length > 1024.
                self.prefill_tokens_per_sec = (
                    config.prefill_length / self.prefill_time_mean
                    if self.prefill_time_mean > 0 else 0.0
                )
            if self.prefill_peak_memories:
                memories_mb = [m / (1024 * 1024) for m in self.prefill_peak_memories]
                self.prefill_peak_memory_mean_mb = float(np.mean(memories_mb))
                self.prefill_peak_memory_std_mb = float(np.std(memories_mb))
                self.prefill_peak_memory_ci_mb = self._bootstrap_ci(memories_mb, config)
            if self.decode_times:
                mean_decode_s = float(np.mean(self.decode_times))
                self.decode_time_per_token_mean_ms = float(np.mean(self.decode_times) * 1000)
                self.decode_time_per_token_std_ms = float(np.std(self.decode_times) * 1000)
                self.decode_time_per_token_ci_ms = tuple(
                    x * 1000 for x in self._bootstrap_ci(self.decode_times, config)
                )
                # Same zero guard as prefill_tokens_per_sec. The result is a plain
                # float like every other derived field (np.mean yields np.float64).
                self.decode_tokens_per_sec = 1.0 / mean_decode_s if mean_decode_s > 0 else 0.0
                self.decode_time_p50_ms = float(np.percentile(self.decode_times, 50) * 1000)
                self.decode_time_p95_ms = float(np.percentile(self.decode_times, 95) * 1000)
            # End-to-end: one prefill plus generation_length decode steps.
            if self.prefill_time_mean > 0 and self.decode_time_per_token_mean_ms > 0:
                total_tokens = config.prefill_length + config.generation_length
                total_time_sec = self.prefill_time_mean + (
                    self.decode_time_per_token_mean_ms * config.generation_length / 1000
                )
                self.end_to_end_throughput = total_tokens / total_time_sec if total_time_sec > 0 else 0.0
                self.end_to_end_latency_ms = total_time_sec * 1000
            if self.decode_peak_memories:
                self.decode_peak_memory_mean_mb = float(np.mean(self.decode_peak_memories) / (1024 * 1024))
            if self.prefill_perplexities:
                self.prefill_perplexity_mean = float(np.mean(self.prefill_perplexities))
                self.prefill_perplexity_std = float(np.std(self.prefill_perplexities))
                self.prefill_perplexity_ci = self._bootstrap_ci(self.prefill_perplexities, config)
            if self.generation_perplexities:
                self.generation_perplexity_mean = float(np.mean(self.generation_perplexities))
                self.generation_perplexity_std = float(np.std(self.generation_perplexities))
                self.generation_perplexity_ci = self._bootstrap_ci(self.generation_perplexities, config)
            if self.compression_ratios:
                self.compression_ratio_mean = float(np.mean(self.compression_ratios))
                self.compression_ratio_std = float(np.std(self.compression_ratios))
            if self.kv_cache_memory_samples_mb:
                self.kv_cache_memory_mb = float(np.mean(self.kv_cache_memory_samples_mb))
        except Exception as e:
            logger.error(f"Error calculating statistics: {e}")
            raise

    def _bootstrap_ci(self, data: List[float], config: "CompressionConfig") -> Tuple[float, float]:
        """Return a percentile-bootstrap confidence interval for the mean of ``data``.

        The procedure:
        - Draw ``config.n_bootstrap`` resamples (with replacement) from a
          generator seeded with ``config.seed``, so the interval is reproducible.
        - Take the ``config.confidence_level`` two-sided percentile bounds of
          the resample means.

        Returns ``(0.0, 0.0)`` for fewer than two samples or when
        ``n_bootstrap`` is 0.
        """
        if not data or len(data) < 2:
            return (0.0, 0.0)
        try:
            rng = np.random.default_rng(config.seed)
            bootstrap_means = []
            data_array = np.array(data)
            for _ in range(config.n_bootstrap):
                sample = rng.choice(data_array, size=len(data_array), replace=True)
                bootstrap_means.append(float(sample.mean()))
            if bootstrap_means:
                alpha = 1 - config.confidence_level
                lower = float(np.percentile(bootstrap_means, alpha / 2 * 100))
                upper = float(np.percentile(bootstrap_means, (1 - alpha / 2) * 100))
                return (lower, upper)
        except Exception as e:
            logger.error(f"Error in bootstrap CI calculation: {e}")
            raise
        return (0.0, 0.0)
def safe_tokenize(tokenizer, text, max_length=512, padding="max_length"):
    """Tokenize ``text`` into PyTorch tensors of at most ``max_length`` tokens.

    Args:
        tokenizer: A Hugging Face-style tokenizer. If it has no pad token, its
            EOS token is assigned as the pad token (this mutates the tokenizer).
        text: Input string (or batch of strings) to encode.
        max_length: Upper bound on the sequence length; must be positive.
        padding: Forwarded to the tokenizer. The default ``"max_length"`` pads
            every sequence to exactly ``max_length`` tokens (on the tokenizer's
            ``padding_side``). Pass ``False`` to keep the natural length.

    Returns:
        The tokenizer's encoding, with ``input_ids`` and ``attention_mask`` tensors.

    Raises:
        ValueError: If ``max_length`` is not positive or the encoding is empty.
    """
    if max_length is None or max_length <= 0:
        raise ValueError(f"max_length must be a positive integer, got {max_length}")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=padding,
        return_attention_mask=True,
        add_special_tokens=True
    )
    if inputs["input_ids"].shape[1] == 0:
        raise ValueError("Tokenization produced empty sequence")
    if inputs["input_ids"].shape[1] > max_length:
        # Defensive only: truncation=True should already enforce the bound.
        # Assigning through the mapping (rather than setting attributes) keeps
        # inputs["input_ids"], inputs.input_ids and **inputs consistent.
        inputs["input_ids"] = inputs["input_ids"][:, :max_length]
        inputs["attention_mask"] = inputs["attention_mask"][:, :max_length]
    return inputs
def validate_model_inputs(model, input_ids, attention_mask):
    """Clip ``input_ids`` / ``attention_mask`` to what ``model`` can accept.

    - Sequence length: ``model.config`` may define ``max_position_embeddings``
      and/or ``n_positions``. Any limit present is applied by keeping the
      leading positions.
    - Token ids: ids outside ``[0, vocab_size - 1]`` are clamped into that
      range instead of raising.

    Returns:
        ``(input_ids, attention_mask)`` after truncation/clamping.
    """
    cfg = model.config
    for limit_name in ("max_position_embeddings", "n_positions"):
        if not hasattr(cfg, limit_name):
            continue
        limit = getattr(cfg, limit_name)
        if input_ids.shape[1] > limit:
            input_ids = input_ids[:, :limit]
            attention_mask = attention_mask[:, :limit]

    highest_id = cfg.vocab_size - 1
    # Out-of-range ids would index past the embedding table.
    if input_ids.max() > highest_id or input_ids.min() < 0:
        input_ids = input_ids.clamp(0, highest_id)
    return input_ids, attention_mask
def _trim_prompt_cache(past_key_values, prompt_len: int):
    """Return ``past_key_values`` trimmed to cover at most ``prompt_len - 1`` positions.

    ``generate()`` only feeds prompt tokens that are not already cached. When a
    prefill cache spans the whole prompt there is nothing left to feed, so the
    last cached position is dropped and the final prompt token is re-encoded to
    produce the first new token.
    NOTE(review): exact generate() behaviour here varies across transformers
    versions — confirm against the pinned version.

    Cache handling:
    - Cache objects exposing ``get_seq_length``/``crop`` are cropped in place.
    - Legacy tuples of (key, value) are sliced, assuming the sequence axis is
      dim -2 (TODO confirm for the models in use).
    - Caches already shorter than the prompt are returned unchanged.
    """
    keep = max(prompt_len - 1, 0)
    if hasattr(past_key_values, "get_seq_length"):
        if past_key_values.get_seq_length() >= prompt_len and hasattr(past_key_values, "crop"):
            past_key_values.crop(keep)
        return past_key_values
    if isinstance(past_key_values, (tuple, list)) and past_key_values:
        first_key = past_key_values[0][0]
        if first_key is not None and first_key.shape[-2] >= prompt_len:
            return tuple((k[..., :keep, :], v[..., :keep, :]) for k, v in past_key_values)
    return past_key_values


def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20):
    """Greedy-decode up to ``max_new_tokens`` tokens and return only the new text.

    Inputs are first clipped with ``validate_model_inputs``. An optional prefill
    cache (e.g. from ``apply_compression_pipeline``) is reused after trimming it
    so it does not already cover the full prompt; that cache may be cropped in place.

    Returns:
        The decoded continuation (special tokens skipped), or ``""`` if
        generation raises (the error is logged).
    """
    try:
        input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
        # pad_token_id may legitimately be 0, so test for None, not truthiness.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id
        gen_config = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "pad_token_id": pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "attention_mask": attention_mask,
            "use_cache": True
        }
        if past_key_values is not None:
            gen_config["past_key_values"] = _trim_prompt_cache(past_key_values, input_ids.shape[1])
        with torch.no_grad():
            output = model.generate(input_ids, **gen_config)
        # generate() returns prompt + continuation; keep the continuation only.
        generated_ids = output[:, input_ids.shape[1]:]
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        # Best-effort by design: callers score an empty string as a miss.
        return ""
def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
                               cache_manager: QuantizedKVCache, config: CompressionConfig,
                               measure_memory: bool = True) -> Dict[str, Any]:
    """
    Unified compression pipeline shared by every benchmark in this module.

    Steps:
      1. Clip inputs with ``validate_model_inputs``.
      2. Run one prefill forward pass (``use_cache=True``) and time it. On CUDA
         with ``measure_memory``, also record the peak allocated memory summed
         over all GPUs.
      3. Compute the mean next-token cross-entropy over the prompt, ignoring
         labels equal to the pad token.
      4. Measure the uncompressed KV-cache size. Unless compression is disabled
         or ``cache_manager`` is None, compress every layer through
         ``cache_manager`` and rebuild the cache from the decompressed tensors.
         Layers that fail fall back to their original tensors.

    Returns:
        Dict with these keys:
        - ``past_key_values``
        - ``prefill_time`` (seconds)
        - ``prefill_peak_mem`` (bytes; 0 when not measured)
        - ``prefill_loss`` (float, or None)
        - ``original_cache_size`` (bytes)
        - ``compressed_cache_size`` (from ``cache_manager.get_memory_footprint()``;
          presumably bytes, since callers divide it by 1024**2)
        - ``compression_ratio``
        - ``logits``

        If the prefill forward pass raises, a neutral result with
        ``past_key_values=None`` is returned instead of propagating.
    """
    input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
    cuda = torch.cuda.is_available()
    if cuda and measure_memory:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    if cuda:
        # Drain pending kernels so they are not billed to the prefill timer.
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    try:
        with torch.inference_mode():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True
            )
        past_key_values = outputs.past_key_values
        logits = outputs.logits
    except Exception as e:
        logger.error(f"Prefill failed: {e}")
        return {
            'past_key_values': None,
            'prefill_time': 0,
            'prefill_peak_mem': 0,
            'prefill_loss': None,
            'original_cache_size': 0,
            'compressed_cache_size': 0,
            'compression_ratio': 1.0,
            'logits': None
        }
    if cuda:
        torch.cuda.synchronize()
    prefill_time = time.perf_counter() - start_time

    prefill_peak_mem = _peak_mem_bytes_all_gpus() if (cuda and measure_memory) else 0

    prefill_loss = None
    if logits is not None and input_ids.shape[1] > 1:
        try:
            seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
            if seq_len > 0:
                # Position t predicts token t + 1.
                shift_logits = logits[:, :seq_len, :].contiguous()
                shift_labels = input_ids[:, 1:seq_len + 1].contiguous()
                pad_id = tokenizer.pad_token_id
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    reduction='mean',
                    # A pad id of 0 is valid; only fall back to -100 when it is None.
                    ignore_index=pad_id if pad_id is not None else -100
                )
                if torch.isfinite(loss):
                    prefill_loss = loss.item()
                else:
                    # e.g. every label was ignored -> mean over zero elements.
                    logger.warning("Prefill loss is not finite; reporting None")
        except Exception as e:
            logger.warning(f"Could not calculate prefill loss: {e}")

    original_cache_size = 0
    compressed_cache_size = 0
    compression_ratio = 1.0
    compress = config.compression_type != CompressionType.NONE and cache_manager is not None
    if past_key_values:
        try:
            if hasattr(past_key_values, 'to_legacy_cache'):
                kv_tuple = past_key_values.to_legacy_cache()
            else:
                kv_tuple = past_key_values
            for layer_idx, (keys, values) in enumerate(kv_tuple):
                if keys is None or values is None:
                    continue
                original_cache_size += keys.nelement() * keys.element_size()
                original_cache_size += values.nelement() * values.element_size()
                if compress:
                    try:
                        cache_manager.compress_and_store(layer_idx, keys, values)
                    except Exception as e:
                        logger.error(f"Compression failed for layer {layer_idx}: {e}")
            if compress:
                reconstructed_kv = []
                for layer_idx in range(len(kv_tuple)):
                    try:
                        dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
                    except Exception as e:
                        logger.error(f"Decompression failed for layer {layer_idx}: {e}")
                        dec_keys = dec_values = None
                    if dec_keys is not None and dec_values is not None:
                        reconstructed_kv.append((dec_keys, dec_values))
                    else:
                        # Fall back to the uncompressed tensors for this layer.
                        reconstructed_kv.append(kv_tuple[layer_idx])
                if hasattr(DynamicCache, 'from_legacy_cache'):
                    past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
                else:
                    past_key_values = tuple(reconstructed_kv)
                try:
                    compressed_cache_size = cache_manager.get_memory_footprint()
                except Exception as e:
                    logger.warning(f"Could not measure compressed cache footprint: {e}")
                    compressed_cache_size = original_cache_size
            else:
                compressed_cache_size = original_cache_size
            if compressed_cache_size > 0:
                compression_ratio = original_cache_size / compressed_cache_size
        except Exception as e:
            logger.error(f"Cache processing failed: {e}")
            compressed_cache_size = original_cache_size
            compression_ratio = 1.0
    return {
        'past_key_values': past_key_values,
        'prefill_time': prefill_time,
        'prefill_peak_mem': prefill_peak_mem,
        'prefill_loss': prefill_loss,
        'original_cache_size': original_cache_size,
        'compressed_cache_size': compressed_cache_size,
        'compression_ratio': compression_ratio,
        'logits': logits
    }
def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
    """Create a Needle-in-a-Haystack test context - NO HARDCODING.

    A repeated filler sentence is cut to ``context_length - len(needle) - 10``
    characters (never below zero). ``needle``, padded with one space on each
    side, is then inserted ``depth_percent`` percent of the way into the filler.

    Args:
        context_length: Target size in characters (not tokens).
        needle: Text to hide in the filler.
        depth_percent: Insertion depth; values outside [0, 100] are clamped.

    Returns:
        The filler with the needle inserted.
    """
    filler = "The quick brown fox jumps over the lazy dog. " * 20
    # A negative bound would slice from the end and keep almost all the filler,
    # so clamp the budget at zero.
    haystack_len = max(0, context_length - len(needle) - 10)
    # Enough " "-joined copies to cover haystack_len. The joined text is
    # periodic, so its prefix does not depend on how many copies are used.
    n_copies = haystack_len // (len(filler) + 1) + 1
    haystack = " ".join([filler] * n_copies)[:haystack_len]
    depth = min(max(float(depth_percent), 0.0), 100.0)
    insertion_point = int(len(haystack) * depth / 100)
    return haystack[:insertion_point] + " " + needle + " " + haystack[insertion_point:]
def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Evaluate NIAH with SAME compression pipeline as WikiText.

    Steps:
    - Hide ``config.niah_needle`` at ``config.niah_depth_percent`` in a filler
      context.
    - Prefill and compress the prompt via ``apply_compression_pipeline``.
    - Greedily generate up to 20 tokens.

    The sample scores 1.0 when the needle's last word, with surrounding
    punctuation stripped, appears in the generated text.

    Returns:
        Dict with ``accuracy``, ``compression_ratio``, ``kv_cache_memory_mb``,
        ``prefill_time``, ``generation_time`` and ``prefill_peak_mem``.

    Raises:
        ValueError: If ``config.niah_needle`` contains no words.
    """
    import string

    needle_words = config.niah_needle.split()
    if not needle_words:
        raise ValueError("config.niah_needle must contain at least one word")
    # For a needle like "... is XYZ." the model should not have to echo the
    # sentence-final period; keep the raw word if stripping leaves nothing.
    answer = needle_words[-1].strip(string.punctuation) or needle_words[-1]

    # NOTE(review): create_niah_haystack sizes its output in characters while
    # prefill_length reads like a token count, so the context is likely far
    # shorter than prefill_length tokens - confirm this is intended.
    context = create_niah_haystack(
        config.prefill_length,
        config.niah_needle,
        config.niah_depth_percent
    )
    prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:"
    inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    compression_result = apply_compression_pipeline(
        model, tokenizer, input_ids, attention_mask, cache_manager, config
    )
    gen_start = time.perf_counter()
    generated_text = safe_generate(model, tokenizer, input_ids, attention_mask,
                                   compression_result['past_key_values'], max_new_tokens=20)
    gen_time = time.perf_counter() - gen_start
    accuracy = 1.0 if answer in generated_text else 0.0
    logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
    logger.info(f"NIAH compression ratio: {compression_result['compression_ratio']:.1f}x")
    return {
        'accuracy': accuracy,
        'compression_ratio': compression_result['compression_ratio'],
        'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
        'prefill_time': compression_result['prefill_time'],
        'generation_time': gen_time,
        'prefill_peak_mem': compression_result['prefill_peak_mem']
    }
def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Evaluate RULER with SAME compression pipeline as WikiText.

    Steps:
    - Build a context of ten "The capital of CountryN is CityN." facts.
    - Ask about one country chosen with the ``random`` module (so results
      follow ``set_seed``).
    - Run the prompt through ``apply_compression_pipeline`` and greedily
      generate up to 10 tokens.

    Exact match is 1.0 when the expected city appears in the generation.

    Returns:
        Dict with ``exact_match``, ``compression_ratio``, ``kv_cache_memory_mb``,
        ``prefill_time``, ``generation_time`` and ``prefill_peak_mem``.
    """
    seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024)
    fact_block = " ".join(
        f"Fact {i}: The capital of Country{i} is City{i}." for i in range(10)
    )
    # seq_len is a token budget, but the context is sized in characters.
    # Always keep at least one complete copy of the facts: a budget below one
    # copy would otherwise give an empty context. Never cut a copy mid-fact,
    # because a character cut can drop the very fact being asked about.
    n_copies = max(1, (seq_len - 100) // (len(fact_block) + 1))
    context = " ".join([fact_block] * n_copies)
    query_idx = random.randint(0, 9)
    prompt = f"{context}\n\nWhat is the capital of Country{query_idx}?"
    # NOTE(review): safe_tokenize truncates to seq_len tokens (tokenizer's
    # truncation side, usually right). If one copy of the facts plus the
    # question exceeds that, the question is lost.
    inputs = safe_tokenize(tokenizer, prompt, max_length=seq_len)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    compression_result = apply_compression_pipeline(
        model, tokenizer, input_ids, attention_mask, cache_manager, config
    )
    gen_start = time.perf_counter()
    generated = safe_generate(model, tokenizer, input_ids, attention_mask,
                              compression_result['past_key_values'], max_new_tokens=10)
    gen_time = time.perf_counter() - gen_start
    expected = f"City{query_idx}"
    exact_match = 1.0 if expected in generated else 0.0
    logger.info(f"RULER exact match: {exact_match}, Generated: {generated[:50]}")
    logger.info(f"RULER compression ratio: {compression_result['compression_ratio']:.1f}x")
    return {
        'exact_match': exact_match,
        'compression_ratio': compression_result['compression_ratio'],
        'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
        'prefill_time': compression_result['prefill_time'],
        'generation_time': gen_time,
        'prefill_peak_mem': compression_result['prefill_peak_mem']
    }
def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Evaluate SCBench with SAME compression pipeline as WikiText.

    Steps:
    - Build a synthetic dialogue of ``config.scbench_num_turns`` "Remember
      that item_N is value_N_XXXX." exchanges, ending with a question about
      one key. The values and the queried key come from the ``random`` module.
    - Send the prompt through ``apply_compression_pipeline``.
    - Greedily generate up to 20 tokens.

    Accuracy is 1.0 when the remembered value appears in the generation.

    Returns:
        Dict with ``accuracy``, ``compression_ratio``, ``kv_cache_memory_mb``,
        ``prefill_time``, ``generation_time`` and ``prefill_peak_mem``.
    """
    facts = {}
    dialogue = []
    for turn in range(config.scbench_num_turns):
        key = f"item_{turn}"
        value = f"value_{turn}_{random.randint(1000, 9999)}"
        facts[key] = value
        dialogue += [
            f"User: Remember that {key} is {value}.",
            f"Assistant: I'll remember that {key} is {value}.",
        ]
    query_key = random.choice(list(facts))
    dialogue.append(f"User: What is {query_key}?")
    prompt_text = "\n".join(dialogue) + "\nAssistant:"

    encoded = safe_tokenize(tokenizer, prompt_text, max_length=min(config.prefill_length, 1024))
    device = model.device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    pipeline_out = apply_compression_pipeline(
        model, tokenizer, input_ids, attention_mask, cache_manager, config
    )
    started = time.perf_counter()
    generated = safe_generate(model, tokenizer, input_ids, attention_mask,
                              pipeline_out['past_key_values'], max_new_tokens=20)
    gen_time = time.perf_counter() - started

    accuracy = float(facts[query_key] in generated)
    ratio = pipeline_out['compression_ratio']
    logger.info(f"SCBench accuracy: {accuracy}, Generated: {generated[:50]}")
    logger.info(f"SCBench compression ratio: {ratio:.1f}x")
    return {
        'accuracy': accuracy,
        'compression_ratio': ratio,
        'kv_cache_memory_mb': pipeline_out['compressed_cache_size'] / (1024 * 1024),
        'prefill_time': pipeline_out['prefill_time'],
        'generation_time': gen_time,
        'prefill_peak_mem': pipeline_out['prefill_peak_mem'],
    }
def _build_longbench_prompt(tokenizer, context: str, question: str, max_length: int,
                            margin: int = 8) -> str:
    """Build the LongBench prompt, middle-truncating ``context`` to fit ``max_length`` tokens.

    The prompt ends with the question and ``Answer:``. Right-truncating the whole
    prompt would drop exactly that part whenever the context is long. Instead,
    the context is cut in the middle so its head and tail survive; this is the
    strategy recommended by LongBench's reference prediction script.
    ``margin`` leaves headroom because the cut context's token count can drift
    slightly when it is decoded and re-encoded.
    """
    prefix = "Context: "
    suffix = f"\n\nQuestion: {question}\n\nAnswer:"
    overhead = len(tokenizer.encode(prefix + suffix, add_special_tokens=True))
    budget = max(0, max_length - overhead - margin)
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    if len(context_ids) > budget:
        head = budget // 2
        tail = budget - head
        kept = context_ids[:head] + context_ids[len(context_ids) - tail:]
        context = tokenizer.decode(kept, skip_special_tokens=True)
    return f"{prefix}{context}{suffix}"


def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
                            task: str, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Evaluate LongBench with SAME compression pipeline as WikiText.

    For each of the first ``min(config.eval_samples, len(test split))`` samples
    of ``THUDM/LongBench`` ``task``:
    - Build a middle-truncated prompt that fits ``min(prefill_length, 1024)``
      tokens.
    - Run ``apply_compression_pipeline`` (without memory tracking).
    - Greedily generate up to 50 tokens.

    A sample scores 1.0 when any non-empty reference answer appears
    (case-insensitively) in the generation.

    Returns:
        Dict with mean ``accuracy``, ``n_samples``, mean ``compression_ratio``,
        ``kv_cache_memory_mb``, ``prefill_time`` and ``generation_time``. Any
        exception is logged and a zeroed result is returned instead.
    """
    max_length = min(config.prefill_length, 1024)
    try:
        dataset = load_dataset("THUDM/LongBench", task, split="test")
        n_samples = min(config.eval_samples, len(dataset))
        samples = dataset.select(range(n_samples))
        scores = []
        compression_ratios = []
        kv_memories = []
        prefill_times = []
        gen_times = []
        for sample in samples:
            context = sample.get("context", "") or ""
            question = sample.get("input", sample.get("question", "")) or ""
            raw_answers = sample.get("answers", [sample.get("answer", "")])
            if not isinstance(raw_answers, (list, tuple)):
                raw_answers = [raw_answers]
            # LongBench may list several acceptable answers. Empty ones are
            # dropped because "" is a substring of every generation.
            references = [
                str(a).lower() for a in raw_answers
                if a is not None and str(a).strip()
            ]
            prompt = _build_longbench_prompt(tokenizer, context, question, max_length)
            inputs = safe_tokenize(tokenizer, prompt, max_length=max_length)
            input_ids = inputs.input_ids.to(model.device)
            attention_mask = inputs.attention_mask.to(model.device)
            compression_result = apply_compression_pipeline(
                model, tokenizer, input_ids, attention_mask, cache_manager, config,
                measure_memory=False
            )
            gen_start = time.perf_counter()
            generated = safe_generate(model, tokenizer, input_ids, attention_mask,
                                      compression_result['past_key_values'], max_new_tokens=50)
            gen_time = time.perf_counter() - gen_start
            if not references:
                logger.warning(f"LongBench {task} sample has no reference answer; scoring 0.0")
            generated_lower = generated.lower()
            score = 1.0 if any(ref in generated_lower for ref in references) else 0.0
            scores.append(score)
            compression_ratios.append(compression_result['compression_ratio'])
            kv_memories.append(compression_result['compressed_cache_size'] / (1024 * 1024))
            prefill_times.append(compression_result['prefill_time'])
            gen_times.append(gen_time)
        avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0
        return {
            'accuracy': float(np.mean(scores)) if scores else 0.0,
            'n_samples': n_samples,
            'compression_ratio': avg_compression,
            'kv_cache_memory_mb': float(np.mean(kv_memories)) if kv_memories else 0.0,
            'prefill_time': float(np.mean(prefill_times)) if prefill_times else 0.0,
            'generation_time': float(np.mean(gen_times)) if gen_times else 0.0
        }
    except Exception as e:
        logger.error(f"Error evaluating LongBench task {task}: {e}")
        return {
            'accuracy': 0.0,
            'n_samples': 0,
            'compression_ratio': 1.0,
            'kv_cache_memory_mb': 0.0,
            'prefill_time': 0.0,
            'generation_time': 0.0
        }
def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
    """Load ``model_name`` and its tokenizer for inference - NO HARDCODING.

    Device and precision:
    - CUDA with float16 (``device_map="auto"``) when available.
    - Otherwise CPU with float32, unless ``config.fail_on_cpu_fallback``
      forbids CPU.

    With ``config.use_flash_attention`` on CUDA, Flash Attention 2 is tried
    first; if that load raises, the model is reloaded with the default
    attention implementation.

    Returns:
        ``(model, tokenizer)``: the model is in eval mode and the tokenizer has
        a pad token (EOS is used when none is defined).

    Raises:
        RuntimeError: If CUDA is unavailable and ``config.fail_on_cpu_fallback`` is set.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    on_gpu = device == "cuda"
    dtype = torch.float16 if on_gpu else torch.float32
    if config.fail_on_cpu_fallback and device == "cpu":
        raise RuntimeError("CUDA required but unavailable (fail_on_cpu_fallback=True)")

    logger.info(f"Loading model: {model_name}")
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # model repository; only point this at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = dict(
        torch_dtype=dtype,
        device_map="auto" if on_gpu else None,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    model = None
    if config.use_flash_attention and on_gpu:
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, attn_implementation="flash_attention_2", **load_kwargs
            )
            logger.info("Successfully loaded with Flash Attention 2")
        except Exception as e:
            logger.warning(f"Flash Attention not available: {e}")
            model = None
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    model.eval()
    return model, tokenizer
def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]:
    """Load dataset samples based on benchmark type - NO HARDCODING.

    Behaviour by ``config.benchmark_type``:
    - ``"wikitext"``: scan ``config.dataset_split``, then ``"train"`` and
      ``"validation"`` (each split at most once), keeping stripped texts
      longer than 50 characters that tokenize to at least
      ``min(prefill_length + generation_length, 256)`` tokens. Splits stop
      being read once ``eval_samples`` texts are collected; a single split
      contributes at most ``3 * eval_samples``.
    - ``"longbench"``: up to ``eval_samples`` contexts longer than 100
      characters from ``config.benchmark_subset`` (nothing when no subset
      is set).
    - ``"niah"`` / ``"ruler"`` / ``"scbench"``: placeholder strings; these
      benchmarks synthesize their own prompts.

    Returns:
        The list of texts. A warning is logged when fewer than ``eval_samples``
        were found.

    Raises:
        ValueError: For an unsupported benchmark type.
    """
    logger.info(f"Loading samples for benchmark: {config.benchmark_type}")
    if config.benchmark_type == "wikitext":
        texts = []
        min_tokens = config.prefill_length + config.generation_length
        token_floor = min(min_tokens, 256)
        try:
            # dict.fromkeys de-duplicates while preserving order, so a configured
            # split equal to "train"/"validation" is not scanned (and its texts
            # appended) a second time.
            for split in dict.fromkeys([config.dataset_split, "train", "validation"]):
                if len(texts) >= config.eval_samples:
                    break
                try:
                    dataset = load_dataset(
                        config.dataset_name,
                        config.dataset_config,
                        split=split,
                        streaming=False
                    )
                    logger.info(f"Trying {split} split with {len(dataset)} samples")
                    for item in dataset:
                        # `or ''` tolerates rows whose text field is None.
                        text = (item.get('text') or '').strip()
                        if len(text) <= 50:
                            continue
                        tokens = tokenizer.encode(text, truncation=False, add_special_tokens=False)
                        if len(tokens) >= token_floor:
                            texts.append(text)
                            if len(texts) >= config.eval_samples * 3:
                                break
                except Exception as e:
                    logger.warning(f"Failed to load {split} split: {e}")
                    continue
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise
    elif config.benchmark_type == "longbench":
        texts = []
        if config.benchmark_subset:
            try:
                dataset = load_dataset("THUDM/LongBench", config.benchmark_subset, split="test")
                for item in dataset:
                    if len(texts) >= config.eval_samples:
                        break
                    context = item.get("context", "")
                    if len(context) > 100:
                        texts.append(context)
            except Exception as e:
                logger.error(f"Failed to load LongBench subset {config.benchmark_subset}: {e}")
                raise
    elif config.benchmark_type in ["niah", "ruler", "scbench"]:
        texts = ["Synthetic benchmark data"] * config.eval_samples
    else:
        raise ValueError(f"Unsupported benchmark type: {config.benchmark_type}")
    if len(texts) < config.eval_samples:
        logger.warning(f"Only loaded {len(texts)} samples, requested {config.eval_samples}")
    logger.info(f"Loaded {len(texts)} text samples")
    return texts
def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_texts: Optional[List[str]] = None) -> Tuple[BenchmarkMetrics, Dict, List[Dict], List[Dict]]:
"""Research-grade benchmark with UNIFIED compression for ALL benchmarks."""
logger.info(f"Starting benchmark: {model_name} with {config.compression_type.value}")
logger.info(f"Benchmark type: {config.benchmark_type}")
logger.info(f"Config hash: {config.get_hash()}")
if torch.cuda.is_available():
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
constants = ResearchConstants()
start_time = datetime.now().isoformat()
per_sample_records = []
per_layer_fingerprints = []
model, tokenizer = load_model_and_tokenizer(model_name, config)
try:
n_layers = detect_model_layers(model)
logger.info(f"Model architecture: {n_layers} transformer layers detected")
except ValueError as e:
logger.error(f"Failed to detect model layers: {e}")
raise
device = model.device
with torch.inference_mode():
dummy = torch.randint(0, tokenizer.vocab_size, (1, min(config.prefill_length, 128)), device=device)
am = torch.ones_like(dummy)
for _ in range(config.warmup_steps):
_ = model(dummy, attention_mask=am, use_cache=True, return_dict=True)
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
if dataset_texts is None:
dataset_texts = load_real_dataset_samples(config, tokenizer)
all_metrics = []
for seed in range(config.n_seeds):
set_seed(config.seed + seed)
logger.info(f"Running evaluation with seed {config.seed + seed}")
metrics = BenchmarkMetrics()
if config.benchmark_type == "niah":
for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
config.niah_depth_percent = depth
for idx in range(min(config.eval_samples, 10)):
if config.compression_type != CompressionType.NONE:
cache_manager = QuantizedKVCache(config)
cache_manager.n_layers = n_layers
else:
cache_manager = None
result = evaluate_niah(model, tokenizer, config, cache_manager)
metrics.niah_retrieval_accuracy.append(result['accuracy'])
metrics.compression_ratios.append(result['compression_ratio'])
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
metrics.prefill_times.append(result['prefill_time'])
metrics.decode_times.append(result['generation_time'] / 20)
if result['prefill_peak_mem'] > 0:
metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
per_sample_records.append({
'benchmark': 'niah',
'depth_percent': depth,
'sample_idx': idx,
'accuracy': result['accuracy'],
'compression_ratio': result['compression_ratio'],
'kv_cache_memory_mb': result['kv_cache_memory_mb'],
'compression_type': config.compression_type.value
})
elif config.benchmark_type == "ruler":
for idx in range(config.eval_samples):
if config.compression_type != CompressionType.NONE:
cache_manager = QuantizedKVCache(config)
cache_manager.n_layers = n_layers
else:
cache_manager = None
result = evaluate_ruler(model, tokenizer, config, cache_manager)
metrics.ruler_exact_match.append(result['exact_match'])
metrics.compression_ratios.append(result['compression_ratio'])
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
metrics.prefill_times.append(result['prefill_time'])
metrics.decode_times.append(result['generation_time'] / 10)
if result['prefill_peak_mem'] > 0:
metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
per_sample_records.append({
'benchmark': 'ruler',
'sample_idx': idx,
'exact_match': result['exact_match'],
'compression_ratio': result['compression_ratio'],
'kv_cache_memory_mb': result['kv_cache_memory_mb'],
'compression_type': config.compression_type.value
})
elif config.benchmark_type == "scbench":
for idx in range(config.eval_samples):
if config.compression_type != CompressionType.NONE:
cache_manager = QuantizedKVCache(config)
cache_manager.n_layers = n_layers
else:
cache_manager = None
result = evaluate_scbench(model, tokenizer, config, cache_manager)
metrics.scbench_turn_accuracy.append(result['accuracy'])
metrics.compression_ratios.append(result['compression_ratio'])
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
metrics.prefill_times.append(result['prefill_time'])
metrics.decode_times.append(result['generation_time'] / 20)
if result['prefill_peak_mem'] > 0:
metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
per_sample_records.append({
'benchmark': 'scbench',
'sample_idx': idx,
'accuracy': result['accuracy'],
'compression_ratio': result['compression_ratio'],
'kv_cache_memory_mb': result['kv_cache_memory_mb'],
'compression_type': config.compression_type.value
})
elif config.benchmark_type == "longbench":
if config.benchmark_subset:
if config.compression_type != CompressionType.NONE:
cache_manager = QuantizedKVCache(config)
cache_manager.n_layers = n_layers
else:
cache_manager = None
result = evaluate_longbench_task(model, tokenizer, config,
config.benchmark_subset, cache_manager)
metrics.longbench_scores.append(result)
metrics.compression_ratios.append(result['compression_ratio'])
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
metrics.prefill_times.append(result['prefill_time'])
if result['generation_time'] > 0:
metrics.decode_times.append(result['generation_time'] / 50)
per_sample_records.append({
'benchmark': 'longbench',
'subset': config.benchmark_subset,
'accuracy': result['accuracy'],
'compression_ratio': result['compression_ratio'],
'kv_cache_memory_mb': result['kv_cache_memory_mb'],
'compression_type': config.compression_type.value
})
else:
for idx in range(config.eval_samples):
logger.info(f"Sample {idx+1}/{config.eval_samples}")
text_idx = (idx + seed * config.eval_samples) % len(dataset_texts)
text = dataset_texts[text_idx]
if config.compression_type != CompressionType.NONE:
cache_manager = QuantizedKVCache(config)
cache_manager.n_layers = n_layers
cache_manager.update_position(config.prefill_length + idx)
else:
cache_manager = None
inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024))
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
compression_result = apply_compression_pipeline(
model, tokenizer, input_ids, attention_mask, cache_manager, config
)
metrics.prefill_times.append(compression_result['prefill_time'])
metrics.compression_ratios.append(compression_result['compression_ratio'])
metrics.kv_cache_memory_samples_mb.append(compression_result['compressed_cache_size'] / (1024 * 1024))
if compression_result['prefill_peak_mem'] > 0:
metrics.prefill_peak_memories.append(compression_result['prefill_peak_mem'])
if compression_result['prefill_loss'] is not None:
prefill_perplexity = np.exp(compression_result['prefill_loss'])
metrics.prefill_perplexities.append(min(prefill_perplexity, 1000))
generated_ids = input_ids.clone()
decode_times = []
generation_losses = []
past_key_values = compression_result['past_key_values']
for gen_step in range(config.generation_length):
if torch.cuda.is_available():
torch.cuda.synchronize()
step_start = time.perf_counter()
with torch.inference_mode():
outputs = model(
generated_ids[:, -1:],
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
next_token_logits = outputs.logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
loss = F.cross_entropy(next_token_logits, next_token)
generation_losses.append(loss.item())
generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)
past_key_values = outputs.past_key_values
if torch.cuda.is_available():
torch.cuda.synchronize()
decode_time = time.perf_counter() - step_start
decode_times.append(decode_time)
if decode_times:
metrics.decode_times.extend(decode_times)
if generation_losses:
generation_perplexity = np.exp(np.mean(generation_losses))
metrics.generation_perplexities.append(min(generation_perplexity, 1000))
per_sample_records.append({
'benchmark': 'wikitext',
'sample_idx': idx,
'prefill_perplexity': metrics.prefill_perplexities[-1] if metrics.prefill_perplexities else None,
'generation_perplexity': metrics.generation_perplexities[-1] if metrics.generation_perplexities else None,
'compression_ratio': compression_result['compression_ratio'],
'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
'compression_type': config.compression_type.value
})
metrics.calculate_statistics(config)
all_metrics.append(metrics)
final_metrics = BenchmarkMetrics()
for m in all_metrics:
final_metrics.prefill_times.extend(m.prefill_times)
final_metrics.prefill_peak_memories.extend(m.prefill_peak_memories)
final_metrics.decode_times.extend(m.decode_times)
final_metrics.decode_peak_memories.extend(m.decode_peak_memories)
final_metrics.prefill_perplexities.extend(m.prefill_perplexities)
final_metrics.generation_perplexities.extend(m.generation_perplexities)
final_metrics.compression_ratios.extend(m.compression_ratios)
final_metrics.kv_cache_memory_samples_mb.extend(m.kv_cache_memory_samples_mb)
final_metrics.niah_retrieval_accuracy.extend(m.niah_retrieval_accuracy)
final_metrics.ruler_exact_match.extend(m.ruler_exact_match)
final_metrics.scbench_turn_accuracy.extend(m.scbench_turn_accuracy)
final_metrics.longbench_scores.extend(m.longbench_scores)
final_metrics.calculate_statistics(config)
end_time = datetime.now().isoformat()
summary = {
'compression_type': config.compression_type.value,
'model': model_name,
'benchmark_type': config.benchmark_type,
'n_seeds': config.n_seeds,
'total_samples': config.eval_samples * config.n_seeds,
'compression_ratio': final_metrics.compression_ratio_mean,
'kv_cache_memory_mb': final_metrics.kv_cache_memory_mb,
'start_time': start_time,
'end_time': end_time
}
if config.benchmark_type == "niah" and final_metrics.niah_retrieval_accuracy:
summary['niah_accuracy'] = float(np.mean(final_metrics.niah_retrieval_accuracy))
elif config.benchmark_type == "ruler" and final_metrics.ruler_exact_match:
summary['ruler_exact_match'] = float(np.mean(final_metrics.ruler_exact_match))
elif config.benchmark_type == "scbench" and final_metrics.scbench_turn_accuracy:
summary['scbench_accuracy'] = float(np.mean(final_metrics.scbench_turn_accuracy))
elif config.benchmark_type == "longbench" and final_metrics.longbench_scores:
summary['longbench_accuracy'] = float(np.mean([s['accuracy'] for s in final_metrics.longbench_scores]))
else:
summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
summary['generation_perplexity'] = final_metrics.generation_perplexity_mean
summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
summary['end_to_end_throughput'] = final_metrics.end_to_end_throughput
summary['end_to_end_latency_ms'] = final_metrics.end_to_end_latency_ms
summary['peak_memory_mb'] = final_metrics.prefill_peak_memory_mean_mb
return final_metrics, summary, per_sample_records, per_layer_fingerprints
def export_proof_bundle(bundle_dir: str, config: CompressionConfig,
                        metrics: BenchmarkMetrics, summary: Dict[str, Any],
                        per_sample_records: List[Dict[str, Any]],
                        per_layer_fingerprints: List[Dict[str, Any]]) -> str:
    """Export an attestable proof bundle with all metrics and fingerprints.

    Writes into ``bundle_dir`` (created if missing):
      * ``manifest.json`` - serialized config, config hash, model, benchmark
        type, Python/torch/transformers/CUDA versions, device name, run
        start/end time and hostname;
      * ``summary.json`` - the run summary (non-JSON values stringified);
      * ``records/metrics.jsonl`` and ``records/kv_fingerprints.jsonl`` -
        one JSON object per line;
      * ``env.lock`` - ``pip freeze`` output, or a comment describing why the
        environment could not be captured.
    Every file under ``bundle_dir`` (including any already present there) is
    then archived into ``<bundle_dir>.zip`` placed next to the directory.

    Args:
        bundle_dir: Output directory for the bundle.
        config: Run configuration; supplies the manifest fields.
        metrics: Not read by this function; kept for interface compatibility.
        summary: Run summary dict; ``start_time``/``end_time`` are copied into
            the manifest.
        per_sample_records: Rows written to ``records/metrics.jsonl``.
        per_layer_fingerprints: Rows written to ``records/kv_fingerprints.jsonl``.

    Returns:
        Path of the written zip archive.
    """
    p = pathlib.Path(bundle_dir)
    p.mkdir(parents=True, exist_ok=True)

    manifest = {
        "config": json.loads(config.to_json()),
        "config_hash": config.get_hash(),
        "model": config.model_name,
        "benchmark_type": config.benchmark_type,
        "python": sys.version,
        "torch": config.torch_version,
        "transformers": config.transformers_version,
        "cuda": config.cuda_version,
        "device_name": config.device_name,
        "start_time": summary.get("start_time"),
        "end_time": summary.get("end_time"),
        "hostname": platform.node()
    }
    (p / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    (p / "summary.json").write_text(json.dumps(summary, indent=2, default=str), encoding="utf-8")

    records_dir = p / "records"
    records_dir.mkdir(exist_ok=True)
    with open(records_dir / "metrics.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, default=str) + "\n" for r in per_sample_records)
    with open(records_dir / "kv_fingerprints.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, default=str) + "\n" for r in per_layer_fingerprints)

    # Environment capture is best-effort: a failure (including a hung pip,
    # bounded by the timeout) is recorded in env.lock instead of aborting.
    try:
        env_text = subprocess.check_output(
            [sys.executable, "-m", "pip", "freeze"], text=True, timeout=120
        )
        (p / "env.lock").write_text(env_text, encoding="utf-8")
    except Exception as e:
        logger.warning(f"Could not capture environment: {e}")
        (p / "env.lock").write_text(f"# Environment capture failed: {e}\n", encoding="utf-8")

    # Append ".zip" rather than using with_suffix(), which would replace a
    # dotted part of the name (e.g. "run_0.5" -> "run_0.zip") and let distinct
    # bundles overwrite each other's archive.
    zip_path = str(p.parent / (p.name + ".zip"))
    # Sorted member order keeps the archive layout deterministic across
    # filesystems (os.walk order is not guaranteed).
    bundle_files = sorted(f for f in p.rglob("*") if f.is_file())
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
        for full in bundle_files:
            z.write(full, arcname=full.relative_to(p).as_posix())
    logger.info(f"Proof bundle exported: {zip_path}")
    return zip_path
def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: ProvingConfig) -> Dict[str, Any]:
    """Verify a proof bundle by recomputing summary metrics from per-sample records.

    Loads ``summary.json`` and ``records/metrics.jsonl`` from ``bundle_root``.
    It keeps only records whose ``compression_type`` equals the summary's
    ``compression_type`` (default ``"enhanced_spg"``) and recomputes, as a
    mean over those records, every summary metric that has a per-sample
    counterpart. Each recomputed value is compared against the stored value
    using the absolute tolerance ``proving.numeric_tolerance``.

    The metrics checked are chosen from the keys present in the summary, not
    from ``config.benchmark_type``. A run whose summary fell back to
    perplexity metrics therefore still has those perplexities verified.

    Args:
        bundle_root: Directory containing the unpacked bundle.
        config: Run configuration. Its ``benchmark_type`` is only compared
            with the bundle's for a consistency warning.
        proving: Supplies ``numeric_tolerance``.

    Returns:
        Dict with keys:

        * ``ok`` - True when there are no failures.
        * ``failures`` - list of human-readable mismatch messages.
        * ``recomputed`` - metric name -> recomputed mean, or None.
        * ``summary`` - the loaded summary.
        * ``n_samples`` - count of all records in the bundle.
        * ``n_verified`` - count of records used for recomputation.

    Raises:
        RuntimeError: If the summary or records cannot be read or parsed.
        ValueError: If the bundle has no records, or none for the primary method.
    """
    try:
        with open(os.path.join(bundle_root, "summary.json"), encoding="utf-8") as f:
            summary = json.load(f)
        records = []
        with open(os.path.join(bundle_root, "records", "metrics.jsonl"), encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    records.append(json.loads(line))
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        raise RuntimeError(f"Failed to load proof bundle: {e}") from e

    if not records:
        raise ValueError("No per-sample records found in proof bundle")

    bundle_benchmark = summary.get("benchmark_type")
    if bundle_benchmark is not None and bundle_benchmark != config.benchmark_type:
        logger.warning(
            f"Config benchmark_type {config.benchmark_type!r} differs from "
            f"bundle benchmark_type {bundle_benchmark!r}"
        )

    primary_method = summary.get("compression_type", "enhanced_spg")
    primary_records = [r for r in records if r.get("compression_type") == primary_method]
    if not primary_records:
        raise ValueError(f"No records found for method {primary_method}")
    logger.info(f"Verifying {len(primary_records)} records for {primary_method}")

    def mean_of(key):
        """Mean of non-None ``key`` values over the primary records, or None."""
        vals = [float(r[key]) for r in primary_records if r.get(key) is not None]
        return float(np.mean(vals)) if vals else None

    # Summary metric -> per-sample record field it is the mean of.
    summary_to_record_key = {
        "niah_accuracy": "accuracy",
        "ruler_exact_match": "exact_match",
        "scbench_accuracy": "accuracy",
        "longbench_accuracy": "accuracy",
        "prefill_perplexity": "prefill_perplexity",
        "generation_perplexity": "generation_perplexity",
    }
    recomputed = {
        summary_key: mean_of(record_key)
        for summary_key, record_key in summary_to_record_key.items()
        if summary_key in summary
    }
    # Cache metrics are always recomputed; they are compared only when the
    # summary carries them.
    recomputed["compression_ratio"] = mean_of("compression_ratio")
    recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")

    tolerance = proving.numeric_tolerance
    failures = []
    for k, v in recomputed.items():
        s = summary.get(k)
        if v is None or s is None:
            if v is None and s is not None:
                logger.warning(f"{k}: no per-sample values to recompute; not verified")
            continue
        # summary.json is written with default=str, so values may be strings;
        # convert once so both the comparison and the message use a float.
        try:
            s_val = float(s)
        except (TypeError, ValueError):
            failures.append(f"{k}: summary value {s!r} is not numeric")
            continue
        if np.isnan(v) or np.isnan(s_val):
            # NaN never compares greater than the tolerance, so a one-sided
            # NaN would otherwise pass silently.
            mismatch = not (np.isnan(v) and np.isnan(s_val))
        else:
            mismatch = abs(v - s_val) > tolerance
        if mismatch:
            failures.append(f"{k}: recomputed {v:.6f} != summary {s_val:.6f}")

    ok = not failures
    result = {
        "ok": ok,
        "failures": failures,
        "recomputed": recomputed,
        "summary": summary,
        "n_samples": len(records),
        "n_verified": len(primary_records)
    }
    if not ok:
        logger.error(f"Proof verification FAILED: {failures}")
    else:
        logger.info(
            f"Proof verification PASSED for {len(primary_records)} samples "
            f"({len(records)} records in bundle)"
        )
    return result