Update core/trainer.py
core/trainer.py  CHANGED  (+8, -9)
@@ -18,7 +18,7 @@ class GraphMambaTrainer:
         self.device = device

         # Conservative learning rate
-        self.lr = config['training']['learning_rate']
+        self.lr = config['training']['learning_rate']
         self.epochs = config['training']['epochs']
         self.patience = config['training'].get('patience', 10)
         self.min_lr = config['training'].get('min_lr', 1e-6)
@@ -27,7 +27,7 @@ class GraphMambaTrainer:
         self.optimizer = optim.AdamW(
             model.parameters(),
             lr=self.lr,
-            weight_decay=config['training']['weight_decay'],
+            weight_decay=config['training']['weight_decay'],
             betas=(0.9, 0.999),
             eps=1e-8
         )
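For context, the trainer reads all of its hyperparameters from a config['training'] dict, as the two hunks above show. A minimal sketch of what that block might look like, assuming the keys accessed in __init__; the values here are illustrative, not taken from the repo's actual config file:

# Hypothetical config covering the keys read above.
# 'learning_rate', 'epochs', and 'weight_decay' are accessed directly;
# 'patience' and 'min_lr' fall back to .get() defaults (10 and 1e-6).
config = {
    'training': {
        'learning_rate': 1e-3,   # conservative LR, value assumed
        'weight_decay': 5e-4,    # passed straight to AdamW
        'epochs': 200,
        'patience': 10,
        'min_lr': 1e-6,
    }
}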
@@ -35,14 +35,13 @@ class GraphMambaTrainer:
         # Proper loss function with label smoothing
         self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

-        # Conservative scheduler
+        # Conservative scheduler - FIXED: removed verbose parameter
         self.scheduler = ReduceLROnPlateau(
             self.optimizer,
             mode='max',
             factor=0.5,
             patience=5,
-            min_lr=self.min_lr,
-            verbose=True
+            min_lr=self.min_lr
         )

         # Training state
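Dropping verbose=True is the substantive fix in this commit: newer PyTorch releases deprecated and then removed the verbose argument of ReduceLROnPlateau, so constructing the scheduler with it fails. A minimal sketch of recovering the same "LR was reduced" message by hand, reading the learning rate from the optimizer before and after each scheduler step; the stand-in model and the val_acc argument are assumptions for illustration:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(16, 4)  # stand-in model for illustration
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                              patience=5, min_lr=1e-6)  # no verbose kwarg

def step_scheduler(val_acc):
    # Reproduce the old verbose=True behaviour manually:
    # print a message whenever the plateau scheduler cuts the LR.
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_acc)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < old_lr:
        print(f"ReduceLROnPlateau: lr {old_lr:.2e} -> {new_lr:.2e}")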
@@ -56,7 +55,7 @@ class GraphMambaTrainer:

         # Track overfitting
         self.best_gap = float('inf')
-        self.overfitting_threshold = 0.3
+        self.overfitting_threshold = 0.3

     def train_node_classification(self, data, verbose=True):
         """Anti-overfitting training"""
@@ -121,7 +120,7 @@ class GraphMambaTrainer:
                print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
                print(f"   Train: {train_metrics['acc']:.3f}, Val: {val_metrics['acc']:.3f}")

-            # Progress logging
+            # Progress logging
            if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
                elapsed = time.time() - start_time
                gap_indicator = "🚨" if acc_gap > 0.2 else "⚠️" if acc_gap > 0.1 else "✅"
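The logging above is driven by the spread between train and validation accuracy, with the 0.3 threshold set in __init__. A small sketch of how that check plausibly fits together; the train_metrics/val_metrics dicts, thresholds, and emoji come from the surrounding code, while the helper name is made up:

def overfitting_report(train_metrics, val_metrics, threshold=0.3):
    # Gap between train and validation accuracy, as used in the logging above.
    acc_gap = train_metrics['acc'] - val_metrics['acc']
    indicator = "🚨" if acc_gap > 0.2 else "⚠️" if acc_gap > 0.1 else "✅"
    if acc_gap > threshold:
        print(f"🚨 OVERFITTING detected: {acc_gap:.3f} gap")
    return acc_gap, indicator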
@@ -171,11 +170,11 @@ class GraphMambaTrainer:
         # Compute loss on training nodes only
         train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

-        # Add L2 regularization manually
+        # Add L2 regularization manually
         l2_reg = 0.0
         for param in self.model.parameters():
             l2_reg += torch.norm(param, p=2)
-        train_loss += 1e-5 * l2_reg
+        train_loss += 1e-5 * l2_reg

         # Backward pass with gradient clipping
         train_loss.backward()
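The rest of the backward pass falls outside this hunk. A sketch of how a training step like this typically finishes, with the manual L2 penalty and the gradient clipping the comment refers to; the max_norm value, the zero_grad placement, and the function name are assumptions, not taken from the file:

import torch

def training_step(model, criterion, optimizer, logits, data):
    # Loss on training nodes only, as in the hunk above.
    train_loss = criterion(logits[data.train_mask], data.y[data.train_mask])

    # Manual L2 penalty on top of AdamW's decoupled weight decay.
    l2_reg = sum(torch.norm(p, p=2) for p in model.parameters())
    train_loss = train_loss + 1e-5 * l2_reg

    optimizer.zero_grad()
    train_loss.backward()
    # Gradient clipping mentioned in the comment; max_norm=1.0 is assumed.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return train_loss.item()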