kfoughali commited on
Commit
33d9292
·
verified ·
1 Parent(s): 52797d8

Update benchmark.py

Browse files
Files changed (1) hide show
  1. benchmark.py +242 -144
benchmark.py CHANGED
@@ -1,8 +1,10 @@
 
1
  """
2
  Benchmarking, metrics, and proof generation for Enhanced SPG.
3
  Supports LongBench, NIAH, RULER, SCBench benchmarks.
4
  MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
5
  ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
 
6
  """
7
 
8
  import torch
@@ -236,15 +238,107 @@ class BenchmarkMetrics:
236
  return (0.0, 0.0)
237
 
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
240
  cache_manager: QuantizedKVCache, config: CompressionConfig,
241
  measure_memory: bool = True) -> Dict[str, Any]:
242
  """
243
- Unified compression pipeline for ALL benchmarks.
244
  Returns compressed cache, metrics, and reconstructed KV pairs.
245
  """
246
  device = input_ids.device
247
 
 
 
 
248
  # Clear GPU cache if requested
249
  if torch.cuda.is_available() and measure_memory:
250
  torch.cuda.empty_cache()
@@ -256,16 +350,30 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
256
  torch.cuda.synchronize()
257
  start_time = time.perf_counter()
258
 
259
- # Prefill phase
260
- with torch.inference_mode():
261
- outputs = model(
262
- input_ids,
263
- attention_mask=attention_mask,
264
- use_cache=True,
265
- return_dict=True
266
- )
267
- past_key_values = outputs.past_key_values
268
- logits = outputs.logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  if torch.cuda.is_available():
271
  torch.cuda.synchronize()
@@ -277,18 +385,26 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
277
  if torch.cuda.is_available() and measure_memory:
278
  prefill_peak_mem = _peak_mem_bytes_all_gpus()
279
 
280
- # Calculate prefill perplexity if we have logits
281
  prefill_loss = None
282
  if logits is not None and input_ids.shape[1] > 1:
283
- shift_logits = logits[..., :-1, :].contiguous()
284
- shift_labels = input_ids[..., 1:].contiguous()
285
- loss = F.cross_entropy(
286
- shift_logits.view(-1, shift_logits.size(-1)),
287
- shift_labels.view(-1),
288
- reduction='mean',
289
- ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
290
- )
291
- prefill_loss = loss.item()
 
 
 
 
 
 
 
 
292
 
293
  # Compression phase - same as WikiText
294
  original_cache_size = 0
@@ -296,39 +412,60 @@ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
296
  compression_ratio = 1.0
297
 
298
  if past_key_values:
299
- # Convert to legacy format for processing
300
- kv_tuple = past_key_values.to_legacy_cache() if hasattr(past_key_values, 'to_legacy_cache') else past_key_values
301
-
302
- # Calculate original size
303
- for layer_idx, (keys, values) in enumerate(kv_tuple):
304
- original_cache_size += keys.nelement() * keys.element_size()
305
- original_cache_size += values.nelement() * values.element_size()
306
 
307
- # Apply compression if enabled
308
- if config.compression_type != CompressionType.NONE:
309
- cache_manager.compress_and_store(layer_idx, keys, values)
310
-
311
- # Reconstruct compressed cache
312
- if config.compression_type != CompressionType.NONE:
313
- reconstructed_kv = []
314
- for layer_idx in range(len(kv_tuple)):
315
- dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
316
- if dec_keys is not None and dec_values is not None:
317
- reconstructed_kv.append((dec_keys, dec_values))
 
318
 
319
- # Convert back to DynamicCache format
320
- if hasattr(DynamicCache, 'from_legacy_cache'):
321
- past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  else:
323
- past_key_values = tuple(reconstructed_kv)
324
 
325
- # Measure compressed size
326
- compressed_cache_size = cache_manager.get_memory_footprint()
327
- else:
 
 
328
  compressed_cache_size = original_cache_size
329
-
330
- # Calculate compression ratio
331
- compression_ratio = original_cache_size / compressed_cache_size if compressed_cache_size > 0 else 1.0
332
 
333
  return {
334
  'past_key_values': past_key_values,
@@ -374,7 +511,8 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op
374
 
375
  prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:"
376
 
377
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config.prefill_length)
 
378
  input_ids = inputs.input_ids.to(model.device)
379
  attention_mask = inputs.attention_mask.to(model.device)
380
 
