import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
import numpy as np
import time
import logging

from utils.metrics import GraphMetrics

logger = logging.getLogger(__name__)


class GraphMambaTrainer:
    """Enhanced trainer with optimized learning rates and schedules"""

    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device

        # Fixed learning rate (much lower than the original 0.01)
        self.lr = 0.001
        self.epochs = config['training']['epochs']
        self.patience = config['training'].get('patience', 15)
        self.min_lr = config['training'].get('min_lr', 1e-6)

        # AdamW optimizer with decoupled weight decay
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=config['training']['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # Loss function for node classification
        self.criterion = nn.CrossEntropyLoss()

        # Learning rate scheduler (set up at the start of training)
        self.scheduler = None

        # Training state
        self.best_val_acc = 0.0
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.training_history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [],
            'lr': []
        }

    def _setup_scheduler(self, total_steps):
        """Set up the learning rate scheduler"""
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=self.lr,
            total_steps=total_steps,
            pct_start=0.1,           # 10% warmup
            anneal_strategy='cos',
            div_factor=10.0,         # start LR = max_lr / 10
            final_div_factor=100.0   # end LR = start LR / 100 (= max_lr / 1000)
        )

    def train_node_classification(self, data, verbose=True):
        """Enhanced training with proper LR scheduling"""
        if verbose:
            print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
            print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"🎯 Classes: {len(torch.unique(data.y))}")
            print(f"💾 Device: {self.device}")
            print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        # Initialize the classifier head for this dataset
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)

        # One optimizer step per epoch (full-batch training), so total_steps == epochs
        self._setup_scheduler(self.epochs)

        self.model.train()
        start_time = time.time()

        for epoch in range(self.epochs):
            # Training step
            train_metrics = self._train_epoch(data, epoch)

            # Validation step
            val_metrics = self._validate_epoch(data, epoch)

            # Update history
            self.training_history['train_loss'].append(train_metrics['loss'])
            self.training_history['train_acc'].append(train_metrics['acc'])
            self.training_history['val_loss'].append(val_metrics['loss'])
            self.training_history['val_acc'].append(val_metrics['acc'])
            self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])

            # Check for improvement
            if val_metrics['acc'] > self.best_val_acc:
                self.best_val_acc = val_metrics['acc']
                self.best_val_loss = val_metrics['loss']
                self.patience_counter = 0
                if verbose:
                    print(f"🎉 New best validation accuracy: {self.best_val_acc:.4f}")
            else:
                self.patience_counter += 1

            # Progress logging
            if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
                elapsed = time.time() - start_time
                print(f"Epoch {epoch:3d} | "
                      f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
                      f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
                      f"LR: {self.optimizer.param_groups[0]['lr']:.6f} | "
                      f"Time: {elapsed:.1f}s")

            # Early stopping
            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"🛑 Early stopping at epoch {epoch}")
                break

            # Step the scheduler (after the optimizer step inside _train_epoch)
            self.scheduler.step()

        if verbose:
            total_time = time.time() - start_time
            print(f"✅ Training completed in {total_time:.2f}s")
            print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")

        return self.training_history

    def _train_epoch(self, data, epoch):
        """Single training epoch"""
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass
        h = self.model(data.x, data.edge_index)
        logits = self.model.classifier(h)

        # Compute loss on training nodes only
        train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

        # Backward pass
        train_loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        # Compute training accuracy
        with torch.no_grad():
            train_pred = logits[data.train_mask].argmax(dim=1)
            train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

        return {'loss': train_loss.item(), 'acc': train_acc}

    def _validate_epoch(self, data, epoch):
        """Single validation epoch"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            logits = self.model.classifier(h)

            # Validation loss and accuracy
            val_loss = self.criterion(logits[data.val_mask], data.y[data.val_mask])
            val_pred = logits[data.val_mask].argmax(dim=1)
            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

        return {'loss': val_loss.item(), 'acc': val_acc}

    def test(self, data):
        """Comprehensive test evaluation"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)

            # Ensure the classifier exists
            if self.model.classifier is None:
                num_classes = len(torch.unique(data.y))
                self.model._init_classifier(num_classes, self.device)

            logits = self.model.classifier(h)

            # Test loss and predictions
            test_loss = self.criterion(logits[data.test_mask], data.y[data.test_mask])
            test_pred = logits[data.test_mask]
            test_target = data.y[data.test_mask]

            # Comprehensive metrics
            metrics = {
                'test_loss': test_loss.item(),
                'test_acc': GraphMetrics.accuracy(test_pred, test_target),
                'f1_macro': GraphMetrics.f1_score_macro(test_pred, test_target),
                'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
            }

            # Additional metrics
            precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
            metrics['precision'] = precision
            metrics['recall'] = recall

        return metrics

    def get_embeddings(self, data):
        """Get node embeddings"""
        self.model.eval()
        with torch.no_grad():
            return self.model(data.x, data.edge_index)


class EnhancedGraphMambaTrainer(GraphMambaTrainer):
    """Enhanced trainer with additional optimizations"""

    def __init__(self, model, config, device):
        super().__init__(model, config, device)

        # Even more conservative learning rate for complex architectures
        if hasattr(model, 'multi_scale') or 'Hybrid' in model.__class__.__name__:
            self.lr = 0.0005  # Lower for complex models
            self.optimizer = optim.AdamW(
                model.parameters(),
                lr=self.lr,
                weight_decay=config['training']['weight_decay'],
                betas=(0.9, 0.99),  # More stable
                eps=1e-8
            )

    def _setup_scheduler(self, total_steps):
        """Enhanced scheduler for complex models"""
        # Cosine annealing with warm restarts
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=20,       # Restart every 20 epochs
            T_mult=2,     # Double the period after each restart
            eta_min=self.min_lr
        )

    def train_node_classification(self, data, verbose=True):
        """Training with enhanced monitoring"""
        if verbose:
            model_type = self.model.__class__.__name__
            print(f"🏋️ Training {model_type} for {self.epochs} epochs")
            print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"🎯 Classes: {len(torch.unique(data.y))}")
            print(f"💾 Device: {self.device}")
            print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
            print(f"📈 Learning Rate: {self.lr} (enhanced schedule)")

        # Call parent method with enhancements
        history = super().train_node_classification(data, verbose)

        # Additional analysis
        if verbose:
            final_acc = history['val_acc'][-1] if history['val_acc'] else 0
            improvement = final_acc - (history['val_acc'][0] if history['val_acc'] else 0)
            print(f"📊 Final validation accuracy: {final_acc:.4f}")
            print(f"📈 Total improvement: {improvement:.4f} ({improvement * 100:.1f} points)")

            if final_acc > 0.6:
                print("🎉 Excellent performance! Model converged well.")
            elif final_acc > 0.4:
                print("👍 Good progress! Consider more epochs or tuning.")
            else:
                print("⚠️ Low accuracy. Check the model architecture or data.")

        return history
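

# --- Usage sketch (illustrative only) ---
# A minimal example of how these trainers might be driven end to end. The
# GraphMamba import path, model constructor arguments, and config values shown
# here are assumptions for illustration; only the trainer API defined above is
# taken from this module. Any PyTorch Geometric `Data` object with
# `train_mask` / `val_mask` / `test_mask` attributes would work in place of Cora.
#
#   from torch_geometric.datasets import Planetoid
#   from models.graph_mamba import GraphMamba  # hypothetical model module
#
#   config = {'training': {'epochs': 100, 'weight_decay': 5e-4,
#                          'patience': 15, 'min_lr': 1e-6}}
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#
#   data = Planetoid(root='data/Planetoid', name='Cora')[0].to(device)
#   model = GraphMamba(in_dim=data.num_features, hidden_dim=64).to(device)
#
#   trainer = GraphMambaTrainer(model, config, device)
#   history = trainer.train_node_classification(data, verbose=True)
#   print(trainer.test(data))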