def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch random generators for a reproducible run.

    Seeds ``random``, NumPy's global RNG and PyTorch. When CUDA is available
    it also seeds every CUDA device, switches cuDNN to deterministic mode and
    disables cuDNN autotuning.

    Args:
        seed: Integer in ``[0, 2**32 - 1]``.

    Raises:
        ValueError: If ``seed`` is not an integer in the accepted range. The
            check runs before any generator is touched, so a rejected seed
            never leaves the generators partially re-seeded.
    """
    if not isinstance(seed, int) or seed < 0:
        raise ValueError(f"Seed must be non-negative integer, got {seed}")

    # np.random.seed rejects values above 2**32 - 1. Checking here keeps an
    # oversized seed from re-seeding ``random`` and then failing in NumPy.
    max_seed = 2**32 - 1
    if seed > max_seed:
        raise ValueError(f"Seed must be at most {max_seed}, got {seed}")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    logger.info(f"Set all random seeds to {seed}")


def _peak_mem_bytes_all_gpus() -> int:
    """Return the summed peak allocated memory, in bytes, over all CUDA devices.

    The value is the sum (not the maximum) of
    ``torch.cuda.max_memory_allocated`` over every visible device.

    Raises:
        RuntimeError: If CUDA is unavailable (fail fast rather than report 0).
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA memory tracking requested but CUDA is unavailable")

    device_ids = range(torch.cuda.device_count())
    # Synchronize every device that is about to be read, not only the current
    # one, so pending work on other GPUs is included in their peak counters.
    for device_id in device_ids:
        torch.cuda.synchronize(device_id)

    total_mem = sum(torch.cuda.max_memory_allocated(device_id) for device_id in device_ids)
    logger.debug(f"Peak GPU memory: {total_mem / 1024 / 1024:.1f} MB")
    return total_mem
@dataclass
class BenchmarkMetrics:
    """Raw benchmark samples plus the summary statistics derived from them.

    The ``List`` fields collect per-sample measurements; the scalar and tuple
    fields are filled in by :meth:`calculate_statistics`. Only measured values
    are stored - nothing here is estimated.
    """

    # Prefill metrics (times in seconds, raw peak memories in bytes)
    prefill_times: List[float] = field(default_factory=list)
    prefill_peak_memories: List[float] = field(default_factory=list)
    prefill_time_mean: float = 0.0
    prefill_time_std: float = 0.0
    prefill_time_ci: Tuple[float, float] = (0.0, 0.0)
    prefill_peak_memory_mean_mb: float = 0.0
    prefill_peak_memory_std_mb: float = 0.0
    prefill_peak_memory_ci_mb: Tuple[float, float] = (0.0, 0.0)
    prefill_tokens_per_sec: float = 0.0

    # Decode metrics (per-token times in seconds, raw peak memories in bytes)
    decode_times: List[float] = field(default_factory=list)
    decode_peak_memories: List[float] = field(default_factory=list)
    decode_time_per_token_mean_ms: float = 0.0
    decode_time_per_token_std_ms: float = 0.0
    decode_time_per_token_ci_ms: Tuple[float, float] = (0.0, 0.0)
    decode_time_p50_ms: float = 0.0
    decode_time_p95_ms: float = 0.0
    decode_peak_memory_mean_mb: float = 0.0
    decode_tokens_per_sec: float = 0.0

    # Quality metrics
    prefill_perplexities: List[float] = field(default_factory=list)
    generation_perplexities: List[float] = field(default_factory=list)
    prefill_perplexity_mean: float = 0.0
    prefill_perplexity_std: float = 0.0
    prefill_perplexity_ci: Tuple[float, float] = (0.0, 0.0)
    generation_perplexity_mean: float = 0.0
    generation_perplexity_std: float = 0.0
    generation_perplexity_ci: Tuple[float, float] = (0.0, 0.0)

    # Benchmark-specific metrics
    longbench_scores: List[Dict[str, float]] = field(default_factory=list)
    niah_retrieval_accuracy: List[float] = field(default_factory=list)
    ruler_exact_match: List[float] = field(default_factory=list)
    scbench_turn_accuracy: List[float] = field(default_factory=list)

    # Compression metrics (MEASURED ONLY - no estimates)
    compression_ratios: List[float] = field(default_factory=list)
    compression_ratio_mean: float = 0.0
    compression_ratio_std: float = 0.0
    kv_cache_memory_mb: float = 0.0
    kv_cache_memory_samples_mb: List[float] = field(default_factory=list)

    # Enhanced SPG metrics (MEASURED ONLY)
    enhanced_spg_measured_compression: List[float] = field(default_factory=list)
    enhanced_spg_measured_auxiliary_overhead_mb: List[float] = field(default_factory=list)
    enhanced_spg_progressive_steps: List[int] = field(default_factory=list)

    # Original SPG metrics
    spg_precision_distributions: List[Dict[str, float]] = field(default_factory=list)
    spg_effective_bits_per_token: List[float] = field(default_factory=list)
    spg_decay_rates_per_layer: List[List[float]] = field(default_factory=list)

    # Statistical comparisons (not computed by calculate_statistics)
    memory_reduction_ratio: float = 1.0
    memory_reduction_pvalue: float = 1.0
    speedup_ratio: float = 1.0
    speedup_pvalue: float = 1.0
    prefill_perplexity_delta: float = 0.0
    generation_perplexity_delta: float = 0.0
    perplexity_pvalue: float = 1.0

    # End-to-end metrics
    end_to_end_throughput: float = 0.0
    end_to_end_latency_ms: float = 0.0

    def calculate_statistics(self, config: CompressionConfig) -> None:
        """Derive means, standard deviations, percentiles and CIs from the samples.

        Each group of statistics is only updated when its sample list is
        non-empty; otherwise the existing field values are left untouched.
        Standard deviations use NumPy's default (population, ``ddof=0``).

        Args:
            config: Supplies ``prefill_length`` and ``generation_length`` for
                the throughput figures, and ``seed``, ``n_bootstrap`` and
                ``confidence_level`` for the bootstrap intervals.

        Raises:
            Exception: Any error raised while computing is logged and re-raised.
        """
        try:
            if self.prefill_times:
                self.prefill_time_mean = float(np.mean(self.prefill_times))
                self.prefill_time_std = float(np.std(self.prefill_times))
                self.prefill_time_ci = self._bootstrap_ci(self.prefill_times, config)
                self.prefill_tokens_per_sec = (
                    config.prefill_length / self.prefill_time_mean
                    if self.prefill_time_mean > 0 else 0.0
                )

            if self.prefill_peak_memories:
                memories_mb = [m / (1024 * 1024) for m in self.prefill_peak_memories]
                self.prefill_peak_memory_mean_mb = float(np.mean(memories_mb))
                self.prefill_peak_memory_std_mb = float(np.std(memories_mb))
                self.prefill_peak_memory_ci_mb = self._bootstrap_ci(memories_mb, config)

            if self.decode_times:
                mean_decode_sec = float(np.mean(self.decode_times))
                self.decode_time_per_token_mean_ms = float(np.mean(self.decode_times) * 1000)
                self.decode_time_per_token_std_ms = float(np.std(self.decode_times) * 1000)
                self.decode_time_per_token_ci_ms = tuple(
                    x * 1000 for x in self._bootstrap_ci(self.decode_times, config)
                )
                # Guard a zero mean the same way prefill_tokens_per_sec does,
                # instead of producing inf and a NumPy divide warning.
                self.decode_tokens_per_sec = 1.0 / mean_decode_sec if mean_decode_sec > 0 else 0.0
                self.decode_time_p50_ms = float(np.percentile(self.decode_times, 50) * 1000)
                self.decode_time_p95_ms = float(np.percentile(self.decode_times, 95) * 1000)

                # Calculate end-to-end throughput: one prefill plus
                # generation_length decode steps at the mean per-token time.
                if self.prefill_time_mean > 0 and self.decode_time_per_token_mean_ms > 0:
                    total_tokens = config.prefill_length + config.generation_length
                    total_time_sec = self.prefill_time_mean + (
                        self.decode_time_per_token_mean_ms * config.generation_length / 1000
                    )
                    self.end_to_end_throughput = total_tokens / total_time_sec if total_time_sec > 0 else 0.0
                    self.end_to_end_latency_ms = total_time_sec * 1000

            if self.decode_peak_memories:
                self.decode_peak_memory_mean_mb = float(np.mean(self.decode_peak_memories) / (1024 * 1024))

            if self.prefill_perplexities:
                self.prefill_perplexity_mean = float(np.mean(self.prefill_perplexities))
                self.prefill_perplexity_std = float(np.std(self.prefill_perplexities))
                self.prefill_perplexity_ci = self._bootstrap_ci(self.prefill_perplexities, config)

            if self.generation_perplexities:
                self.generation_perplexity_mean = float(np.mean(self.generation_perplexities))
                self.generation_perplexity_std = float(np.std(self.generation_perplexities))
                self.generation_perplexity_ci = self._bootstrap_ci(self.generation_perplexities, config)

            if self.compression_ratios:
                self.compression_ratio_mean = float(np.mean(self.compression_ratios))
                self.compression_ratio_std = float(np.std(self.compression_ratios))

            if self.kv_cache_memory_samples_mb:
                self.kv_cache_memory_mb = float(np.mean(self.kv_cache_memory_samples_mb))

        except Exception as e:
            logger.error(f"Error calculating statistics: {e}")
            raise

    def _bootstrap_ci(self, data: List[float], config: CompressionConfig) -> Tuple[float, float]:
        """Return a percentile-bootstrap confidence interval for the mean of ``data``.

        A fresh generator is created from ``config.seed`` on every call, so the
        interval is reproducible and independent of earlier calls.

        Args:
            data: Sample values.
            config: Supplies ``seed``, ``n_bootstrap`` (number of resamples)
                and ``confidence_level`` (e.g. 0.95).

        Returns:
            ``(lower, upper)``; ``(0.0, 0.0)`` when fewer than two samples are
            given or no resamples were drawn.

        Raises:
            Exception: Any error during resampling is logged and re-raised.
        """
        if not data or len(data) < 2:
            return (0.0, 0.0)

        try:
            rng = np.random.default_rng(config.seed)
            bootstrap_means = []
            data_array = np.array(data)

            # Kept as one draw per resample so the random stream (and hence
            # the reported interval) stays identical to earlier runs.
            for _ in range(config.n_bootstrap):
                sample = rng.choice(data_array, size=len(data_array), replace=True)
                bootstrap_means.append(float(sample.mean()))

            if bootstrap_means:
                alpha = 1 - config.confidence_level
                lower = float(np.percentile(bootstrap_means, alpha/2 * 100))
                upper = float(np.percentile(bootstrap_means, (1 - alpha/2) * 100))
                return (lower, upper)

        except Exception as e:
            logger.error(f"Error in bootstrap CI calculation: {e}")
            raise

        return (0.0, 0.0)
