# serpent/core/trainer.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
import numpy as np
import time
import logging
from utils.metrics import GraphMetrics
logger = logging.getLogger(__name__)


class GraphMambaTrainer:
    """Enhanced trainer with optimized learning rates and schedules"""

    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device

        # Hard-coded learning rate (lowered from 0.01; not read from the config)
        self.lr = 0.001
        self.epochs = config['training']['epochs']
        self.patience = config['training'].get('patience', 15)
        self.min_lr = config['training'].get('min_lr', 1e-6)

        # Enhanced optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.lr,
            weight_decay=config['training']['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # Proper loss function
        self.criterion = nn.CrossEntropyLoss()

        # Learning rate scheduler (created by _setup_scheduler at training time)
        self.scheduler = None

        # Training state
        self.best_val_acc = 0.0
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.training_history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [], 'lr': []
        }
    def _setup_scheduler(self, total_steps):
        """Set up the one-cycle learning rate scheduler"""
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=self.lr,
            total_steps=total_steps,
            pct_start=0.1,              # 10% warmup
            anneal_strategy='cos',
            div_factor=10.0,            # start LR = max_lr / 10
            final_div_factor=100.0      # end LR = start LR / 100 (= max_lr / 1000)
        )
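        # With self.lr = 0.001 this warms up from 1e-4 to 1e-3 over the first
        # 10% of total_steps, then cosine-anneals down to roughly 1e-6.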
    def train_node_classification(self, data, verbose=True):
        """Enhanced training with proper LR scheduling"""
        if verbose:
            print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
            print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"🎯 Classes: {len(torch.unique(data.y))}")
            print(f"💾 Device: {self.device}")
            print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        # Initialize classifier
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)

        # Setup scheduler
        self._setup_scheduler(self.epochs)

        self.model.train()
        start_time = time.time()

        for epoch in range(self.epochs):
            # Training step
            train_metrics = self._train_epoch(data, epoch)

            # Validation step
            val_metrics = self._validate_epoch(data, epoch)

            # Update history
            self.training_history['train_loss'].append(train_metrics['loss'])
            self.training_history['train_acc'].append(train_metrics['acc'])
            self.training_history['val_loss'].append(val_metrics['loss'])
            self.training_history['val_acc'].append(val_metrics['acc'])
            self.training_history['lr'].append(self.optimizer.param_groups[0]['lr'])

            # Check for improvement
            if val_metrics['acc'] > self.best_val_acc:
                self.best_val_acc = val_metrics['acc']
                self.best_val_loss = val_metrics['loss']
                self.patience_counter = 0
                if verbose:
                    print(f"🎉 New best validation accuracy: {self.best_val_acc:.4f}")
            else:
                self.patience_counter += 1

            # Progress logging
            if verbose and (epoch == 0 or (epoch + 1) % 10 == 0 or epoch == self.epochs - 1):
                elapsed = time.time() - start_time
                print(f"Epoch {epoch:3d} | "
                      f"Train: {train_metrics['loss']:.4f} ({train_metrics['acc']:.4f}) | "
                      f"Val: {val_metrics['loss']:.4f} ({val_metrics['acc']:.4f}) | "
                      f"LR: {self.optimizer.param_groups[0]['lr']:.6f} | "
                      f"Time: {elapsed:.1f}s")

            # Early stopping
            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"🛑 Early stopping at epoch {epoch}")
                break

            # Step scheduler
            self.scheduler.step()

        if verbose:
            total_time = time.time() - start_time
            print(f"✅ Training completed in {total_time:.2f}s")
            print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")

        return self.training_history
    def _train_epoch(self, data, epoch):
        """Single training epoch"""
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass
        h = self.model(data.x, data.edge_index)
        logits = self.model.classifier(h)

        # Compute loss on training nodes
        train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])

        # Backward pass
        train_loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Compute accuracy
        with torch.no_grad():
            train_pred = logits[data.train_mask].argmax(dim=1)
            train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

        return {'loss': train_loss.item(), 'acc': train_acc}
    def _validate_epoch(self, data, epoch):
        """Single validation epoch"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            logits = self.model.classifier(h)

            # Validation loss and accuracy
            val_loss = self.criterion(logits[data.val_mask], data.y[data.val_mask])
            val_pred = logits[data.val_mask].argmax(dim=1)
            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

        return {'loss': val_loss.item(), 'acc': val_acc}
    def test(self, data):
        """Comprehensive test evaluation"""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)

            # Ensure classifier exists
            if self.model.classifier is None:
                num_classes = len(torch.unique(data.y))
                self.model._init_classifier(num_classes, self.device)

            logits = self.model.classifier(h)

            # Test metrics
            test_loss = self.criterion(logits[data.test_mask], data.y[data.test_mask])
            test_pred = logits[data.test_mask]
            test_target = data.y[data.test_mask]

            # Comprehensive metrics
            metrics = {
                'test_loss': test_loss.item(),
                'test_acc': GraphMetrics.accuracy(test_pred, test_target),
                'f1_macro': GraphMetrics.f1_score_macro(test_pred, test_target),
                'f1_micro': GraphMetrics.f1_score_micro(test_pred, test_target),
            }

            # Additional metrics
            precision, recall = GraphMetrics.precision_recall(test_pred, test_target)
            metrics['precision'] = precision
            metrics['recall'] = recall

        return metrics
    def get_embeddings(self, data):
        """Get node embeddings"""
        self.model.eval()
        with torch.no_grad():
            return self.model(data.x, data.edge_index)


class EnhancedGraphMambaTrainer(GraphMambaTrainer):
    """Enhanced trainer with additional optimizations"""

    def __init__(self, model, config, device):
        super().__init__(model, config, device)

        # Even more conservative learning rate for complex architectures
        if hasattr(model, 'multi_scale') or 'Hybrid' in model.__class__.__name__:
            self.lr = 0.0005  # Lower for complex models
            self.optimizer = optim.AdamW(
                model.parameters(),
                lr=self.lr,
                weight_decay=config['training']['weight_decay'],
                betas=(0.9, 0.99),  # More stable
                eps=1e-8
            )

    def _setup_scheduler(self, total_steps):
        """Enhanced scheduler for complex models"""
        # Cosine annealing with warm restarts
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=20,       # First restart after 20 epochs
            T_mult=2,     # Double the cycle length after each restart
            eta_min=self.min_lr
        )
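        # With T_0=20 and T_mult=2 the first cosine cycle lasts 20 epochs, the
        # next 40, then 80, ...; each cycle anneals the LR from its base value
        # down toward eta_min before restarting.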
    def train_node_classification(self, data, verbose=True):
        """Training with enhanced monitoring"""
        if verbose:
            model_type = self.model.__class__.__name__
            print(f"🏋️ Training {model_type} for {self.epochs} epochs")
            print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"🎯 Classes: {len(torch.unique(data.y))}")
            print(f"💾 Device: {self.device}")
            print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
            print(f"📈 Learning Rate: {self.lr} (enhanced schedule)")

        # Call parent method with enhancements
        history = super().train_node_classification(data, verbose)

        # Additional analysis
        if verbose:
            final_acc = history['val_acc'][-1] if history['val_acc'] else 0
            improvement = final_acc - (history['val_acc'][0] if history['val_acc'] else 0)
            print(f"📊 Final validation accuracy: {final_acc:.4f}")
            print(f"📈 Total improvement: {improvement:.4f} ({improvement * 100:.1f}%)")

            if final_acc > 0.6:
                print("🎉 Excellent performance! Model converged well.")
            elif final_acc > 0.4:
                print("👍 Good progress! Consider more epochs or tuning.")
            else:
                print("⚠️ Low accuracy. Check model architecture or data.")

        return history
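# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The GraphMamba import path and
# constructor signature below are assumptions; the config keys mirror the ones
# this trainer actually reads ('epochs', 'weight_decay', 'patience', 'min_lr'),
# and `data` is a standard torch_geometric object with x, edge_index, y and
# train/val/test masks (e.g. a Planetoid split).
#
#     from torch_geometric.datasets import Planetoid
#     from core.graph_mamba import GraphMamba  # hypothetical import path
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     data = Planetoid(root='data/Planetoid', name='Cora')[0].to(device)
#     config = {'training': {'epochs': 200, 'weight_decay': 5e-4,
#                            'patience': 15, 'min_lr': 1e-6}}
#     model = GraphMamba(config).to(device)  # hypothetical constructor
#
#     trainer = GraphMambaTrainer(model, config, device)
#     history = trainer.train_node_classification(data, verbose=True)
#     results = trainer.test(data)
#     print(results['test_acc'], results['f1_macro'])
# ---------------------------------------------------------------------------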