|
|
|
|
|
""" |
|
|
Benchmarking, metrics, and proof generation for Enhanced SPG. |
|
|
Supports LongBench, NIAH, RULER, SCBench benchmarks. |
|
|
MEASURED VALUES ONLY - no estimations. FAIL FAST on errors. |
|
|
ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT. |
|
|
FIXED: Generation errors, proper fallback handling. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from transformers import ( |
|
|
AutoTokenizer, AutoModelForCausalLM, |
|
|
DynamicCache |
|
|
) |
|
|
from datasets import load_dataset |
|
|
from typing import Tuple, Optional, Dict, Any, List |
|
|
from dataclasses import dataclass, field |
|
|
from scipy import stats |
|
|
import time |
|
|
import json |
|
|
import hashlib |
|
|
import logging |
|
|
import gc |
|
|
import os |
|
|
import sys |
|
|
import platform |
|
|
import subprocess |
|
|
import zipfile |
|
|
import pathlib |
|
|
from datetime import datetime |
|
|
import random |
|
|
|
|
|
from config import ( |
|
|
CompressionConfig, CompressionType, ProvingConfig, |
|
|
ResearchConstants, SUPPORTED_MODELS, BENCHMARK_CONFIGS |
|
|
) |
|
|
from compression import QuantizedKVCache, detect_model_layers |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def set_seed(seed: int = 42) -> None: |
|
|
"""Set all seeds for reproducibility with explicit validation.""" |
|
|
if not isinstance(seed, int) or seed < 0: |
|
|
raise ValueError(f"Seed must be non-negative integer, got {seed}") |
|
|
|
|
|
random.seed(seed) |
|
|
np.random.seed(seed) |
|
|
torch.manual_seed(seed) |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.manual_seed_all(seed) |
|
|
torch.backends.cudnn.deterministic = True |
|
|
torch.backends.cudnn.benchmark = False |
|
|
|
|
|
logger.info(f"Set all random seeds to {seed}") |
|
|
|
|
|
def _peak_mem_bytes_all_gpus() -> int: |
|
|
"""Get peak memory across all GPUs. FAIL FAST if CUDA unavailable when expected.""" |
|
|
if not torch.cuda.is_available(): |
|
|
raise RuntimeError("CUDA memory tracking requested but CUDA is unavailable") |
|
|
|
|
|
torch.cuda.synchronize() |
|
|
total_mem = sum(torch.cuda.max_memory_allocated(d) for d in range(torch.cuda.device_count())) |
|
|
logger.debug(f"Peak GPU memory: {total_mem / 1024 / 1024:.1f} MB") |
|
|
return total_mem |
|
|
|
|
|
@dataclass |
|
|
class BenchmarkMetrics: |
|
|
"""Comprehensive metrics with proper statistical handling - NO ESTIMATES.""" |
|
|
|
|
|
prefill_times: List[float] = field(default_factory=list) |
|
|
prefill_peak_memories: List[float] = field(default_factory=list) |
|
|
prefill_time_mean: float = 0.0 |
|
|
prefill_time_std: float = 0.0 |
|
|
prefill_time_ci: Tuple[float, float] = (0.0, 0.0) |
|
|
prefill_peak_memory_mean_mb: float = 0.0 |
|
|
prefill_peak_memory_std_mb: float = 0.0 |
|
|
prefill_peak_memory_ci_mb: Tuple[float, float] = (0.0, 0.0) |
|
|
prefill_tokens_per_sec: float = 0.0 |
|
|
|
|
|
|
|
|
decode_times: List[float] = field(default_factory=list) |
|
|
decode_peak_memories: List[float] = field(default_factory=list) |
|
|
decode_time_per_token_mean_ms: float = 0.0 |
|
|
decode_time_per_token_std_ms: float = 0.0 |
|
|
decode_time_per_token_ci_ms: Tuple[float, float] = (0.0, 0.0) |
|
|
decode_time_p50_ms: float = 0.0 |
|
|
decode_time_p95_ms: float = 0.0 |
|
|
decode_peak_memory_mean_mb: float = 0.0 |
|
|
decode_tokens_per_sec: float = 0.0 |
|
|
|
|
|
|
|
|
prefill_perplexities: List[float] = field(default_factory=list) |
|
|
generation_perplexities: List[float] = field(default_factory=list) |
|
|
prefill_perplexity_mean: float = 0.0 |
|
|
prefill_perplexity_std: float = 0.0 |
|
|
prefill_perplexity_ci: Tuple[float, float] = (0.0, 0.0) |
|
|
generation_perplexity_mean: float = 0.0 |
|
|
generation_perplexity_std: float = 0.0 |
|
|
generation_perplexity_ci: Tuple[float, float] = (0.0, 0.0) |
|
|
|
|
|
|
|
|
longbench_scores: List[Dict[str, float]] = field(default_factory=list) |
|
|
niah_retrieval_accuracy: List[float] = field(default_factory=list) |
|
|
ruler_exact_match: List[float] = field(default_factory=list) |
|
|
scbench_turn_accuracy: List[float] = field(default_factory=list) |
|
|
|
|
|
|
|
|
compression_ratios: List[float] = field(default_factory=list) |
|
|
compression_ratio_mean: float = 0.0 |
|
|
compression_ratio_std: float = 0.0 |
|
|
kv_cache_memory_mb: float = 0.0 |
|
|
kv_cache_memory_samples_mb: List[float] = field(default_factory=list) |
|
|
|
|
|
|
|
|
enhanced_spg_measured_compression: List[float] = field(default_factory=list) |
|
|
enhanced_spg_measured_auxiliary_overhead_mb: List[float] = field(default_factory=list) |
|
|
enhanced_spg_progressive_steps: List[int] = field(default_factory=list) |
|
|
|
|
|
|
|
|
spg_precision_distributions: List[Dict[str, float]] = field(default_factory=list) |
|
|
spg_effective_bits_per_token: List[float] = field(default_factory=list) |
|
|
spg_decay_rates_per_layer: List[List[float]] = field(default_factory=list) |
|
|
|
|
|
|
|
|
memory_reduction_ratio: float = 1.0 |
|
|
memory_reduction_pvalue: float = 1.0 |
|
|
speedup_ratio: float = 1.0 |
|
|
speedup_pvalue: float = 1.0 |
|
|
prefill_perplexity_delta: float = 0.0 |
|
|
generation_perplexity_delta: float = 0.0 |
|
|
perplexity_pvalue: float = 1.0 |
|
|
|
|
|
|
|
|
end_to_end_throughput: float = 0.0 |
|
|
end_to_end_latency_ms: float = 0.0 |
|
|
|
|
|
def calculate_statistics(self, config: CompressionConfig) -> None: |
|
|
"""Calculate all statistics with proper error handling.""" |
|
|
try: |
|
|
if self.prefill_times: |
|
|
self.prefill_time_mean = float(np.mean(self.prefill_times)) |
|
|
self.prefill_time_std = float(np.std(self.prefill_times)) |
|
|
self.prefill_time_ci = self._bootstrap_ci(self.prefill_times, config) |
|
|
self.prefill_tokens_per_sec = config.prefill_length / self.prefill_time_mean if self.prefill_time_mean > 0 else 0.0 |
|
|
|
|
|
if self.prefill_peak_memories: |
|
|
memories_mb = [m / (1024 * 1024) for m in self.prefill_peak_memories] |
|
|
self.prefill_peak_memory_mean_mb = float(np.mean(memories_mb)) |
|
|
self.prefill_peak_memory_std_mb = float(np.std(memories_mb)) |
|
|
self.prefill_peak_memory_ci_mb = self._bootstrap_ci(memories_mb, config) |
|
|
|
|
|
if self.decode_times: |
|
|
self.decode_time_per_token_mean_ms = float(np.mean(self.decode_times) * 1000) |
|
|
self.decode_time_per_token_std_ms = float(np.std(self.decode_times) * 1000) |
|
|
self.decode_time_per_token_ci_ms = tuple(x * 1000 for x in self._bootstrap_ci(self.decode_times, config)) |
|
|
self.decode_tokens_per_sec = 1.0 / np.mean(self.decode_times) if self.decode_times else 0.0 |
|
|
self.decode_time_p50_ms = float(np.percentile(self.decode_times, 50) * 1000) |
|
|
self.decode_time_p95_ms = float(np.percentile(self.decode_times, 95) * 1000) |
|
|
|
|
|
|
|
|
if self.prefill_time_mean > 0 and self.decode_time_per_token_mean_ms > 0: |
|
|
total_tokens = config.prefill_length + config.generation_length |
|
|
total_time_sec = self.prefill_time_mean + (self.decode_time_per_token_mean_ms * config.generation_length / 1000) |
|
|
self.end_to_end_throughput = total_tokens / total_time_sec if total_time_sec > 0 else 0.0 |
|
|
self.end_to_end_latency_ms = total_time_sec * 1000 |
|
|
|
|
|
if self.decode_peak_memories: |
|
|
self.decode_peak_memory_mean_mb = float(np.mean(self.decode_peak_memories) / (1024 * 1024)) |
|
|
|
|
|
if self.prefill_perplexities: |
|
|
self.prefill_perplexity_mean = float(np.mean(self.prefill_perplexities)) |
|
|
self.prefill_perplexity_std = float(np.std(self.prefill_perplexities)) |
|
|
self.prefill_perplexity_ci = self._bootstrap_ci(self.prefill_perplexities, config) |
|
|
|
|
|
if self.generation_perplexities: |
|
|
self.generation_perplexity_mean = float(np.mean(self.generation_perplexities)) |
|
|
self.generation_perplexity_std = float(np.std(self.generation_perplexities)) |
|
|
self.generation_perplexity_ci = self._bootstrap_ci(self.generation_perplexities, config) |
|
|
|
|
|
if self.compression_ratios: |
|
|
self.compression_ratio_mean = float(np.mean(self.compression_ratios)) |
|
|
self.compression_ratio_std = float(np.std(self.compression_ratios)) |
|
|
|
|
|
if self.kv_cache_memory_samples_mb: |
|
|
self.kv_cache_memory_mb = float(np.mean(self.kv_cache_memory_samples_mb)) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error calculating statistics: {e}") |
|
|
raise |
|
|
|
|
|
def _bootstrap_ci(self, data: List[float], config: CompressionConfig) -> Tuple[float, float]: |
|
|
"""Calculate bootstrap confidence interval with reproducible RNG.""" |
|
|
if not data or len(data) < 2: |
|
|
return (0.0, 0.0) |
|
|
|
|
|
try: |
|
|
rng = np.random.default_rng(config.seed) |
|
|
bootstrap_means = [] |
|
|
data_array = np.array(data) |
|
|
|
|
|
for _ in range(config.n_bootstrap): |
|
|
sample = rng.choice(data_array, size=len(data_array), replace=True) |
|
|
bootstrap_means.append(float(sample.mean())) |
|
|
|
|
|
if bootstrap_means: |
|
|
alpha = 1 - config.confidence_level |
|
|
lower = float(np.percentile(bootstrap_means, alpha/2 * 100)) |
|
|
upper = float(np.percentile(bootstrap_means, (1 - alpha/2) * 100)) |
|
|
return (lower, upper) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in bootstrap CI calculation: {e}") |
|
|
raise |
|
|
|
|
|
return (0.0, 0.0) |
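
# Illustrative usage sketch for BenchmarkMetrics (comment only, not executed; the numbers
# below are placeholders, not measurements). calculate_statistics() reads config.prefill_length,
# config.generation_length, config.n_bootstrap, config.confidence_level and config.seed from
# the CompressionConfig defined in config.py.
#
#   m = BenchmarkMetrics()
#   m.prefill_times.extend([0.81, 0.79, 0.84])      # seconds per prefill
#   m.decode_times.extend([0.021, 0.019, 0.022])    # seconds per generated token
#   m.calculate_statistics(config)
#   print(m.prefill_time_mean, m.prefill_time_ci)
#   print(m.decode_time_per_token_mean_ms, m.decode_time_p95_ms)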
|
|
|
|
|
|
|
|
def safe_tokenize(tokenizer, text, max_length=512): |
|
|
"""Safe tokenization with proper padding and truncation.""" |
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
inputs = tokenizer( |
|
|
text, |
|
|
return_tensors="pt", |
|
|
truncation=True, |
|
|
max_length=max_length, |
|
|
padding="max_length", |
|
|
return_attention_mask=True, |
|
|
add_special_tokens=True |
|
|
) |
|
|
|
|
|
if inputs.input_ids.shape[1] == 0: |
|
|
raise ValueError("Tokenization produced empty sequence") |
|
|
|
|
|
if inputs.input_ids.shape[1] > max_length: |
|
|
inputs.input_ids = inputs.input_ids[:, :max_length] |
|
|
inputs.attention_mask = inputs.attention_mask[:, :max_length] |
|
|
|
|
|
return inputs |
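
# Illustrative call pattern for safe_tokenize (placeholder prompt text). Because
# padding="max_length" pads every prompt up to max_length, downstream calls must pass the
# returned attention_mask alongside input_ids, as the evaluators below do.
#
#   inputs = safe_tokenize(tokenizer, "some prompt text", max_length=512)
#   input_ids = inputs.input_ids.to(model.device)
#   attention_mask = inputs.attention_mask.to(model.device)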
|
|
|
|
|
|
|
|
def validate_model_inputs(model, input_ids, attention_mask): |
|
|
"""Validate inputs are compatible with model.""" |
|
|
if hasattr(model.config, 'max_position_embeddings'): |
|
|
max_pos = model.config.max_position_embeddings |
|
|
if input_ids.shape[1] > max_pos: |
|
|
input_ids = input_ids[:, :max_pos] |
|
|
attention_mask = attention_mask[:, :max_pos] |
|
|
|
|
|
if hasattr(model.config, 'n_positions'): |
|
|
n_pos = model.config.n_positions |
|
|
if input_ids.shape[1] > n_pos: |
|
|
input_ids = input_ids[:, :n_pos] |
|
|
attention_mask = attention_mask[:, :n_pos] |
|
|
|
|
|
vocab_size = model.config.vocab_size |
|
|
if input_ids.max() >= vocab_size: |
|
|
input_ids = input_ids.clamp(0, vocab_size - 1) |
|
|
|
|
|
if input_ids.min() < 0: |
|
|
input_ids = input_ids.clamp(0, vocab_size - 1) |
|
|
|
|
|
return input_ids, attention_mask |
|
|
|
|
|
|
|
|
def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20): |
|
|
"""Safe generation with proper error handling - returns generated text.""" |
|
|
try: |
|
|
input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask) |
|
|
|
|
|
gen_config = { |
|
|
"max_new_tokens": max_new_tokens, |
|
|
"do_sample": False, |
|
|
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id, |
|
|
"eos_token_id": tokenizer.eos_token_id, |
|
|
"attention_mask": attention_mask, |
|
|
"use_cache": True |
|
|
} |
|
|
|
|
|
if past_key_values is not None: |
|
|
gen_config["past_key_values"] = past_key_values |
|
|
|
|
|
with torch.no_grad(): |
|
|
output = model.generate(input_ids, **gen_config) |
|
|
|
|
|
|
|
|
generated_ids = output[:, input_ids.shape[1]:] |
|
|
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) |
|
|
return generated_text |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Generation failed: {e}") |
|
|
|
|
|
return "" |
|
|
|
|
|
|
|
|
def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask, |
|
|
cache_manager: QuantizedKVCache, config: CompressionConfig, |
|
|
measure_memory: bool = True) -> Dict[str, Any]: |
|
|
""" |
|
|
Unified compression pipeline for ALL benchmarks with safety fixes. |
|
|
Returns compressed cache, metrics, and reconstructed KV pairs. |
|
|
""" |
|
|
device = input_ids.device |
|
|
|
|
|
input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask) |
|
|
|
|
|
if torch.cuda.is_available() and measure_memory: |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.reset_peak_memory_stats() |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
start_time = time.perf_counter() |
|
|
|
|
|
try: |
|
|
with torch.inference_mode(): |
|
|
outputs = model( |
|
|
input_ids, |
|
|
attention_mask=attention_mask, |
|
|
use_cache=True, |
|
|
return_dict=True |
|
|
) |
|
|
past_key_values = outputs.past_key_values |
|
|
logits = outputs.logits |
|
|
except Exception as e: |
|
|
logger.error(f"Prefill failed: {e}") |
|
|
return { |
|
|
'past_key_values': None, |
|
|
'prefill_time': 0, |
|
|
'prefill_peak_mem': 0, |
|
|
'prefill_loss': None, |
|
|
'original_cache_size': 0, |
|
|
'compressed_cache_size': 0, |
|
|
'compression_ratio': 1.0, |
|
|
'logits': None |
|
|
} |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
prefill_time = time.perf_counter() - start_time |
|
|
|
|
|
prefill_peak_mem = 0 |
|
|
if torch.cuda.is_available() and measure_memory: |
|
|
prefill_peak_mem = _peak_mem_bytes_all_gpus() |
|
|
|
|
|
prefill_loss = None |
|
|
if logits is not None and input_ids.shape[1] > 1: |
|
|
try: |
|
|
seq_len = min(logits.shape[1], input_ids.shape[1] - 1) |
|
|
if seq_len > 0: |
|
|
shift_logits = logits[:, :seq_len, :].contiguous() |
|
|
shift_labels = input_ids[:, 1:seq_len+1].contiguous() |
|
|
|
|
|
loss = F.cross_entropy( |
|
|
shift_logits.view(-1, shift_logits.size(-1)), |
|
|
shift_labels.view(-1), |
|
|
reduction='mean', |
|
|
                    ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
|
|
) |
|
|
prefill_loss = loss.item() |
|
|
except Exception as e: |
|
|
logger.warning(f"Could not calculate prefill loss: {e}") |
|
|
|
|
|
original_cache_size = 0 |
|
|
compressed_cache_size = 0 |
|
|
compression_ratio = 1.0 |
|
|
|
|
|
if past_key_values: |
|
|
try: |
|
|
if hasattr(past_key_values, 'to_legacy_cache'): |
|
|
kv_tuple = past_key_values.to_legacy_cache() |
|
|
else: |
|
|
kv_tuple = past_key_values |
|
|
|
|
|
for layer_idx, (keys, values) in enumerate(kv_tuple): |
|
|
if keys is not None and values is not None: |
|
|
original_cache_size += keys.nelement() * keys.element_size() |
|
|
original_cache_size += values.nelement() * values.element_size() |
|
|
|
|
|
if config.compression_type != CompressionType.NONE and cache_manager is not None: |
|
|
try: |
|
|
cache_manager.compress_and_store(layer_idx, keys, values) |
|
|
except Exception as e: |
|
|
logger.error(f"Compression failed for layer {layer_idx}: {e}") |
|
|
|
|
|
if config.compression_type != CompressionType.NONE and cache_manager is not None: |
|
|
reconstructed_kv = [] |
|
|
for layer_idx in range(len(kv_tuple)): |
|
|
try: |
|
|
dec_keys, dec_values = cache_manager.get_decompressed(layer_idx) |
|
|
if dec_keys is not None and dec_values is not None: |
|
|
reconstructed_kv.append((dec_keys, dec_values)) |
|
|
else: |
|
|
reconstructed_kv.append(kv_tuple[layer_idx]) |
|
|
except Exception as e: |
|
|
logger.error(f"Decompression failed for layer {layer_idx}: {e}") |
|
|
reconstructed_kv.append(kv_tuple[layer_idx]) |
|
|
|
|
|
if hasattr(DynamicCache, 'from_legacy_cache'): |
|
|
past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv)) |
|
|
else: |
|
|
past_key_values = tuple(reconstructed_kv) |
|
|
|
|
|
                try:
                    compressed_cache_size = cache_manager.get_memory_footprint()
                except Exception as e:
                    logger.warning(f"Could not read compressed cache footprint: {e}")
                    compressed_cache_size = original_cache_size
|
|
else: |
|
|
compressed_cache_size = original_cache_size |
|
|
|
|
|
if compressed_cache_size > 0: |
|
|
compression_ratio = original_cache_size / compressed_cache_size |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Cache processing failed: {e}") |
|
|
compressed_cache_size = original_cache_size |
|
|
compression_ratio = 1.0 |
|
|
|
|
|
return { |
|
|
'past_key_values': past_key_values, |
|
|
'prefill_time': prefill_time, |
|
|
'prefill_peak_mem': prefill_peak_mem, |
|
|
'prefill_loss': prefill_loss, |
|
|
'original_cache_size': original_cache_size, |
|
|
'compressed_cache_size': compressed_cache_size, |
|
|
'compression_ratio': compression_ratio, |
|
|
'logits': logits |
|
|
} |
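
# Illustrative sketch of the unified pipeline contract. The dictionary keys are exactly
# the ones returned above; constructing QuantizedKVCache and setting n_layers mirrors
# run_research_benchmark below (an assumption about the compression.py API, which is
# only imported here).
#
#   cache_manager = QuantizedKVCache(config)
#   cache_manager.n_layers = detect_model_layers(model)
#   result = apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
#                                       cache_manager, config)
#   past_kv = result['past_key_values']                      # decompressed cache for generation
#   ratio = result['compression_ratio']                      # original_bytes / compressed_bytes
#   kv_mb = result['compressed_cache_size'] / (1024 * 1024)  # measured, not estimated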
|
|
|
|
|
|
|
|
def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str: |
|
|
"""Create Needle-in-a-Haystack test context - NO HARDCODING.""" |
|
|
haystack_template = "The quick brown fox jumps over the lazy dog. " * 20 |
|
|
haystack_chunks = [] |
|
|
|
|
|
while len(" ".join(haystack_chunks)) < context_length: |
|
|
haystack_chunks.append(haystack_template) |
|
|
|
|
|
haystack = " ".join(haystack_chunks)[:context_length - len(needle) - 10] |
|
|
|
|
|
insertion_point = int(len(haystack) * depth_percent / 100) |
|
|
haystack_with_needle = ( |
|
|
haystack[:insertion_point] + |
|
|
" " + needle + " " + |
|
|
haystack[insertion_point:] |
|
|
) |
|
|
|
|
|
return haystack_with_needle |
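
# Example (illustrative; the needle text is made up): depth_percent controls how far into
# the character stream the needle is spliced, so sweeping depths stresses retrieval from
# different cache positions.
#
#   ctx = create_niah_haystack(context_length=2000,
#                              needle="The secret password is swordfish.",
#                              depth_percent=25.0)
#   assert "swordfish" in ctx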
|
|
|
|
|
|
|
|
def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]: |
|
|
"""Evaluate NIAH with SAME compression pipeline as WikiText.""" |
|
|
context = create_niah_haystack( |
|
|
config.prefill_length, |
|
|
config.niah_needle, |
|
|
config.niah_depth_percent |
|
|
) |
|
|
|
|
|
prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:" |
|
|
|
|
|
inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024)) |
|
|
input_ids = inputs.input_ids.to(model.device) |
|
|
attention_mask = inputs.attention_mask.to(model.device) |
|
|
|
|
|
compression_result = apply_compression_pipeline( |
|
|
model, tokenizer, input_ids, attention_mask, cache_manager, config |
|
|
) |
|
|
|
|
|
gen_start = time.perf_counter() |
|
|
generated_text = safe_generate(model, tokenizer, input_ids, attention_mask, |
|
|
compression_result['past_key_values'], max_new_tokens=20) |
|
|
gen_time = time.perf_counter() - gen_start |
|
|
|
|
|
accuracy = 1.0 if config.niah_needle.split()[-1] in generated_text else 0.0 |
|
|
|
|
|
logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}") |
|
|
logger.info(f"NIAH compression ratio: {compression_result['compression_ratio']:.1f}x") |
|
|
|
|
|
return { |
|
|
'accuracy': accuracy, |
|
|
'compression_ratio': compression_result['compression_ratio'], |
|
|
'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024), |
|
|
'prefill_time': compression_result['prefill_time'], |
|
|
'generation_time': gen_time, |
|
|
'prefill_peak_mem': compression_result['prefill_peak_mem'] |
|
|
} |
|
|
|
|
|
|
|
|
def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]: |
|
|
"""Evaluate RULER with SAME compression pipeline as WikiText.""" |
|
|
seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024) |
|
|
|
|
|
facts = [] |
|
|
for i in range(10): |
|
|
facts.append(f"Fact {i}: The capital of Country{i} is City{i}.") |
|
|
|
|
|
context = " ".join(facts) * (seq_len // (len(" ".join(facts)) + 1)) |
|
|
context = context[:seq_len - 100] |
|
|
|
|
|
query_idx = random.randint(0, 9) |
|
|
prompt = f"{context}\n\nWhat is the capital of Country{query_idx}?" |
|
|
|
|
|
inputs = safe_tokenize(tokenizer, prompt, max_length=seq_len) |
|
|
input_ids = inputs.input_ids.to(model.device) |
|
|
attention_mask = inputs.attention_mask.to(model.device) |
|
|
|
|
|
compression_result = apply_compression_pipeline( |
|
|
model, tokenizer, input_ids, attention_mask, cache_manager, config |
|
|
) |
|
|
|
|
|
gen_start = time.perf_counter() |
|
|
generated = safe_generate(model, tokenizer, input_ids, attention_mask, |
|
|
compression_result['past_key_values'], max_new_tokens=10) |
|
|
gen_time = time.perf_counter() - gen_start |
|
|
|
|
|
expected = f"City{query_idx}" |
|
|
exact_match = 1.0 if expected in generated else 0.0 |
|
|
|
|
|
logger.info(f"RULER exact match: {exact_match}, Generated: {generated[:50]}") |
|
|
logger.info(f"RULER compression ratio: {compression_result['compression_ratio']:.1f}x") |
|
|
|
|
|
return { |
|
|
'exact_match': exact_match, |
|
|
'compression_ratio': compression_result['compression_ratio'], |
|
|
'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024), |
|
|
'prefill_time': compression_result['prefill_time'], |
|
|
'generation_time': gen_time, |
|
|
'prefill_peak_mem': compression_result['prefill_peak_mem'] |
|
|
} |
|
|
|
|
|
|
|
|
def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]: |
|
|
"""Evaluate SCBench with SAME compression pipeline as WikiText.""" |
|
|
conversation = [] |
|
|
facts = {} |
|
|
|
|
|
for turn in range(config.scbench_num_turns): |
|
|
fact_key = f"item_{turn}" |
|
|
fact_value = f"value_{turn}_{random.randint(1000, 9999)}" |
|
|
facts[fact_key] = fact_value |
|
|
|
|
|
user_msg = f"Remember that {fact_key} is {fact_value}." |
|
|
assistant_msg = f"I'll remember that {fact_key} is {fact_value}." |
|
|
|
|
|
conversation.append(f"User: {user_msg}") |
|
|
conversation.append(f"Assistant: {assistant_msg}") |
|
|
|
|
|
query_key = random.choice(list(facts.keys())) |
|
|
conversation.append(f"User: What is {query_key}?") |
|
|
|
|
|
full_conversation = "\n".join(conversation) + "\nAssistant:" |
|
|
|
|
|
inputs = safe_tokenize(tokenizer, full_conversation, max_length=min(config.prefill_length, 1024)) |
|
|
input_ids = inputs.input_ids.to(model.device) |
|
|
attention_mask = inputs.attention_mask.to(model.device) |
|
|
|
|
|
compression_result = apply_compression_pipeline( |
|
|
model, tokenizer, input_ids, attention_mask, cache_manager, config |
|
|
) |
|
|
|
|
|
gen_start = time.perf_counter() |
|
|
generated = safe_generate(model, tokenizer, input_ids, attention_mask, |
|
|
compression_result['past_key_values'], max_new_tokens=20) |
|
|
gen_time = time.perf_counter() - gen_start |
|
|
|
|
|
expected_value = facts[query_key] |
|
|
accuracy = 1.0 if expected_value in generated else 0.0 |
|
|
|
|
|
logger.info(f"SCBench accuracy: {accuracy}, Generated: {generated[:50]}") |
|
|
logger.info(f"SCBench compression ratio: {compression_result['compression_ratio']:.1f}x") |
|
|
|
|
|
return { |
|
|
'accuracy': accuracy, |
|
|
'compression_ratio': compression_result['compression_ratio'], |
|
|
'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024), |
|
|
'prefill_time': compression_result['prefill_time'], |
|
|
'generation_time': gen_time, |
|
|
'prefill_peak_mem': compression_result['prefill_peak_mem'] |
|
|
} |
|
|
|
|
|
|
|
|
def evaluate_longbench_task(model, tokenizer, config: CompressionConfig, |
|
|
task: str, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]: |
|
|
"""Evaluate LongBench with SAME compression pipeline as WikiText.""" |
|
|
try: |
|
|
dataset = load_dataset("THUDM/LongBench", task, split="test") |
|
|
|
|
|
n_samples = min(config.eval_samples, len(dataset)) |
|
|
samples = dataset.select(range(n_samples)) |
|
|
|
|
|
scores = [] |
|
|
compression_ratios = [] |
|
|
kv_memories = [] |
|
|
prefill_times = [] |
|
|
gen_times = [] |
|
|
|
|
|
for sample in samples: |
|
|
context = sample.get("context", "") |
|
|
question = sample.get("input", sample.get("question", "")) |
|
|
answer = sample.get("answers", [sample.get("answer", "")]) |
|
|
|
|
|
if isinstance(answer, list) and answer: |
|
|
answer = answer[0] |
|
|
|
|
|
prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" |
|
|
|
|
|
inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024)) |
|
|
input_ids = inputs.input_ids.to(model.device) |
|
|
attention_mask = inputs.attention_mask.to(model.device) |
|
|
|
|
|
compression_result = apply_compression_pipeline( |
|
|
model, tokenizer, input_ids, attention_mask, cache_manager, config, |
|
|
measure_memory=False |
|
|
) |
|
|
|
|
|
gen_start = time.perf_counter() |
|
|
generated = safe_generate(model, tokenizer, input_ids, attention_mask, |
|
|
compression_result['past_key_values'], max_new_tokens=50) |
|
|
gen_time = time.perf_counter() - gen_start |
|
|
|
|
|
score = 1.0 if str(answer).lower() in generated.lower() else 0.0 |
|
|
scores.append(score) |
|
|
compression_ratios.append(compression_result['compression_ratio']) |
|
|
kv_memories.append(compression_result['compressed_cache_size'] / (1024 * 1024)) |
|
|
prefill_times.append(compression_result['prefill_time']) |
|
|
gen_times.append(gen_time) |
|
|
|
|
|
avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0 |
|
|
|
|
|
return { |
|
|
'accuracy': float(np.mean(scores)), |
|
|
'n_samples': n_samples, |
|
|
'compression_ratio': avg_compression, |
|
|
'kv_cache_memory_mb': float(np.mean(kv_memories)) if kv_memories else 0.0, |
|
|
'prefill_time': float(np.mean(prefill_times)) if prefill_times else 0.0, |
|
|
'generation_time': float(np.mean(gen_times)) if gen_times else 0.0 |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error evaluating LongBench task {task}: {e}") |
|
|
return { |
|
|
'accuracy': 0.0, |
|
|
'n_samples': 0, |
|
|
'compression_ratio': 1.0, |
|
|
'kv_cache_memory_mb': 0.0, |
|
|
'prefill_time': 0.0, |
|
|
'generation_time': 0.0 |
|
|
} |
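
# Illustrative call (sketch): `task` must name a LongBench subset on the Hugging Face Hub
# (e.g. "narrativeqa"); the subset actually benchmarked comes from config.benchmark_subset.
#
#   result = evaluate_longbench_task(model, tokenizer, config, task=config.benchmark_subset,
#                                    cache_manager=cache_manager)
#   print(result['accuracy'], result['compression_ratio'])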
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(model_name: str, config: CompressionConfig): |
|
|
"""Load model and tokenizer with proper configuration - NO HARDCODING.""" |
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
dtype = torch.float16 if device == "cuda" else torch.float32 |
|
|
|
|
|
if config.fail_on_cpu_fallback and device == "cpu": |
|
|
raise RuntimeError("CUDA required but unavailable (fail_on_cpu_fallback=True)") |
|
|
|
|
|
logger.info(f"Loading model: {model_name}") |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
|
model_name, |
|
|
trust_remote_code=True |
|
|
) |
|
|
|
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
model_kwargs = { |
|
|
"torch_dtype": dtype, |
|
|
"device_map": "auto" if device == "cuda" else None, |
|
|
"low_cpu_mem_usage": True, |
|
|
"trust_remote_code": True |
|
|
} |
|
|
|
|
|
if config.use_flash_attention and device == "cuda": |
|
|
try: |
|
|
model_kwargs["attn_implementation"] = "flash_attention_2" |
|
|
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) |
|
|
logger.info("Successfully loaded with Flash Attention 2") |
|
|
except Exception as e: |
|
|
logger.warning(f"Flash Attention not available: {e}") |
|
|
model_kwargs.pop("attn_implementation", None) |
|
|
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) |
|
|
else: |
|
|
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) |
|
|
|
|
|
model.eval() |
|
|
|
|
|
return model, tokenizer |
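
# Illustrative usage (sketch; config.model_name is whatever was configured in config.py):
#
#   model, tokenizer = load_model_and_tokenizer(config.model_name, config)
#   texts = load_real_dataset_samples(config, tokenizer)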
|
|
|
|
|
|
|
|
def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]: |
|
|
"""Load dataset samples based on benchmark type - NO HARDCODING.""" |
|
|
logger.info(f"Loading samples for benchmark: {config.benchmark_type}") |
|
|
|
|
|
if config.benchmark_type == "wikitext": |
|
|
texts = [] |
|
|
min_tokens = config.prefill_length + config.generation_length |
|
|
|
|
|
try: |
|
|
            for split in dict.fromkeys([config.dataset_split, "train", "validation"]):
|
|
if len(texts) >= config.eval_samples: |
|
|
break |
|
|
|
|
|
try: |
|
|
dataset = load_dataset( |
|
|
config.dataset_name, |
|
|
config.dataset_config, |
|
|
split=split, |
|
|
streaming=False |
|
|
) |
|
|
|
|
|
logger.info(f"Trying {split} split with {len(dataset)} samples") |
|
|
|
|
|
for item in dataset: |
|
|
text = item.get('text', '').strip() |
|
|
|
|
|
if len(text) > 50: |
|
|
tokens = tokenizer.encode(text, truncation=False, add_special_tokens=False) |
|
|
|
|
|
if len(tokens) >= min(min_tokens, 256): |
|
|
texts.append(text) |
|
|
if len(texts) >= config.eval_samples * 3: |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to load {split} split: {e}") |
|
|
continue |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load dataset: {e}") |
|
|
raise |
|
|
|
|
|
elif config.benchmark_type == "longbench": |
|
|
texts = [] |
|
|
if config.benchmark_subset: |
|
|
try: |
|
|
dataset = load_dataset("THUDM/LongBench", config.benchmark_subset, split="test") |
|
|
for item in dataset: |
|
|
if len(texts) >= config.eval_samples: |
|
|
break |
|
|
context = item.get("context", "") |
|
|
if len(context) > 100: |
|
|
texts.append(context) |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load LongBench subset {config.benchmark_subset}: {e}") |
|
|
raise |
|
|
|
|
|
elif config.benchmark_type in ["niah", "ruler", "scbench"]: |
|
|
texts = ["Synthetic benchmark data"] * config.eval_samples |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unsupported benchmark type: {config.benchmark_type}") |
|
|
|
|
|
if len(texts) < config.eval_samples: |
|
|
logger.warning(f"Only loaded {len(texts)} samples, requested {config.eval_samples}") |
|
|
|
|
|
logger.info(f"Loaded {len(texts)} text samples") |
|
|
return texts |
|
|
|
|
|
|
|
|
def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_texts: Optional[List[str]] = None) -> Tuple[BenchmarkMetrics, Dict, List[Dict], List[Dict]]: |
|
|
"""Research-grade benchmark with UNIFIED compression for ALL benchmarks.""" |
|
|
logger.info(f"Starting benchmark: {model_name} with {config.compression_type.value}") |
|
|
logger.info(f"Benchmark type: {config.benchmark_type}") |
|
|
logger.info(f"Config hash: {config.get_hash()}") |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' |
|
|
|
|
|
constants = ResearchConstants() |
|
|
start_time = datetime.now().isoformat() |
|
|
per_sample_records = [] |
|
|
per_layer_fingerprints = [] |
|
|
|
|
|
model, tokenizer = load_model_and_tokenizer(model_name, config) |
|
|
|
|
|
try: |
|
|
n_layers = detect_model_layers(model) |
|
|
logger.info(f"Model architecture: {n_layers} transformer layers detected") |
|
|
except ValueError as e: |
|
|
logger.error(f"Failed to detect model layers: {e}") |
|
|
raise |
|
|
|
|
|
device = model.device |
|
|
with torch.inference_mode(): |
|
|
dummy = torch.randint(0, tokenizer.vocab_size, (1, min(config.prefill_length, 128)), device=device) |
|
|
am = torch.ones_like(dummy) |
|
|
for _ in range(config.warmup_steps): |
|
|
_ = model(dummy, attention_mask=am, use_cache=True, return_dict=True) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
torch.cuda.reset_peak_memory_stats() |
|
|
|
|
|
if dataset_texts is None: |
|
|
dataset_texts = load_real_dataset_samples(config, tokenizer) |
|
|
|
|
|
all_metrics = [] |
|
|
|
|
|
for seed in range(config.n_seeds): |
|
|
set_seed(config.seed + seed) |
|
|
logger.info(f"Running evaluation with seed {config.seed + seed}") |
|
|
|
|
|
metrics = BenchmarkMetrics() |
|
|
|
|
|
if config.benchmark_type == "niah": |
|
|
for depth in BENCHMARK_CONFIGS["niah"]["depths"]: |
|
|
config.niah_depth_percent = depth |
|
|
for idx in range(min(config.eval_samples, 10)): |
|
|
if config.compression_type != CompressionType.NONE: |
|
|
cache_manager = QuantizedKVCache(config) |
|
|
cache_manager.n_layers = n_layers |
|
|
else: |
|
|
cache_manager = None |
|
|
|
|
|
result = evaluate_niah(model, tokenizer, config, cache_manager) |
|
|
|
|
|
metrics.niah_retrieval_accuracy.append(result['accuracy']) |
|
|
metrics.compression_ratios.append(result['compression_ratio']) |
|
|
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb']) |
|
|
metrics.prefill_times.append(result['prefill_time']) |
|
|
metrics.decode_times.append(result['generation_time'] / 20) |
|
|
|
|
|
if result['prefill_peak_mem'] > 0: |
|
|
metrics.prefill_peak_memories.append(result['prefill_peak_mem']) |
|
|
|
|
|
per_sample_records.append({ |
|
|
'benchmark': 'niah', |
|
|
'depth_percent': depth, |
|
|
'sample_idx': idx, |
|
|
'accuracy': result['accuracy'], |
|
|
'compression_ratio': result['compression_ratio'], |
|
|
'kv_cache_memory_mb': result['kv_cache_memory_mb'], |
|
|
'compression_type': config.compression_type.value |
|
|
}) |
|
|
|
|
|
elif config.benchmark_type == "ruler": |
|
|
for idx in range(config.eval_samples): |
|
|
if config.compression_type != CompressionType.NONE: |
|
|
cache_manager = QuantizedKVCache(config) |
|
|
cache_manager.n_layers = n_layers |
|
|
else: |
|
|
cache_manager = None |
|
|
|
|
|
result = evaluate_ruler(model, tokenizer, config, cache_manager) |
|
|
|
|
|
metrics.ruler_exact_match.append(result['exact_match']) |
|
|
metrics.compression_ratios.append(result['compression_ratio']) |
|
|
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb']) |
|
|
metrics.prefill_times.append(result['prefill_time']) |
|
|
metrics.decode_times.append(result['generation_time'] / 10) |
|
|
|
|
|
if result['prefill_peak_mem'] > 0: |
|
|
metrics.prefill_peak_memories.append(result['prefill_peak_mem']) |
|
|
|
|
|
per_sample_records.append({ |
|
|
'benchmark': 'ruler', |
|
|
'sample_idx': idx, |
|
|
'exact_match': result['exact_match'], |
|
|
'compression_ratio': result['compression_ratio'], |
|
|
'kv_cache_memory_mb': result['kv_cache_memory_mb'], |
|
|
'compression_type': config.compression_type.value |
|
|
}) |
|
|
|
|
|
elif config.benchmark_type == "scbench": |
|
|
for idx in range(config.eval_samples): |
|
|
if config.compression_type != CompressionType.NONE: |
|
|
cache_manager = QuantizedKVCache(config) |
|
|
cache_manager.n_layers = n_layers |
|
|
else: |
|
|
cache_manager = None |
|
|
|
|
|
result = evaluate_scbench(model, tokenizer, config, cache_manager) |
|
|
|
|
|
metrics.scbench_turn_accuracy.append(result['accuracy']) |
|
|
metrics.compression_ratios.append(result['compression_ratio']) |
|
|
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb']) |
|
|
metrics.prefill_times.append(result['prefill_time']) |
|
|
metrics.decode_times.append(result['generation_time'] / 20) |
|
|
|
|
|
if result['prefill_peak_mem'] > 0: |
|
|
metrics.prefill_peak_memories.append(result['prefill_peak_mem']) |
|
|
|
|
|
per_sample_records.append({ |
|
|
'benchmark': 'scbench', |
|
|
'sample_idx': idx, |
|
|
'accuracy': result['accuracy'], |
|
|
'compression_ratio': result['compression_ratio'], |
|
|
'kv_cache_memory_mb': result['kv_cache_memory_mb'], |
|
|
'compression_type': config.compression_type.value |
|
|
}) |
|
|
|
|
|
elif config.benchmark_type == "longbench": |
|
|
if config.benchmark_subset: |
|
|
if config.compression_type != CompressionType.NONE: |
|
|
cache_manager = QuantizedKVCache(config) |
|
|
cache_manager.n_layers = n_layers |
|
|
else: |
|
|
cache_manager = None |
|
|
|
|
|
result = evaluate_longbench_task(model, tokenizer, config, |
|
|
config.benchmark_subset, cache_manager) |
|
|
|
|
|
metrics.longbench_scores.append(result) |
|
|
metrics.compression_ratios.append(result['compression_ratio']) |
|
|
metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb']) |
|
|
metrics.prefill_times.append(result['prefill_time']) |
|
|
|
|
|
if result['generation_time'] > 0: |
|
|
metrics.decode_times.append(result['generation_time'] / 50) |
|
|
|
|
|
per_sample_records.append({ |
|
|
'benchmark': 'longbench', |
|
|
'subset': config.benchmark_subset, |
|
|
'accuracy': result['accuracy'], |
|
|
'compression_ratio': result['compression_ratio'], |
|
|
'kv_cache_memory_mb': result['kv_cache_memory_mb'], |
|
|
'compression_type': config.compression_type.value |
|
|
}) |
|
|
|
|
|
else: |
|
|
for idx in range(config.eval_samples): |
|
|
logger.info(f"Sample {idx+1}/{config.eval_samples}") |
|
|
|
|
|
text_idx = (idx + seed * config.eval_samples) % len(dataset_texts) |
|
|
text = dataset_texts[text_idx] |
|
|
|
|
|
if config.compression_type != CompressionType.NONE: |
|
|
cache_manager = QuantizedKVCache(config) |
|
|
cache_manager.n_layers = n_layers |
|
|
cache_manager.update_position(config.prefill_length + idx) |
|
|
else: |
|
|
cache_manager = None |
|
|
|
|
|
inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024)) |
|
|
input_ids = inputs.input_ids.to(device) |
|
|
attention_mask = inputs.attention_mask.to(device) |
|
|
|
|
|
compression_result = apply_compression_pipeline( |
|
|
model, tokenizer, input_ids, attention_mask, cache_manager, config |
|
|
) |
|
|
|
|
|
metrics.prefill_times.append(compression_result['prefill_time']) |
|
|
metrics.compression_ratios.append(compression_result['compression_ratio']) |
|
|
metrics.kv_cache_memory_samples_mb.append(compression_result['compressed_cache_size'] / (1024 * 1024)) |
|
|
|
|
|
if compression_result['prefill_peak_mem'] > 0: |
|
|
metrics.prefill_peak_memories.append(compression_result['prefill_peak_mem']) |
|
|
|
|
|
if compression_result['prefill_loss'] is not None: |
|
|
prefill_perplexity = np.exp(compression_result['prefill_loss']) |
|
|
metrics.prefill_perplexities.append(min(prefill_perplexity, 1000)) |
|
|
|
|
|
generated_ids = input_ids.clone() |
|
|
decode_times = [] |
|
|
generation_losses = [] |
|
|
past_key_values = compression_result['past_key_values'] |
|
|
|
|
|
for gen_step in range(config.generation_length): |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
step_start = time.perf_counter() |
|
|
|
|
|
                    with torch.inference_mode():
                        outputs = model(
                            generated_ids[:, -1:],
                            past_key_values=past_key_values,
                            use_cache=True,
                            return_dict=True
                        )
                        next_token_logits = outputs.logits[:, -1, :]
                        next_token = torch.argmax(next_token_logits, dim=-1)

                        # NLL of the greedily selected token: a confidence proxy (there is no
                        # reference continuation here); exponentiated later into "generation perplexity".
                        loss = F.cross_entropy(next_token_logits, next_token)
                        generation_losses.append(loss.item())

                        generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)
                        past_key_values = outputs.past_key_values
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
decode_time = time.perf_counter() - step_start |
|
|
decode_times.append(decode_time) |
|
|
|
|
|
if decode_times: |
|
|
metrics.decode_times.extend(decode_times) |
|
|
|
|
|
if generation_losses: |
|
|
generation_perplexity = np.exp(np.mean(generation_losses)) |
|
|
metrics.generation_perplexities.append(min(generation_perplexity, 1000)) |
|
|
|
|
|
per_sample_records.append({ |
|
|
'benchmark': 'wikitext', |
|
|
'sample_idx': idx, |
|
|
'prefill_perplexity': metrics.prefill_perplexities[-1] if metrics.prefill_perplexities else None, |
|
|
'generation_perplexity': metrics.generation_perplexities[-1] if metrics.generation_perplexities else None, |
|
|
'compression_ratio': compression_result['compression_ratio'], |
|
|
'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024), |
|
|
'compression_type': config.compression_type.value |
|
|
}) |
|
|
|
|
|
metrics.calculate_statistics(config) |
|
|
all_metrics.append(metrics) |
|
|
|
|
|
final_metrics = BenchmarkMetrics() |
|
|
for m in all_metrics: |
|
|
final_metrics.prefill_times.extend(m.prefill_times) |
|
|
final_metrics.prefill_peak_memories.extend(m.prefill_peak_memories) |
|
|
final_metrics.decode_times.extend(m.decode_times) |
|
|
final_metrics.decode_peak_memories.extend(m.decode_peak_memories) |
|
|
final_metrics.prefill_perplexities.extend(m.prefill_perplexities) |
|
|
final_metrics.generation_perplexities.extend(m.generation_perplexities) |
|
|
final_metrics.compression_ratios.extend(m.compression_ratios) |
|
|
final_metrics.kv_cache_memory_samples_mb.extend(m.kv_cache_memory_samples_mb) |
|
|
final_metrics.niah_retrieval_accuracy.extend(m.niah_retrieval_accuracy) |
|
|
final_metrics.ruler_exact_match.extend(m.ruler_exact_match) |
|
|
final_metrics.scbench_turn_accuracy.extend(m.scbench_turn_accuracy) |
|
|
final_metrics.longbench_scores.extend(m.longbench_scores) |
|
|
|
|
|
final_metrics.calculate_statistics(config) |
|
|
|
|
|
end_time = datetime.now().isoformat() |
|
|
summary = { |
|
|
'compression_type': config.compression_type.value, |
|
|
'model': model_name, |
|
|
'benchmark_type': config.benchmark_type, |
|
|
'n_seeds': config.n_seeds, |
|
|
'total_samples': config.eval_samples * config.n_seeds, |
|
|
'compression_ratio': final_metrics.compression_ratio_mean, |
|
|
'kv_cache_memory_mb': final_metrics.kv_cache_memory_mb, |
|
|
'start_time': start_time, |
|
|
'end_time': end_time |
|
|
} |
|
|
|
|
|
if config.benchmark_type == "niah" and final_metrics.niah_retrieval_accuracy: |
|
|
summary['niah_accuracy'] = float(np.mean(final_metrics.niah_retrieval_accuracy)) |
|
|
elif config.benchmark_type == "ruler" and final_metrics.ruler_exact_match: |
|
|
summary['ruler_exact_match'] = float(np.mean(final_metrics.ruler_exact_match)) |
|
|
elif config.benchmark_type == "scbench" and final_metrics.scbench_turn_accuracy: |
|
|
summary['scbench_accuracy'] = float(np.mean(final_metrics.scbench_turn_accuracy)) |
|
|
elif config.benchmark_type == "longbench" and final_metrics.longbench_scores: |
|
|
summary['longbench_accuracy'] = float(np.mean([s['accuracy'] for s in final_metrics.longbench_scores])) |
|
|
else: |
|
|
summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean |
|
|
summary['generation_perplexity'] = final_metrics.generation_perplexity_mean |
|
|
|
|
|
summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000 |
|
|
summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms |
|
|
summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec |
|
|
summary['end_to_end_throughput'] = final_metrics.end_to_end_throughput |
|
|
summary['end_to_end_latency_ms'] = final_metrics.end_to_end_latency_ms |
|
|
summary['peak_memory_mb'] = final_metrics.prefill_peak_memory_mean_mb |
|
|
|
|
|
return final_metrics, summary, per_sample_records, per_layer_fingerprints |
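
# Illustrative end-to-end invocation (a sketch, not a CLI; CompressionConfig field values
# are defined in config.py and not restated here):
#
#   metrics, summary, records, fingerprints = run_research_benchmark(config.model_name, config)
#   print(json.dumps(summary, indent=2, default=str))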
|
|
|
|
|
|
|
|
def export_proof_bundle(bundle_dir: str, config: CompressionConfig, |
|
|
metrics: BenchmarkMetrics, summary: Dict[str, Any], |
|
|
per_sample_records: List[Dict[str, Any]], |
|
|
per_layer_fingerprints: List[Dict[str, Any]]) -> str: |
|
|
"""Export attestable proof bundle with all metrics and fingerprints.""" |
|
|
p = pathlib.Path(bundle_dir) |
|
|
p.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
manifest = { |
|
|
"config": json.loads(config.to_json()), |
|
|
"config_hash": config.get_hash(), |
|
|
"model": config.model_name, |
|
|
"benchmark_type": config.benchmark_type, |
|
|
"python": sys.version, |
|
|
"torch": config.torch_version, |
|
|
"transformers": config.transformers_version, |
|
|
"cuda": config.cuda_version, |
|
|
"device_name": config.device_name, |
|
|
"start_time": summary.get("start_time"), |
|
|
"end_time": summary.get("end_time"), |
|
|
"hostname": platform.node() |
|
|
} |
|
|
|
|
|
(p / "manifest.json").write_text(json.dumps(manifest, indent=2)) |
|
|
(p / "summary.json").write_text(json.dumps(summary, indent=2, default=str)) |
|
|
|
|
|
records_dir = p / "records" |
|
|
records_dir.mkdir(exist_ok=True) |
|
|
|
|
|
with open(records_dir / "metrics.jsonl", "w") as f: |
|
|
for r in per_sample_records: |
|
|
f.write(json.dumps(r, default=str) + "\n") |
|
|
|
|
|
with open(records_dir / "kv_fingerprints.jsonl", "w") as f: |
|
|
for r in per_layer_fingerprints: |
|
|
f.write(json.dumps(r, default=str) + "\n") |
|
|
|
|
|
try: |
|
|
env_text = subprocess.check_output([sys.executable, "-m", "pip", "freeze"], text=True) |
|
|
(p / "env.lock").write_text(env_text) |
|
|
except Exception as e: |
|
|
logger.warning(f"Could not capture environment: {e}") |
|
|
(p / "env.lock").write_text(f"# Environment capture failed: {e}\n") |
|
|
|
|
|
zip_path = str(p.with_suffix(".zip")) |
|
|
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z: |
|
|
for root, _, files in os.walk(p): |
|
|
for name in files: |
|
|
full = pathlib.Path(root) / name |
|
|
z.write(full, arcname=str(full.relative_to(p))) |
|
|
|
|
|
logger.info(f"Proof bundle exported: {zip_path}") |
|
|
return zip_path |
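
# Resulting bundle layout, as written by export_proof_bundle above:
#
#   <bundle_dir>/manifest.json                 config, config hash, library versions, host
#   <bundle_dir>/summary.json                  aggregate metrics
#   <bundle_dir>/records/metrics.jsonl         one JSON object per evaluated sample
#   <bundle_dir>/records/kv_fingerprints.jsonl per-layer KV fingerprints (may be empty)
#   <bundle_dir>/env.lock                      `pip freeze` output (best effort)
#   <bundle_dir>.zip                           deflated archive of the directory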
|
|
|
|
|
|
|
|
def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: ProvingConfig) -> Dict[str, Any]: |
|
|
"""Verify proof bundle - recompute metrics and check tolerances.""" |
|
|
try: |
|
|
with open(os.path.join(bundle_root, "summary.json")) as f: |
|
|
summary = json.load(f) |
|
|
|
|
|
records = [] |
|
|
with open(os.path.join(bundle_root, "records", "metrics.jsonl")) as f: |
|
|
for line in f: |
|
|
if line.strip(): |
|
|
records.append(json.loads(line)) |
|
|
except Exception as e: |
|
|
raise RuntimeError(f"Failed to load proof bundle: {e}") |
|
|
|
|
|
if not records: |
|
|
raise ValueError("No per-sample records found in proof bundle") |
|
|
|
|
|
primary_method = summary.get("compression_type", "enhanced_spg") |
|
|
primary_records = [r for r in records if r.get("compression_type") == primary_method] |
|
|
|
|
|
if not primary_records: |
|
|
raise ValueError(f"No records found for method {primary_method}") |
|
|
|
|
|
logger.info(f"Verifying {len(primary_records)} records for {primary_method}") |
|
|
|
|
|
def mean_of(key): |
|
|
vals = [float(r[key]) for r in primary_records if key in r and r[key] is not None] |
|
|
return float(np.mean(vals)) if vals else None |
|
|
|
|
|
recomputed = {} |
|
|
failures = [] |
|
|
|
|
|
if config.benchmark_type == "niah": |
|
|
if "niah_accuracy" in summary: |
|
|
recomputed["niah_accuracy"] = mean_of("accuracy") |
|
|
elif config.benchmark_type == "ruler": |
|
|
if "ruler_exact_match" in summary: |
|
|
recomputed["ruler_exact_match"] = mean_of("exact_match") |
|
|
elif config.benchmark_type == "scbench": |
|
|
if "scbench_accuracy" in summary: |
|
|
recomputed["scbench_accuracy"] = mean_of("accuracy") |
|
|
elif config.benchmark_type == "longbench": |
|
|
if "longbench_accuracy" in summary: |
|
|
recomputed["longbench_accuracy"] = mean_of("accuracy") |
|
|
elif config.benchmark_type == "wikitext": |
|
|
if "prefill_perplexity" in summary: |
|
|
recomputed["prefill_perplexity"] = mean_of("prefill_perplexity") |
|
|
if "generation_perplexity" in summary: |
|
|
recomputed["generation_perplexity"] = mean_of("generation_perplexity") |
|
|
|
|
|
recomputed["compression_ratio"] = mean_of("compression_ratio") |
|
|
recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb") |
|
|
|
|
|
for k, v in recomputed.items(): |
|
|
s = summary.get(k) |
|
|
if v is not None and s is not None: |
|
|
if abs(v - float(s)) > proving.numeric_tolerance: |
|
|
failures.append(f"{k}: recomputed {v:.6f} != summary {s:.6f}") |
|
|
|
|
|
ok = len(failures) == 0 |
|
|
|
|
|
result = { |
|
|
"ok": ok, |
|
|
"failures": failures, |
|
|
"recomputed": recomputed, |
|
|
"summary": summary, |
|
|
"n_samples": len(records) |
|
|
} |
|
|
|
|
|
if not ok: |
|
|
logger.error(f"Proof verification FAILED: {failures}") |
|
|
else: |
|
|
logger.info(f"Proof verification PASSED for {len(records)} samples") |
|
|
|
|
|
return result |
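
# Illustrative round trip (sketch; the bundle path is a placeholder and constructing
# ProvingConfig with defaults is an assumption about config.py):
#
#   zip_path = export_proof_bundle("results/proof_bundle", config, metrics, summary,
#                                  per_sample_records, per_layer_fingerprints)
#   report = verify_proof_bundle("results/proof_bundle", config, proving=ProvingConfig())
#   if not report["ok"]:
#       raise SystemExit(f"Proof verification failed: {report['failures']}")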