def safe_tokenize(tokenizer, text, max_length=512):
    """Tokenize ``text`` to exactly ``max_length`` positions (padded or truncated).

    If the tokenizer has no pad token, its EOS token is installed as the pad
    token; note this mutates the tokenizer passed in.

    Args:
        tokenizer: Callable tokenizer exposing ``pad_token`` / ``eos_token``.
        text: Text to encode.
        max_length: Length used for both truncation and padding.

    Returns:
        The tokenizer's encoding, with ``input_ids`` and ``attention_mask``
        returned as PyTorch tensors.

    Raises:
        ValueError: If the encoded sequence has zero length.
    """
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_attention_mask=True,
        add_special_tokens=True
    )

    if inputs.input_ids.shape[1] == 0:
        raise ValueError("Tokenization produced empty sequence")

    # Defensive fallback in case the tokenizer ignored ``truncation``. Assign
    # through item access: setting attributes on a mapping-style encoding
    # would leave inputs["input_ids"] at the old, untruncated length.
    if inputs.input_ids.shape[1] > max_length:
        inputs["input_ids"] = inputs.input_ids[:, :max_length]
        inputs["attention_mask"] = inputs.attention_mask[:, :max_length]

    return inputs


def validate_model_inputs(model, input_ids, attention_mask):
    """Truncate and clamp inputs so they are acceptable to ``model``.

    The sequence is cut to ``max_position_embeddings`` and/or ``n_positions``
    when the model config defines them, and token ids are clamped into
    ``[0, vocab_size - 1]``. Out-of-range ids are silently clamped, not
    reported.

    Args:
        model: Object whose ``config`` provides ``vocab_size`` and optionally
            the position limits above.
        input_ids: Token id tensor indexed as ``[batch, sequence]``.
        attention_mask: Mask tensor truncated in step with ``input_ids``.

    Returns:
        ``(input_ids, attention_mask)``; the original tensor objects are
        returned when no adjustment was needed.
    """
    for limit_name in ('max_position_embeddings', 'n_positions'):
        limit = getattr(model.config, limit_name, None)
        # A config may define the attribute but leave it as None.
        if limit is not None and input_ids.shape[1] > limit:
            input_ids = input_ids[:, :limit]
            attention_mask = attention_mask[:, :limit]

    vocab_size = model.config.vocab_size
    # max()/min() raise on an empty tensor, and there is nothing to clamp then.
    if input_ids.numel() > 0 and (input_ids.max() >= vocab_size or input_ids.min() < 0):
        input_ids = input_ids.clamp(0, vocab_size - 1)

    return input_ids, attention_mask
def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20):
    """Greedy-generate a continuation and return only the newly generated text.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used for pad/EOS ids and for decoding.
        input_ids: Prompt token ids (validated against the model first).
        attention_mask: Mask matching ``input_ids``.
        past_key_values: Optional cache forwarded to ``generate`` when given.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation of the first sequence in the batch, or ``""``
        if anything raises (the error is logged, not propagated).
    """
    try:
        input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)

        # Use an explicit None check: ``pad_token_id or eos_token_id`` would
        # discard a legitimate pad id of 0.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        gen_config = {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "pad_token_id": pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "attention_mask": attention_mask,
            "use_cache": True
        }

        # NOTE(review): callers pass a cache built from these same input_ids;
        # confirm the installed generate() handles a fully pre-filled cache.
        if past_key_values is not None:
            gen_config["past_key_values"] = past_key_values

        with torch.no_grad():
            output = model.generate(input_ids, **gen_config)

        # Decode only the generated part
        generated_ids = output[:, input_ids.shape[1]:]
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    except Exception as e:
        logger.error(f"Generation failed: {e}")
        # Return empty string on failure
        return ""