@@ -383,25 +521,11 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op
383
  model, tokenizer, input_ids, attention_mask, cache_manager, config
384
  )
385
 
386
- # Generate with compressed cache
387
- with torch.inference_mode():
388
- # Measure generation time
389
- if torch.cuda.is_available():
390
- torch.cuda.synchronize()
391
- gen_start = time.perf_counter()
392
-
393
- output = model.generate(
394
- input_ids,
395
- past_key_values=compression_result['past_key_values'],
396
- max_new_tokens=20,
397
- temperature=0.0,
398
- do_sample=False,
399
- attention_mask=attention_mask
400
- )
401
-
402
- if torch.cuda.is_available():
403
- torch.cuda.synchronize()
404
- gen_time = time.perf_counter() - gen_start
405
 
406
  generated_text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
407
 
@@ -424,7 +548,7 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op
424
  def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
425
  """Evaluate RULER with SAME compression pipeline as WikiText."""
426
  # Create synthetic RULER-like task
427
- seq_len = min(config.ruler_max_seq_length, config.prefill_length)
428
 
429
  # Create a retrieval task with multiple facts
430
  facts = []
@@ -437,7 +561,7 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: O
437
  query_idx = random.randint(0, 9)
438
  prompt = f"{context}\n\nWhat is the capital of Country{query_idx}?"
439
 
440
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=seq_len)
441
  input_ids = inputs.input_ids.to(model.device)
442
  attention_mask = inputs.attention_mask.to(model.device)
443
 
@@ -447,23 +571,10 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: O
447
  )
448
 
449
  # Generate with compressed cache
450
- with torch.inference_mode():
451
- if torch.cuda.is_available():
452
- torch.cuda.synchronize()
453
- gen_start = time.perf_counter()
454
-
455
- output = model.generate(
456
- input_ids,
457
- past_key_values=compression_result['past_key_values'],
458
- max_new_tokens=10,
459
- temperature=0.0,
460
- do_sample=False,
461
- attention_mask=attention_mask
462
- )
463
-
464
- if torch.cuda.is_available():
465
- torch.cuda.synchronize()
466
- gen_time = time.perf_counter() - gen_start
467
 
468
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
469
 
@@ -507,8 +618,7 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager:
507
 
508
  full_conversation = "\n".join(conversation) + "\nAssistant:"
509
 
510
- inputs = tokenizer(full_conversation, return_tensors="pt", truncation=True,
511
- max_length=config.prefill_length)
512
  input_ids = inputs.input_ids.to(model.device)
513
  attention_mask = inputs.attention_mask.to(model.device)
514
 
@@ -518,23 +628,10 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager:
518
  )
519
 
520
  # Generate with compressed cache
521
- with torch.inference_mode():
522
- if torch.cuda.is_available():
523
- torch.cuda.synchronize()
524
- gen_start = time.perf_counter()
525
-
526
- output = model.generate(
527
- input_ids,
528
- past_key_values=compression_result['past_key_values'],
529
- max_new_tokens=20,
530
- temperature=0.0,
531
- do_sample=False,
532
- attention_mask=attention_mask
533
- )
534
-
535
- if torch.cuda.is_available():
536
- torch.cuda.synchronize()
537
- gen_time = time.perf_counter() - gen_start
538
 
539
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
540
 
@@ -581,8 +678,7 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
581
 
582
  prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
583
 
584
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
585
- max_length=config.prefill_length)
586
  input_ids = inputs.input_ids.to(model.device)
587
  attention_mask = inputs.attention_mask.to(model.device)
588
 
@@ -593,23 +689,10 @@ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
593
  )
594
 
595
  # Generate with compressed cache
596
- with torch.inference_mode():
597
- if torch.cuda.is_available():
598
- torch.cuda.synchronize()
599
- gen_start = time.perf_counter()
600
-
601
- output = model.generate(
602
- input_ids,
603
- past_key_values=compression_result['past_key_values'],
604
- max_new_tokens=50,
605
- temperature=0.0,
606
- do_sample=False,
607
- attention_mask=attention_mask
608
- )
609
-
610
- if torch.cuda.is_available():
611
- torch.cuda.synchronize()
612
- gen_time = time.perf_counter() - gen_start
613
 
614
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
615
 
@@ -775,6 +858,10 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
775
  logger.info(f"Benchmark type: {config.benchmark_type}")
776
  logger.info(f"Config hash: {config.get_hash()}")
777
 
 
 
 
 
778
  constants = ResearchConstants()
