Update core/trainer.py

core/trainer.py  (+175 -203)

@@ -1,299 +1,271 @@
Previous version (removed lines prefixed with -):

import torch
import torch.nn as nn
import torch.optim as optim
-from torch.optim.lr_scheduler import
import numpy as np
-from tqdm import tqdm
import time

class GraphMambaTrainer:
-    """
-    Production-ready trainer for GraphMamba
-    Includes advanced training techniques
-    """

-    def __init__(self, model, config, device
-        self.model = model
        self.config = config
        self.device = device

-        self.lr =
-        self.weight_decay = config['training']['weight_decay']
        self.epochs = config['training']['epochs']
-        self.patience = config['training']
-        self.min_lr = config['training']['min_lr']

        self.optimizer = optim.AdamW(
            lr=self.lr,
-            weight_decay=
            betas=(0.9, 0.999),
            eps=1e-8
        )

        self.criterion = nn.CrossEntropyLoss()

        self.scheduler = None

-        # Training
-        self.history = {
-            'train_loss': [],
-            'train_acc': [],
-            'val_loss': [],
-            'val_acc': [],
-            'lr': []
-        }

-        # Best model tracking
        self.best_val_acc = 0.0
        self.patience_counter = 0

-        # Update optimizer to include new parameters
-        self.optimizer = optim.AdamW(
-            self.model.parameters(),
-            lr=self.lr,
-            weight_decay=self.weight_decay,
-            betas=(0.9, 0.999)
-        )

-        # Initialize scheduler
-        self.scheduler = CosineAnnealingLR(
            self.optimizer,
        )

        if verbose:
            print(f"ποΈ Training GraphMamba for {self.epochs} epochs")
            print(f"π Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
-            print(f"π― Classes: {
            print(f"πΎ Device: {self.device}")
            print(f"βοΈ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        for epoch in range(self.epochs):
-            # Training

-            # Validation

-            for param_group in self.optimizer.param_groups:
-                param_group['lr'] = warmup_lr

-            # Record history
-            current_lr = self.optimizer.param_groups[0]['lr']
-            self.history['train_loss'].append(train_loss)
-            self.history['train_acc'].append(train_acc)
-            self.history['val_loss'].append(val_loss)
-            self.history['val_acc'].append(val_acc)
-            self.history['lr'].append(current_lr)

            # Check for improvement
-            if
-                self.best_val_acc =
                self.patience_counter = 0
-                print(f"π New best validation accuracy: {val_acc:.4f}")
            else:
                self.patience_counter += 1

            # Early stopping
            if self.patience_counter >= self.patience:
                if verbose:
-                    print(f"
                break

-            print(f"Epoch {epoch:3d} | "
-                  f"Train: {train_loss:.4f} ({train_acc:.4f}) | "
-                  f"Val: {val_loss:.4f} ({val_acc:.4f}) | "
-                  f"LR: {current_lr:.6f}")

-        # Load best model
-        if self.best_model_state is not None:
-            self.model.load_state_dict(self.best_model_state)

        if verbose:
            print(f"π Best validation accuracy: {self.best_val_acc:.4f}")

-        return self.

    def _train_epoch(self, data, epoch):
        """Single training epoch"""
        self.model.train()

-        # Forward pass
        self.optimizer.zero_grad()

        h = self.model(data.x, data.edge_index)

        # Backward pass

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        with torch.no_grad():

-        return loss.item(), acc

-    def _validate_epoch(self, data):
        """Single validation epoch"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)

-            # Loss on validation nodes
-            val_loss = self.criterion(pred[data.val_mask], data.y[data.val_mask])

-        return val_loss.item(), val_acc

    def test(self, data):
-        """
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)

            # Test metrics
-            test_loss = self.criterion(

-                    class_acc.append(0.0)

-        return {
-            'test_loss': test_loss.item(),
-            'test_acc': test_acc.item(),
-            'class_acc': class_acc
-        }

-    def
-        ax1.legend()
-        ax1.grid(True, alpha=0.3)

-        # Accuracy plot
-        ax2.plot(epochs, self.history['train_acc'], label='Train', color='blue')
-        ax2.plot(epochs, self.history['val_acc'], label='Validation', color='red')
-        ax2.set_title('Training Accuracy')
-        ax2.set_xlabel('Epoch')
-        ax2.set_ylabel('Accuracy')
-        ax2.legend()
-        ax2.grid(True, alpha=0.3)

-        # Learning rate plot
-        ax3.plot(epochs, self.history['lr'], color='green')
-        ax3.set_title('Learning Rate')
-        ax3.set_xlabel('Epoch')
-        ax3.set_ylabel('Learning Rate')
-        ax3.set_yscale('log')
-        ax3.grid(True, alpha=0.3)

-        # Best metrics
-        best_train_acc = max(self.history['train_acc'])
-        best_val_acc = max(self.history['val_acc'])

-        ax4.bar(['Best Train Acc', 'Best Val Acc'], [best_train_acc, best_val_acc],
-                color=['blue', 'red'], alpha=0.7)
-        ax4.set_title('Best Accuracies')
-        ax4.set_ylabel('Accuracy')
-        ax4.set_ylim(0, 1)

-        for i, v in enumerate([best_train_acc, best_val_acc]):
-            ax4.text(i, v + 0.01, f'{v:.4f}', ha='center', va='bottom')

-        plt.tight_layout()

-        if save_path:
-            plt.savefig(save_path, dpi=300, bbox_inches='tight')

-    def
-        torch.save({
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
-            'best_val_acc': self.best_val_acc,
-            'history': self.history,
-            'config': self.config
-        }, path)

-        return
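The removed trainer handled warmup by hand, overwriting param_group['lr'] in the optimizer at the start of each epoch (the warmup_lr fragment above); the exact ramp was not preserved. Below is a minimal sketch of that pattern, assuming a linear ramp; base_lr and warmup_epochs are illustrative values, not taken from the original file.

import torch
import torch.optim as optim

def apply_linear_warmup(optimizer, epoch, base_lr=0.01, warmup_epochs=10):
    """Ramp the learning rate linearly from base_lr/warmup_epochs up to base_lr."""
    if epoch < warmup_epochs:
        warmup_lr = base_lr * (epoch + 1) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = warmup_lr

# Example: the ramp over the first three epochs
model = torch.nn.Linear(4, 2)
optimizer = optim.AdamW(model.parameters(), lr=0.01)
for epoch in range(3):
    apply_linear_warmup(optimizer, epoch)
    print(epoch, optimizer.param_groups[0]['lr'])  # 0.001, 0.002, 0.003

The updated version below drops this manual loop in favor of OneCycleLR, whose pct_start=0.1 performs the same kind of warmup over the first 10% of scheduled steps.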
Updated version (added lines prefixed with +):

import torch
import torch.nn as nn
import torch.optim as optim
+from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
import numpy as np
import time
+import logging
+from utils.metrics import GraphMetrics
+
+logger = logging.getLogger(__name__)

class GraphMambaTrainer:
+    """Enhanced trainer with optimized learning rates and schedules"""

+    def __init__(self, model, config, device):
+        self.model = model
        self.config = config
        self.device = device

+        # Fixed learning rate (much lower)
+        self.lr = 0.001  # Changed from 0.01
        self.epochs = config['training']['epochs']
+        self.patience = config['training'].get('patience', 15)
+        self.min_lr = config['training'].get('min_lr', 1e-6)

+        # Enhanced optimizer
        self.optimizer = optim.AdamW(
+            model.parameters(),
            lr=self.lr,
+            weight_decay=config['training']['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8
        )

+        # Proper loss function
        self.criterion = nn.CrossEntropyLoss()

+        # Learning rate scheduler (will be set in training)
        self.scheduler = None

+        # Training state
        self.best_val_acc = 0.0
+        self.best_val_loss = float('inf')
        self.patience_counter = 0
+        self.training_history = {
+            'train_loss': [], 'train_acc': [],
+            'val_loss': [], 'val_acc': [], 'lr': []
+        }

+    def _setup_scheduler(self, total_steps):
+        """Setup learning rate scheduler"""
+        self.scheduler = OneCycleLR(
            self.optimizer,
+            max_lr=self.lr,
+            total_steps=total_steps,
+            pct_start=0.1,           # 10% warmup
+            anneal_strategy='cos',
+            div_factor=10.0,         # Start LR = max_lr / 10
+            final_div_factor=100.0   # End LR = start LR / 100 = max_lr / 1000
        )

+    def train_node_classification(self, data, verbose=True):
+        """Enhanced training with proper LR scheduling"""

        if verbose:
            print(f"ποΈ Training GraphMamba for {self.epochs} epochs")
            print(f"π Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
+            print(f"π― Classes: {len(torch.unique(data.y))}")
            print(f"πΎ Device: {self.device}")
            print(f"βοΈ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

+        # Initialize classifier
+        num_classes = len(torch.unique(data.y))
+        self.model._init_classifier(num_classes, self.device)
+
+        # Make sure the freshly created classifier parameters are covered by the optimizer
+        existing = {id(p) for group in self.optimizer.param_groups for p in group['params']}
+        new_params = [p for p in self.model.classifier.parameters() if id(p) not in existing]
+        if new_params:
+            self.optimizer.add_param_group({'params': new_params})
+
+        # Setup scheduler
+        self._setup_scheduler(self.epochs)
+
+        self.model.train()
+        start_time = time.time()
+
        for epoch in range(self.epochs):
+            # Training step
+            train_metrics = self._train_epoch(data, epoch)

+            # Validation step
+            val_metrics = self._validate_epoch(data, epoch)

+            # Update history
+            self.training_history['train_loss'].append(train_metrics['loss'])
+            self.training_history['train_acc'].append(train_metrics['acc'])
+            self.training_history['val_loss'].append(val_metrics['loss'])
+            self.training_history['val_acc'].append(val_metrics['acc'])
+            self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])

            # Check for improvement
+            if val_metrics['acc'] > self.best_val_acc:
+                self.best_val_acc = val_metrics['acc']
+                self.best_val_loss = val_metrics['loss']
                self.patience_counter = 0
+                if verbose:
+                    print(f"π New best validation accuracy: {self.best_val_acc:.4f}")
            else:
                self.patience_counter += 1

+            # Progress logging
+            if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
+                elapsed = time.time() - start_time
+                print(f"Epoch {epoch:3d} | "
+                      f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
+                      f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
+                      f"LR: {self.optimizer.param_groups[0]['lr']:.6f} | "
+                      f"Time: {elapsed:.1f}s")

            # Early stopping
            if self.patience_counter >= self.patience:
                if verbose:
+                    print(f"π Early stopping at epoch {epoch}")
                break

+            # Step scheduler
+            self.scheduler.step()

        if verbose:
+            total_time = time.time() - start_time
+            print(f"β Training completed in {total_time:.2f}s")
            print(f"π Best validation accuracy: {self.best_val_acc:.4f}")

+        return self.training_history
    def _train_epoch(self, data, epoch):
        """Single training epoch"""
        self.model.train()
        self.optimizer.zero_grad()

+        # Forward pass
        h = self.model(data.x, data.edge_index)
+        logits = self.model.classifier(h)

+        # Compute loss on training nodes
+        train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

        # Backward pass
+        train_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

+        # Compute accuracy
        with torch.no_grad():
+            train_pred = logits[data.train_mask].argmax(dim=1)
+            train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

+        return {'loss': train_loss.item(), 'acc': train_acc}

+    def _validate_epoch(self, data, epoch):
        """Single validation epoch"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
+            logits = self.model.classifier(h)

+            # Validation loss and accuracy
+            val_loss = self.criterion(logits[data.val_mask], data.y[data.val_mask])
+            val_pred = logits[data.val_mask].argmax(dim=1)
+            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

+        return {'loss': val_loss.item(), 'acc': val_acc}

    def test(self, data):
+        """Comprehensive test evaluation"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
+
+            # Ensure classifier exists
+            if self.model.classifier is None:
+                num_classes = len(torch.unique(data.y))
+                self.model._init_classifier(num_classes, self.device)
+
+            logits = self.model.classifier(h)

            # Test metrics
+            test_loss = self.criterion(logits[data.test_mask], data.y[data.test_mask])
+            test_pred = logits[data.test_mask]
+            test_target = data.y[data.test_mask]

+            # Comprehensive metrics
+            metrics = {
+                'test_loss': test_loss.item(),
+                'test_acc': GraphMetrics.accuracy(test_pred, test_target),
+                'f1_macro': GraphMetrics.f1_score_macro(test_pred, test_target),
+                'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
+            }

+            # Additional metrics
+            precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
+            metrics['precision'] = precision
+            metrics['recall'] = recall
+
+            return metrics
+    def get_embeddings(self, data):
+        """Get node embeddings"""
+        self.model.eval()
+        with torch.no_grad():
+            return self.model(data.x, data.edge_index)
+
+
+class EnhancedGraphMambaTrainer(GraphMambaTrainer):
+    """Enhanced trainer with additional optimizations"""
+
+    def __init__(self, model, config, device):
+        super().__init__(model, config, device)

+        # Even more conservative learning rate for complex architectures
+        if hasattr(model, 'multi_scale') or 'Hybrid' in model.__class__.__name__:
+            self.lr = 0.0005  # Lower for complex models
+
+        self.optimizer = optim.AdamW(
+            model.parameters(),
+            lr=self.lr,
+            weight_decay=config['training']['weight_decay'],
+            betas=(0.9, 0.99),  # More stable
+            eps=1e-8
+        )
+
+    def _setup_scheduler(self, total_steps):
+        """Enhanced scheduler for complex models"""
+        # Cosine annealing with warm restarts
+        self.scheduler = CosineAnnealingWarmRestarts(
+            self.optimizer,
+            T_0=20,     # Restart every 20 epochs
+            T_mult=2,   # Double period after restart
+            eta_min=self.min_lr
+        )

+    def train_node_classification(self, data, verbose=True):
+        """Training with enhanced monitoring"""

+        if verbose:
+            model_type = self.model.__class__.__name__
+            print(f"ποΈ Training {model_type} for {self.epochs} epochs")
+            print(f"π Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
+            print(f"π― Classes: {len(torch.unique(data.y))}")
+            print(f"πΎ Device: {self.device}")
+            print(f"βοΈ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+            print(f"π Learning Rate: {self.lr} (enhanced schedule)")

+        # Call parent method with enhancements
+        history = super().train_node_classification(data, verbose)

+        # Additional analysis
+        if verbose:
+            final_acc = history['val_acc'][-1] if history['val_acc'] else 0
+            improvement = final_acc - (history['val_acc'][0] if history['val_acc'] else 0)
+            print(f"π Final validation accuracy: {final_acc:.4f}")
+            print(f"π Total improvement: {improvement:.4f} ({improvement*100:.1f}%)")

+            if final_acc > 0.6:
+                print("π Excellent performance! Model converged well.")
+            elif final_acc > 0.4:
+                print("π Good progress! Consider more epochs or tuning.")
+            else:
+                print("β οΈ Low accuracy. Check model architecture or data.")

+        return history
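A rough sketch of how the new trainer might be driven end to end. The config keys are the ones the trainer actually reads (epochs, weight_decay, plus the optional patience and min_lr); the core.graph_mamba.GraphMamba import, its constructor signature, and the Planetoid Cora split are assumptions made for illustration, not part of this commit.

import torch
from torch_geometric.datasets import Planetoid

from core.graph_mamba import GraphMamba              # assumed model module
from core.trainer import EnhancedGraphMambaTrainer

config = {
    'training': {
        'epochs': 100,
        'weight_decay': 5e-4,
        'patience': 15,
        'min_lr': 1e-6,
    }
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = Planetoid(root='data/Planetoid', name='Cora')[0].to(device)

model = GraphMamba(config).to(device)                # assumed constructor signature
trainer = EnhancedGraphMambaTrainer(model, config, device)

history = trainer.train_node_classification(data, verbose=True)
results = trainer.test(data)
print(f"test acc {results['test_acc']:.4f} | macro-F1 {results['f1_macro']:.4f}")

The test() method also leans on utils.metrics.GraphMetrics, which this diff does not show. Judging from the four calls above, it is assumed to expose static accuracy, f1_score_macro, f1_score_micro, and precision_recall helpers that take logits and integer targets; a minimal stand-in with that shape could look like the following, though the real module may differ.

import torch

class GraphMetrics:
    """Minimal stand-in for utils.metrics.GraphMetrics (single-label node classification)."""

    @staticmethod
    def accuracy(logits, target):
        return (logits.argmax(dim=1) == target).float().mean().item()

    @staticmethod
    def _per_class_counts(logits, target):
        pred = logits.argmax(dim=1)
        counts = []
        for c in torch.unique(target):
            tp = ((pred == c) & (target == c)).sum().item()
            fp = ((pred == c) & (target != c)).sum().item()
            fn = ((pred != c) & (target == c)).sum().item()
            counts.append((tp, fp, fn))
        return counts

    @staticmethod
    def f1_score_macro(logits, target):
        f1s = [2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
               for tp, fp, fn in GraphMetrics._per_class_counts(logits, target)]
        return sum(f1s) / len(f1s)

    @staticmethod
    def f1_score_micro(logits, target):
        # With exactly one label per node, micro-averaged F1 reduces to accuracy.
        return GraphMetrics.accuracy(logits, target)

    @staticmethod
    def precision_recall(logits, target):
        counts = GraphMetrics._per_class_counts(logits, target)
        precision = sum(tp / (tp + fp) if (tp + fp) else 0.0 for tp, fp, _ in counts) / len(counts)
        recall = sum(tp / (tp + fn) if (tp + fn) else 0.0 for tp, _, fn in counts) / len(counts)
        return precision, recall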
|