def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
                               cache_manager: QuantizedKVCache,
                               config: CompressionConfig,
                               measure_memory: bool = True) -> Dict[str, Any]:
    """Run one timed prefill pass and push the resulting KV cache through compression.

    Steps: validate inputs, run the model with ``use_cache=True``, compute the
    next-token loss over the prompt, then (unless ``config.compression_type``
    is ``NONE`` or ``cache_manager`` is ``None``) compress every layer via
    ``cache_manager`` and rebuild the cache from the decompressed tensors.
    Layers that fail to compress or decompress fall back to the original
    tensors; such failures are logged, not raised.

    Args:
        model: Causal LM called as ``model(input_ids, attention_mask=..., ...)``.
        tokenizer: Supplies ``pad_token_id`` for the loss ``ignore_index``.
        input_ids: Prompt token ids.
        attention_mask: Mask matching ``input_ids``.
        cache_manager: Compression backend, or ``None`` to skip compression.
        config: Supplies ``compression_type``.
        measure_memory: When true and CUDA is available, reset the peak-memory
            counters before prefill and read them afterwards.

    Returns:
        Dict with keys ``past_key_values``, ``prefill_time`` (seconds),
        ``prefill_peak_mem``, ``prefill_loss`` (``None`` if not computed),
        ``original_cache_size`` (bytes), ``compressed_cache_size`` (as
        reported by ``cache_manager.get_memory_footprint()``; presumably
        bytes - callers convert it to MB), ``compression_ratio`` and
        ``logits``. If the prefill itself raises, a dict of neutral values is
        returned with ``past_key_values`` and ``logits`` set to ``None``.
    """
    input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
    cuda_available = torch.cuda.is_available()

    if cuda_available and measure_memory:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    if cuda_available:
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    try:
        with torch.inference_mode():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True
            )
            past_key_values = outputs.past_key_values
            logits = outputs.logits
    except Exception as e:
        logger.error(f"Prefill failed: {e}")
        return {
            'past_key_values': None,
            'prefill_time': 0,
            'prefill_peak_mem': 0,
            'prefill_loss': None,
            'original_cache_size': 0,
            'compressed_cache_size': 0,
            'compression_ratio': 1.0,
            'logits': None
        }

    # Wait for queued kernels so the timer covers the whole prefill.
    if cuda_available:
        torch.cuda.synchronize()
    prefill_time = time.perf_counter() - start_time

    prefill_peak_mem = 0
    if cuda_available and measure_memory:
        prefill_peak_mem = _peak_mem_bytes_all_gpus()

    prefill_loss = None
    if logits is not None and input_ids.shape[1] > 1:
        try:
            seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
            if seq_len > 0:
                # Position t predicts token t + 1.
                shift_logits = logits[:, :seq_len, :].contiguous()
                shift_labels = input_ids[:, 1:seq_len+1].contiguous()

                # Explicit None check so a pad id of 0 is still ignored.
                pad_token_id = tokenizer.pad_token_id
                ignore_index = pad_token_id if pad_token_id is not None else -100

                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    reduction='mean',
                    ignore_index=ignore_index
                )
                prefill_loss = loss.item()
        except Exception as e:
            logger.warning(f"Could not calculate prefill loss: {e}")

    original_cache_size = 0
    compressed_cache_size = 0
    compression_ratio = 1.0

    if past_key_values:
        try:
            if hasattr(past_key_values, 'to_legacy_cache'):
                kv_tuple = past_key_values.to_legacy_cache()
            else:
                kv_tuple = past_key_values

            compress = config.compression_type != CompressionType.NONE and cache_manager is not None

            for layer_idx, (keys, values) in enumerate(kv_tuple):
                if keys is not None and values is not None:
                    original_cache_size += keys.nelement() * keys.element_size()
                    original_cache_size += values.nelement() * values.element_size()

                    if compress:
                        try:
                            cache_manager.compress_and_store(layer_idx, keys, values)
                        except Exception as e:
                            logger.error(f"Compression failed for layer {layer_idx}: {e}")

            if compress:
                reconstructed_kv = []
                for layer_idx in range(len(kv_tuple)):
                    try:
                        dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
                        if dec_keys is not None and dec_values is not None:
                            reconstructed_kv.append((dec_keys, dec_values))
                        else:
                            reconstructed_kv.append(kv_tuple[layer_idx])
                    except Exception as e:
                        logger.error(f"Decompression failed for layer {layer_idx}: {e}")
                        reconstructed_kv.append(kv_tuple[layer_idx])

                if hasattr(DynamicCache, 'from_legacy_cache'):
                    past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
                else:
                    past_key_values = tuple(reconstructed_kv)

                try:
                    compressed_cache_size = cache_manager.get_memory_footprint()
                except Exception:
                    # Best effort: report "no compression" rather than fail,
                    # but do not swallow KeyboardInterrupt/SystemExit.
                    compressed_cache_size = original_cache_size
            else:
                compressed_cache_size = original_cache_size

            if compressed_cache_size > 0:
                compression_ratio = original_cache_size / compressed_cache_size

        except Exception as e:
            logger.error(f"Cache processing failed: {e}")
            compressed_cache_size = original_cache_size
            compression_ratio = 1.0

    return {
        'past_key_values': past_key_values,
        'prefill_time': prefill_time,
        'prefill_peak_mem': prefill_peak_mem,
        'prefill_loss': prefill_loss,
        'original_cache_size': original_cache_size,
        'compressed_cache_size': compressed_cache_size,
        'compression_ratio': compression_ratio,
        'logits': logits
    }
def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
    """Build filler text with ``needle`` inserted at a relative depth.

    The filler is a repeated pangram, cut to
    ``context_length - len(needle) - 10`` characters (never below zero); the
    needle, surrounded by single spaces, is then inserted at
    ``depth_percent`` percent of the filler length.

    Args:
        context_length: Target size in characters (not tokens).
        needle: Text to hide in the filler.
        depth_percent: Insertion depth; values outside ``[0, 100]`` are
            clamped to that range.

    Returns:
        The filler text containing ``" " + needle + " "``.
    """
    filler = "The quick brown fox jumps over the lazy dog. " * 20

    # Track the joined length arithmetically instead of re-joining the whole
    # list on every iteration.
    chunks = []
    joined_len = 0
    while joined_len < context_length:
        joined_len += len(filler) + (1 if chunks else 0)
        chunks.append(filler)

    # A negative budget would turn the slice below into a "drop the last N
    # characters" slice and return far more text than requested.
    budget = max(0, context_length - len(needle) - 10)
    haystack = " ".join(chunks)[:budget]

    depth = min(max(depth_percent, 0), 100)
    insertion_point = int(len(haystack) * depth / 100)

    return haystack[:insertion_point] + " " + needle + " " + haystack[insertion_point:]