779
  start_time = datetime.now().isoformat()
780
  per_sample_records = []
@@ -818,8 +905,12 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
818
  for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
819
  config.niah_depth_percent = depth
820
  for idx in range(min(config.eval_samples, 10)):
821
- cache_manager = QuantizedKVCache(config)
822
- cache_manager.n_layers = n_layers
 
 
 
 
823
 
824
  result = evaluate_niah(model, tokenizer, config, cache_manager)
825
 
@@ -846,8 +937,11 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
846
  elif config.benchmark_type == "ruler":
847
  # RULER evaluation with unified compression
848
  for idx in range(config.eval_samples):
849
- cache_manager = QuantizedKVCache(config)
850
- cache_manager.n_layers = n_layers
 
 
 
851
 
852
  result = evaluate_ruler(model, tokenizer, config, cache_manager)
853
 
@@ -872,8 +966,11 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
872
  elif config.benchmark_type == "scbench":
873
  # SCBench evaluation with unified compression
874
  for idx in range(config.eval_samples):
875
- cache_manager = QuantizedKVCache(config)
876
- cache_manager.n_layers = n_layers
 
 
 
877
 
878
  result = evaluate_scbench(model, tokenizer, config, cache_manager)
879
 
@@ -898,8 +995,11 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
898
  elif config.benchmark_type == "longbench":
899
  # LongBench evaluation with unified compression
900
  if config.benchmark_subset:
901
- cache_manager = QuantizedKVCache(config)
902
- cache_manager.n_layers = n_layers
 
 
 
903
 
904
  result = evaluate_longbench_task(model, tokenizer, config,
905
  config.benchmark_subset, cache_manager)
@@ -929,17 +1029,15 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
929
  text_idx = (idx + seed * config.eval_samples) % len(dataset_texts)
930
  text = dataset_texts[text_idx]
931
 
932
- cache_manager = QuantizedKVCache(config)
933
- cache_manager.n_layers = n_layers
934
- cache_manager.update_position(config.prefill_length + idx)
 
 
 
935
 
936
- inputs = tokenizer(
937
- text,
938
- return_tensors="pt",
939
- truncation=True,
940
- max_length=config.prefill_length,
941
- padding="max_length"
942
- )
943
  input_ids = inputs.input_ids.to(device)
944
  attention_mask = inputs.attention_mask.to(device)
945
 
 
1
+ # benchmark.py
2
  """
3
  Benchmarking, metrics, and proof generation for Enhanced SPG.
4
  Supports LongBench, NIAH, RULER, SCBench benchmarks.
5
  MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
6
  ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
7
+ FIXED: CUDA assert errors, tokenization issues, safe generation.
8
  """
9
 
10
  import torch
 
238
  return (0.0, 0.0)
239
 
240
 
