# serpent/core/trainer.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
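

# The trainer reads its hyperparameters from a nested config dict. A minimal
# sketch of the expected layout (key names taken from the lookups in
# __init__ below; the values shown are illustrative placeholders):
#
# config = {
#     'training': {
#         'learning_rate': 1e-3,
#         'weight_decay': 5e-4,
#         'epochs': 200,
#         'patience': 30,
#         'warmup_epochs': 10,
#         'min_lr': 1e-6,
#     }
# }
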
class GraphMambaTrainer:
    """
    Trainer for GraphMamba node classification.

    Combines AdamW, linear learning-rate warmup followed by cosine
    annealing, gradient clipping, early stopping, and best-model
    checkpointing.
    """
def __init__(self, model, config, device='cpu'):
self.model = model.to(device)
self.config = config
self.device = device
# Training parameters
self.lr = config['training']['learning_rate']
self.weight_decay = config['training']['weight_decay']
self.epochs = config['training']['epochs']
self.patience = config['training']['patience']
self.warmup_epochs = config['training']['warmup_epochs']
self.min_lr = config['training']['min_lr']
# Initialize optimizer
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
betas=(0.9, 0.999),
eps=1e-8
)
# Loss function
self.criterion = nn.CrossEntropyLoss()
        # Scheduler is created in train_node_classification()
        self.scheduler = None
# Training history
self.history = {
'train_loss': [],
'train_acc': [],
'val_loss': [],
'val_acc': [],
'lr': []
}
# Best model tracking
self.best_val_acc = 0.0
self.best_model_state = None
self.patience_counter = 0
    def train_node_classification(self, data, verbose=True):
        """
        Train the model for node classification.

        Expects a PyG-style data object with x, edge_index, y, and boolean
        train_mask / val_mask attributes. Returns the training history dict.
        """
        # Attach a classification head sized to the number of classes
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)
        # Rebuild the optimizer so it also covers the new classifier parameters
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )
# Initialize scheduler
self.scheduler = CosineAnnealingLR(
self.optimizer,
T_max=self.epochs - self.warmup_epochs,
eta_min=self.min_lr
)
if verbose:
print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
print(f"🎯 Classes: {num_classes}")
print(f"💾 Device: {self.device}")
print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
# Training loop
for epoch in range(self.epochs):
            # Training phase
            train_loss, train_acc = self._train_epoch(data)
# Validation phase
val_loss, val_acc = self._validate_epoch(data)
            # LR schedule: linear warmup for the first warmup_epochs epochs,
            # then cosine annealing from lr down to min_lr
            if epoch >= self.warmup_epochs:
                self.scheduler.step()
            else:
                warmup_lr = self.lr * (epoch + 1) / self.warmup_epochs
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = warmup_lr
# Record history
current_lr = self.optimizer.param_groups[0]['lr']
self.history['train_loss'].append(train_loss)
self.history['train_acc'].append(train_acc)
self.history['val_loss'].append(val_loss)
self.history['val_acc'].append(val_acc)
self.history['lr'].append(current_lr)
# Check for improvement
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                # Clone the weights: state_dict() returns references to live
                # tensors, so a shallow .copy() would be silently overwritten
                # by later optimizer steps
                self.best_model_state = {
                    k: v.detach().clone()
                    for k, v in self.model.state_dict().items()
                }
                self.patience_counter = 0
if verbose and epoch % 10 == 0:
print(f"🎉 New best validation accuracy: {val_acc:.4f}")
else:
self.patience_counter += 1
# Early stopping
if self.patience_counter >= self.patience:
if verbose:
print(f"⏹️ Early stopping at epoch {epoch}")
break
# Progress reporting
if verbose and epoch % 20 == 0:
print(f"Epoch {epoch:3d} | "
f"Train: {train_loss:.4f} ({train_acc:.4f}) | "
f"Val: {val_loss:.4f} ({val_acc:.4f}) | "
f"LR: {current_lr:.6f}")
# Load best model
if self.best_model_state is not None:
self.model.load_state_dict(self.best_model_state)
if verbose:
print(f"✅ Training completed!")
print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")
return self.history
    def _train_epoch(self, data):
        """Run one full-batch training step; returns (loss, accuracy) on the train mask."""
        self.model.train()
# Forward pass
self.optimizer.zero_grad()
h = self.model(data.x, data.edge_index)
pred = self.model.classifier(h)
# Loss only on training nodes
loss = self.criterion(pred[data.train_mask], data.y[data.train_mask])
# Backward pass
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# Calculate accuracy
with torch.no_grad():
pred_labels = pred[data.train_mask].argmax(dim=1)
acc = (pred_labels == data.y[data.train_mask]).float().mean()
return loss.item(), acc.item()
    def _validate_epoch(self, data):
        """Evaluate loss and accuracy on the validation mask."""
self.model.eval()
with torch.no_grad():
h = self.model(data.x, data.edge_index)
pred = self.model.classifier(h)
# Loss on validation nodes
val_loss = self.criterion(pred[data.val_mask], data.y[data.val_mask])
# Accuracy
pred_labels = pred[data.val_mask].argmax(dim=1)
val_acc = (pred_labels == data.y[data.val_mask]).float().mean()
return val_loss.item(), val_acc.item()
    def test(self, data):
        """Evaluate on the test mask; returns loss, overall accuracy, and per-class accuracy."""
self.model.eval()
with torch.no_grad():
h = self.model(data.x, data.edge_index)
pred = self.model.classifier(h)
# Test metrics
test_loss = self.criterion(pred[data.test_mask], data.y[data.test_mask])
pred_labels = pred[data.test_mask].argmax(dim=1)
test_acc = (pred_labels == data.y[data.test_mask]).float().mean()
            # Per-class accuracy (recall): fraction of test nodes of each
            # true class that are predicted correctly
num_classes = len(torch.unique(data.y))
class_acc = []
for c in range(num_classes):
class_mask = data.y[data.test_mask] == c
if class_mask.any():
class_correct = (pred_labels[class_mask] == c).float().mean()
class_acc.append(class_correct.item())
else:
class_acc.append(0.0)
return {
'test_loss': test_loss.item(),
'test_acc': test_acc.item(),
'class_acc': class_acc
}
def plot_training_history(self, save_path=None):
"""Plot training history"""
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
epochs = range(len(self.history['train_loss']))
# Loss plot
ax1.plot(epochs, self.history['train_loss'], label='Train', color='blue')
ax1.plot(epochs, self.history['val_loss'], label='Validation', color='red')
        ax1.set_title('Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Accuracy plot
ax2.plot(epochs, self.history['train_acc'], label='Train', color='blue')
ax2.plot(epochs, self.history['val_acc'], label='Validation', color='red')
        ax2.set_title('Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)
# Learning rate plot
ax3.plot(epochs, self.history['lr'], color='green')
ax3.set_title('Learning Rate')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Learning Rate')
ax3.set_yscale('log')
ax3.grid(True, alpha=0.3)
# Best metrics
best_train_acc = max(self.history['train_acc'])
best_val_acc = max(self.history['val_acc'])
ax4.bar(['Best Train Acc', 'Best Val Acc'], [best_train_acc, best_val_acc],
color=['blue', 'red'], alpha=0.7)
ax4.set_title('Best Accuracies')
ax4.set_ylabel('Accuracy')
ax4.set_ylim(0, 1)
for i, v in enumerate([best_train_acc, best_val_acc]):
ax4.text(i, v + 0.01, f'{v:.4f}', ha='center', va='bottom')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
return fig
def save_model(self, path):
"""Save model and training state"""
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'best_val_acc': self.best_val_acc,
'history': self.history,
'config': self.config
}, path)
    def load_model(self, path):
        """Load model and training state"""
        # weights_only=False is needed on recent PyTorch versions because the
        # checkpoint stores plain Python objects (history, config) alongside
        # the tensors
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if checkpoint['scheduler_state_dict'] is not None and self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.best_val_acc = checkpoint['best_val_acc']
self.history = checkpoint['history']
return checkpoint['config']
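

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, kept commented out so the module stays
# import-safe). It assumes a hypothetical GraphMamba model from
# core.graph_mamba that exposes forward(x, edge_index), _init_classifier(),
# and .classifier, plus a PyG Planetoid dataset; adjust the imports and the
# config (see the layout sketched near the top of this file) to the project.
# ---------------------------------------------------------------------------
# from torch_geometric.datasets import Planetoid
# from core.graph_mamba import GraphMamba
#
# data = Planetoid(root='data', name='Cora')[0]
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = GraphMamba(config)
# trainer = GraphMambaTrainer(model, config, device=device)
#
# history = trainer.train_node_classification(data)
# results = trainer.test(data)
# print(f"Test accuracy: {results['test_acc']:.4f}")
#
# trainer.plot_training_history(save_path='training_history.png')
# trainer.save_model('checkpoint.pt')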