def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Run one Needle-in-a-Haystack sample through the shared compression pipeline.

    Builds the haystack from ``config.prefill_length``, ``config.niah_needle``
    and ``config.niah_depth_percent``, prefills through
    ``apply_compression_pipeline`` and generates up to 20 tokens. The sample
    counts as correct when the last whitespace-separated word of the needle
    appears in the generated text.

    Returns:
        Dict with ``accuracy`` (0.0 or 1.0), ``compression_ratio``,
        ``kv_cache_memory_mb``, ``prefill_time``, ``generation_time``
        (seconds, wall clock around generation) and ``prefill_peak_mem``.
    """
    haystack = create_niah_haystack(
        config.prefill_length,
        config.niah_needle,
        config.niah_depth_percent
    )
    prompt = f"{haystack}\n\nQuestion: What is the secret password?\nAnswer:"

    encoded = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.attention_mask.to(model.device)

    pipeline = apply_compression_pipeline(
        model, tokenizer, input_ids, attention_mask, cache_manager, config
    )

    gen_start = time.perf_counter()
    generated_text = safe_generate(
        model, tokenizer, input_ids, attention_mask,
        pipeline['past_key_values'], max_new_tokens=20
    )
    gen_time = time.perf_counter() - gen_start

    target_word = config.niah_needle.split()[-1]
    accuracy = 1.0 if target_word in generated_text else 0.0

    logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
    logger.info(f"NIAH compression ratio: {pipeline['compression_ratio']:.1f}x")

    return {
        'accuracy': accuracy,
        'compression_ratio': pipeline['compression_ratio'],
        'kv_cache_memory_mb': pipeline['compressed_cache_size'] / (1024 * 1024),
        'prefill_time': pipeline['prefill_time'],
        'generation_time': gen_time,
        'prefill_peak_mem': pipeline['prefill_peak_mem']
    }
def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Run one RULER-style fact-retrieval sample through the shared pipeline.

    Ten synthetic "capital of CountryN is CityN" facts are repeated to fill a
    character budget derived from
    ``min(config.ruler_max_seq_length, config.prefill_length, 1024)``. One
    fact is queried and the sample counts as an exact match when the expected
    ``City<idx>`` appears in the (up to 10 token) generation.

    The queried fact is only drawn from facts that are fully present in the
    context, so the question is always answerable from the prompt.

    Returns:
        Dict with ``exact_match`` (0.0 or 1.0), ``compression_ratio``,
        ``kv_cache_memory_mb``, ``prefill_time``, ``generation_time`` and
        ``prefill_peak_mem``.
    """
    seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024)

    facts = [f"Fact {i}: The capital of Country{i} is City{i}." for i in range(10)]
    fact_block = " ".join(facts)

    context = fact_block * (seq_len // (len(fact_block) + 1))
    context = context[:max(0, seq_len - 100)]

    # The character cut above can drop or truncate facts (and a short budget
    # yields no context at all). Only ask about facts the prompt contains.
    answerable = [i for i, fact in enumerate(facts) if fact in context]
    if not answerable:
        context = fact_block
        answerable = list(range(len(facts)))

    # With all ten facts present this draws uniformly from 0..9, as before.
    query_idx = random.choice(answerable)
    prompt = f"{context}\n\nWhat is the capital of Country{query_idx}?"

    inputs = safe_tokenize(tokenizer, prompt, max_length=seq_len)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    compression_result = apply_compression_pipeline(
        model, tokenizer, input_ids, attention_mask, cache_manager, config
    )

    gen_start = time.perf_counter()
    generated = safe_generate(
        model, tokenizer, input_ids, attention_mask,
        compression_result['past_key_values'], max_new_tokens=10
    )
    gen_time = time.perf_counter() - gen_start

    expected = f"City{query_idx}"
    exact_match = 1.0 if expected in generated else 0.0

    logger.info(f"RULER exact match: {exact_match}, Generated: {generated[:50]}")
    logger.info(f"RULER compression ratio: {compression_result['compression_ratio']:.1f}x")

    return {
        'exact_match': exact_match,
        'compression_ratio': compression_result['compression_ratio'],
        'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
        'prefill_time': compression_result['prefill_time'],
        'generation_time': gen_time,
        'prefill_peak_mem': compression_result['prefill_peak_mem']
    }


def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Run one SCBench-style multi-turn recall sample through the shared pipeline.

    A synthetic conversation of ``config.scbench_num_turns`` turns is built in
    which the user states ``item_<turn> is value_<turn>_<random 4 digits>``
    and the assistant acknowledges it. A final user turn asks for one randomly
    chosen item; the sample is correct when its value appears in the (up to
    20 token) generation.

    Returns:
        Dict with ``accuracy`` (0.0 or 1.0), ``compression_ratio``,
        ``kv_cache_memory_mb``, ``prefill_time``, ``generation_time`` and
        ``prefill_peak_mem``.
    """
    turns = []
    remembered = {}

    for turn in range(config.scbench_num_turns):
        key = f"item_{turn}"
        value = f"value_{turn}_{random.randint(1000, 9999)}"
        remembered[key] = value

        turns.append(f"User: Remember that {key} is {value}.")
        turns.append(f"Assistant: I'll remember that {key} is {value}.")

    query_key = random.choice(list(remembered.keys()))
    turns.append(f"User: What is {query_key}?")

    full_conversation = "\n".join(turns) + "\nAssistant:"

    encoded = safe_tokenize(tokenizer, full_conversation, max_length=min(config.prefill_length, 1024))
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.attention_mask.to(model.device)

    pipeline = apply_compression_pipeline(
        model, tokenizer, input_ids, attention_mask, cache_manager, config
    )

    gen_start = time.perf_counter()
    generated = safe_generate(
        model, tokenizer, input_ids, attention_mask,
        pipeline['past_key_values'], max_new_tokens=20
    )
    gen_time = time.perf_counter() - gen_start

    expected_value = remembered[query_key]
    accuracy = 1.0 if expected_value in generated else 0.0

    logger.info(f"SCBench accuracy: {accuracy}, Generated: {generated[:50]}")
    logger.info(f"SCBench compression ratio: {pipeline['compression_ratio']:.1f}x")

    return {
        'accuracy': accuracy,
        'compression_ratio': pipeline['compression_ratio'],
        'kv_cache_memory_mb': pipeline['compressed_cache_size'] / (1024 * 1024),
        'prefill_time': pipeline['prefill_time'],
        'generation_time': gen_time,
        'prefill_peak_mem': pipeline['prefill_peak_mem']
    }
def evaluate_longbench_task(model, tokenizer, config: CompressionConfig, task: str, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
    """Evaluate one LongBench task through the shared compression pipeline.

    Loads the ``test`` split of ``THUDM/LongBench`` for ``task``, takes the
    first ``config.eval_samples`` rows and, for each, prefills through
    ``apply_compression_pipeline`` (memory measurement disabled) and
    generates up to 50 tokens. A sample scores 1.0 when the first reference
    answer occurs, case-insensitively, in the generation; samples with an
    empty reference answer score 0.0.

    Returns:
        Dict with ``accuracy``, ``n_samples``, ``compression_ratio``,
        ``kv_cache_memory_mb``, ``prefill_time`` and ``generation_time``
        (each averaged over samples). On any exception the error is logged
        and a dict with zero accuracy, zero samples and ratio 1.0 is returned.
    """
    try:
        dataset = load_dataset("THUDM/LongBench", task, split="test")

        n_samples = min(config.eval_samples, len(dataset))
        samples = dataset.select(range(n_samples))

        scores = []
        compression_ratios = []
        kv_memories = []
        prefill_times = []
        gen_times = []

        for sample in samples:
            context = sample.get("context", "")
            question = sample.get("input", sample.get("question", ""))
            answer = sample.get("answers", [sample.get("answer", "")])
            if isinstance(answer, list) and answer:
                answer = answer[0]

            # NOTE(review): the prompt is capped at 1024 tokens and the
            # question comes after the context; confirm long contexts do not
            # push the question out of the truncated window.
            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

            inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
            input_ids = inputs.input_ids.to(model.device)
            attention_mask = inputs.attention_mask.to(model.device)

            compression_result = apply_compression_pipeline(
                model, tokenizer, input_ids, attention_mask, cache_manager, config,
                measure_memory=False
            )

            gen_start = time.perf_counter()
            generated = safe_generate(
                model, tokenizer, input_ids, attention_mask,
                compression_result['past_key_values'], max_new_tokens=50
            )
            gen_time = time.perf_counter() - gen_start

            # An empty reference is a substring of every string and would
            # otherwise be counted as a correct answer.
            reference = str(answer).lower()
            score = 1.0 if reference.strip() and reference in generated.lower() else 0.0
            scores.append(score)

            compression_ratios.append(compression_result['compression_ratio'])
            kv_memories.append(compression_result['compressed_cache_size'] / (1024 * 1024))
            prefill_times.append(compression_result['prefill_time'])
            gen_times.append(gen_time)

        avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0

        return {
            # Guarded like the other averages: np.mean([]) would be NaN.
            'accuracy': float(np.mean(scores)) if scores else 0.0,
            'n_samples': n_samples,
            'compression_ratio': avg_compression,
            'kv_cache_memory_mb': float(np.mean(kv_memories)) if kv_memories else 0.0,
            'prefill_time': float(np.mean(prefill_times)) if prefill_times else 0.0,
            'generation_time': float(np.mean(gen_times)) if gen_times else 0.0
        }

    except Exception as e:
        logger.error(f"Error evaluating LongBench task {task}: {e}")
        return {
            'accuracy': 0.0,
            'n_samples': 0,
            'compression_ratio': 1.0,
            'kv_cache_memory_mb': 0.0,
            'prefill_time': 0.0,
            'generation_time': 0.0
        }
def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
    """Load a causal LM and its tokenizer, preferring CUDA with fp16 weights.

    On CUDA the model is loaded in float16 with ``device_map="auto"``; on CPU
    it is loaded in float32 with no device map. When
    ``config.use_flash_attention`` is set and CUDA is available, Flash
    Attention 2 is tried first and a plain load is used if that fails. A
    tokenizer without a pad token gets its EOS token as pad token, and the
    model is put into eval mode.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        config: Supplies ``fail_on_cpu_fallback`` and ``use_flash_attention``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        RuntimeError: If CUDA is unavailable and
            ``config.fail_on_cpu_fallback`` is true.
    """
    on_gpu = torch.cuda.is_available()

    if config.fail_on_cpu_fallback and not on_gpu:
        raise RuntimeError("CUDA required but unavailable (fail_on_cpu_fallback=True)")

    logger.info(f"Loading model: {model_name}")

    # NOTE(review): trust_remote_code=True runs code shipped with the model
    # repository; only use with model sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs = {
        "torch_dtype": torch.float16 if on_gpu else torch.float32,
        "device_map": "auto" if on_gpu else None,
        "low_cpu_mem_usage": True,
        "trust_remote_code": True
    }

    if not (config.use_flash_attention and on_gpu):
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    else:
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, **model_kwargs, attn_implementation="flash_attention_2"
            )
            logger.info("Successfully loaded with Flash Attention 2")
        except Exception as e:
            # Fall back to the default attention implementation.
            logger.warning(f"Flash Attention not available: {e}")
            model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    model.eval()
    return model, tokenizer
def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]:
    """Collect the text samples needed for ``config.benchmark_type``.

    * ``"wikitext"``: tries ``config.dataset_split``, then ``"train"``, then
      ``"validation"`` (each distinct split at most once) until at least
      ``config.eval_samples`` texts are collected. A text qualifies when it
      is longer than 50 characters and encodes to at least
      ``min(prefill_length + generation_length, 256)`` tokens; at most
      ``3 * eval_samples`` texts are kept. A split that fails to load is
      logged and skipped.
    * ``"longbench"``: up to ``eval_samples`` contexts longer than 100
      characters from ``config.benchmark_subset`` (none if no subset is set).
    * ``"niah"``, ``"ruler"``, ``"scbench"``: placeholder strings, since those
      benchmarks synthesize their own prompts.

    Returns:
        The collected texts; may contain fewer than ``eval_samples`` entries
        (a warning is logged in that case).

    Raises:
        ValueError: For an unsupported ``benchmark_type``.
        Exception: LongBench loading errors are logged and re-raised.
    """
    logger.info(f"Loading samples for benchmark: {config.benchmark_type}")

    if config.benchmark_type == "wikitext":
        texts = []
        min_tokens = config.prefill_length + config.generation_length

        # De-duplicate while keeping order: when dataset_split is already
        # "train" or "validation", revisiting it would only append the same
        # texts a second time.
        candidate_splits = list(dict.fromkeys([config.dataset_split, "train", "validation"]))

        try:
            for split in candidate_splits:
                if len(texts) >= config.eval_samples:
                    break

                try:
                    dataset = load_dataset(
                        config.dataset_name,
                        config.dataset_config,
                        split=split,
                        streaming=False
                    )

                    logger.info(f"Trying {split} split with {len(dataset)} samples")

                    for item in dataset:
                        text = item.get('text', '').strip()

                        if len(text) > 50:
                            tokens = tokenizer.encode(text, truncation=False, add_special_tokens=False)

                            if len(tokens) >= min(min_tokens, 256):
                                texts.append(text)

                                if len(texts) >= config.eval_samples * 3:
                                    break

                except Exception as e:
                    logger.warning(f"Failed to load {split} split: {e}")
                    continue

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    elif config.benchmark_type == "longbench":
        texts = []
        if config.benchmark_subset:
            try:
                dataset = load_dataset("THUDM/LongBench", config.benchmark_subset, split="test")
                for item in dataset:
                    if len(texts) >= config.eval_samples:
                        break
                    context = item.get("context", "")
                    if len(context) > 100:
                        texts.append(context)
            except Exception as e:
                logger.error(f"Failed to load LongBench subset {config.benchmark_subset}: {e}")
                raise

    elif config.benchmark_type in ["niah", "ruler", "scbench"]:
        texts = ["Synthetic benchmark data"] * config.eval_samples

    else:
        raise ValueError(f"Unsupported benchmark type: {config.benchmark_type}")

    if len(texts) < config.eval_samples:
        logger.warning(f"Only loaded {len(texts)} samples, requested {config.eval_samples}")

    logger.info(f"Loaded {len(texts)} text samples")
    return texts
def _new_cache_manager(config: CompressionConfig, n_layers: int) -> Optional[QuantizedKVCache]:
    """Return a fresh cache manager for one sample, or None when compression is off."""
    if config.compression_type == CompressionType.NONE:
        return None
    cache_manager = QuantizedKVCache(config)
    cache_manager.n_layers = n_layers
    return cache_manager


def run_research_benchmark(model_name: str, config: CompressionConfig,
                           dataset_texts: Optional[List[str]] = None) -> Tuple[BenchmarkMetrics, Dict, List[Dict], List[Dict]]:
    """Run the configured benchmark for every seed and aggregate the results.

    Loads the model, runs ``config.warmup_steps`` warm-up forward passes, then
    for each of ``config.n_seeds`` seeds evaluates ``config.benchmark_type``
    (``niah``, ``ruler``, ``scbench``, ``longbench``, or the default
    perplexity/decoding benchmark on ``dataset_texts``). Every sample goes
    through the same compression pipeline. Per-seed samples are merged into
    one ``BenchmarkMetrics`` whose statistics are then recomputed.

    Args:
        model_name: Model identifier passed to ``load_model_and_tokenizer``.
        config: Benchmark and compression settings. ``niah_depth_percent`` is
            changed during the NIAH depth sweep and restored afterwards.
        dataset_texts: Pre-loaded texts; loaded via
            ``load_real_dataset_samples`` when ``None``.

    Returns:
        ``(final_metrics, summary, per_sample_records, per_layer_fingerprints)``.
        ``per_layer_fingerprints`` is returned empty by this function.

    Raises:
        ValueError: If the model's layers cannot be detected, or if the
            text benchmark is asked for samples but no texts are available.
    """
    logger.info(f"Starting benchmark: {model_name} with {config.compression_type.value}")
    logger.info(f"Benchmark type: {config.benchmark_type}")
    logger.info(f"Config hash: {config.get_hash()}")

    if torch.cuda.is_available():
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # NOTE(review): not read below; kept in case construction has effects.
    constants = ResearchConstants()
    start_time = datetime.now().isoformat()
    per_sample_records = []
    per_layer_fingerprints = []

    model, tokenizer = load_model_and_tokenizer(model_name, config)

    try:
        n_layers = detect_model_layers(model)
        logger.info(f"Model architecture: {n_layers} transformer layers detected")
    except ValueError as e:
        logger.error(f"Failed to detect model layers: {e}")
        raise

    # Warm-up passes on random tokens so one-time setup cost is not timed.
    device = model.device
    with torch.inference_mode():
        dummy = torch.randint(0, tokenizer.vocab_size, (1, min(config.prefill_length, 128)), device=device)
        am = torch.ones_like(dummy)
        for _ in range(config.warmup_steps):
            _ = model(dummy, attention_mask=am, use_cache=True, return_dict=True)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    if dataset_texts is None:
        dataset_texts = load_real_dataset_samples(config, tokenizer)

    all_metrics = []

    for seed in range(config.n_seeds):
        set_seed(config.seed + seed)
        logger.info(f"Running evaluation with seed {config.seed + seed}")

        metrics = BenchmarkMetrics()

        if config.benchmark_type == "niah":
            # The sweep overwrites config.niah_depth_percent; put the caller's
            # value back afterwards so the config is not left mutated.
            original_depth = config.niah_depth_percent
            try:
                for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
                    config.niah_depth_percent = depth
                    for idx in range(min(config.eval_samples, 10)):
                        cache_manager = _new_cache_manager(config, n_layers)

                        result = evaluate_niah(model, tokenizer, config, cache_manager)

                        metrics.niah_retrieval_accuracy.append(result['accuracy'])
                        metrics.compression_ratios.append(result['compression_ratio'])
                        metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
                        metrics.prefill_times.append(result['prefill_time'])
                        # 20 = max_new_tokens requested by evaluate_niah.
                        metrics.decode_times.append(result['generation_time'] / 20)
                        if result['prefill_peak_mem'] > 0:
                            metrics.prefill_peak_memories.append(result['prefill_peak_mem'])

                        per_sample_records.append({
                            'benchmark': 'niah',
                            'depth_percent': depth,
                            'sample_idx': idx,
                            'accuracy': result['accuracy'],
                            'compression_ratio': result['compression_ratio'],
                            'kv_cache_memory_mb': result['kv_cache_memory_mb'],
                            'compression_type': config.compression_type.value
                        })
            finally:
                config.niah_depth_percent = original_depth

        elif config.benchmark_type == "ruler":
            for idx in range(config.eval_samples):
                cache_manager = _new_cache_manager(config, n_layers)

                result = evaluate_ruler(model, tokenizer, config, cache_manager)

                metrics.ruler_exact_match.append(result['exact_match'])
                metrics.compression_ratios.append(result['compression_ratio'])
                metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
                metrics.prefill_times.append(result['prefill_time'])
                # 10 = max_new_tokens requested by evaluate_ruler.
                metrics.decode_times.append(result['generation_time'] / 10)
                if result['prefill_peak_mem'] > 0:
                    metrics.prefill_peak_memories.append(result['prefill_peak_mem'])

                per_sample_records.append({
                    'benchmark': 'ruler',
                    'sample_idx': idx,
                    'exact_match': result['exact_match'],
                    'compression_ratio': result['compression_ratio'],
                    'kv_cache_memory_mb': result['kv_cache_memory_mb'],
                    'compression_type': config.compression_type.value
                })

        elif config.benchmark_type == "scbench":
            for idx in range(config.eval_samples):
                cache_manager = _new_cache_manager(config, n_layers)

                result = evaluate_scbench(model, tokenizer, config, cache_manager)

                metrics.scbench_turn_accuracy.append(result['accuracy'])
                metrics.compression_ratios.append(result['compression_ratio'])
                metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
                metrics.prefill_times.append(result['prefill_time'])
                # 20 = max_new_tokens requested by evaluate_scbench.
                metrics.decode_times.append(result['generation_time'] / 20)
                if result['prefill_peak_mem'] > 0:
                    metrics.prefill_peak_memories.append(result['prefill_peak_mem'])

                per_sample_records.append({
                    'benchmark': 'scbench',
                    'sample_idx': idx,
                    'accuracy': result['accuracy'],
                    'compression_ratio': result['compression_ratio'],
                    'kv_cache_memory_mb': result['kv_cache_memory_mb'],
                    'compression_type': config.compression_type.value
                })

        elif config.benchmark_type == "longbench":
            if config.benchmark_subset:
                cache_manager = _new_cache_manager(config, n_layers)

                result = evaluate_longbench_task(model, tokenizer, config,
                                                 config.benchmark_subset, cache_manager)

                metrics.longbench_scores.append(result)
                metrics.compression_ratios.append(result['compression_ratio'])
                metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
                metrics.prefill_times.append(result['prefill_time'])
                if result['generation_time'] > 0:
                    # 50 = max_new_tokens requested by evaluate_longbench_task.
                    metrics.decode_times.append(result['generation_time'] / 50)

                per_sample_records.append({
                    'benchmark': 'longbench',
                    'subset': config.benchmark_subset,
                    'accuracy': result['accuracy'],
                    'compression_ratio': result['compression_ratio'],
                    'kv_cache_memory_mb': result['kv_cache_memory_mb'],
                    'compression_type': config.compression_type.value
                })

        else:
            # Fail fast with a clear message instead of a ZeroDivisionError
            # from the modulo below.
            if config.eval_samples > 0 and not dataset_texts:
                raise ValueError(
                    f"No dataset samples available for benchmark '{config.benchmark_type}'"
                )

            for idx in range(config.eval_samples):
                logger.info(f"Sample {idx+1}/{config.eval_samples}")

                # Each seed starts at a different offset; wraps around when
                # fewer texts than requested samples are available.
                text_idx = (idx + seed * config.eval_samples) % len(dataset_texts)
                text = dataset_texts[text_idx]

                cache_manager = _new_cache_manager(config, n_layers)
                if cache_manager is not None:
                    cache_manager.update_position(config.prefill_length + idx)

                inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024))
                input_ids = inputs.input_ids.to(device)
                attention_mask = inputs.attention_mask.to(device)

                compression_result = apply_compression_pipeline(
                    model, tokenizer, input_ids, attention_mask, cache_manager, config
                )

                metrics.prefill_times.append(compression_result['prefill_time'])
                metrics.compression_ratios.append(compression_result['compression_ratio'])
                metrics.kv_cache_memory_samples_mb.append(
                    compression_result['compressed_cache_size'] / (1024 * 1024)
                )

                if compression_result['prefill_peak_mem'] > 0:
                    metrics.prefill_peak_memories.append(compression_result['prefill_peak_mem'])

                if compression_result['prefill_loss'] is not None:
                    prefill_perplexity = np.exp(compression_result['prefill_loss'])
                    # Perplexities are capped at 1000.
                    metrics.prefill_perplexities.append(min(prefill_perplexity, 1000))

                generated_ids = input_ids.clone()
                decode_times = []
                generation_losses = []

                past_key_values = compression_result['past_key_values']

                # Greedy decoding, one token per step, timed per step.
                for _ in range(config.generation_length):
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()

                    step_start = time.perf_counter()

                    with torch.inference_mode():
                        outputs = model(
                            generated_ids[:, -1:],
                            past_key_values=past_key_values,
                            use_cache=True,
                            return_dict=True
                        )

                        next_token_logits = outputs.logits[:, -1, :]
                        next_token = torch.argmax(next_token_logits, dim=-1)

                        # Loss of the model's own greedy choice.
                        loss = F.cross_entropy(next_token_logits, next_token)
                        generation_losses.append(loss.item())

                        generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)
                        past_key_values = outputs.past_key_values

                    if torch.cuda.is_available():
                        torch.cuda.synchronize()

                    decode_time = time.perf_counter() - step_start
                    decode_times.append(decode_time)

                if decode_times:
                    metrics.decode_times.extend(decode_times)

                if generation_losses:
                    generation_perplexity = np.exp(np.mean(generation_losses))
                    metrics.generation_perplexities.append(min(generation_perplexity, 1000))

                per_sample_records.append({
                    'benchmark': 'wikitext',
                    'sample_idx': idx,
                    'prefill_perplexity': metrics.prefill_perplexities[-1] if metrics.prefill_perplexities else None,
                    'generation_perplexity': metrics.generation_perplexities[-1] if metrics.generation_perplexities else None,
                    'compression_ratio': compression_result['compression_ratio'],
                    'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
                    'compression_type': config.compression_type.value
                })

        metrics.calculate_statistics(config)
        all_metrics.append(metrics)

    # Pool the raw samples of every seed and recompute statistics over them.
    final_metrics = BenchmarkMetrics()
    for m in all_metrics:
        final_metrics.prefill_times.extend(m.prefill_times)
        final_metrics.prefill_peak_memories.extend(m.prefill_peak_memories)
        final_metrics.decode_times.extend(m.decode_times)
        final_metrics.decode_peak_memories.extend(m.decode_peak_memories)
        final_metrics.prefill_perplexities.extend(m.prefill_perplexities)
        final_metrics.generation_perplexities.extend(m.generation_perplexities)
        final_metrics.compression_ratios.extend(m.compression_ratios)
        final_metrics.kv_cache_memory_samples_mb.extend(m.kv_cache_memory_samples_mb)
        final_metrics.niah_retrieval_accuracy.extend(m.niah_retrieval_accuracy)
        final_metrics.ruler_exact_match.extend(m.ruler_exact_match)
        final_metrics.scbench_turn_accuracy.extend(m.scbench_turn_accuracy)
        final_metrics.longbench_scores.extend(m.longbench_scores)

    final_metrics.calculate_statistics(config)

    end_time = datetime.now().isoformat()

    summary = {
        'compression_type': config.compression_type.value,
        'model': model_name,
        'benchmark_type': config.benchmark_type,
        'n_seeds': config.n_seeds,
        # NOTE(review): nominal count; the NIAH sweep and LongBench evaluate a
        # different number of samples than eval_samples * n_seeds.
        'total_samples': config.eval_samples * config.n_seeds,
        'compression_ratio': final_metrics.compression_ratio_mean,
        'kv_cache_memory_mb': final_metrics.kv_cache_memory_mb,
        'start_time': start_time,
        'end_time': end_time
    }

    if config.benchmark_type == "niah" and final_metrics.niah_retrieval_accuracy:
        summary['niah_accuracy'] = float(np.mean(final_metrics.niah_retrieval_accuracy))
    elif config.benchmark_type == "ruler" and final_metrics.ruler_exact_match:
        summary['ruler_exact_match'] = float(np.mean(final_metrics.ruler_exact_match))
    elif config.benchmark_type == "scbench" and final_metrics.scbench_turn_accuracy:
        summary['scbench_accuracy'] = float(np.mean(final_metrics.scbench_turn_accuracy))
    elif config.benchmark_type == "longbench" and final_metrics.longbench_scores:
        summary['longbench_accuracy'] = float(np.mean([s['accuracy'] for s in final_metrics.longbench_scores]))
    else:
        summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
        summary['generation_perplexity'] = final_metrics.generation_perplexity_mean

    summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
    summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
    summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
    summary['end_to_end_throughput'] = final_metrics.end_to_end_throughput
    summary['end_to_end_latency_ms'] = final_metrics.end_to_end_latency_ms
    summary['peak_memory_mb'] = final_metrics.prefill_peak_memory_mean_mb

    return final_metrics, summary, per_sample_records, per_layer_fingerprints
def export_proof_bundle(bundle_dir: str, config: CompressionConfig,
                       metrics: BenchmarkMetrics, summary: Dict[str, Any],
                       per_sample_records: List[Dict[str, Any]],
                       per_layer_fingerprints: List[Dict[str, Any]]) -> str:
    """Export attestable proof bundle with all metrics and fingerprints.

    Writes the following files under ``bundle_dir`` (created if missing) and
    then zips the directory into a sibling ``<bundle_dir>.zip`` archive:

    * ``manifest.json`` - config, config hash and environment description
    * ``summary.json`` - the ``summary`` dict as given
    * ``records/metrics.jsonl`` - one JSON object per per-sample record
    * ``records/kv_fingerprints.jsonl`` - one JSON object per fingerprint
    * ``env.lock`` - output of ``pip freeze`` (or a failure note)

    Args:
        bundle_dir: Directory to write the bundle into.
        config: Run configuration; its ``to_json()``, ``get_hash()`` and
            version/device attributes are recorded in the manifest.
        metrics: Accepted for interface compatibility; not read here.
        summary: Aggregated results; ``start_time``/``end_time`` are copied
            into the manifest when present.
        per_sample_records: Per-sample result dicts.
        per_layer_fingerprints: Per-layer fingerprint dicts.

    Returns:
        Path of the created zip archive.
    """
    p = pathlib.Path(bundle_dir)
    p.mkdir(parents=True, exist_ok=True)

    manifest = {
        "config": json.loads(config.to_json()),
        "config_hash": config.get_hash(),
        "model": config.model_name,
        "benchmark_type": config.benchmark_type,
        "python": sys.version,
        "torch": config.torch_version,
        "transformers": config.transformers_version,
        "cuda": config.cuda_version,
        "device_name": config.device_name,
        "start_time": summary.get("start_time"),
        "end_time": summary.get("end_time"),
        "hostname": platform.node()
    }

    (p / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    (p / "summary.json").write_text(json.dumps(summary, indent=2, default=str), encoding="utf-8")

    records_dir = p / "records"
    records_dir.mkdir(exist_ok=True)

    with open(records_dir / "metrics.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, default=str) + "\n" for r in per_sample_records)

    with open(records_dir / "kv_fingerprints.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, default=str) + "\n" for r in per_layer_fingerprints)

    # Environment capture is deliberately best-effort: a missing/broken pip
    # must not prevent the measured results from being exported.
    try:
        env_text = subprocess.check_output([sys.executable, "-m", "pip", "freeze"], text=True)
        (p / "env.lock").write_text(env_text, encoding="utf-8")
    except Exception as e:
        logger.warning(f"Could not capture environment: {e}")
        (p / "env.lock").write_text(f"# Environment capture failed: {e}\n", encoding="utf-8")

    # Append ".zip" to the full directory name. with_suffix(".zip") would
    # replace everything after the last dot, so "run_v1.2" would become
    # "run_v1.zip" and could collide with another bundle.
    zip_path = str(p.with_name(p.name + ".zip"))
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
        for root, dirs, files in os.walk(p):
            dirs.sort()  # fixed traversal order -> reproducible archive layout
            for name in sorted(files):
                full = pathlib.Path(root) / name
                z.write(full, arcname=str(full.relative_to(p)))

    logger.info(f"Proof bundle exported: {zip_path}")
    return zip_path


def verify_proof_bundle(bundle_root: str, config: CompressionConfig,
                       proving: ProvingConfig) -> Dict[str, Any]:
    """Verify proof bundle - recompute metrics and check tolerances.

    Loads ``summary.json`` and ``records/metrics.jsonl`` from ``bundle_root``,
    recomputes the mean of the per-sample values for the summary's
    compression type, and compares each recomputed value with the one stored
    in the summary using ``proving.numeric_tolerance``.

    Args:
        bundle_root: Directory containing an extracted proof bundle.
        config: Run configuration; ``benchmark_type`` selects which quality
            metrics are recomputed.
        proving: Supplies ``numeric_tolerance`` for the comparison.

    Returns:
        Dict with keys ``ok``, ``failures``, ``recomputed``, ``summary`` and
        ``n_samples`` (the total number of records loaded).

    Raises:
        RuntimeError: If the bundle files cannot be read or parsed.
        ValueError: If there are no records at all, or none for the
            summary's compression type.
    """
    try:
        with open(os.path.join(bundle_root, "summary.json"), encoding="utf-8") as f:
            summary = json.load(f)

        with open(os.path.join(bundle_root, "records", "metrics.jsonl"), encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to load proof bundle: {e}") from e

    if not records:
        raise ValueError("No per-sample records found in proof bundle")

    primary_method = summary.get("compression_type", "enhanced_spg")
    primary_records = [r for r in records if r.get("compression_type") == primary_method]

    if not primary_records:
        raise ValueError(f"No records found for method {primary_method}")

    logger.info(f"Verifying {len(primary_records)} records for {primary_method}")

    def mean_of(key):
        vals = [float(r[key]) for r in primary_records if key in r and r[key] is not None]
        return float(np.mean(vals)) if vals else None

    # Per benchmark: summary key -> per-sample record key it is averaged from.
    quality_keys = {
        "niah": {"niah_accuracy": "accuracy"},
        "ruler": {"ruler_exact_match": "exact_match"},
        "scbench": {"scbench_accuracy": "accuracy"},
        "longbench": {"longbench_accuracy": "accuracy"},
        "wikitext": {
            "prefill_perplexity": "prefill_perplexity",
            "generation_perplexity": "generation_perplexity",
        },
    }

    recomputed = {}
    failures = []

    for summary_key, record_key in quality_keys.get(config.benchmark_type, {}).items():
        if summary_key in summary:
            recomputed[summary_key] = mean_of(record_key)

    recomputed["compression_ratio"] = mean_of("compression_ratio")
    recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")

    for k, v in recomputed.items():
        s = summary.get(k)
        if v is None or s is None:
            continue

        # The summary is written with default=str, so a value may come back
        # as a string; convert once and use the float for both the
        # comparison and the message (":.6f" on a str raises ValueError).
        try:
            s_val = float(s)
        except (TypeError, ValueError):
            failures.append(f"{k}: summary value {s!r} is not numeric")
            continue

        if np.isnan(v) and np.isnan(s_val):
            continue

        # "not (<= tol)" instead of "> tol" so a NaN on one side is reported
        # as a mismatch instead of silently passing.
        if not (abs(v - s_val) <= proving.numeric_tolerance):
            failures.append(f"{k}: recomputed {v:.6f} != summary {s_val:.6f}")

    ok = len(failures) == 0

    result = {
        "ok": ok,
        "failures": failures,
        "recomputed": recomputed,
        "summary": summary,
        "n_samples": len(records)
    }

    if not ok:
        logger.error(f"Proof verification FAILED: {failures}")
    else:
        logger.info(f"Proof verification PASSED for {len(records)} samples")

    return result