241
+ def safe_tokenize(tokenizer, text, max_length=512):
242
+ """Safe tokenization with proper padding and truncation."""
243
+ # Ensure pad_token is set
244
+ if tokenizer.pad_token is None:
245
+ tokenizer.pad_token = tokenizer.eos_token
246
+
247
+ # Tokenize with explicit parameters
248
+ inputs = tokenizer(
249
+ text,
250
+ return_tensors="pt",
251
+ truncation=True,
252
+ max_length=max_length,
253
+ padding="max_length",
254
+ return_attention_mask=True,
255
+ add_special_tokens=True
256
+ )
257
+
258
+ # Validate outputs
259
+ if inputs.input_ids.shape[1] == 0:
260
+ raise ValueError("Tokenization produced empty sequence")
261
+
262
+ if inputs.input_ids.shape[1] > max_length:
263
+ logger.warning(f"Sequence length {inputs.input_ids.shape[1]} exceeds max {max_length}")
264
+ inputs.input_ids = inputs.input_ids[:, :max_length]
265
+ inputs.attention_mask = inputs.attention_mask[:, :max_length]
266
+
267
+ return inputs
268
+
269
+
270
+ def validate_model_inputs(model, input_ids, attention_mask):
271
+ """Validate inputs are compatible with model."""
272
+ # Check sequence length against model's max position embeddings
273
+ if hasattr(model.config, 'max_position_embeddings'):
274
+ max_pos = model.config.max_position_embeddings
275
+ if input_ids.shape[1] > max_pos:
276
+ logger.warning(f"Input length {input_ids.shape[1]} exceeds model max {max_pos}")
277
+ input_ids = input_ids[:, :max_pos]
278
+ attention_mask = attention_mask[:, :max_pos]
279
+
280
+ # For GPT-2, check n_positions
281
+ if hasattr(model.config, 'n_positions'):
282
+ n_pos = model.config.n_positions
283
+ if input_ids.shape[1] > n_pos:
284
+ logger.warning(f"Input length {input_ids.shape[1]} exceeds GPT-2 positions {n_pos}")
285
+ input_ids = input_ids[:, :n_pos]
286
+ attention_mask = attention_mask[:, :n_pos]
287
+
288
+ # Ensure input_ids are within vocabulary range
289
+ vocab_size = model.config.vocab_size
290
+ if input_ids.max() >= vocab_size:
291
+ logger.error(f"Token id {input_ids.max()} exceeds vocab size {vocab_size}")
292
+ input_ids = input_ids.clamp(0, vocab_size - 1)
293
+
294
+ return input_ids, attention_mask
295
+
296
+
297
+ def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20):
298
+ """Safe generation with proper error handling."""
299
+ try:
300
+ # Validate inputs
301
+ input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
302
+
303
+ # Set generation config
304
+ gen_config = {
305
+ "max_new_tokens": max_new_tokens,
306
+ "temperature": 0.7,
307
+ "do_sample": False,
308
+ "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
309
+ "eos_token_id": tokenizer.eos_token_id,
310
+ "attention_mask": attention_mask,
311
+ "use_cache": True
312
+ }
313
+
314
+ # Add past_key_values if available
315
+ if past_key_values is not None:
316
+ gen_config["past_key_values"] = past_key_values
317
+
318
+ # Generate with error handling
319
+ with torch.no_grad():
320
+ output = model.generate(input_ids, **gen_config)
321
+
322
+ return output
323
+
324
+ except Exception as e:
325
+ logger.error(f"Generation failed: {e}")
326
+ # Return input as fallback
327
+ return input_ids
328
+
329
+
330
  def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
331
  cache_manager: QuantizedKVCache, config: CompressionConfig,
332
  measure_memory: bool = True) -> Dict[str, Any]:
333
  """
334
+ Unified compression pipeline for ALL benchmarks with safety fixes.
335
  Returns compressed cache, metrics, and reconstructed KV pairs.
336
  """
337
  device = input_ids.device
338
 
339
+ # Validate inputs first
340
+ input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
341
+
342
  # Clear GPU cache if requested
343
  if torch.cuda.is_available() and measure_memory:
344
  torch.cuda.empty_cache()
 
350
  torch.cuda.synchronize()
351
  start_time = time.perf_counter()
352
 
353
+ # Prefill phase with error handling
354
+ try:
355
+ with torch.inference_mode():
356
+ outputs = model(
357
+ input_ids,
358
+ attention_mask=attention_mask,
359
+ use_cache=True,
360
+ return_dict=True
361
+ )
362
+ past_key_values = outputs.past_key_values
363
+ logits = outputs.logits
364
+ except Exception as e:
365
+ logger.error(f"Prefill failed: {e}")
366
+ # Return minimal valid result
367
+ return {
368
+ 'past_key_values': None,
369
+ 'prefill_time': 0,
370
+ 'prefill_peak_mem': 0,
371
+ 'prefill_loss': None,
372
+ 'original_cache_size': 0,
373
+ 'compressed_cache_size': 0,
374
+ 'compression_ratio': 1.0,
375
+ 'logits': None
376
+ }
377
 
378
  if torch.cuda.is_available():
379
  torch.cuda.synchronize()
 
385
  if torch.cuda.is_available() and measure_memory:
386
  prefill_peak_mem = _peak_mem_bytes_all_gpus()
387
 
388
+ # Calculate prefill perplexity safely
389
  prefill_loss = None
390
  if logits is not None and input_ids.shape[1] > 1:
391
+ try:
392
+ # Ensure we have valid shapes
393
+ seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
394
+ if seq_len > 0:
395
+ shift_logits = logits[:, :seq_len, :].contiguous()
396
+ shift_labels = input_ids[:, 1:seq_len+1].contiguous()
397
+
398
+ # Calculate loss with ignore_index for padding
399
+ loss = F.cross_entropy(
400
+ shift_logits.view(-1, shift_logits.size(-1)),
401
+ shift_labels.view(-1),
402
+ reduction='mean',
403
+ ignore_index=tokenizer.pad_token_id or -100
404
+ )
405
+ prefill_loss = loss.item()
406
+ except Exception as e:
407
+ logger.warning(f"Could not calculate prefill loss: {e}")
408
 
