import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

import matplotlib.pyplot as plt
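
# Example configuration consumed by GraphMambaTrainer. This is an illustrative
# sketch rather than values from the original source: the trainer only assumes
# that config['training'] provides the six keys read in __init__ below.
EXAMPLE_CONFIG = {
    'training': {
        'learning_rate': 1e-3,   # peak LR reached at the end of warmup
        'weight_decay': 1e-4,    # AdamW weight decay
        'epochs': 200,           # total training epochs
        'patience': 30,          # early-stopping patience (in epochs)
        'warmup_epochs': 10,     # linear warmup duration
        'min_lr': 1e-6,          # eta_min for cosine annealing
    }
}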


class GraphMambaTrainer:
    """
    Trainer for GraphMamba node classification.

    Implements linear learning-rate warmup followed by cosine annealing,
    gradient clipping, early stopping on validation accuracy, and
    checkpointing of the best model.
    """

    def __init__(self, model, config, device='cpu'):
        self.model = model.to(device)
        self.config = config
        self.device = device

        # Training hyperparameters
        self.lr = config['training']['learning_rate']
        self.weight_decay = config['training']['weight_decay']
        self.epochs = config['training']['epochs']
        self.patience = config['training']['patience']
        self.warmup_epochs = config['training']['warmup_epochs']
        self.min_lr = config['training']['min_lr']

        # Optimizer (recreated in train_node_classification once the
        # classifier head exists)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        self.criterion = nn.CrossEntropyLoss()
        self.scheduler = None

        # Per-epoch metrics
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        # Early-stopping state
        self.best_val_acc = 0.0
        self.best_model_state = None
        self.patience_counter = 0

    def train_node_classification(self, data, verbose=True):
        """Train the model for full-batch node classification."""
        num_classes = len(torch.unique(data.y))
        self.model._init_classifier(num_classes, self.device)

        # Recreate the optimizer so it also covers the freshly initialized
        # classifier parameters
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999)
        )

        # Cosine annealing over the post-warmup epochs
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=self.epochs - self.warmup_epochs,
            eta_min=self.min_lr
        )

        if verbose:
            print(f"🏋️ Training GraphMamba for {self.epochs} epochs")
            print(f"📊 Dataset: {data.num_nodes} nodes, {data.num_edges} edges")
            print(f"🎯 Classes: {num_classes}")
            print(f"💾 Device: {self.device}")
            print(f"⚙️ Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        for epoch in range(self.epochs):
            train_loss, train_acc = self._train_epoch(data, epoch)
            val_loss, val_acc = self._validate_epoch(data)

            # Linear warmup for the first warmup_epochs, then cosine annealing
            if epoch >= self.warmup_epochs:
                self.scheduler.step()
            else:
                warmup_lr = self.lr * (epoch + 1) / self.warmup_epochs
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = warmup_lr

            # Record per-epoch metrics
            current_lr = self.optimizer.param_groups[0]['lr']
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(current_lr)

            # Snapshot the best model. deepcopy is required here: a shallow
            # dict copy would keep references to the live parameter tensors,
            # which later optimizer steps mutate in place.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_model_state = copy.deepcopy(self.model.state_dict())
                self.patience_counter = 0

                if verbose and epoch % 10 == 0:
                    print(f"🎉 New best validation accuracy: {val_acc:.4f}")
            else:
                self.patience_counter += 1

            # Early stopping
            if self.patience_counter >= self.patience:
                if verbose:
                    print(f"⏹️ Early stopping at epoch {epoch}")
                break

            if verbose and epoch % 20 == 0:
                print(f"Epoch {epoch:3d} | "
                      f"Train: {train_loss:.4f} ({train_acc:.4f}) | "
                      f"Val: {val_loss:.4f} ({val_acc:.4f}) | "
                      f"LR: {current_lr:.6f}")

        # Restore the best checkpoint before returning
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        if verbose:
            print("✅ Training completed!")
            print(f"🏆 Best validation accuracy: {self.best_val_acc:.4f}")

        return self.history

    def _train_epoch(self, data, epoch):
        """Run one full-batch training step (one optimizer update per epoch)."""
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass: node embeddings, then the classification head
        h = self.model(data.x, data.edge_index)
        pred = self.model.classifier(h)

        # Loss on training nodes only
        loss = self.criterion(pred[data.train_mask], data.y[data.train_mask])
        loss.backward()

        # Clip gradients to stabilize training
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Training accuracy on the same mask
        with torch.no_grad():
            pred_labels = pred[data.train_mask].argmax(dim=1)
            acc = (pred_labels == data.y[data.train_mask]).float().mean()

        return loss.item(), acc.item()

    def _validate_epoch(self, data):
        """Evaluate loss and accuracy on the validation mask."""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            pred = self.model.classifier(h)

            val_loss = self.criterion(pred[data.val_mask], data.y[data.val_mask])
            pred_labels = pred[data.val_mask].argmax(dim=1)
            val_acc = (pred_labels == data.y[data.val_mask]).float().mean()

        return val_loss.item(), val_acc.item()

    def test(self, data):
        """Evaluate on the test mask, including per-class accuracy."""
        self.model.eval()

        with torch.no_grad():
            h = self.model(data.x, data.edge_index)
            pred = self.model.classifier(h)

            test_loss = self.criterion(pred[data.test_mask], data.y[data.test_mask])
            pred_labels = pred[data.test_mask].argmax(dim=1)
            test_acc = (pred_labels == data.y[data.test_mask]).float().mean()

            # Per-class accuracy (recall) over the test nodes
            num_classes = len(torch.unique(data.y))
            class_acc = []

            for c in range(num_classes):
                class_mask = data.y[data.test_mask] == c
                if class_mask.any():
                    class_correct = (pred_labels[class_mask] == c).float().mean()
                    class_acc.append(class_correct.item())
                else:
                    class_acc.append(0.0)

        return {
            'test_loss': test_loss.item(),
            'test_acc': test_acc.item(),
            'class_acc': class_acc
        }

    def plot_training_history(self, save_path=None):
        """Plot loss, accuracy, and learning-rate curves from self.history."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))

        epochs = range(len(self.history['train_loss']))

        # Loss curves
        ax1.plot(epochs, self.history['train_loss'], label='Train', color='blue')
        ax1.plot(epochs, self.history['val_loss'], label='Validation', color='red')
        ax1.set_title('Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Accuracy curves
        ax2.plot(epochs, self.history['train_acc'], label='Train', color='blue')
        ax2.plot(epochs, self.history['val_acc'], label='Validation', color='red')
        ax2.set_title('Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Learning-rate schedule (log scale to show warmup and annealing)
        ax3.plot(epochs, self.history['lr'], color='green')
        ax3.set_title('Learning Rate')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)

        # Best accuracies
        best_train_acc = max(self.history['train_acc'])
        best_val_acc = max(self.history['val_acc'])

        ax4.bar(['Best Train Acc', 'Best Val Acc'], [best_train_acc, best_val_acc],
                color=['blue', 'red'], alpha=0.7)
        ax4.set_title('Best Accuracies')
        ax4.set_ylabel('Accuracy')
        ax4.set_ylim(0, 1)

        for i, v in enumerate([best_train_acc, best_val_acc]):
            ax4.text(i, v + 0.01, f'{v:.4f}', ha='center', va='bottom')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        return fig

    def save_model(self, path):
        """Save model weights and full training state to a checkpoint."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'best_val_acc': self.best_val_acc,
            'history': self.history,
            'config': self.config
        }, path)

    def load_model(self, path):
        """Load model weights and training state from a checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if checkpoint['scheduler_state_dict'] and self.scheduler:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.best_val_acc = checkpoint['best_val_acc']
        self.history = checkpoint['history']

        return checkpoint['config']
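

# --- Usage sketch ------------------------------------------------------------
# Minimal end-to-end example, assuming a PyG-style `data` object (x, edge_index,
# y, and train/val/test masks). The `graphmamba` import path and the GraphMamba
# constructor signature are assumptions for illustration, not part of this
# module; Planetoid/Cora stands in for any node-classification dataset.
if __name__ == '__main__':
    from torch_geometric.datasets import Planetoid
    from graphmamba import GraphMamba  # hypothetical import path

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    data = Planetoid(root='data/Cora', name='Cora')[0].to(device)

    model = GraphMamba(EXAMPLE_CONFIG)  # constructor signature assumed
    trainer = GraphMambaTrainer(model, EXAMPLE_CONFIG, device=device)

    trainer.train_node_classification(data)
    results = trainer.test(data)
    print(f"Test accuracy: {results['test_acc']:.4f}")

    trainer.plot_training_history(save_path='training_history.png')
    trainer.save_model('graphmamba_checkpoint.pt')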