eyad-silx committed
Commit 8c261f1 · verified · 1 Parent(s): 5556482

Update config/baseline_config.py

Files changed (1):
  config/baseline_config.py (+26 -13)
config/baseline_config.py CHANGED
@@ -1,39 +1,43 @@
 """
-Configuration for baseline transformer on enwik8.
-Matches DTAT config exactly for fair comparison.
+Configuration for Baseline Transformer on enwik8.
+Matches DTAT's training setup for fair comparison.
 """
 
 class BaselineConfig:
     def __init__(self):
         # Model architecture (exactly matching DTAT)
         self.n_layer = 12
-        self.n_head = 8
-        self.n_embd = 512
+        self.n_head = 8  # Same as DTAT
+        self.n_embd = 512  # Same as DTAT
         self.dropout = 0.1
         self.bias = True
 
         # Sequence parameters
-        self.block_size = 1024
+        self.block_size = 1024  # Same as DTAT
         self.vocab_size = 256  # For character-level model
 
-        # Training parameters
+        # Training parameters (matched with DTAT)
         self.learning_rate = 6e-4
-        self.min_lr = 6e-5
-        self.warmup_iters = 733  # 5% of 14,667 iterations
-        self.max_iters = 14667  # Exactly 4 epochs with batch_size=24
-        self.weight_decay = 1e-1
+        self.min_lr = 1e-5  # Lower minimum to allow fine-tuning
+        self.warmup_iters = 367  # 5% of total iterations
+        self.max_iters = 7334  # Exactly 4 epochs with batch_size=24
+        self.weight_decay = 0.1  # Same as DTAT
         self.beta1 = 0.9
         self.beta2 = 0.95
         self.grad_clip = 1.0
 
         # Learning rate schedule
         self.decay_lr = True
-        self.lr_decay_iters = self.max_iters  # Full schedule
+        self.lr_decay_iters = 5000  # Same as DTAT
+
+        # Early stopping
+        self.patience = 15  # Same as DTAT
+        self.min_delta = 0.005  # Same as DTAT
+        self.eval_interval = 250  # Same as DTAT
+        self.eval_iters = 200  # Same as DTAT
 
         # Logging
         self.log_interval = 10
-        self.eval_interval = 500
-        self.eval_iters = 200
 
         # Mixed precision training
         self.mixed_precision = True
@@ -46,9 +50,18 @@ class BaselineConfig:
         # System
         self.device = 'cuda'
         self.compile = True
+
+        # Performance optimization
+        self.compile_model = True
+        self.cudnn_benchmark = True
+
+        # Git config for model versioning
+        self.git_name = "Your Name"
+        self.git_email = "[email protected]"
 
     def get_config(self):
         return self
 
 def get_config():
+    """Helper function to get config instance."""
     return BaselineConfig()
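
The schedule fields (learning_rate, min_lr, warmup_iters, lr_decay_iters, decay_lr) follow the nanoGPT naming convention, which suggests a linear-warmup-then-cosine-decay schedule. A minimal sketch of how a training loop might derive the per-iteration learning rate from this config — the get_lr helper below is illustrative, not part of this commit:

```python
import math

from config.baseline_config import get_config

cfg = get_config()

def get_lr(it: int) -> float:
    """Linear warmup, then cosine decay to cfg.min_lr (nanoGPT-style; illustrative)."""
    if not cfg.decay_lr:
        return cfg.learning_rate
    if it < cfg.warmup_iters:
        # Linear warmup over the first 5% of training (367 of 7334 iterations).
        return cfg.learning_rate * (it + 1) / cfg.warmup_iters
    if it > cfg.lr_decay_iters:
        # Past the decay horizon (5000 iterations), hold at the floor.
        return cfg.min_lr
    # Cosine decay from learning_rate down to min_lr between warmup and decay end.
    ratio = (it - cfg.warmup_iters) / (cfg.lr_decay_iters - cfg.warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return cfg.min_lr + coeff * (cfg.learning_rate - cfg.min_lr)
```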
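The new early-stopping fields (patience, min_delta, eval_interval, eval_iters) imply a stagnation check at each evaluation. A sketch of one way a training loop might wire them up, assuming a hypothetical evaluate() helper that averages validation loss over cfg.eval_iters batches:

```python
best_val = float('inf')
bad_evals = 0

for it in range(cfg.max_iters):
    # ... forward / backward / optimizer step with lr = get_lr(it) ...
    if it % cfg.eval_interval == 0:
        val_loss = evaluate(model, num_batches=cfg.eval_iters)  # hypothetical helper
        if val_loss < best_val - cfg.min_delta:
            best_val = val_loss  # improved by at least min_delta (0.005 BPC)
            bad_evals = 0
        else:
            bad_evals += 1       # stagnant evaluation
            if bad_evals >= cfg.patience:
                break            # stop after 15 evals without improvement
```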
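The system and performance flags map naturally onto standard PyTorch switches; note that cfg here comes from the module-level get_config() helper added at the end of the file. A sketch under that assumption — Model is a stand-in for the baseline transformer, which is not defined in this file:

```python
import torch

from config.baseline_config import get_config

cfg = get_config()
torch.backends.cudnn.benchmark = cfg.cudnn_benchmark   # let cuDNN autotune kernels

model = Model(cfg).to(cfg.device)                      # Model: hypothetical baseline transformer
if cfg.compile_model:
    model = torch.compile(model)                       # PyTorch 2.x graph compilation
scaler = torch.cuda.amp.GradScaler(enabled=cfg.mixed_precision)  # AMP loss scaling
```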