409
  # Compression phase - same as WikiText
410
  original_cache_size = 0
 
412
  compression_ratio = 1.0
413
 
414
  if past_key_values:
415
+ try:
416
+ # Convert to legacy format for processing
417
+ kv_tuple = past_key_values.to_legacy_cache() if hasattr(past_key_values, 'to_legacy_cache') else past_key_values
 
 
 
 
418
 
419
+ # Calculate original size
420
+ for layer_idx, (keys, values) in enumerate(kv_tuple):
421
+ if keys is not None and values is not None:
422
+ original_cache_size += keys.nelement() * keys.element_size()
423
+ original_cache_size += values.nelement() * values.element_size()
424
+
425
+ # Apply compression if enabled
426
+ if config.compression_type != CompressionType.NONE and cache_manager is not None:
427
+ try:
428
+ cache_manager.compress_and_store(layer_idx, keys, values)
429
+ except Exception as e:
430
+ logger.error(f"Compression failed for layer {layer_idx}: {e}")
431
 
432
+ # Reconstruct compressed cache
433
+ if config.compression_type != CompressionType.NONE and cache_manager is not None:
434
+ reconstructed_kv = []
435
+ for layer_idx in range(len(kv_tuple)):
436
+ try:
437
+ dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
438
+ if dec_keys is not None and dec_values is not None:
439
+ reconstructed_kv.append((dec_keys, dec_values))
440
+ else:
441
+ # Use original if decompression fails
442
+ logger.warning(f"Decompression returned None for layer {layer_idx}, using original")
443
+ reconstructed_kv.append(kv_tuple[layer_idx])
444
+ except Exception as e:
445
+ logger.error(f"Decompression failed for layer {layer_idx}: {e}")
446
+ reconstructed_kv.append(kv_tuple[layer_idx])
447
+
448
+ # Convert back to DynamicCache format
449
+ if hasattr(DynamicCache, 'from_legacy_cache'):
450
+ past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
451
+ else:
452
+ past_key_values = tuple(reconstructed_kv)
453
+
454
+ # Measure compressed size
455
+ try:
456
+ compressed_cache_size = cache_manager.get_memory_footprint()
457
+ except:
458
+ compressed_cache_size = original_cache_size
459
  else:
460
+ compressed_cache_size = original_cache_size
461
 
462
+ # Calculate compression ratio
463
+ compression_ratio = original_cache_size / compressed_cache_size if compressed_cache_size > 0 else 1.0
464
+
465
+ except Exception as e:
466
+ logger.error(f"Cache processing failed: {e}")
467
  compressed_cache_size = original_cache_size
468
+ compression_ratio = 1.0
 
 
469
 
470
  return {
471
  'past_key_values': past_key_values,
 
511
 
512
  prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:"
513
 
514
+ # Use safe tokenization
515
+ inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
516
  input_ids = inputs.input_ids.to(model.device)
517
  attention_mask = inputs.attention_mask.to(model.device)
518
 
 
521
  model, tokenizer, input_ids, attention_mask, cache_manager, config
522
  )
523
 
524
+ # Generate with compressed cache using safe generation
525
+ gen_start = time.perf_counter()
526
+ output = safe_generate(model, tokenizer, input_ids, attention_mask,
527
+ compression_result['past_key_values'], max_new_tokens=20)
528
+ gen_time = time.perf_counter() - gen_start
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
  generated_text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
531
 
 
548
  def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
549
  """Evaluate RULER with SAME compression pipeline as WikiText."""
550
  # Create synthetic RULER-like task
551
+ seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024) # Cap at GPT-2 limit
552
 
553
  # Create a retrieval task with multiple facts
554
  facts = []
 
561
  query_idx = random.randint(0, 9)
562
  prompt = f"{context}\n\nWhat is the capital of Country{query_idx}?"
563
 
564
+ inputs = safe_tokenize(tokenizer, prompt, max_length=seq_len)
565
  input_ids = inputs.input_ids.to(model.device)
566
  attention_mask = inputs.attention_mask.to(model.device)
567
 
 
571
  )
572
 
573
  # Generate with compressed cache
574
+ gen_start = time.perf_counter()
575
+ output = safe_generate(model, tokenizer, input_ids, attention_mask,
576
+ compression_result['past_key_values'], max_new_tokens=10)
577
+ gen_time = time.perf_counter() - gen_start
 
 
 
 
 
 
 
 
 
 
 
 
 
