Update core/trainer.py
core/trainer.py (+8 -9)
@@ -18,7 +18,7 @@ class GraphMambaTrainer:
         self.device = device
 
         # Conservative learning rate
-        self.lr = config['training']['learning_rate']
+        self.lr = config['training']['learning_rate']
         self.epochs = config['training']['epochs']
         self.patience = config['training'].get('patience', 10)
         self.min_lr = config['training'].get('min_lr', 1e-6)
@@ -27,7 +27,7 @@ class GraphMambaTrainer:
         self.optimizer = optim.AdamW(
             model.parameters(),
             lr=self.lr,
-            weight_decay=config['training']['weight_decay'],
+            weight_decay=config['training']['weight_decay'],
             betas=(0.9, 0.999),
             eps=1e-8
         )
@@ -35,14 +35,13 @@ class GraphMambaTrainer:
         # Proper loss function with label smoothing
         self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
 
-        # Conservative scheduler
+        # Conservative scheduler - FIXED: removed verbose parameter
         self.scheduler = ReduceLROnPlateau(
             self.optimizer,
             mode='max',
             factor=0.5,
             patience=5,
-            min_lr=self.min_lr,
-            verbose=True
+            min_lr=self.min_lr
         )
 
         # Training state
@@ -56,7 +55,7 @@ class GraphMambaTrainer:
 
         # Track overfitting
         self.best_gap = float('inf')
-        self.overfitting_threshold = 0.3
+        self.overfitting_threshold = 0.3
 
     def train_node_classification(self, data, verbose=True):
         """Anti-overfitting training"""
@@ -121,7 +120,7 @@ class GraphMambaTrainer:
                     print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
                     print(f"   Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")
 
-            # Progress logging
+            # Progress logging
             if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
                 elapsed = time.time() - start_time
                 gap_indicator = "🚨" if acc_gap > 0.2 else "⚠️" if acc_gap > 0.1 else "✅"
@@ -171,11 +170,11 @@ class GraphMambaTrainer:
         # Compute loss on training nodes only
         train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])
 
-        # Add L2 regularization manually
+        # Add L2 regularization manually
         l2_reg = 0.0
         for param in self.model.parameters():
             l2_reg += torch.norm(param, p=2)
-        train_loss += 1e-5 * l2_reg
+        train_loss += 1e-5 * l2_reg
 
         # Backward pass with gradient clipping
         train_loss.backward()
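The functional fix in this commit is the scheduler hunk: `verbose=True` is dropped from `ReduceLROnPlateau`, and the new comment attributes the fix to that parameter, which recent PyTorch releases deprecate. Below is a minimal sketch of how the same "learning rate reduced" message can be kept without it, by reading the rate straight from the optimizer. The toy model, the hyperparameter values, and `val_acc` are placeholders for illustration, not values taken from this repository.

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Placeholder model/optimizer standing in for GraphMamba and the trainer's config.
model = nn.Linear(16, 4)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Same scheduler configuration as the patched trainer, without `verbose`.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                              patience=5, min_lr=1e-6)

def step_scheduler(val_acc: float) -> None:
    """Step on a validation metric and log learning-rate drops manually."""
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_acc)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr != prev_lr:
        # Replaces the message that `verbose=True` used to print.
        print(f"Learning rate reduced: {prev_lr:.2e} -> {new_lr:.2e}")
```

Stepping the scheduler on a validation metric matches the trainer's `mode='max'` setting, which implies it tracks validation accuracy rather than loss.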
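The L2 hunk is otherwise unchanged, but it is worth noting what it does: a penalty of 1e-5 times the sum of (unsquared) parameter norms is added on top of AdamW's decoupled weight decay, so the model is effectively regularized twice. The following self-contained sketch reproduces that pattern with a toy linear model and random tensors standing in for the graph batch; every name here is illustrative, not from the repository.

```python
import torch
import torch.nn as nn

toy_model = nn.Linear(8, 3)                # stand-in for the GraphMamba model
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

x = torch.randn(32, 8)                     # fake node features
y = torch.randint(0, 3, (32,))             # fake labels

loss = criterion(toy_model(x), y)

# Same pattern as the diff: 1e-5 times the sum of unsquared L2 norms,
# applied in addition to the optimizer's weight_decay.
l2_reg = sum(torch.norm(p, p=2) for p in toy_model.parameters())
loss = loss + 1e-5 * l2_reg

loss.backward()

# The diff's trailing comment mentions gradient clipping; one common form:
torch.nn.utils.clip_grad_norm_(toy_model.parameters(), max_norm=1.0)
```

If double regularization is not intended, either the manual term or the optimizer's `weight_decay` could be dropped; squaring the norms would also match the conventional L2 penalty more closely.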