kfoughali committed on
Commit 93db32e · verified · 1 Parent(s): b09f924

Update core/trainer.py

Files changed (1)
  1. core/trainer.py +175 -203
core/trainer.py CHANGED
@@ -1,299 +1,271 @@
  import torch
  import torch.nn as nn
  import torch.optim as optim
- from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
  import numpy as np
- from tqdm import tqdm
  import time
- import matplotlib.pyplot as plt

  class GraphMambaTrainer:
-     """
-     Production-ready trainer for GraphMamba
-     Includes advanced training techniques
-     """

-     def __init__(self, model, config, device='cpu'):
-         self.model = model.to(device)
          self.config = config
          self.device = device

-         # Training parameters
-         self.lr = config['training']['learning_rate']
-         self.weight_decay = config['training']['weight_decay']
          self.epochs = config['training']['epochs']
-         self.patience = config['training']['patience']
-         self.warmup_epochs = config['training']['warmup_epochs']
-         self.min_lr = config['training']['min_lr']

-         # Initialize optimizer
          self.optimizer = optim.AdamW(
-             self.model.parameters(),
              lr=self.lr,
-             weight_decay=self.weight_decay,
              betas=(0.9, 0.999),
              eps=1e-8
          )

-         # Loss function
          self.criterion = nn.CrossEntropyLoss()

-         # Scheduler
          self.scheduler = None

-         # Training history
-         self.history = {
-             'train_loss': [],
-             'train_acc': [],
-             'val_loss': [],
-             'val_acc': [],
-             'lr': []
-         }
-
-         # Best model tracking
          self.best_val_acc = 0.0
-         self.best_model_state = None
          self.patience_counter = 0
-
-     def train_node_classification(self, data, verbose=True):
-         """
-         Train model for node classification
-         """
-         # Initialize classifier
-         num_classes = len(torch.unique(data.y))
-         self.model._init_classifier(num_classes, self.device)
-
-         # Update optimizer to include new parameters
-         self.optimizer = optim.AdamW(
-             self.model.parameters(),
-             lr=self.lr,
-             weight_decay=self.weight_decay,
-             betas=(0.9, 0.999)
-         )
-
-         # Initialize scheduler
-         self.scheduler = CosineAnnealingLR(
              self.optimizer,
-             T_max=self.epochs - self.warmup_epochs,
-             eta_min=self.min_lr
          )

          if verbose:
              print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
              print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
-             print(f"🎯 Classes: {num_classes}")
              print(f"💾 Device: {self.device}")
              print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

-         # Training loop
          for epoch in range(self.epochs):
-             # Training phase
-             train_loss, train_acc = self._train_epoch(data, epoch)

-             # Validation phase
-             val_loss, val_acc = self._validate_epoch(data)

-             # Learning rate scheduling
-             if epoch >= self.warmup_epochs:
-                 self.scheduler.step()
-             else:
-                 # Warmup
-                 warmup_lr = self.lr * (epoch + 1) / self.warmup_epochs
-                 for param_group in self.optimizer.param_groups:
-                     param_group['lr'] = warmup_lr
-
-             # Record history
-             current_lr = self.optimizer.param_groups[0]['lr']
-             self.history['train_loss'].append(train_loss)
-             self.history['train_acc'].append(train_acc)
-             self.history['val_loss'].append(val_loss)
-             self.history['val_acc'].append(val_acc)
-             self.history['lr'].append(current_lr)

              # Check for improvement
-             if val_acc > self.best_val_acc:
-                 self.best_val_acc = val_acc
-                 self.best_model_state = self.model.state_dict().copy()
                  self.patience_counter = 0
-
-                 if verbose and epoch % 10 == 0:
-                     print(f"🎉 New best validation accuracy: {val_acc:.4f}")
              else:
                  self.patience_counter += 1

              # Early stopping
              if self.patience_counter >= self.patience:
                  if verbose:
-                     print(f"⏹️ Early stopping at epoch {epoch}")
                  break

-             # Progress reporting
-             if verbose and epoch % 20 == 0:
-                 print(f"Epoch {epoch:3d} | "
-                       f"Train: {train_loss:.4f} ({train_acc:.4f}) | "
-                       f"Val: {val_loss:.4f} ({val_acc:.4f}) | "
-                       f"LR: {current_lr:.6f}")

-         # Load best model
-         if self.best_model_state is not None:
-             self.model.load_state_dict(self.best_model_state)
-
          if verbose:
-             print(f"✅ Training completed!")
              print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")

-         return self.history

      def _train_epoch(self, data, epoch):
          """Single training epoch"""
          self.model.train()
-
-         # Forward pass
          self.optimizer.zero_grad()

          h = self.model(data.x, data.edge_index)
-         pred = self.model.classifier(h)

-         # Loss only on training nodes
-         loss = self.criterion(pred[data.train_mask], data.y[data.train_mask])

          # Backward pass
-         loss.backward()

          # Gradient clipping
          torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

          self.optimizer.step()

-         # Calculate accuracy
          with torch.no_grad():
-             pred_labels = pred[data.train_mask].argmax(dim=1)
-             acc = (pred_labels == data.y[data.train_mask]).float().mean()

-         return loss.item(), acc.item()

-     def _validate_epoch(self, data):
          """Single validation epoch"""
          self.model.eval()

          with torch.no_grad():
              h = self.model(data.x, data.edge_index)
-             pred = self.model.classifier(h)
-
-             # Loss on validation nodes
-             val_loss = self.criterion(pred[data.val_mask], data.y[data.val_mask])

-             # Accuracy
-             pred_labels = pred[data.val_mask].argmax(dim=1)
-             val_acc = (pred_labels == data.y[data.val_mask]).float().mean()

-         return val_loss.item(), val_acc.item()

      def test(self, data):
-         """Test the model"""
          self.model.eval()

          with torch.no_grad():
              h = self.model(data.x, data.edge_index)
-             pred = self.model.classifier(h)

              # Test metrics
-             test_loss = self.criterion(pred[data.test_mask], data.y[data.test_mask])
-             pred_labels = pred[data.test_mask].argmax(dim=1)
-             test_acc = (pred_labels == data.y[data.test_mask]).float().mean()

-             # Per-class accuracy
-             num_classes = len(torch.unique(data.y))
-             class_acc = []

-             for c in range(num_classes):
-                 class_mask = data.y[data.test_mask] == c
-                 if class_mask.any():
-                     class_correct = (pred_labels[class_mask] == c).float().mean()
-                     class_acc.append(class_correct.item())
-                 else:
-                     class_acc.append(0.0)
-
-         return {
-             'test_loss': test_loss.item(),
-             'test_acc': test_acc.item(),
-             'class_acc': class_acc
-         }

-     def plot_training_history(self, save_path=None):
-         """Plot training history"""
-         fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
-
-         epochs = range(len(self.history['train_loss']))
-
-         # Loss plot
-         ax1.plot(epochs, self.history['train_loss'], label='Train', color='blue')
-         ax1.plot(epochs, self.history['val_loss'], label='Validation', color='red')
-         ax1.set_title('Training Loss')
-         ax1.set_xlabel('Epoch')
-         ax1.set_ylabel('Loss')
-         ax1.legend()
-         ax1.grid(True, alpha=0.3)
-
-         # Accuracy plot
-         ax2.plot(epochs, self.history['train_acc'], label='Train', color='blue')
-         ax2.plot(epochs, self.history['val_acc'], label='Validation', color='red')
-         ax2.set_title('Training Accuracy')
-         ax2.set_xlabel('Epoch')
-         ax2.set_ylabel('Accuracy')
-         ax2.legend()
-         ax2.grid(True, alpha=0.3)
-
-         # Learning rate plot
-         ax3.plot(epochs, self.history['lr'], color='green')
-         ax3.set_title('Learning Rate')
-         ax3.set_xlabel('Epoch')
-         ax3.set_ylabel('Learning Rate')
-         ax3.set_yscale('log')
-         ax3.grid(True, alpha=0.3)
-
-         # Best metrics
-         best_train_acc = max(self.history['train_acc'])
-         best_val_acc = max(self.history['val_acc'])
-
-         ax4.bar(['Best Train Acc', 'Best Val Acc'], [best_train_acc, best_val_acc],
-                 color=['blue', 'red'], alpha=0.7)
-         ax4.set_title('Best Accuracies')
-         ax4.set_ylabel('Accuracy')
-         ax4.set_ylim(0, 1)
-
-         for i, v in enumerate([best_train_acc, best_val_acc]):
-             ax4.text(i, v + 0.01, f'{v:.4f}', ha='center', va='bottom')
-
-         plt.tight_layout()
-
-         if save_path:
-             plt.savefig(save_path, dpi=300, bbox_inches='tight')

-         return fig

-     def save_model(self, path):
-         """Save model and training state"""
-         torch.save({
-             'model_state_dict': self.model.state_dict(),
-             'optimizer_state_dict': self.optimizer.state_dict(),
-             'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
-             'best_val_acc': self.best_val_acc,
-             'history': self.history,
-             'config': self.config
-         }, path)

-     def load_model(self, path):
-         """Load model and training state"""
-         checkpoint = torch.load(path, map_location=self.device)

-         self.model.load_state_dict(checkpoint['model_state_dict'])
-         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

-         if checkpoint['scheduler_state_dict'] and self.scheduler:
-             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

-         self.best_val_acc = checkpoint['best_val_acc']
-         self.history = checkpoint['history']

-         return checkpoint['config']
 
  import torch
  import torch.nn as nn
  import torch.optim as optim
+ from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
  import numpy as np
  import time
+ import logging
+ from utils.metrics import GraphMetrics
+
+ logger = logging.getLogger(__name__)

  class GraphMambaTrainer:
+     """Enhanced trainer with optimized learning rates and schedules"""

+     def __init__(self, model, config, device):
+         self.model = model
          self.config = config
          self.device = device

+         # Fixed learning rate (much lower)
+         self.lr = 0.001  # Changed from 0.01
          self.epochs = config['training']['epochs']
+         self.patience = config['training'].get('patience', 15)
+         self.min_lr = config['training'].get('min_lr', 1e-6)

+         # Enhanced optimizer
          self.optimizer = optim.AdamW(
+             model.parameters(),
              lr=self.lr,
+             weight_decay=config['training']['weight_decay'],
              betas=(0.9, 0.999),
              eps=1e-8
          )

+         # Proper loss function
          self.criterion = nn.CrossEntropyLoss()

+         # Learning rate scheduler (will be set in training)
          self.scheduler = None

+         # Training state
          self.best_val_acc = 0.0
+         self.best_val_loss = float('inf')
          self.patience_counter = 0
+         self.training_history = {
+             'train_loss': [], 'train_acc': [],
+             'val_loss': [], 'val_acc': [], 'lr': []
+         }
+
+     def _setup_scheduler(self, total_steps):
+         """Setup learning rate scheduler"""
+         self.scheduler = OneCycleLR(
              self.optimizer,
+             max_lr=self.lr,
+             total_steps=total_steps,
+             pct_start=0.1,            # 10% warmup
+             anneal_strategy='cos',
+             div_factor=10.0,          # Start LR = max_lr/10
+             final_div_factor=100.0    # End LR = max_lr/100
          )
+
+     def train_node_classification(self, data, verbose=True):
+         """Enhanced training with proper LR scheduling"""

          if verbose:
              print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
              print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
+             print(f"🎯 Classes: {len(torch.unique(data.y))}")
              print(f"💾 Device: {self.device}")
              print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

+         # Initialize classifier
+         num_classes = len(torch.unique(data.y))
+         self.model._init_classifier(num_classes, self.device)
+
+         # Setup scheduler
+         self._setup_scheduler(self.epochs)
+
+         self.model.train()
+         start_time = time.time()
+
          for epoch in range(self.epochs):
+             # Training step
+             train_metrics = self._train_epoch(data, epoch)

+             # Validation step
+             val_metrics = self._validate_epoch(data, epoch)

+             # Update history
+             self.training_history['train_loss'].append(train_metrics['loss'])
+             self.training_history['train_acc'].append(train_metrics['acc'])
+             self.training_history['val_loss'].append(val_metrics['loss'])
+             self.training_history['val_acc'].append(val_metrics['acc'])
+             self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])

              # Check for improvement
+             if val_metrics['acc'] > self.best_val_acc:
+                 self.best_val_acc = val_metrics['acc']
+                 self.best_val_loss = val_metrics['loss']
                  self.patience_counter = 0
+                 if verbose:
+                     print(f"🎉 New best validation accuracy: {self.best_val_acc:.4f}")
              else:
                  self.patience_counter += 1

+             # Progress logging
+             if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
+                 elapsed = time.time() - start_time
+                 print(f"Epoch {epoch:3d} | "
+                       f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
+                       f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
+                       f"LR: {self.optimizer.param_groups[0]['lr']:.6f} | "
+                       f"Time: {elapsed:.1f}s")
+
              # Early stopping
              if self.patience_counter >= self.patience:
                  if verbose:
+                     print(f"🛑 Early stopping at epoch {epoch}")
                  break

+             # Step scheduler
+             self.scheduler.step()

          if verbose:
+             total_time = time.time() - start_time
+             print(f"✅ Training completed in {total_time:.2f}s")
              print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")

+         return self.training_history

      def _train_epoch(self, data, epoch):
          """Single training epoch"""
          self.model.train()
          self.optimizer.zero_grad()

+         # Forward pass
          h = self.model(data.x, data.edge_index)
+         logits = self.model.classifier(h)

+         # Compute loss on training nodes
+         train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

          # Backward pass
+         train_loss.backward()

          # Gradient clipping
          torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

          self.optimizer.step()

+         # Compute accuracy
          with torch.no_grad():
+             train_pred = logits[data.train_mask].argmax(dim=1)
+             train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

+         return {'loss': train_loss.item(), 'acc': train_acc}

+     def _validate_epoch(self, data, epoch):
          """Single validation epoch"""
          self.model.eval()

          with torch.no_grad():
              h = self.model(data.x, data.edge_index)
+             logits = self.model.classifier(h)

+             # Validation loss and accuracy
+             val_loss = self.criterion(logits[data.val_mask], data.y[data.val_mask])
+             val_pred = logits[data.val_mask].argmax(dim=1)
+             val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

+         return {'loss': val_loss.item(), 'acc': val_acc}

      def test(self, data):
+         """Comprehensive test evaluation"""
          self.model.eval()

          with torch.no_grad():
              h = self.model(data.x, data.edge_index)
+
+             # Ensure classifier exists
+             if self.model.classifier is None:
+                 num_classes = len(torch.unique(data.y))
+                 self.model._init_classifier(num_classes, self.device)
+
+             logits = self.model.classifier(h)

              # Test metrics
+             test_loss = self.criterion(logits[data.test_mask], data.y[data.test_mask])
+             test_pred = logits[data.test_mask]
+             test_target = data.y[data.test_mask]

+             # Comprehensive metrics
+             metrics = {
+                 'test_loss': test_loss.item(),
+                 'test_acc': GraphMetrics.accuracy(test_pred, test_target),
+                 'f1_macro': GraphMetrics.f1_score_macro(test_pred, test_target),
+                 'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
+             }

+             # Additional metrics
+             precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
+             metrics['precision'] = precision
+             metrics['recall'] = recall
+
+         return metrics

+     def get_embeddings(self, data):
+         """Get node embeddings"""
+         self.model.eval()
+         with torch.no_grad():
+             return self.model(data.x, data.edge_index)
+
+
+ class EnhancedGraphMambaTrainer(GraphMambaTrainer):
+     """Enhanced trainer with additional optimizations"""
+
+     def __init__(self, model, config, device):
+         super().__init__(model, config, device)

+         # Even more conservative learning rate for complex architectures
+         if hasattr(model, 'multi_scale') or 'Hybrid' in model.__class__.__name__:
+             self.lr = 0.0005  # Lower for complex models
+
+         self.optimizer = optim.AdamW(
+             model.parameters(),
+             lr=self.lr,
+             weight_decay=config['training']['weight_decay'],
+             betas=(0.9, 0.99),  # More stable
+             eps=1e-8
+         )
+
+     def _setup_scheduler(self, total_steps):
+         """Enhanced scheduler for complex models"""
+         # Cosine annealing with warm restarts
+         self.scheduler = CosineAnnealingWarmRestarts(
+             self.optimizer,
+             T_0=20,      # Restart every 20 epochs
+             T_mult=2,    # Double period after restart
+             eta_min=self.min_lr
+         )

+     def train_node_classification(self, data, verbose=True):
+         """Training with enhanced monitoring"""

+         if verbose:
+             model_type = self.model.__class__.__name__
+             print(f"🏋️ Training {model_type} for {self.epochs} epochs")
+             print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
+             print(f"🎯 Classes: {len(torch.unique(data.y))}")
+             print(f"💾 Device: {self.device}")
+             print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
+             print(f"📈 Learning Rate: {self.lr} (enhanced schedule)")

+         # Call parent method with enhancements
+         history = super().train_node_classification(data, verbose)

+         # Additional analysis
+         if verbose:
+             final_acc = history['val_acc'][-1] if history['val_acc'] else 0
+             improvement = final_acc - (history['val_acc'][0] if history['val_acc'] else 0)
+             print(f"📊 Final validation accuracy: {final_acc:.4f}")
+             print(f"📈 Total improvement: {improvement:.4f} ({improvement*100:.1f}%)")

+             if final_acc > 0.6:
+                 print("🎉 Excellent performance! Model converged well.")
+             elif final_acc > 0.4:
+                 print("👍 Good progress! Consider more epochs or tuning.")
+             else:
+                 print("⚠️ Low accuracy. Check model architecture or data.")

+         return history
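For reference, the updated trainer can be exercised end-to-end with a small script. The sketch below is a minimal, illustrative harness and not part of this commit: `TinyGraphModel` is a hypothetical stand-in for the real GraphMamba model (the trainer only needs `forward(x, edge_index)` returning node embeddings, a `classifier` attribute, and `_init_classifier(num_classes, device)`), the random graph replaces a real dataset such as Cora, torch_geometric is assumed to be installed since the data objects follow PyG conventions, and `trainer.test()` additionally relies on the repository's `utils.metrics.GraphMetrics`.

import torch
import torch.nn as nn
from torch_geometric.data import Data

from core.trainer import GraphMambaTrainer  # class updated in this commit


class TinyGraphModel(nn.Module):
    """Hypothetical stand-in exposing the interface GraphMambaTrainer expects."""

    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.hidden_dim = hidden_dim
        self.classifier = None  # created lazily by _init_classifier

    def _init_classifier(self, num_classes, device):
        self.classifier = nn.Linear(self.hidden_dim, num_classes).to(device)

    def forward(self, x, edge_index):
        return self.encoder(x)  # node embeddings; edges ignored in this toy model


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Synthetic graph with the attributes the trainer reads (x, y, edge_index, masks).
num_nodes, num_feats, num_classes = 200, 16, 3
data = Data(
    x=torch.randn(num_nodes, num_feats),
    y=torch.randint(0, num_classes, (num_nodes,)),
    edge_index=torch.randint(0, num_nodes, (2, 800)),
)
split = torch.rand(num_nodes)
data.train_mask = split < 0.6
data.val_mask = (split >= 0.6) & (split < 0.8)
data.test_mask = split >= 0.8
data = data.to(device)

# Only these config keys are read by the new __init__ (the learning rate itself is hard-coded to 0.001).
config = {'training': {'epochs': 50, 'weight_decay': 5e-4, 'patience': 15, 'min_lr': 1e-6}}

model = TinyGraphModel(in_dim=num_feats).to(device)
trainer = GraphMambaTrainer(model, config, device)

history = trainer.train_node_classification(data, verbose=True)  # loss/acc/lr curves per epoch
results = trainer.test(data)  # test_loss, test_acc, f1_macro, f1_micro, precision, recall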