kfoughali committed
Commit e49d439 · verified · 1 Parent(s): 33d9292

Update benchmark.py

Files changed (1)
  1. benchmark.py +32 -138
benchmark.py CHANGED
@@ -4,7 +4,7 @@ Benchmarking, metrics, and proof generation for Enhanced SPG.
4
  Supports LongBench, NIAH, RULER, SCBench benchmarks.
5
  MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
6
  ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
7
- FIXED: CUDA assert errors, tokenization issues, safe generation.
8
  """
9
 
10
  import torch
@@ -144,16 +144,12 @@ class BenchmarkMetrics:
144
  self.prefill_time_std = float(np.std(self.prefill_times))
145
  self.prefill_time_ci = self._bootstrap_ci(self.prefill_times, config)
146
  self.prefill_tokens_per_sec = config.prefill_length / self.prefill_time_mean if self.prefill_time_mean > 0 else 0.0
147
- else:
148
- logger.debug("No prefill time data available")
149
 
150
  if self.prefill_peak_memories:
151
  memories_mb = [m / (1024 * 1024) for m in self.prefill_peak_memories]
152
  self.prefill_peak_memory_mean_mb = float(np.mean(memories_mb))
153
  self.prefill_peak_memory_std_mb = float(np.std(memories_mb))
154
  self.prefill_peak_memory_ci_mb = self._bootstrap_ci(memories_mb, config)
155
- else:
156
- logger.debug("No prefill memory data available")
157
 
158
  if self.decode_times:
159
  self.decode_time_per_token_mean_ms = float(np.mean(self.decode_times) * 1000)
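For reference, a minimal sketch of the timing statistics kept in these hunks; the sample values below are made up and only the formulas mirror the code:

import numpy as np

prefill_length = 512                                  # tokens in the prefill
prefill_times = [0.21, 0.19, 0.20]                    # seconds per prefill run
decode_times = [0.041, 0.039, 0.045, 0.040]           # seconds per decoded token

prefill_tokens_per_sec = prefill_length / float(np.mean(prefill_times))
decode_time_per_token_ms = float(np.mean(decode_times) * 1000)
decode_time_p95_ms = float(np.percentile(decode_times, 95) * 1000)
decode_tokens_per_sec = 1.0 / float(np.mean(decode_times))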
@@ -162,8 +158,6 @@ class BenchmarkMetrics:
162
  self.decode_tokens_per_sec = 1.0 / np.mean(self.decode_times) if self.decode_times else 0.0
163
  self.decode_time_p50_ms = float(np.percentile(self.decode_times, 50) * 1000)
164
  self.decode_time_p95_ms = float(np.percentile(self.decode_times, 95) * 1000)
165
- else:
166
- logger.debug("No decode time data available")
167
 
168
  # Calculate end-to-end throughput
169
  if self.prefill_time_mean > 0 and self.decode_time_per_token_mean_ms > 0:
@@ -174,37 +168,23 @@ class BenchmarkMetrics:
174
 
175
  if self.decode_peak_memories:
176
  self.decode_peak_memory_mean_mb = float(np.mean(self.decode_peak_memories) / (1024 * 1024))
177
- else:
178
- logger.debug("No decode memory data available")
179
 
180
  if self.prefill_perplexities:
181
  self.prefill_perplexity_mean = float(np.mean(self.prefill_perplexities))
182
  self.prefill_perplexity_std = float(np.std(self.prefill_perplexities))
183
  self.prefill_perplexity_ci = self._bootstrap_ci(self.prefill_perplexities, config)
184
- logger.info(f"Calculated prefill perplexity: mean={self.prefill_perplexity_mean:.2f}, "
185
- f"std={self.prefill_perplexity_std:.2f}, samples={len(self.prefill_perplexities)}")
186
- else:
187
- logger.warning("No prefill perplexity data available")
188
 
189
  if self.generation_perplexities:
190
  self.generation_perplexity_mean = float(np.mean(self.generation_perplexities))
191
  self.generation_perplexity_std = float(np.std(self.generation_perplexities))
192
  self.generation_perplexity_ci = self._bootstrap_ci(self.generation_perplexities, config)
193
- logger.info(f"Calculated generation perplexity: mean={self.generation_perplexity_mean:.2f}, "
194
- f"std={self.generation_perplexity_std:.2f}, samples={len(self.generation_perplexities)}")
195
- else:
196
- logger.warning("No generation perplexity data available")
197
 
198
  if self.compression_ratios:
199
  self.compression_ratio_mean = float(np.mean(self.compression_ratios))
200
  self.compression_ratio_std = float(np.std(self.compression_ratios))
201
- else:
202
- logger.debug("No compression ratio data available")
203
 
204
  if self.kv_cache_memory_samples_mb:
205
  self.kv_cache_memory_mb = float(np.mean(self.kv_cache_memory_samples_mb))
206
- else:
207
- logger.debug("No KV cache memory data available")
208
 
209
  except Exception as e:
210
  logger.error(f"Error calculating statistics: {e}")
@@ -213,7 +193,6 @@ class BenchmarkMetrics:
213
  def _bootstrap_ci(self, data: List[float], config: CompressionConfig) -> Tuple[float, float]:
214
  """Calculate bootstrap confidence interval with reproducible RNG."""
215
  if not data or len(data) < 2:
216
- logger.warning("Insufficient data for confidence interval calculation")
217
  return (0.0, 0.0)
218
 
219
  try:
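Only the guard clause of _bootstrap_ci appears in this hunk. As a hedged sketch of what a reproducible percentile-bootstrap CI of the mean typically looks like (the resample count, seed, and 95% level below are assumptions, not values taken from this commit):

import numpy as np

def bootstrap_ci(data, n_bootstrap=1000, seed=42, alpha=0.05):
    """Percentile bootstrap CI for the mean, reproducible via a fixed seed."""
    if not data or len(data) < 2:
        return (0.0, 0.0)
    rng = np.random.default_rng(seed)
    arr = np.asarray(data, dtype=float)
    means = [rng.choice(arr, size=arr.size, replace=True).mean()
             for _ in range(n_bootstrap)]
    low, high = np.percentile(means, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return (float(low), float(high))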
@@ -240,11 +219,9 @@ class BenchmarkMetrics:
240
 
241
  def safe_tokenize(tokenizer, text, max_length=512):
242
  """Safe tokenization with proper padding and truncation."""
243
- # Ensure pad_token is set
244
  if tokenizer.pad_token is None:
245
  tokenizer.pad_token = tokenizer.eos_token
246
 
247
- # Tokenize with explicit parameters
248
  inputs = tokenizer(
249
  text,
250
  return_tensors="pt",
@@ -255,12 +232,10 @@ def safe_tokenize(tokenizer, text, max_length=512):
255
  add_special_tokens=True
256
  )
257
 
258
- # Validate outputs
259
  if inputs.input_ids.shape[1] == 0:
260
  raise ValueError("Tokenization produced empty sequence")
261
 
262
  if inputs.input_ids.shape[1] > max_length:
263
- logger.warning(f"Sequence length {inputs.input_ids.shape[1]} exceeds max {max_length}")
264
  inputs.input_ids = inputs.input_ids[:, :max_length]
265
  inputs.attention_mask = inputs.attention_mask[:, :max_length]
266
 
@@ -269,41 +244,35 @@ def safe_tokenize(tokenizer, text, max_length=512):
269
 
270
  def validate_model_inputs(model, input_ids, attention_mask):
271
  """Validate inputs are compatible with model."""
272
- # Check sequence length against model's max position embeddings
273
  if hasattr(model.config, 'max_position_embeddings'):
274
  max_pos = model.config.max_position_embeddings
275
  if input_ids.shape[1] > max_pos:
276
- logger.warning(f"Input length {input_ids.shape[1]} exceeds model max {max_pos}")
277
  input_ids = input_ids[:, :max_pos]
278
  attention_mask = attention_mask[:, :max_pos]
279
 
280
- # For GPT-2, check n_positions
281
  if hasattr(model.config, 'n_positions'):
282
  n_pos = model.config.n_positions
283
  if input_ids.shape[1] > n_pos:
284
- logger.warning(f"Input length {input_ids.shape[1]} exceeds GPT-2 positions {n_pos}")
285
  input_ids = input_ids[:, :n_pos]
286
  attention_mask = attention_mask[:, :n_pos]
287
 
288
- # Ensure input_ids are within vocabulary range
289
  vocab_size = model.config.vocab_size
290
  if input_ids.max() >= vocab_size:
291
- logger.error(f"Token id {input_ids.max()} exceeds vocab size {vocab_size}")
 
 
292
  input_ids = input_ids.clamp(0, vocab_size - 1)
293
 
294
  return input_ids, attention_mask
295
 
296
 
297
  def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20):
298
- """Safe generation with proper error handling."""
299
  try:
300
- # Validate inputs
301
  input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
302
 
303
- # Set generation config
304
  gen_config = {
305
  "max_new_tokens": max_new_tokens,
306
- "temperature": 0.7,
307
  "do_sample": False,
308
  "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
309
  "eos_token_id": tokenizer.eos_token_id,
@@ -311,20 +280,21 @@ def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=N
311
  "use_cache": True
312
  }
313
 
314
- # Add past_key_values if available
315
  if past_key_values is not None:
316
  gen_config["past_key_values"] = past_key_values
317
 
318
- # Generate with error handling
319
  with torch.no_grad():
320
  output = model.generate(input_ids, **gen_config)
321
 
322
- return output
 
 
 
323
 
324
  except Exception as e:
325
  logger.error(f"Generation failed: {e}")
326
- # Return input as fallback
327
- return input_ids
328
 
329
 
330
  def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
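The safe_generate hunk above is the behavioural core of this commit: the old version returned the raw generate() output tensor (or the input ids as a fallback), while the new version decodes and returns only the newly generated text, and an empty string on failure. A minimal sketch of that contract, with illustrative defaults rather than the commit's exact generation config:

import torch

def generate_text(model, tokenizer, input_ids, attention_mask, max_new_tokens=20):
    """Greedy generation that returns decoded new text only, '' on failure."""
    try:
        with torch.no_grad():
            out = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
            )
        # Slice off the prompt so callers never have to re-slice the output.
        return tokenizer.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True)
    except Exception:
        return ""

Returning text rather than ids is also why the benchmark-specific evaluators further down no longer call tokenizer.decode themselves.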
@@ -336,21 +306,17 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
336
  """
337
  device = input_ids.device
338
 
339
- # Validate inputs first
340
  input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
341
 
342
- # Clear GPU cache if requested
343
  if torch.cuda.is_available() and measure_memory:
344
  torch.cuda.empty_cache()
345
  torch.cuda.reset_peak_memory_stats()
346
  torch.cuda.synchronize()
347
 
348
- # Measure prefill time
349
  if torch.cuda.is_available():
350
  torch.cuda.synchronize()
351
  start_time = time.perf_counter()
352
 
353
- # Prefill phase with error handling
354
  try:
355
  with torch.inference_mode():
356
  outputs = model(
@@ -363,7 +329,6 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
363
  logits = outputs.logits
364
  except Exception as e:
365
  logger.error(f"Prefill failed: {e}")
366
- # Return minimal valid result
367
  return {
368
  'past_key_values': None,
369
  'prefill_time': 0,
@@ -380,22 +345,18 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
380
 
381
  prefill_time = time.perf_counter() - start_time
382
 
383
- # Measure peak memory
384
  prefill_peak_mem = 0
385
  if torch.cuda.is_available() and measure_memory:
386
  prefill_peak_mem = _peak_mem_bytes_all_gpus()
387
 
388
- # Calculate prefill perplexity safely
389
  prefill_loss = None
390
  if logits is not None and input_ids.shape[1] > 1:
391
  try:
392
- # Ensure we have valid shapes
393
  seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
394
  if seq_len > 0:
395
  shift_logits = logits[:, :seq_len, :].contiguous()
396
  shift_labels = input_ids[:, 1:seq_len+1].contiguous()
397
 
398
- # Calculate loss with ignore_index for padding
399
  loss = F.cross_entropy(
400
  shift_logits.view(-1, shift_logits.size(-1)),
401
  shift_labels.view(-1),
@@ -406,30 +367,28 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
406
  except Exception as e:
407
  logger.warning(f"Could not calculate prefill loss: {e}")
408
 
409
- # Compression phase - same as WikiText
410
  original_cache_size = 0
411
  compressed_cache_size = 0
412
  compression_ratio = 1.0
413
 
414
  if past_key_values:
415
  try:
416
- # Convert to legacy format for processing
417
- kv_tuple = past_key_values.to_legacy_cache() if hasattr(past_key_values, 'to_legacy_cache') else past_key_values
 
 
418
 
419
- # Calculate original size
420
  for layer_idx, (keys, values) in enumerate(kv_tuple):
421
  if keys is not None and values is not None:
422
  original_cache_size += keys.nelement() * keys.element_size()
423
  original_cache_size += values.nelement() * values.element_size()
424
 
425
- # Apply compression if enabled
426
  if config.compression_type != CompressionType.NONE and cache_manager is not None:
427
  try:
428
  cache_manager.compress_and_store(layer_idx, keys, values)
429
  except Exception as e:
430
  logger.error(f"Compression failed for layer {layer_idx}: {e}")
431
 
432
- # Reconstruct compressed cache
433
  if config.compression_type != CompressionType.NONE and cache_manager is not None:
434
  reconstructed_kv = []
435
  for layer_idx in range(len(kv_tuple)):
@@ -438,20 +397,16 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
438
  if dec_keys is not None and dec_values is not None:
439
  reconstructed_kv.append((dec_keys, dec_values))
440
  else:
441
- # Use original if decompression fails
442
- logger.warning(f"Decompression returned None for layer {layer_idx}, using original")
443
  reconstructed_kv.append(kv_tuple[layer_idx])
444
  except Exception as e:
445
  logger.error(f"Decompression failed for layer {layer_idx}: {e}")
446
  reconstructed_kv.append(kv_tuple[layer_idx])
447
 
448
- # Convert back to DynamicCache format
449
  if hasattr(DynamicCache, 'from_legacy_cache'):
450
  past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
451
  else:
452
  past_key_values = tuple(reconstructed_kv)
453
 
454
- # Measure compressed size
455
  try:
456
  compressed_cache_size = cache_manager.get_memory_footprint()
457
  except:
@@ -459,8 +414,8 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
459
  else:
460
  compressed_cache_size = original_cache_size
461
 
462
- # Calculate compression ratio
463
- compression_ratio = original_cache_size / compressed_cache_size if compressed_cache_size > 0 else 1.0
464
 
465
  except Exception as e:
466
  logger.error(f"Cache processing failed: {e}")
@@ -481,7 +436,6 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
481
 
482
  def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
483
  """Create Needle-in-a-Haystack test context - NO HARDCODING."""
484
- # Generate haystack text
485
  haystack_template = "The quick brown fox jumps over the lazy dog. " * 20
486
  haystack_chunks = []
487
 
@@ -490,7 +444,6 @@ def create_niah_haystack(context_length: int, needle: str, depth_percent: float)
490
 
491
  haystack = " ".join(haystack_chunks)[:context_length - len(needle) - 10]
492
 
493
- # Insert needle at specified depth
494
  insertion_point = int(len(haystack) * depth_percent / 100)
495
  haystack_with_needle = (
496
  haystack[:insertion_point] +
@@ -511,25 +464,19 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op
511
 
512
  prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:"
513
 
514
- # Use safe tokenization
515
  inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
516
  input_ids = inputs.input_ids.to(model.device)
517
  attention_mask = inputs.attention_mask.to(model.device)
518
 
519
- # Apply SAME compression pipeline as WikiText
520
  compression_result = apply_compression_pipeline(
521
  model, tokenizer, input_ids, attention_mask, cache_manager, config
522
  )
523
 
524
- # Generate with compressed cache using safe generation
525
  gen_start = time.perf_counter()
526
- output = safe_generate(model, tokenizer, input_ids, attention_mask,
527
- compression_result['past_key_values'], max_new_tokens=20)
528
  gen_time = time.perf_counter() - gen_start
529
 
530
- generated_text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
531
-
532
- # Check if needle was retrieved
533
  accuracy = 1.0 if config.niah_needle.split()[-1] in generated_text else 0.0
534
 
535
  logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
@@ -547,10 +494,8 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op
547
 
548
  def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
549
  """Evaluate RULER with SAME compression pipeline as WikiText."""
550
- # Create synthetic RULER-like task
551
- seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024) # Cap at GPT-2 limit
552
 
553
- # Create a retrieval task with multiple facts
554
  facts = []
555
  for i in range(10):
556
  facts.append(f"Fact {i}: The capital of Country{i} is City{i}.")
@@ -565,20 +510,15 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: O
565
  input_ids = inputs.input_ids.to(model.device)
566
  attention_mask = inputs.attention_mask.to(model.device)
567
 
568
- # Apply SAME compression pipeline as WikiText
569
  compression_result = apply_compression_pipeline(
570
  model, tokenizer, input_ids, attention_mask, cache_manager, config
571
  )
572
 
573
- # Generate with compressed cache
574
  gen_start = time.perf_counter()
575
- output = safe_generate(model, tokenizer, input_ids, attention_mask,
576
- compression_result['past_key_values'], max_new_tokens=10)
577
  gen_time = time.perf_counter() - gen_start
578
 
579
- generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
580
-
581
- # Check exact match
582
  expected = f"City{query_idx}"
583
  exact_match = 1.0 if expected in generated else 0.0
584
 
@@ -597,7 +537,6 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: O
597
 
598
  def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
599
  """Evaluate SCBench with SAME compression pipeline as WikiText."""
600
- # Create multi-turn conversation
601
  conversation = []
602
  facts = {}
603
 
@@ -612,7 +551,6 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager:
612
  conversation.append(f"User: {user_msg}")
613
  conversation.append(f"Assistant: {assistant_msg}")
614
 
615
- # Query a random fact
616
  query_key = random.choice(list(facts.keys()))
617
  conversation.append(f"User: What is {query_key}?")
618
 
@@ -622,20 +560,15 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager:
622
  input_ids = inputs.input_ids.to(model.device)
623
  attention_mask = inputs.attention_mask.to(model.device)
624
 
625
- # Apply SAME compression pipeline as WikiText
626
  compression_result = apply_compression_pipeline(
627
  model, tokenizer, input_ids, attention_mask, cache_manager, config
628
  )
629
 
630
- # Generate with compressed cache
631
  gen_start = time.perf_counter()
632
- output = safe_generate(model, tokenizer, input_ids, attention_mask,
633
- compression_result['past_key_values'], max_new_tokens=20)
634
  gen_time = time.perf_counter() - gen_start
635
 
636
- generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
637
-
638
- # Check if correct value is recalled
639
  expected_value = facts[query_key]
640
  accuracy = 1.0 if expected_value in generated else 0.0
641
 
@@ -658,7 +591,6 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
658
  try:
659
  dataset = load_dataset("THUDM/LongBench", task, split="test")
660
 
661
- # Sample evaluation examples
662
  n_samples = min(config.eval_samples, len(dataset))
663
  samples = dataset.select(range(n_samples))
664
 
@@ -682,21 +614,16 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
682
  input_ids = inputs.input_ids.to(model.device)
683
  attention_mask = inputs.attention_mask.to(model.device)
684
 
685
- # Apply SAME compression pipeline as WikiText
686
  compression_result = apply_compression_pipeline(
687
  model, tokenizer, input_ids, attention_mask, cache_manager, config,
688
- measure_memory=False # Don't measure memory for each sample
689
  )
690
 
691
- # Generate with compressed cache
692
  gen_start = time.perf_counter()
693
- output = safe_generate(model, tokenizer, input_ids, attention_mask,
694
- compression_result['past_key_values'], max_new_tokens=50)
695
  gen_time = time.perf_counter() - gen_start
696
 
697
- generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
698
-
699
- # Simple accuracy metric
700
  score = 1.0 if str(answer).lower() in generated.lower() else 0.0
701
  scores.append(score)
702
  compression_ratios.append(compression_result['compression_ratio'])
@@ -705,7 +632,6 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
705
  gen_times.append(gen_time)
706
 
707
  avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0
708
- logger.info(f"LongBench {task} avg compression: {avg_compression:.1f}x")
709
 
710
  return {
711
  'accuracy': float(np.mean(scores)),
@@ -733,15 +659,11 @@ def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
733
  device = "cuda" if torch.cuda.is_available() else "cpu"
734
  dtype = torch.float16 if device == "cuda" else torch.float32
735
 
736
- # FAIL FAST if CUDA required but unavailable
737
  if config.fail_on_cpu_fallback and device == "cpu":
738
  raise RuntimeError("CUDA required but unavailable (fail_on_cpu_fallback=True)")
739
 
740
  logger.info(f"Loading model: {model_name}")
741
 
742
- # Check if model requires authentication
743
- model_info = SUPPORTED_MODELS.get(config.model_key, {})
744
-
745
  tokenizer = AutoTokenizer.from_pretrained(
746
  model_name,
747
  trust_remote_code=True
@@ -750,7 +672,6 @@ def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
750
  if tokenizer.pad_token is None:
751
  tokenizer.pad_token = tokenizer.eos_token
752
 
753
- # Model loading with Flash Attention support
754
  model_kwargs = {
755
  "torch_dtype": dtype,
756
  "device_map": "auto" if device == "cuda" else None,
@@ -758,20 +679,16 @@ def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
758
  "trust_remote_code": True
759
  }
760
 
761
- # Try Flash Attention if requested and available
762
  if config.use_flash_attention and device == "cuda":
763
  try:
764
- # First try to load with Flash Attention
765
  model_kwargs["attn_implementation"] = "flash_attention_2"
766
  model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
767
  logger.info("Successfully loaded with Flash Attention 2")
768
  except Exception as e:
769
- # Fall back to standard attention
770
- logger.warning(f"Flash Attention not available, using standard attention: {e}")
771
  model_kwargs.pop("attn_implementation", None)
772
  model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
773
  else:
774
- # Load without Flash Attention
775
  model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
776
 
777
  model.eval()
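The loader above keeps its try-Flash-Attention-then-fall-back behaviour and only loses the explanatory comments. A condensed sketch of the same pattern (the dtype and device_map choices here are illustrative):

import torch
from transformers import AutoModelForCausalLM

def load_causal_lm(model_name, use_flash_attention=True):
    kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    if use_flash_attention and torch.cuda.is_available():
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_name, attn_implementation="flash_attention_2", **kwargs)
        except Exception:
            pass  # flash-attn not installed or unsupported; use default attention
    return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)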
@@ -784,7 +701,6 @@ def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]
784
  logger.info(f"Loading samples for benchmark: {config.benchmark_type}")
785
 
786
  if config.benchmark_type == "wikitext":
787
- # Original WikiText loading
788
  texts = []
789
  min_tokens = config.prefill_length + config.generation_length
790
 
@@ -823,7 +739,6 @@ def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]
823
  raise
824
 
825
  elif config.benchmark_type == "longbench":
826
- # Load LongBench dataset
827
  texts = []
828
  if config.benchmark_subset:
829
  try:
@@ -839,7 +754,6 @@ def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]
839
  raise
840
 
841
  elif config.benchmark_type in ["niah", "ruler", "scbench"]:
842
- # These benchmarks generate synthetic data
843
  texts = ["Synthetic benchmark data"] * config.eval_samples
844
 
845
  else:
@@ -858,7 +772,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
858
  logger.info(f"Benchmark type: {config.benchmark_type}")
859
  logger.info(f"Config hash: {config.get_hash()}")
860
 
861
- # Enable synchronous CUDA for debugging
862
  if torch.cuda.is_available():
863
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
864
 
@@ -876,7 +789,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
876
  logger.error(f"Failed to detect model layers: {e}")
877
  raise
878
 
879
- # Warmup
880
  device = model.device
881
  with torch.inference_mode():
882
  dummy = torch.randint(0, tokenizer.vocab_size, (1, min(config.prefill_length, 128)), device=device)
@@ -899,13 +811,10 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
899
 
900
  metrics = BenchmarkMetrics()
901
 
902
- # Run benchmark-specific evaluation with UNIFIED compression
903
  if config.benchmark_type == "niah":
904
- # NIAH evaluation with unified compression
905
  for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
906
  config.niah_depth_percent = depth
907
  for idx in range(min(config.eval_samples, 10)):
908
- # Create cache manager for compression types
909
  if config.compression_type != CompressionType.NONE:
910
  cache_manager = QuantizedKVCache(config)
911
  cache_manager.n_layers = n_layers
@@ -918,12 +827,11 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
918
  metrics.compression_ratios.append(result['compression_ratio'])
919
  metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
920
  metrics.prefill_times.append(result['prefill_time'])
921
- metrics.decode_times.append(result['generation_time'] / 20) # Per token
922
 
923
  if result['prefill_peak_mem'] > 0:
924
  metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
925
 
926
- # Record per-sample data
927
  per_sample_records.append({
928
  'benchmark': 'niah',
929
  'depth_percent': depth,
@@ -935,7 +843,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
935
  })
936
 
937
  elif config.benchmark_type == "ruler":
938
- # RULER evaluation with unified compression
939
  for idx in range(config.eval_samples):
940
  if config.compression_type != CompressionType.NONE:
941
  cache_manager = QuantizedKVCache(config)
@@ -949,7 +856,7 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
949
  metrics.compression_ratios.append(result['compression_ratio'])
950
  metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
951
  metrics.prefill_times.append(result['prefill_time'])
952
- metrics.decode_times.append(result['generation_time'] / 10) # Per token
953
 
954
  if result['prefill_peak_mem'] > 0:
955
  metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
@@ -964,7 +871,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
964
  })
965
 
966
  elif config.benchmark_type == "scbench":
967
- # SCBench evaluation with unified compression
968
  for idx in range(config.eval_samples):
969
  if config.compression_type != CompressionType.NONE:
970
  cache_manager = QuantizedKVCache(config)
@@ -978,7 +884,7 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
978
  metrics.compression_ratios.append(result['compression_ratio'])
979
  metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
980
  metrics.prefill_times.append(result['prefill_time'])
981
- metrics.decode_times.append(result['generation_time'] / 20) # Per token
982
 
983
  if result['prefill_peak_mem'] > 0:
984
  metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
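These hunks only drop the trailing "# Per token" comments; the quantity recorded is unchanged: total generate() wall time divided by max_new_tokens (20 for NIAH and SCBench, 10 for RULER, 50 for LongBench). Spelled out with made-up numbers:

generation_time = 0.84            # seconds for one whole generate() call
max_new_tokens = 20               # NIAH / SCBench setting in this file
decode_time_per_token = generation_time / max_new_tokens    # 0.042 s
decode_time_per_token_ms = decode_time_per_token * 1000     # 42.0 ms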
@@ -993,7 +899,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
993
  })
994
 
995
  elif config.benchmark_type == "longbench":
996
- # LongBench evaluation with unified compression
997
  if config.benchmark_subset:
998
  if config.compression_type != CompressionType.NONE:
999
  cache_manager = QuantizedKVCache(config)
@@ -1010,7 +915,7 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1010
  metrics.prefill_times.append(result['prefill_time'])
1011
 
1012
  if result['generation_time'] > 0:
1013
- metrics.decode_times.append(result['generation_time'] / 50) # Per token
1014
 
1015
  per_sample_records.append({
1016
  'benchmark': 'longbench',
@@ -1022,7 +927,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1022
  })
1023
 
1024
  else:
1025
- # Standard WikiText perplexity evaluation with existing compression
1026
  for idx in range(config.eval_samples):
1027
  logger.info(f"Sample {idx+1}/{config.eval_samples}")
1028
 
@@ -1036,12 +940,10 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1036
  else:
1037
  cache_manager = None
1038
 
1039
- # Use safe tokenization
1040
  inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024))
1041
  input_ids = inputs.input_ids.to(device)
1042
  attention_mask = inputs.attention_mask.to(device)
1043
 
1044
- # Apply unified compression pipeline
1045
  compression_result = apply_compression_pipeline(
1046
  model, tokenizer, input_ids, attention_mask, cache_manager, config
1047
  )
@@ -1057,7 +959,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1057
  prefill_perplexity = np.exp(compression_result['prefill_loss'])
1058
  metrics.prefill_perplexities.append(min(prefill_perplexity, 1000))
1059
 
1060
- # Generation phase with timing
1061
  generated_ids = input_ids.clone()
1062
  decode_times = []
1063
  generation_losses = []
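The prefill perplexity appended just above is exp of the shifted next-token cross-entropy computed earlier in apply_compression_pipeline, capped at 1000. A self-contained sketch of that calculation on dummy tensors (sizes are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 7, 50257)              # (batch, seq_len, vocab)
input_ids = torch.randint(0, 50257, (1, 8))

seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
loss = F.cross_entropy(
    logits[:, :seq_len, :].reshape(-1, logits.size(-1)),
    input_ids[:, 1:seq_len + 1].reshape(-1),
)
perplexity = min(float(torch.exp(loss)), 1000)  # ppl = exp(mean NLL), capped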
@@ -1110,7 +1011,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1110
  metrics.calculate_statistics(config)
1111
  all_metrics.append(metrics)
1112
 
1113
- # Aggregate results across seeds
1114
  final_metrics = BenchmarkMetrics()
1115
  for m in all_metrics:
1116
  final_metrics.prefill_times.extend(m.prefill_times)
@@ -1128,7 +1028,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1128
 
1129
  final_metrics.calculate_statistics(config)
1130
 
1131
- # Summary
1132
  end_time = datetime.now().isoformat()
1133
  summary = {
1134
  'compression_type': config.compression_type.value,
@@ -1142,7 +1041,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1142
  'end_time': end_time
1143
  }
1144
 
1145
- # Add benchmark-specific metrics
1146
  if config.benchmark_type == "niah" and final_metrics.niah_retrieval_accuracy:
1147
  summary['niah_accuracy'] = float(np.mean(final_metrics.niah_retrieval_accuracy))
1148
  elif config.benchmark_type == "ruler" and final_metrics.ruler_exact_match:
@@ -1155,7 +1053,6 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
1155
  summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
1156
  summary['generation_perplexity'] = final_metrics.generation_perplexity_mean
1157
 
1158
- # Always add timing and memory metrics
1159
  summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
1160
  summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
1161
  summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
@@ -1253,7 +1150,6 @@ def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: Pr
1253
  recomputed = {}
1254
  failures = []
1255
 
1256
- # Verify based on benchmark type
1257
  if config.benchmark_type == "niah":
1258
  if "niah_accuracy" in summary:
1259
  recomputed["niah_accuracy"] = mean_of("accuracy")
@@ -1267,13 +1163,11 @@ def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: Pr
1267
  if "longbench_accuracy" in summary:
1268
  recomputed["longbench_accuracy"] = mean_of("accuracy")
1269
  elif config.benchmark_type == "wikitext":
1270
- # WikiText benchmark metrics
1271
  if "prefill_perplexity" in summary:
1272
  recomputed["prefill_perplexity"] = mean_of("prefill_perplexity")
1273
  if "generation_perplexity" in summary:
1274
  recomputed["generation_perplexity"] = mean_of("generation_perplexity")
1275
 
1276
- # Always verify compression metrics
1277
  recomputed["compression_ratio"] = mean_of("compression_ratio")
1278
  recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")
1279
 
 
4
  Supports LongBench, NIAH, RULER, SCBench benchmarks.
5
  MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
6
  ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
7
+ FIXED: Generation errors, proper fallback handling.
8
  """
9
 
10
  import torch
 
144
  self.prefill_time_std = float(np.std(self.prefill_times))
145
  self.prefill_time_ci = self._bootstrap_ci(self.prefill_times, config)
146
  self.prefill_tokens_per_sec = config.prefill_length / self.prefill_time_mean if self.prefill_time_mean > 0 else 0.0
 
 
147
 
148
  if self.prefill_peak_memories:
149
  memories_mb = [m / (1024 * 1024) for m in self.prefill_peak_memories]
150
  self.prefill_peak_memory_mean_mb = float(np.mean(memories_mb))
151
  self.prefill_peak_memory_std_mb = float(np.std(memories_mb))
152
  self.prefill_peak_memory_ci_mb = self._bootstrap_ci(memories_mb, config)
 
 
153
 
154
  if self.decode_times:
155
  self.decode_time_per_token_mean_ms = float(np.mean(self.decode_times) * 1000)
 
158
  self.decode_tokens_per_sec = 1.0 / np.mean(self.decode_times) if self.decode_times else 0.0
159
  self.decode_time_p50_ms = float(np.percentile(self.decode_times, 50) * 1000)
160
  self.decode_time_p95_ms = float(np.percentile(self.decode_times, 95) * 1000)
 
 
161
 
162
  # Calculate end-to-end throughput
163
  if self.prefill_time_mean > 0 and self.decode_time_per_token_mean_ms > 0:
 
168
 
169
  if self.decode_peak_memories:
170
  self.decode_peak_memory_mean_mb = float(np.mean(self.decode_peak_memories) / (1024 * 1024))
 
 
171
 
172
  if self.prefill_perplexities:
173
  self.prefill_perplexity_mean = float(np.mean(self.prefill_perplexities))
174
  self.prefill_perplexity_std = float(np.std(self.prefill_perplexities))
175
  self.prefill_perplexity_ci = self._bootstrap_ci(self.prefill_perplexities, config)
 
 
 
 
176
 
177
  if self.generation_perplexities:
178
  self.generation_perplexity_mean = float(np.mean(self.generation_perplexities))
179
  self.generation_perplexity_std = float(np.std(self.generation_perplexities))
180
  self.generation_perplexity_ci = self._bootstrap_ci(self.generation_perplexities, config)
 
 
 
 
181
 
182
  if self.compression_ratios:
183
  self.compression_ratio_mean = float(np.mean(self.compression_ratios))
184
  self.compression_ratio_std = float(np.std(self.compression_ratios))
 
 
185
 
186
  if self.kv_cache_memory_samples_mb:
187
  self.kv_cache_memory_mb = float(np.mean(self.kv_cache_memory_samples_mb))
 
 
188
 
189
  except Exception as e:
190
  logger.error(f"Error calculating statistics: {e}")
 
193
  def _bootstrap_ci(self, data: List[float], config: CompressionConfig) -> Tuple[float, float]:
194
  """Calculate bootstrap confidence interval with reproducible RNG."""
195
  if not data or len(data) < 2:
 
196
  return (0.0, 0.0)
197
 
198
  try:
 
219
 
220
  def safe_tokenize(tokenizer, text, max_length=512):
221
  """Safe tokenization with proper padding and truncation."""
 
222
  if tokenizer.pad_token is None:
223
  tokenizer.pad_token = tokenizer.eos_token
224
 
 
225
  inputs = tokenizer(
226
  text,
227
  return_tensors="pt",
 
232
  add_special_tokens=True
233
  )
234
 
 
235
  if inputs.input_ids.shape[1] == 0:
236
  raise ValueError("Tokenization produced empty sequence")
237
 
238
  if inputs.input_ids.shape[1] > max_length:
 
239
  inputs.input_ids = inputs.input_ids[:, :max_length]
240
  inputs.attention_mask = inputs.attention_mask[:, :max_length]
241
 
 
244
 
245
  def validate_model_inputs(model, input_ids, attention_mask):
246
  """Validate inputs are compatible with model."""
 
247
  if hasattr(model.config, 'max_position_embeddings'):
248
  max_pos = model.config.max_position_embeddings
249
  if input_ids.shape[1] > max_pos:
 
250
  input_ids = input_ids[:, :max_pos]
251
  attention_mask = attention_mask[:, :max_pos]
252
 
 
253
  if hasattr(model.config, 'n_positions'):
254
  n_pos = model.config.n_positions
255
  if input_ids.shape[1] > n_pos:
 
256
  input_ids = input_ids[:, :n_pos]
257
  attention_mask = attention_mask[:, :n_pos]
258
 
 
259
  vocab_size = model.config.vocab_size
260
  if input_ids.max() >= vocab_size:
261
+ input_ids = input_ids.clamp(0, vocab_size - 1)
262
+
263
+ if input_ids.min() < 0:
264
  input_ids = input_ids.clamp(0, vocab_size - 1)
265
 
266
  return input_ids, attention_mask
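The new branch above clamps out-of-range token ids into the vocabulary instead of only logging an error before clamping. A tiny illustration (the vocabulary size is just an example):

import torch

vocab_size = 50257                               # e.g. GPT-2
ids = torch.tensor([[5, 50300, -1, 42]])         # one too-large id, one negative id
safe_ids = ids.clamp(0, vocab_size - 1)          # -> [[5, 50256, 0, 42]]
assert int(safe_ids.max()) < vocab_size and int(safe_ids.min()) >= 0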
267
 
268
 
269
  def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20):
270
+ """Safe generation with proper error handling - returns generated text."""
271
  try:
 
272
  input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
273
 
 
274
  gen_config = {
275
  "max_new_tokens": max_new_tokens,
 
276
  "do_sample": False,
277
  "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
278
  "eos_token_id": tokenizer.eos_token_id,
 
280
  "use_cache": True
281
  }
282
 
 
283
  if past_key_values is not None:
284
  gen_config["past_key_values"] = past_key_values
285
 
 
286
  with torch.no_grad():
287
  output = model.generate(input_ids, **gen_config)
288
 
289
+ # Decode only the generated part
290
+ generated_ids = output[:, input_ids.shape[1]:]
291
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
292
+ return generated_text
293
 
294
  except Exception as e:
295
  logger.error(f"Generation failed: {e}")
296
+ # Return empty string on failure
297
+ return ""
298
 
299
 
300
  def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
 
306
  """
307
  device = input_ids.device
308
 
 
309
  input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
310
 
 
311
  if torch.cuda.is_available() and measure_memory:
312
  torch.cuda.empty_cache()
313
  torch.cuda.reset_peak_memory_stats()
314
  torch.cuda.synchronize()
315
 
 
316
  if torch.cuda.is_available():
317
  torch.cuda.synchronize()
318
  start_time = time.perf_counter()
319
 
 
320
  try:
321
  with torch.inference_mode():
322
  outputs = model(
 
329
  logits = outputs.logits
330
  except Exception as e:
331
  logger.error(f"Prefill failed: {e}")
 
332
  return {
333
  'past_key_values': None,
334
  'prefill_time': 0,
 
345
 
346
  prefill_time = time.perf_counter() - start_time
347
 
 
348
  prefill_peak_mem = 0
349
  if torch.cuda.is_available() and measure_memory:
350
  prefill_peak_mem = _peak_mem_bytes_all_gpus()
351
 
 
352
  prefill_loss = None
353
  if logits is not None and input_ids.shape[1] > 1:
354
  try:
 
355
  seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
356
  if seq_len > 0:
357
  shift_logits = logits[:, :seq_len, :].contiguous()
358
  shift_labels = input_ids[:, 1:seq_len+1].contiguous()
359
 
 
360
  loss = F.cross_entropy(
361
  shift_logits.view(-1, shift_logits.size(-1)),
362
  shift_labels.view(-1),
 
367
  except Exception as e:
368
  logger.warning(f"Could not calculate prefill loss: {e}")
369
 
 
370
  original_cache_size = 0
371
  compressed_cache_size = 0
372
  compression_ratio = 1.0
373
 
374
  if past_key_values:
375
  try:
376
+ if hasattr(past_key_values, 'to_legacy_cache'):
377
+ kv_tuple = past_key_values.to_legacy_cache()
378
+ else:
379
+ kv_tuple = past_key_values
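The added branch converts the cache to the legacy tuple-of-(keys, values) format only when the DynamicCache helper is present. A small sketch of the round trip it relies on, assuming a transformers version that still exposes these DynamicCache helpers (shapes below are illustrative):

import torch
from transformers import DynamicCache

legacy = tuple((torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8))
               for _ in range(2))                 # (keys, values) per layer
cache = DynamicCache.from_legacy_cache(legacy)    # tuple -> DynamicCache
round_trip = cache.to_legacy_cache()              # DynamicCache -> tuple
assert len(round_trip) == len(legacy)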
380
 
 
381
  for layer_idx, (keys, values) in enumerate(kv_tuple):
382
  if keys is not None and values is not None:
383
  original_cache_size += keys.nelement() * keys.element_size()
384
  original_cache_size += values.nelement() * values.element_size()
385
 
 
386
  if config.compression_type != CompressionType.NONE and cache_manager is not None:
387
  try:
388
  cache_manager.compress_and_store(layer_idx, keys, values)
389
  except Exception as e:
390
  logger.error(f"Compression failed for layer {layer_idx}: {e}")
391
 
 
392
  if config.compression_type != CompressionType.NONE and cache_manager is not None:
393
  reconstructed_kv = []
394
  for layer_idx in range(len(kv_tuple)):
 
397
  if dec_keys is not None and dec_values is not None:
398
  reconstructed_kv.append((dec_keys, dec_values))
399
  else:
 
 
400
  reconstructed_kv.append(kv_tuple[layer_idx])
401
  except Exception as e:
402
  logger.error(f"Decompression failed for layer {layer_idx}: {e}")
403
  reconstructed_kv.append(kv_tuple[layer_idx])
404
 
 
405
  if hasattr(DynamicCache, 'from_legacy_cache'):
406
  past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
407
  else:
408
  past_key_values = tuple(reconstructed_kv)
409
 
 
410
  try:
411
  compressed_cache_size = cache_manager.get_memory_footprint()
412
  except:
 
414
  else:
415
  compressed_cache_size = original_cache_size
416
 
417
+ if compressed_cache_size > 0:
418
+ compression_ratio = original_cache_size / compressed_cache_size
419
 
420
  except Exception as e:
421
  logger.error(f"Cache processing failed: {e}")
 
436
 
437
  def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
438
  """Create Needle-in-a-Haystack test context - NO HARDCODING."""
 
439
  haystack_template = "The quick brown fox jumps over the lazy dog. " * 20
440
  haystack_chunks = []
441
 
 
444
 
445
  haystack = " ".join(haystack_chunks)[:context_length - len(needle) - 10]
446
 
 
447
  insertion_point = int(len(haystack) * depth_percent / 100)
448
  haystack_with_needle = (
449
  haystack[:insertion_point] +
 
464
 
465
  prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:"
466
 
 
467
  inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
468
  input_ids = inputs.input_ids.to(model.device)
469
  attention_mask = inputs.attention_mask.to(model.device)
470
 
 
471
  compression_result = apply_compression_pipeline(
472
  model, tokenizer, input_ids, attention_mask, cache_manager, config
473
  )
474
 
 
475
  gen_start = time.perf_counter()
476
+ generated_text = safe_generate(model, tokenizer, input_ids, attention_mask,
477
+ compression_result['past_key_values'], max_new_tokens=20)
478
  gen_time = time.perf_counter() - gen_start
479
 
 
 
 
480
  accuracy = 1.0 if config.niah_needle.split()[-1] in generated_text else 0.0
481
 
482
  logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
 
494
 
495
  def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
496
  """Evaluate RULER with SAME compression pipeline as WikiText."""
497
+ seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024)
 
498
 
 
499
  facts = []
500
  for i in range(10):
501
  facts.append(f"Fact {i}: The capital of Country{i} is City{i}.")
 
510
  input_ids = inputs.input_ids.to(model.device)
511
  attention_mask = inputs.attention_mask.to(model.device)
512
 
 
513
  compression_result = apply_compression_pipeline(
514
  model, tokenizer, input_ids, attention_mask, cache_manager, config
515
  )
516
 
 
517
  gen_start = time.perf_counter()
518
+ generated = safe_generate(model, tokenizer, input_ids, attention_mask,
519
+ compression_result['past_key_values'], max_new_tokens=10)
520
  gen_time = time.perf_counter() - gen_start
521
 
 
 
 
522
  expected = f"City{query_idx}"
523
  exact_match = 1.0 if expected in generated else 0.0
524
 
 
537
 
538
  def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
539
  """Evaluate SCBench with SAME compression pipeline as WikiText."""
 
540
  conversation = []
541
  facts = {}
542
 
 
551
  conversation.append(f"User: {user_msg}")
552
  conversation.append(f"Assistant: {assistant_msg}")
553
 
 
554
  query_key = random.choice(list(facts.keys()))
555
  conversation.append(f"User: What is {query_key}?")
556
 
 
560
  input_ids = inputs.input_ids.to(model.device)
561
  attention_mask = inputs.attention_mask.to(model.device)
562
 
 
563
  compression_result = apply_compression_pipeline(
564
  model, tokenizer, input_ids, attention_mask, cache_manager, config
565
  )
566
 
 
567
  gen_start = time.perf_counter()
568
+ generated = safe_generate(model, tokenizer, input_ids, attention_mask,
569
+ compression_result['past_key_values'], max_new_tokens=20)
570
  gen_time = time.perf_counter() - gen_start
571
 
 
 
 
572
  expected_value = facts[query_key]
573
  accuracy = 1.0 if expected_value in generated else 0.0
574
 
 
591
  try:
592
  dataset = load_dataset("THUDM/LongBench", task, split="test")
593
 
 
594
  n_samples = min(config.eval_samples, len(dataset))
595
  samples = dataset.select(range(n_samples))
596
 
 
614
  input_ids = inputs.input_ids.to(model.device)
615
  attention_mask = inputs.attention_mask.to(model.device)
616
 
 
617
  compression_result = apply_compression_pipeline(
618
  model, tokenizer, input_ids, attention_mask, cache_manager, config,
619
+ measure_memory=False
620
  )
621
 
 
622
  gen_start = time.perf_counter()
623
+ generated = safe_generate(model, tokenizer, input_ids, attention_mask,
624
+ compression_result['past_key_values'], max_new_tokens=50)
625
  gen_time = time.perf_counter() - gen_start
626
 
 
 
 
627
  score = 1.0 if str(answer).lower() in generated.lower() else 0.0
628
  scores.append(score)
629
  compression_ratios.append(compression_result['compression_ratio'])
 
632
  gen_times.append(gen_time)
633
 
634
  avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0
 
635
 
636
  return {
637
  'accuracy': float(np.mean(scores)),
 
659
  device = "cuda" if torch.cuda.is_available() else "cpu"
660
  dtype = torch.float16 if device == "cuda" else torch.float32
661
 
 
662
  if config.fail_on_cpu_fallback and device == "cpu":
663
  raise RuntimeError("CUDA required but unavailable (fail_on_cpu_fallback=True)")
664
 
665
  logger.info(f"Loading model: {model_name}")
666
 
 
 
 
667
  tokenizer = AutoTokenizer.from_pretrained(
668
  model_name,
669
  trust_remote_code=True
 
672
  if tokenizer.pad_token is None:
673
  tokenizer.pad_token = tokenizer.eos_token
674
 
 
675
  model_kwargs = {
676
  "torch_dtype": dtype,
677
  "device_map": "auto" if device == "cuda" else None,
 
679
  "trust_remote_code": True
680
  }
681
 
 
682
  if config.use_flash_attention and device == "cuda":
683
  try:
 
684
  model_kwargs["attn_implementation"] = "flash_attention_2"
685
  model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
686
  logger.info("Successfully loaded with Flash Attention 2")
687
  except Exception as e:
688
+ logger.warning(f"Flash Attention not available: {e}")
 
689
  model_kwargs.pop("attn_implementation", None)
690
  model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
691
  else:
 
692
  model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
693
 
694
  model.eval()
 
701
  logger.info(f"Loading samples for benchmark: {config.benchmark_type}")
702
 
703
  if config.benchmark_type == "wikitext":
 
704
  texts = []
705
  min_tokens = config.prefill_length + config.generation_length
706
 
 
739
  raise
740
 
741
  elif config.benchmark_type == "longbench":
 
742
  texts = []
743
  if config.benchmark_subset:
744
  try:
 
754
  raise
755
 
756
  elif config.benchmark_type in ["niah", "ruler", "scbench"]:
 
757
  texts = ["Synthetic benchmark data"] * config.eval_samples
758
 
759
  else:
 
772
  logger.info(f"Benchmark type: {config.benchmark_type}")
773
  logger.info(f"Config hash: {config.get_hash()}")
774
 
 
775
  if torch.cuda.is_available():
776
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
777
 
 
789
  logger.error(f"Failed to detect model layers: {e}")
790
  raise
791
 
 
792
  device = model.device
793
  with torch.inference_mode():
794
  dummy = torch.randint(0, tokenizer.vocab_size, (1, min(config.prefill_length, 128)), device=device)
 
811
 
812
  metrics = BenchmarkMetrics()
813
 
 
814
  if config.benchmark_type == "niah":
 
815
  for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
816
  config.niah_depth_percent = depth
817
  for idx in range(min(config.eval_samples, 10)):
 
818
  if config.compression_type != CompressionType.NONE:
819
  cache_manager = QuantizedKVCache(config)
820
  cache_manager.n_layers = n_layers
 
827
  metrics.compression_ratios.append(result['compression_ratio'])
828
  metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
829
  metrics.prefill_times.append(result['prefill_time'])
830
+ metrics.decode_times.append(result['generation_time'] / 20)
831
 
832
  if result['prefill_peak_mem'] > 0:
833
  metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
834
 
 
835
  per_sample_records.append({
836
  'benchmark': 'niah',
837
  'depth_percent': depth,
 
843
  })
844
 
845
  elif config.benchmark_type == "ruler":
 
846
  for idx in range(config.eval_samples):
847
  if config.compression_type != CompressionType.NONE:
848
  cache_manager = QuantizedKVCache(config)
 
856
  metrics.compression_ratios.append(result['compression_ratio'])
857
  metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
858
  metrics.prefill_times.append(result['prefill_time'])
859
+ metrics.decode_times.append(result['generation_time'] / 10)
860
 
861
  if result['prefill_peak_mem'] > 0:
862
  metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
 
871
  })
872
 
873
  elif config.benchmark_type == "scbench":
 
874
  for idx in range(config.eval_samples):
875
  if config.compression_type != CompressionType.NONE:
876
  cache_manager = QuantizedKVCache(config)
 
884
  metrics.compression_ratios.append(result['compression_ratio'])
885
  metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
886
  metrics.prefill_times.append(result['prefill_time'])
887
+ metrics.decode_times.append(result['generation_time'] / 20)
888
 
889
  if result['prefill_peak_mem'] > 0:
890
  metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
 
899
  })
900
 
901
  elif config.benchmark_type == "longbench":
 
902
  if config.benchmark_subset:
903
  if config.compression_type != CompressionType.NONE:
904
  cache_manager = QuantizedKVCache(config)
 
915
  metrics.prefill_times.append(result['prefill_time'])
916
 
917
  if result['generation_time'] > 0:
918
+ metrics.decode_times.append(result['generation_time'] / 50)
919
 
920
  per_sample_records.append({
921
  'benchmark': 'longbench',
 
927
  })
928
 
929
  else:
 
930
  for idx in range(config.eval_samples):
931
  logger.info(f"Sample {idx+1}/{config.eval_samples}")
932
 
 
940
  else:
941
  cache_manager = None
942
 
 
943
  inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024))
944
  input_ids = inputs.input_ids.to(device)
945
  attention_mask = inputs.attention_mask.to(device)
946
 
 
947
  compression_result = apply_compression_pipeline(
948
  model, tokenizer, input_ids, attention_mask, cache_manager, config
949
  )
 
959
  prefill_perplexity = np.exp(compression_result['prefill_loss'])
960
  metrics.prefill_perplexities.append(min(prefill_perplexity, 1000))
961
 
 
962
  generated_ids = input_ids.clone()
963
  decode_times = []
964
  generation_losses = []
 
1011
  metrics.calculate_statistics(config)
1012
  all_metrics.append(metrics)
1013
 
 
1014
  final_metrics = BenchmarkMetrics()
1015
  for m in all_metrics:
1016
  final_metrics.prefill_times.extend(m.prefill_times)
 
1028
 
1029
  final_metrics.calculate_statistics(config)
1030
 
 
1031
  end_time = datetime.now().isoformat()
1032
  summary = {
1033
  'compression_type': config.compression_type.value,
 
1041
  'end_time': end_time
1042
  }
1043
 
 
1044
  if config.benchmark_type == "niah" and final_metrics.niah_retrieval_accuracy:
1045
  summary['niah_accuracy'] = float(np.mean(final_metrics.niah_retrieval_accuracy))
1046
  elif config.benchmark_type == "ruler" and final_metrics.ruler_exact_match:
 
1053
  summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
1054
  summary['generation_perplexity'] = final_metrics.generation_perplexity_mean
1055
 
 
1056
  summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
1057
  summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
1058
  summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
 
1150
  recomputed = {}
1151
  failures = []
1152
 
 
1153
  if config.benchmark_type == "niah":
1154
  if "niah_accuracy" in summary:
1155
  recomputed["niah_accuracy"] = mean_of("accuracy")
 
1163
  if "longbench_accuracy" in summary:
1164
  recomputed["longbench_accuracy"] = mean_of("accuracy")
1165
  elif config.benchmark_type == "wikitext":
 
1166
  if "prefill_perplexity" in summary:
1167
  recomputed["prefill_perplexity"] = mean_of("prefill_perplexity")
1168
  if "generation_perplexity" in summary:
1169
  recomputed["generation_perplexity"] = mean_of("generation_perplexity")
1170
 
 
1171
  recomputed["compression_ratio"] = mean_of("compression_ratio")
1172
  recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")
1173