578
 
579
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
580
 
 
618
 
619
  full_conversation = "\n".join(conversation) + "\nAssistant:"
620
 
621
+ inputs = safe_tokenize(tokenizer, full_conversation, max_length=min(config.prefill_length, 1024))
 
622
  input_ids = inputs.input_ids.to(model.device)
623
  attention_mask = inputs.attention_mask.to(model.device)
624
 
 
628
  )
629
 
630
  # Generate with compressed cache
631
+ gen_start = time.perf_counter()
632
+ output = safe_generate(model, tokenizer, input_ids, attention_mask,
633
+ compression_result['past_key_values'], max_new_tokens=20)
634
+ gen_time = time.perf_counter() - gen_start
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
637
 
 
678
 
679
  prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
680
 
681
+ inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
 
682
  input_ids = inputs.input_ids.to(model.device)
683
  attention_mask = inputs.attention_mask.to(model.device)
684
 
 
689
  )
690
 
691
  # Generate with compressed cache
692
+ gen_start = time.perf_counter()
693
+ output = safe_generate(model, tokenizer, input_ids, attention_mask,
694
+ compression_result['past_key_values'], max_new_tokens=50)
695
+ gen_time = time.perf_counter() - gen_start
 
 
 
 
 
 
 
 
 
 
 
 
 
696
 
697
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
698
 
 
858
  logger.info(f"Benchmark type: {config.benchmark_type}")
859
  logger.info(f"Config hash: {config.get_hash()}")
860
 
861
+ # Enable synchronous CUDA for debugging
862
+ if torch.cuda.is_available():
863
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
864
+
865
  constants = ResearchConstants()
866
  start_time = datetime.now().isoformat()
867
  per_sample_records = []
 
905
  for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
906
  config.niah_depth_percent = depth
907
  for idx in range(min(config.eval_samples, 10)):
908
+ # Create cache manager for compression types
909
+ if config.compression_type != CompressionType.NONE:
910
+ cache_manager = QuantizedKVCache(config)
911
+ cache_manager.n_layers = n_layers
912
+ else:
913
+ cache_manager = None
914
 
915
  result = evaluate_niah(model, tokenizer, config, cache_manager)
916
 
 
937
  elif config.benchmark_type == "ruler":
938
  # RULER evaluation with unified compression
939
  for idx in range(config.eval_samples):
940
+ if config.compression_type != CompressionType.NONE:
941
+ cache_manager = QuantizedKVCache(config)
942
+ cache_manager.n_layers = n_layers
943
+ else:
944
+ cache_manager = None
945
 
946
  result = evaluate_ruler(model, tokenizer, config, cache_manager)
947
 
 
966
  elif config.benchmark_type == "scbench":
967
  # SCBench evaluation with unified compression
968
  for idx in range(config.eval_samples):
969
+ if config.compression_type != CompressionType.NONE:
970
+ cache_manager = QuantizedKVCache(config)
971
+ cache_manager.n_layers = n_layers
972
+ else:
973
+ cache_manager = None
974
 
975
  result = evaluate_scbench(model, tokenizer, config, cache_manager)
976
 
 
995
  elif config.benchmark_type == "longbench":
996
  # LongBench evaluation with unified compression
997
  if config.benchmark_subset:
998
+ if config.compression_type != CompressionType.NONE:
999
+ cache_manager = QuantizedKVCache(config)
1000
+ cache_manager.n_layers = n_layers
1001
+ else:
1002
+ cache_manager = None
1003
 
1004
  result = evaluate_longbench_task(model, tokenizer, config,
1005
  config.benchmark_subset, cache_manager)
 
1029
  text_idx = (idx + seed * config.eval_samples) % len(dataset_texts)
1030
  text = dataset_texts[text_idx]
1031
 
1032
+ if config.compression_type != CompressionType.NONE:
1033
+ cache_manager = QuantizedKVCache(config)
1034
+ cache_manager.n_layers = n_layers
1035
+ cache_manager.update_position(config.prefill_length + idx)
1036
+ else:
1037
+ cache_manager = None
1038
 
1039
+ # Use safe tokenization
1040
+ inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024))
 
 
 
 
 
1041
  input_ids = inputs.input_ids.to(device)
1042
  attention_mask = inputs.attention_mask.to(device)
1043