Update config.py
Browse files
config.py
CHANGED
@@ -7,7 +7,7 @@ import json
|
|
7 |
import hashlib
|
8 |
from dataclasses import dataclass, field, asdict
|
9 |
from enum import Enum
|
10 |
-
from typing import List, Optional, NamedTuple
|
11 |
from datetime import datetime
|
12 |
import torch
|
13 |
import transformers
|
@@ -17,6 +17,69 @@ import logging
|
|
17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
class CompressionType(Enum):
|
21 |
"""RocketKV-enhanced SPG methods with explicit validation."""
|
22 |
NONE = "none"
|
@@ -184,6 +247,9 @@ class EnhancedSPGConfig:
|
|
184 |
stage_compression_min: float = 2.0 # Minimum stage compression ratio
|
185 |
stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
|
186 |
|
|
|
|
|
|
|
187 |
def __post_init__(self):
|
188 |
"""Validate all parameters - fail fast on invalid config."""
|
189 |
constants = ResearchConstants()
|
@@ -304,6 +370,10 @@ class CompressionConfig:
|
|
304 |
compression_type: CompressionType = CompressionType.ENHANCED_SPG
|
305 |
seed: int = 42
|
306 |
|
|
|
|
|
|
|
|
|
307 |
# Enhanced SPG configuration
|
308 |
enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
|
309 |
|
@@ -327,10 +397,25 @@ class CompressionConfig:
|
|
327 |
dataset_config: str = "wikitext-2-raw-v1"
|
328 |
dataset_split: str = "test"
|
329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
# Memory and system settings
|
331 |
clear_cache_between_runs: bool = True
|
332 |
use_memory_snapshot: bool = True
|
333 |
fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
|
|
|
334 |
|
335 |
# Output settings
|
336 |
generate_latex: bool = True
|
@@ -347,6 +432,15 @@ class CompressionConfig:
|
|
347 |
"""Comprehensive validation - fail fast on any invalid parameter."""
|
348 |
constants = ResearchConstants()
|
349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
# Validate core parameters
|
351 |
if not isinstance(self.seed, int) or self.seed < 0:
|
352 |
raise ValueError(f"seed must be non-negative integer, got {self.seed}")
|
@@ -371,6 +465,9 @@ class CompressionConfig:
|
|
371 |
if not 100 <= self.n_bootstrap <= 10000:
|
372 |
logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
|
373 |
|
|
|
|
|
|
|
374 |
logger.info("RocketKV-enhanced SPG config validated successfully")
|
375 |
|
376 |
def to_json(self) -> str:
|
|
|
7 |
import hashlib
|
8 |
from dataclasses import dataclass, field, asdict
|
9 |
from enum import Enum
|
10 |
+
from typing import List, Optional, NamedTuple, Dict, Any
|
11 |
from datetime import datetime
|
12 |
import torch
|
13 |
import transformers
|
|
|
17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
+
# Model configurations - NO HARDCODING
|
21 |
+
SUPPORTED_MODELS: Dict[str, Dict[str, Any]] = {
|
22 |
+
"gpt2": {
|
23 |
+
"name": "gpt2",
|
24 |
+
"requires_auth": False,
|
25 |
+
"max_context": 1024,
|
26 |
+
"default_dtype": "float16"
|
27 |
+
},
|
28 |
+
"llama2-7b": {
|
29 |
+
"name": "meta-llama/Llama-2-7b-hf",
|
30 |
+
"requires_auth": True,
|
31 |
+
"max_context": 4096,
|
32 |
+
"default_dtype": "float16"
|
33 |
+
},
|
34 |
+
"mistral-7b": {
|
35 |
+
"name": "mistralai/Mistral-7B-v0.1",
|
36 |
+
"requires_auth": False,
|
37 |
+
"max_context": 8192,
|
38 |
+
"default_dtype": "float16"
|
39 |
+
},
|
40 |
+
"opt-1.3b": {
|
41 |
+
"name": "facebook/opt-1.3b",
|
42 |
+
"requires_auth": False,
|
43 |
+
"max_context": 2048,
|
44 |
+
"default_dtype": "float16"
|
45 |
+
}
|
46 |
+
}
|
47 |
+
|
48 |
+
# Benchmark configurations - NO HARDCODING
|
49 |
+
BENCHMARK_CONFIGS: Dict[str, Dict[str, Any]] = {
|
50 |
+
"perplexity": {
|
51 |
+
"type": "perplexity",
|
52 |
+
"default_samples": 50,
|
53 |
+
"default_prefill": 512,
|
54 |
+
"default_generation": 64
|
55 |
+
},
|
56 |
+
"niah": {
|
57 |
+
"type": "needle_in_haystack",
|
58 |
+
"depths": [10, 25, 50, 75, 90], # Percentage depths
|
59 |
+
"needle": "The secret password is BANANA",
|
60 |
+
"default_samples": 10,
|
61 |
+
"default_context": 4096
|
62 |
+
},
|
63 |
+
"ruler": {
|
64 |
+
"type": "ruler",
|
65 |
+
"max_seq_lengths": [1024, 2048, 4096, 8192],
|
66 |
+
"default_samples": 10,
|
67 |
+
"default_n_facts": 10
|
68 |
+
},
|
69 |
+
"scbench": {
|
70 |
+
"type": "shared_context",
|
71 |
+
"num_turns": [5, 10, 20],
|
72 |
+
"default_samples": 10,
|
73 |
+
"default_context": 2048
|
74 |
+
},
|
75 |
+
"longbench": {
|
76 |
+
"type": "longbench",
|
77 |
+
"subsets": ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa"],
|
78 |
+
"default_samples": 20,
|
79 |
+
"max_context": 8192
|
80 |
+
}
|
81 |
+
}
|
82 |
+
|
83 |
class CompressionType(Enum):
|
84 |
"""RocketKV-enhanced SPG methods with explicit validation."""
|
85 |
NONE = "none"
|
|
|
247 |
stage_compression_min: float = 2.0 # Minimum stage compression ratio
|
248 |
stage_compression_max: float = 500.0 # Maximum stage compression ratio (INCREASED for 450x)
|
249 |
|
250 |
+
# Flash Attention support
|
251 |
+
use_flash_attention: bool = False # Try to use Flash Attention if available
|
252 |
+
|
253 |
def __post_init__(self):
|
254 |
"""Validate all parameters - fail fast on invalid config."""
|
255 |
constants = ResearchConstants()
|
|
|
370 |
compression_type: CompressionType = CompressionType.ENHANCED_SPG
|
371 |
seed: int = 42
|
372 |
|
373 |
+
# Model selection
|
374 |
+
model_key: str = "gpt2" # Key into SUPPORTED_MODELS
|
375 |
+
model_name: str = field(init=False) # Will be set in __post_init__
|
376 |
+
|
377 |
# Enhanced SPG configuration
|
378 |
enhanced_spg_config: EnhancedSPGConfig = field(default_factory=EnhancedSPGConfig)
|
379 |
|
|
|
397 |
dataset_config: str = "wikitext-2-raw-v1"
|
398 |
dataset_split: str = "test"
|
399 |
|
400 |
+
# Benchmark configuration
|
401 |
+
benchmark_type: str = "perplexity" # perplexity, niah, ruler, scbench, longbench
|
402 |
+
benchmark_subset: Optional[str] = None # For longbench subsets
|
403 |
+
|
404 |
+
# NIAH-specific parameters
|
405 |
+
niah_needle: str = field(default_factory=lambda: BENCHMARK_CONFIGS["niah"]["needle"])
|
406 |
+
niah_depth_percent: float = 50.0
|
407 |
+
|
408 |
+
# RULER-specific parameters
|
409 |
+
ruler_max_seq_length: int = 4096
|
410 |
+
|
411 |
+
# SCBench-specific parameters
|
412 |
+
scbench_num_turns: int = 10
|
413 |
+
|
414 |
# Memory and system settings
|
415 |
clear_cache_between_runs: bool = True
|
416 |
use_memory_snapshot: bool = True
|
417 |
fail_on_cpu_fallback: bool = True # CHANGED: Default to True for strict compliance
|
418 |
+
use_flash_attention: bool = False # Try to use Flash Attention if available
|
419 |
|
420 |
# Output settings
|
421 |
generate_latex: bool = True
|
|
|
432 |
"""Comprehensive validation - fail fast on any invalid parameter."""
|
433 |
constants = ResearchConstants()
|
434 |
|
435 |
+
# Set model name from key
|
436 |
+
if self.model_key not in SUPPORTED_MODELS:
|
437 |
+
raise ValueError(f"model_key {self.model_key} not in SUPPORTED_MODELS: {list(SUPPORTED_MODELS.keys())}")
|
438 |
+
self.model_name = SUPPORTED_MODELS[self.model_key]["name"]
|
439 |
+
|
440 |
+
# Validate benchmark type
|
441 |
+
if self.benchmark_type not in BENCHMARK_CONFIGS:
|
442 |
+
raise ValueError(f"benchmark_type {self.benchmark_type} not in BENCHMARK_CONFIGS: {list(BENCHMARK_CONFIGS.keys())}")
|
443 |
+
|
444 |
# Validate core parameters
|
445 |
if not isinstance(self.seed, int) or self.seed < 0:
|
446 |
raise ValueError(f"seed must be non-negative integer, got {self.seed}")
|
|
|
465 |
if not 100 <= self.n_bootstrap <= 10000:
|
466 |
logger.warning(f"n_bootstrap {self.n_bootstrap} outside recommended range [100, 10000]")
|
467 |
|
468 |
+
# Pass Flash Attention setting to EnhancedSPGConfig
|
469 |
+
self.enhanced_spg_config.use_flash_attention = self.use_flash_attention
|
470 |
+
|
471 |
logger.info("RocketKV-enhanced SPG config validated successfully")
|
472 |
|
473 |
def to_json(self) -> str:
|