kfoughali committed on
Commit
5677fec
·
verified ·
1 Parent(s): cf47595

Update core/trainer.py

Files changed (1)
  1. core/trainer.py +8 -9
core/trainer.py CHANGED
@@ -18,7 +18,7 @@ class GraphMambaTrainer:
         self.device = device
 
         # Conservative learning rate
-        self.lr = config['training']['learning_rate']  # Should be 0.0005
+        self.lr = config['training']['learning_rate']
         self.epochs = config['training']['epochs']
         self.patience = config['training'].get('patience', 10)
         self.min_lr = config['training'].get('min_lr', 1e-6)
@@ -27,7 +27,7 @@ class GraphMambaTrainer:
         self.optimizer = optim.AdamW(
             model.parameters(),
             lr=self.lr,
-            weight_decay=config['training']['weight_decay'],  # Should be 0.01
+            weight_decay=config['training']['weight_decay'],
             betas=(0.9, 0.999),
             eps=1e-8
         )
@@ -35,14 +35,13 @@ class GraphMambaTrainer:
         # Proper loss function with label smoothing
         self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
 
-        # Conservative scheduler
+        # Conservative scheduler - FIXED: removed verbose parameter
         self.scheduler = ReduceLROnPlateau(
             self.optimizer,
             mode='max',
             factor=0.5,
             patience=5,
-            min_lr=self.min_lr,
-            verbose=True
+            min_lr=self.min_lr
         )
 
         # Training state
@@ -56,7 +55,7 @@ class GraphMambaTrainer:
 
         # Track overfitting
         self.best_gap = float('inf')
-        self.overfitting_threshold = 0.3  # Stop if train-val gap > 30%
+        self.overfitting_threshold = 0.3
 
     def train_node_classification(self, data, verbose=True):
         """Anti-overfitting training"""
@@ -121,7 +120,7 @@ class GraphMambaTrainer:
             print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
             print(f" Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")
 
-        # Progress logging with overfitting monitoring
+        # Progress logging
         if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
             elapsed = time.time() - start_time
             gap_indicator = "🚨" if acc_gap > 0.2 else "⚠️" if acc_gap > 0.1 else "✅"
@@ -171,11 +170,11 @@ class GraphMambaTrainer:
         # Compute loss on training nodes only
         train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])
 
-        # Add L2 regularization manually if needed
+        # Add L2 regularization manually
        l2_reg = 0.0
        for param in self.model.parameters():
            l2_reg += torch.norm(param, p=2)
-        train_loss += 1e-5 * l2_reg  # Small additional L2
+        train_loss += 1e-5 * l2_reg
 
         # Backward pass with gradient clipping
         train_loss.backward()
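
For context on the scheduler hunk: recent PyTorch releases deprecate the verbose argument of ReduceLROnPlateau, and newer builds reject it, which is presumably why this commit drops it. The sketch below is not part of the repository; it shows one way to keep the old learning-rate-drop message by comparing the optimizer's learning rate before and after scheduler.step(). The stand-in model, the metric values, and the step_scheduler helper are illustrative only.

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Stand-in model with the same optimizer/scheduler settings as in the diff.
model = nn.Linear(16, 4)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01,
                        betas=(0.9, 0.999), eps=1e-8)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                              patience=5, min_lr=1e-6)

def step_scheduler(val_acc):
    """Step the plateau scheduler and report any learning-rate reduction."""
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_acc)          # mode='max': pass the metric being maximized
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < prev_lr:
        print(f"Reducing learning rate: {prev_lr:.2e} -> {new_lr:.2e}")

for val_acc in [0.70] * 12:          # a flat validation metric triggers one reduction
    step_scheduler(val_acc)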
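
For reference, the comments deleted in the first two hunks recorded the intended hyperparameter values. A config['training'] block consistent with the keys this constructor reads might look like the sketch below; the epochs value and the commented-out constructor call are assumptions, not taken from the repository.

# Keys read by GraphMambaTrainer.__init__ in this file.
config = {
    'training': {
        'learning_rate': 0.0005,  # value noted in the removed comment
        'weight_decay': 0.01,     # value noted in the removed comment
        'epochs': 100,            # required key; 100 is an assumed value
        'patience': 10,           # optional; code defaults to 10
        'min_lr': 1e-6,           # optional; code defaults to 1e-6
    }
}
# trainer = GraphMambaTrainer(model, config, device)  # hypothetical call signature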