kfoughali committed on
Commit 6aa4c8c · verified · 1 Parent(s): e4d5cc2

Create core/trainer.py

Files changed (1)
  1. core/trainer.py +299 -0
core/trainer.py ADDED
@@ -0,0 +1,299 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ import matplotlib.pyplot as plt
+
+ class GraphMambaTrainer:
+     """
+     Trainer for GraphMamba node classification.
+
+     Implements linear LR warmup, cosine annealing, gradient clipping,
+     and early stopping on validation accuracy. Expects
+     config['training'] to provide: learning_rate, weight_decay,
+     epochs, patience, warmup_epochs, min_lr.
+     """
+
+     def __init__(self, model, config, device='cpu'):
+         self.model = model.to(device)
+         self.config = config
+         self.device = device
+
+         # Training parameters
+         self.lr = config['training']['learning_rate']
+         self.weight_decay = config['training']['weight_decay']
+         self.epochs = config['training']['epochs']
+         self.patience = config['training']['patience']
+         self.warmup_epochs = config['training']['warmup_epochs']
+         self.min_lr = config['training']['min_lr']
+
+         # Optimizer
+         self.optimizer = optim.AdamW(
+             self.model.parameters(),
+             lr=self.lr,
+             weight_decay=self.weight_decay,
+             betas=(0.9, 0.999),
+             eps=1e-8
+         )
+
+         # Loss function
+         self.criterion = nn.CrossEntropyLoss()
+
+         # Scheduler (created in train_node_classification, once the
+         # classifier head exists)
+         self.scheduler = None
+
+         # Training history
+         self.history = {
+             'train_loss': [],
+             'train_acc': [],
+             'val_loss': [],
+             'val_acc': [],
+             'lr': []
+         }
+
+         # Best model tracking
+         self.best_val_acc = 0.0
+         self.best_model_state = None
+         self.patience_counter = 0
+
+     def train_node_classification(self, data, verbose=True):
+         """
+         Train the model for node classification on a single graph.
+         """
+         # Initialize classifier head for the number of classes in the data
+         num_classes = len(torch.unique(data.y))
+         self.model._init_classifier(num_classes, self.device)
+
+         # Rebuild the optimizer so it includes the new classifier parameters
+         self.optimizer = optim.AdamW(
+             self.model.parameters(),
+             lr=self.lr,
+             weight_decay=self.weight_decay,
+             betas=(0.9, 0.999)
+         )
+
+         # Cosine annealing over the post-warmup epochs
+         self.scheduler = CosineAnnealingLR(
+             self.optimizer,
+             T_max=self.epochs - self.warmup_epochs,
+             eta_min=self.min_lr
+         )
+
+         if verbose:
+             print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
+             print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
+             print(f"🎯 Classes: {num_classes}")
+             print(f"💾 Device: {self.device}")
+             print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+
+         # Training loop
+         for epoch in range(self.epochs):
+             # Training phase
+             train_loss, train_acc = self._train_epoch(data, epoch)
+
+             # Validation phase
+             val_loss, val_acc = self._validate_epoch(data)
+
+             # Learning rate scheduling
+             if epoch >= self.warmup_epochs:
+                 self.scheduler.step()
+             else:
+                 # Linear warmup from lr/warmup_epochs up to lr
+                 warmup_lr = self.lr * (epoch + 1) / self.warmup_epochs
+                 for param_group in self.optimizer.param_groups:
+                     param_group['lr'] = warmup_lr
+
+             # Record history
+             current_lr = self.optimizer.param_groups[0]['lr']
+             self.history['train_loss'].append(train_loss)
+             self.history['train_acc'].append(train_acc)
+             self.history['val_loss'].append(val_loss)
+             self.history['val_acc'].append(val_acc)
+             self.history['lr'].append(current_lr)
+
+             # Check for improvement
+             if val_acc > self.best_val_acc:
+                 self.best_val_acc = val_acc
+                 # Snapshot tensors: state_dict().copy() is shallow, so the
+                 # saved weights would keep changing as training continues
+                 self.best_model_state = {
+                     k: v.detach().clone() for k, v in self.model.state_dict().items()
+                 }
+                 self.patience_counter = 0
+
+                 if verbose and epoch % 10 == 0:
+                     print(f"🎉 New best validation accuracy: {val_acc:.4f}")
+             else:
+                 self.patience_counter += 1
+
+             # Early stopping
+             if self.patience_counter >= self.patience:
+                 if verbose:
+                     print(f"⏹️ Early stopping at epoch {epoch}")
+                 break
+
+             # Progress reporting
+             if verbose and epoch % 20 == 0:
+                 print(f"Epoch {epoch:3d} | "
+                       f"Train: {train_loss:.4f} ({train_acc:.4f}) | "
+                       f"Val: {val_loss:.4f} ({val_acc:.4f}) | "
+                       f"LR: {current_lr:.6f}")
+
+         # Restore the best checkpoint seen during training
+         if self.best_model_state is not None:
+             self.model.load_state_dict(self.best_model_state)
+
+         if verbose:
+             print("✅ Training completed!")
+             print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")
+
+         return self.history
+
+     def _train_epoch(self, data, epoch):
+         """Single training epoch"""
+         self.model.train()
+
+         # Forward pass
+         self.optimizer.zero_grad()
+
+         h = self.model(data.x, data.edge_index)
+         pred = self.model.classifier(h)
+
+         # Loss only on training nodes
+         loss = self.criterion(pred[data.train_mask], data.y[data.train_mask])
+
+         # Backward pass
+         loss.backward()
+
+         # Gradient clipping
+         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+         self.optimizer.step()
+
+         # Calculate accuracy
+         with torch.no_grad():
+             pred_labels = pred[data.train_mask].argmax(dim=1)
+             acc = (pred_labels == data.y[data.train_mask]).float().mean()
+
+         return loss.item(), acc.item()
+
+     def _validate_epoch(self, data):
+         """Single validation epoch"""
+         self.model.eval()
+
+         with torch.no_grad():
+             h = self.model(data.x, data.edge_index)
+             pred = self.model.classifier(h)
+
+             # Loss on validation nodes
+             val_loss = self.criterion(pred[data.val_mask], data.y[data.val_mask])
+
+             # Accuracy
+             pred_labels = pred[data.val_mask].argmax(dim=1)
+             val_acc = (pred_labels == data.y[data.val_mask]).float().mean()
+
+         return val_loss.item(), val_acc.item()
+
+     def test(self, data):
+         """Test the model"""
+         self.model.eval()
+
+         with torch.no_grad():
+             h = self.model(data.x, data.edge_index)
+             pred = self.model.classifier(h)
+
+             # Test metrics
+             test_loss = self.criterion(pred[data.test_mask], data.y[data.test_mask])
+             pred_labels = pred[data.test_mask].argmax(dim=1)
+             test_acc = (pred_labels == data.y[data.test_mask]).float().mean()
+
+             # Per-class accuracy
+             num_classes = len(torch.unique(data.y))
+             class_acc = []
+
+             for c in range(num_classes):
+                 class_mask = data.y[data.test_mask] == c
+                 if class_mask.any():
+                     class_correct = (pred_labels[class_mask] == c).float().mean()
+                     class_acc.append(class_correct.item())
+                 else:
+                     class_acc.append(0.0)
+
+         return {
+             'test_loss': test_loss.item(),
+             'test_acc': test_acc.item(),
+             'class_acc': class_acc
+         }
+
+     def plot_training_history(self, save_path=None):
+         """Plot training history"""
+         fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
+
+         epochs = range(len(self.history['train_loss']))
+
+         # Loss plot
+         ax1.plot(epochs, self.history['train_loss'], label='Train', color='blue')
+         ax1.plot(epochs, self.history['val_loss'], label='Validation', color='red')
+         ax1.set_title('Training Loss')
+         ax1.set_xlabel('Epoch')
+         ax1.set_ylabel('Loss')
+         ax1.legend()
+         ax1.grid(True, alpha=0.3)
+
+         # Accuracy plot
+         ax2.plot(epochs, self.history['train_acc'], label='Train', color='blue')
+         ax2.plot(epochs, self.history['val_acc'], label='Validation', color='red')
+         ax2.set_title('Training Accuracy')
+         ax2.set_xlabel('Epoch')
+         ax2.set_ylabel('Accuracy')
+         ax2.legend()
+         ax2.grid(True, alpha=0.3)
+
+         # Learning rate plot
+         ax3.plot(epochs, self.history['lr'], color='green')
+         ax3.set_title('Learning Rate')
+         ax3.set_xlabel('Epoch')
+         ax3.set_ylabel('Learning Rate')
+         ax3.set_yscale('log')
+         ax3.grid(True, alpha=0.3)
+
+         # Best metrics
+         best_train_acc = max(self.history['train_acc'])
+         best_val_acc = max(self.history['val_acc'])
+
+         ax4.bar(['Best Train Acc', 'Best Val Acc'], [best_train_acc, best_val_acc],
+                 color=['blue', 'red'], alpha=0.7)
+         ax4.set_title('Best Accuracies')
+         ax4.set_ylabel('Accuracy')
+         ax4.set_ylim(0, 1)
+
+         for i, v in enumerate([best_train_acc, best_val_acc]):
+             ax4.text(i, v + 0.01, f'{v:.4f}', ha='center', va='bottom')
+
+         plt.tight_layout()
+
+         if save_path:
+             plt.savefig(save_path, dpi=300, bbox_inches='tight')
+
+         return fig
+
+     def save_model(self, path):
+         """Save model and training state"""
+         torch.save({
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': self.optimizer.state_dict(),
+             'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
+             'best_val_acc': self.best_val_acc,
+             'history': self.history,
+             'config': self.config
+         }, path)
+
+     def load_model(self, path):
+         """Load model and training state"""
+         checkpoint = torch.load(path, map_location=self.device)
+
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+         # Restore scheduler state only if both sides have one
+         if checkpoint['scheduler_state_dict'] is not None and self.scheduler is not None:
+             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+         self.best_val_acc = checkpoint['best_val_acc']
+         self.history = checkpoint['history']
+
+         return checkpoint['config']
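
A minimal usage sketch follows (not part of this commit). It assumes a GraphMamba model class exposing the _init_classifier hook and classifier head used above, a PyG-style data object with x, edge_index, y, and train/val/test masks, and the config['training'] keys read in __init__; the GraphMamba constructor and its import path are illustrative placeholders, not APIs confirmed by this repository.

import torch
from torch_geometric.datasets import Planetoid  # assumes PyTorch Geometric is installed

from core.trainer import GraphMambaTrainer
# from core.model import GraphMamba  # hypothetical import path for the model

# Training hyperparameters matching the keys read in GraphMambaTrainer.__init__
config = {
    'training': {
        'learning_rate': 1e-3,
        'weight_decay': 5e-4,
        'epochs': 200,
        'patience': 30,
        'warmup_epochs': 10,
        'min_lr': 1e-5,
    }
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Cora ships with train/val/test masks, as the trainer expects
data = Planetoid(root='data', name='Cora')[0].to(device)

model = GraphMamba(config)  # hypothetical constructor; see the model definition
trainer = GraphMambaTrainer(model, config, device=device)

history = trainer.train_node_classification(data)
results = trainer.test(data)
print(f"Test accuracy: {results['test_acc']:.4f}")
trainer.plot_training_history(save_path='training_history.png')
trainer.save_model('graphmamba_checkpoint.pt')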