kfoughali commited on
Commit
d7cde9b
·
verified ·
1 Parent(s): e7b895b

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +98 -1
config.py CHANGED
@@ -7,7 +7,7 @@ import json
7
  import hashlib
8
  from dataclasses import dataclass, field, asdict
9
  from enum import Enum
10
- from typing import List, Optional, NamedTuple
11
  from datetime import datetime
12
  import torch
13
  import transformers
@@ -17,6 +17,69 @@ import logging
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class CompressionType(Enum):
21
  """RocketKV-enhanced SPG methods with explicit validation."""
22
  NONE = "none"
@@ -184,6 +247,9 @@ class EnhancedSPGConfig:
184
  stage_compression_min: float = 2.0 # Minimum stage compression ratio
185
  stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
186
 
 
 
 
187
  def __post_init__(self):
188
  """Validate all parameters - fail fast on invalid config."""
189
  constants = ResearchConstants()
@@ -304,6 +370,10 @@ class CompressionConfig:
304
  compression_type: CompressionType = CompressionType.ENHANCED_SPG
305
  seed: int = 42
306
 
 
 
 
 
307
  # Enhanced SPG configuration
308
  enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
309
 
@@ -327,10 +397,25 @@ class CompressionConfig:
327
  dataset_config: str = "wikitext-2-raw-v1"
328
  dataset_split: str = "test"
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  # Memory and system settings
331
  clear_cache_between_runs: bool = True
332
  use_memory_snapshot: bool = True
333
  fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
 
334
 
335
  # Output settings
336
  generate_latex: bool = True
@@ -347,6 +432,15 @@ class CompressionConfig:
347
  """Comprehensive validation - fail fast on any invalid parameter."""
348
  constants = ResearchConstants()
349
 
 
 
 
 
 
 
 
 
 
350
  # Validate core parameters
351
  if not isinstance(self.seed, int) or self.seed < 0:
352
  raise ValueError(f"seed must be non-negative integer, got {self.seed}")
@@ -371,6 +465,9 @@ class CompressionConfig:
371
  if not 100 <= self.n_bootstrap <= 10000:
372
  logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
373
 
 
 
 
374
  logger.info("RocketKV-enhanced SPG config validated successfully")
375
 
376
  def to_json(self) -> str:
 
7
  import hashlib
8
  from dataclasses import dataclass, field, asdict
9
  from enum import Enum
10
+ from typing import List, Optional, NamedTuple, Dict, Any
11
  from datetime import datetime
12
  import torch
13
  import transformers
 
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
20
+ # Model configurations - NO HARDCODING
21
+ SUPPORTED_MODELS: Dict[str, Dict[str, Any]] = {
22
+ "gpt2": {
23
+ "name": "gpt2",
24
+ "requires_auth": False,
25
+ "max_context": 1024,
26
+ "default_dtype": "float16"
27
+ },
28
+ "llama2-7b": {
29
+ "name": "meta-llama/Llama-2-7b-hf",
30
+ "requires_auth": True,
31
+ "max_context": 4096,
32
+ "default_dtype": "float16"
33
+ },
34
+ "mistral-7b": {
35
+ "name": "mistralai/Mistral-7B-v0.1",
36
+ "requires_auth": False,
37
+ "max_context": 8192,
38
+ "default_dtype": "float16"
39
+ },
40
+ "opt-1.3b": {
41
+ "name": "facebook/opt-1.3b",
42
+ "requires_auth": False,
43
+ "max_context": 2048,
44
+ "default_dtype": "float16"
45
+ }
46
+ }
47
+
48
+ # Benchmark configurations - NO HARDCODING
49
+ BENCHMARK_CONFIGS: Dict[str, Dict[str, Any]] = {
50
+ "perplexity": {
51
+ "type": "perplexity",
52
+ "default_samples": 50,
53
+ "default_prefill": 512,
54
+ "default_generation": 64
55
+ },
56
+ "niah": {
57
+ "type": "needle_in_haystack",
58
+ "depths": [10, 25, 50, 75, 90], # Percentage depths
59
+ "needle": "The secret password is BANANA",
60
+ "default_samples": 10,
61
+ "default_context": 4096
62
+ },
63
+ "ruler": {
64
+ "type": "ruler",
65
+ "max_seq_lengths": [1024, 2048, 4096, 8192],
66
+ "default_samples": 10,
67
+ "default_n_facts": 10
68
+ },
69
+ "scbench": {
70
+ "type": "shared_context",
71
+ "num_turns": [5, 10, 20],
72
+ "default_samples": 10,
73
+ "default_context": 2048
74
+ },
75
+ "longbench": {
76
+ "type": "longbench",
77
+ "subsets": ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa"],
78
+ "default_samples": 20,
79
+ "max_context": 8192
80
+ }
81
+ }
82
+
83
  class CompressionType(Enum):
84
  """RocketKV-enhanced SPG methods with explicit validation."""
85
  NONE = "none"
 
247
  stage_compression_min: float = 2.0 # Minimum stage compression ratio
248
  stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
249
 
250
+ # Flash Attention support
251
+ use_flash_attention: bool = False # Try to use Flash Attention if available
252
+
253
  def __post_init__(self):
254
  """Validate all parameters - fail fast on invalid config."""
255
  constants = ResearchConstants()
 
370
  compression_type: CompressionType = CompressionType.ENHANCED_SPG
371
  seed: int = 42
372
 
373
+ # Model selection
374
+ model_key: str = "gpt2" # Key into SUPPORTED_MODELS
375
+ model_name: str = field(init=False) # Will be set in __post_init__
376
+
377
  # Enhanced SPG configuration
378
  enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
379
 
 
397
  dataset_config: str = "wikitext-2-raw-v1"
398
  dataset_split: str = "test"
399
 
400
+ # Benchmark configuration
401
+ benchmark_type: str = "perplexity" # perplexity, niah, ruler, scbench, longbench
402
+ benchmark_subset: Optional[str] = None # For longbench subsets
403
+
404
+ # NIAH-specific parameters
405
+ niah_needle: str = field(default_factory=lambda: BENCHMARK_CONFIGS["niah"]["needle"])
406
+ niah_depth_percent: float = 50.0
407
+
408
+ # RULER-specific parameters
409
+ ruler_max_seq_length: int = 4096
410
+
411
+ # SCBench-specific parameters
412
+ scbench_num_turns: int = 10
413
+
414
  # Memory and system settings
415
  clear_cache_between_runs: bool = True
416
  use_memory_snapshot: bool = True
417
  fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
418
+ use_flash_attention: bool = False # Try to use Flash Attention if available
419
 
420
  # Output settings
421
  generate_latex: bool = True
 
432
  """Comprehensive validation - fail fast on any invalid parameter."""
433
  constants = ResearchConstants()
434
 
435
+ # Set model name from key
436
+ if self.model_key not in SUPPORTED_MODELS:
437
+ raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
438
+ self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
439
+
440
+ # Validate benchmark type
441
+ if self.benchmark_type not in BENCHMARK_CONFIGS:
442
+ raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
443
+
444
  # Validate core parameters
445
  if not isinstance(self.seed, int) or self.seed < 0:
446
  raise ValueError(f"seed must be non-negative integer, got {self.seed}")
 
465
  if not 100 <= self.n_bootstrap <= 10000:
466
  logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
467
 
468
+ # Pass Flash Attention setting to EnhancedSPGConfig
469
+ self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
470
+
471
  logger.info("RocketKV-enhanced SPG config validated successfully")
472
 
473
  def to_json(self) -> str: