Create core/trainer.py
core/trainer.py  (ADDED, +299 -0)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
import numpy as np
from tqdm import tqdm
import time
import matplotlib.pyplot as plt

class GraphMambaTrainer:
    """
    Production-ready trainer for GraphMamba
    Includes advanced training techniques
    """

    def __init__(self, model, config, device='cpu'):
        self.model = model.to(device)
        self.config = config
        self.device = device

        # Training parameters
        self.lr = config['training']['learning_rate']
        self.weight_decay = config['training']['weight_decay']
        self.epochs = config['training']['epochs']
        self.patience = config['training']['patience']
        self.warmup_epochs = config['training']['warmup_epochs']
        self.min_lr = config['training']['min_lr']

        # Initialize optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Scheduler
        self.scheduler = None

        # Training history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        # Best model tracking
        self.best_val_acc = 0.0
        self.best_model_state = None
        self.patience_counter = 0

    def train_node_classification(self, data, verbose=True):
        """
        Train model for node classification
        """
        # Initialize classifier
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)

        # Update optimizer to include the new classifier parameters
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999)
        )

        # Initialize scheduler
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=self.epochs - self.warmup_epochs,
            eta_min=self.min_lr
        )

        if verbose:
            print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
            print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"🎯 Classes: {num_classes}")
            print(f"💾 Device: {self.device}")
            print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        # Training loop
        for epoch in range(self.epochs):
            # Training phase
            train_loss, train_acc = self._train_epoch(data, epoch)

            # Validation phase
            val_loss, val_acc = self._validate_epoch(data)

            # Learning rate scheduling: linear warmup, then cosine annealing
            if epoch >= self.warmup_epochs:
                self.scheduler.step()
            else:
                warmup_lr = self.lr * (epoch + 1) / self.warmup_epochs
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = warmup_lr

            # Record history
            current_lr = self.optimizer.param_groups[0]['lr']
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(current_lr)

            # Check for improvement
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                # Clone the tensors: a shallow state_dict().copy() keeps references
                # that the optimizer would keep updating in place.
                self.best_model_state = {
                    k: v.detach().clone() for k, v in self.model.state_dict().items()
                }
                self.patience_counter = 0

                if verbose and epoch % 10 == 0:
                    print(f"🎉 New best validation accuracy: {val_acc:.4f}")
            else:
                self.patience_counter += 1

            # Early stopping
            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"⏹️ Early stopping at epoch {epoch}")
                break

            # Progress reporting
            if verbose and epoch % 20 == 0:
                print(f"Epoch {epoch:3d} | "
                      f"Train: {train_loss:.4f} ({train_acc:.4f}) | "
                      f"Val: {val_loss:.4f} ({val_acc:.4f}) | "
                      f"LR: {current_lr:.6f}")

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        if verbose:
            print(f"✅ Training completed!")
            print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")

        return self.history

    def _train_epoch(self, data, epoch):
        """Single training epoch"""
        self.model.train()

        # Full-batch forward pass over the whole graph
        self.optimizer.zero_grad()

        h = self.model(data.x, data.edge_index)
        pred = self.model.classifier(h)

        # Loss only on training nodes
        loss = self.criterion(pred[data.train_mask], data.y[data.train_mask])

        # Backward pass
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        # Calculate accuracy on the training nodes
        with torch.no_grad():
            pred_labels = pred[data.train_mask].argmax(dim=1)
            acc = (pred_labels == data.y[data.train_mask]).float().mean()

        return loss.item(), acc.item()

    def _validate_epoch(self, data):
        """Single validation epoch"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            pred = self.model.classifier(h)

            # Loss on validation nodes
            val_loss = self.criterion(pred[data.val_mask], data.y[data.val_mask])

            # Accuracy
            pred_labels = pred[data.val_mask].argmax(dim=1)
            val_acc = (pred_labels == data.y[data.val_mask]).float().mean()

        return val_loss.item(), val_acc.item()

    def test(self, data):
        """Test the model"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            pred = self.model.classifier(h)

            # Test metrics
            test_loss = self.criterion(pred[data.test_mask], data.y[data.test_mask])
            pred_labels = pred[data.test_mask].argmax(dim=1)
            test_acc = (pred_labels == data.y[data.test_mask]).float().mean()

            # Per-class accuracy
            num_classes = len(torch.unique(data.y))
            class_acc = []

            for c in range(num_classes):
                class_mask = data.y[data.test_mask] == c
                if class_mask.any():
                    class_correct = (pred_labels[class_mask] == c).float().mean()
                    class_acc.append(class_correct.item())
                else:
                    class_acc.append(0.0)

        return {
            'test_loss': test_loss.item(),
            'test_acc': test_acc.item(),
            'class_acc': class_acc
        }

    def plot_training_history(self, save_path=None):
        """Plot training history"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))

        epochs = range(len(self.history['train_loss']))

        # Loss plot
        ax1.plot(epochs, self.history['train_loss'], label='Train', color='blue')
        ax1.plot(epochs, self.history['val_loss'], label='Validation', color='red')
        ax1.set_title('Training Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Accuracy plot
        ax2.plot(epochs, self.history['train_acc'], label='Train', color='blue')
        ax2.plot(epochs, self.history['val_acc'], label='Validation', color='red')
        ax2.set_title('Training Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Learning rate plot
        ax3.plot(epochs, self.history['lr'], color='green')
        ax3.set_title('Learning Rate')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)

        # Best metrics
        best_train_acc = max(self.history['train_acc'])
        best_val_acc = max(self.history['val_acc'])

        ax4.bar(['Best Train Acc', 'Best Val Acc'], [best_train_acc, best_val_acc],
                color=['blue', 'red'], alpha=0.7)
        ax4.set_title('Best Accuracies')
        ax4.set_ylabel('Accuracy')
        ax4.set_ylim(0, 1)

        for i, v in enumerate([best_train_acc, best_val_acc]):
            ax4.text(i, v + 0.01, f'{v:.4f}', ha='center', va='bottom')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        return fig

    def save_model(self, path):
        """Save model and training state"""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'best_val_acc': self.best_val_acc,
            'history': self.history,
            'config': self.config
        }, path)

    def load_model(self, path):
        """Load model and training state"""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if checkpoint['scheduler_state_dict'] and self.scheduler:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.best_val_acc = checkpoint['best_val_acc']
        self.history = checkpoint['history']

        return checkpoint['config']
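
For reference, a minimal usage sketch (not part of the committed file). It assumes a GraphMamba model class importable from core.graph_mamba with a config-dict constructor (both assumptions), a PyTorch Geometric Planetoid dataset that provides train/val/test masks, and a config whose 'training' keys match those read in GraphMambaTrainer.__init__; the illustrative hyperparameter values are placeholders.

# Usage sketch: GraphMamba class/module name, its constructor signature, and the
# hyperparameter values below are assumptions for illustration only.
import torch
from torch_geometric.datasets import Planetoid

from core.graph_mamba import GraphMamba      # assumed module/class name
from core.trainer import GraphMambaTrainer

config = {
    'training': {                            # keys read by GraphMambaTrainer.__init__
        'learning_rate': 1e-3,
        'weight_decay': 5e-4,
        'epochs': 200,
        'patience': 30,
        'warmup_epochs': 10,
        'min_lr': 1e-5,
    }
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = Planetoid(root='data', name='Cora')[0].to(device)   # graph with masks, on the same device as the model

model = GraphMamba(config)                   # constructor signature assumed
trainer = GraphMambaTrainer(model, config, device=device)

history = trainer.train_node_classification(data, verbose=True)
metrics = trainer.test(data)
print(f"Test accuracy: {metrics['test_acc']:.4f}")

trainer.plot_training_history(save_path='training_history.png')
trainer.save_model('graphmamba_cora.pt')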