#!/usr/bin/env python3
"""
Complete test script for Mamba Graph implementation
Tests training, evaluation, and visualization
"""
import os
import time

import torch

from core.graph_mamba import GraphMamba
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer
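
# core.*, data.*, and utils.* are local packages from this repository; the
# script assumes it is run from the repo root so that they resolve on sys.path.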


def main():
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    # Configuration
    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2
        },
        'training': {
            'learning_rate': 0.01,
            'weight_decay': 0.0005,
            'epochs': 50,  # Quick test
            'patience': 10,
            'warmup_epochs': 5,
            'min_lr': 1e-6
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True
        }
    }
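
    # What these knobs usually mean for Mamba-style blocks (assuming
    # GraphMamba forwards them to a standard Mamba layer):
    #   d_model - hidden size per node token
    #   d_state - SSM state dimension per channel
    #   d_conv  - kernel width of the local causal convolution
    #   expand  - inner expansion factor (block width = expand * d_model)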

    # Setup device (Hugging Face Spaces sets SPACE_ID; stay on CPU there)
    if os.getenv('SPACE_ID'):
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"πŸ’Ύ Device: {device}")

    # Load dataset
    print("\nπŸ“Š Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)

        print("βœ… Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        # .item() turns the 0-dim mask sums into plain ints for printing
        print(f"   Train nodes: {data.train_mask.sum().item()}")
        print(f"   Val nodes: {data.val_mask.sum().item()}")
        print(f"   Test nodes: {data.test_mask.sum().item()}")
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return
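
    # For reference, the standard Planetoid Cora split is roughly:
    # 2,708 nodes, 10,556 (directed) edges, 1,433 bag-of-words features,
    # 7 classes, with 140 train / 500 val / 1,000 test nodes.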

    # Initialize model
    print("\nπŸ—οΈ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())

        print("βœ… Model initialized!")
        print(f"   Parameters: {total_params:,}")
        # rough estimate assuming float32 weights (4 bytes per parameter)
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return

    # Test forward pass
    print("\nπŸš€ Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)

        print("βœ… Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return
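
    # GraphMamba is sequence-based: Mamba consumes a 1-D token stream, so the
    # graph has to be linearized before the SSM sees it. Each strategy below
    # is one heuristic for that node ordering (breadth-first traversal,
    # spectral ordering, degree sort, community grouping); both accuracy and
    # runtime can depend on the order chosen.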

    # Test ordering strategies
    print("\nπŸ”„ Testing ordering strategies...")
    strategies = ['bfs', 'spectral', 'degree', 'community']
    original_strategy = config['ordering']['strategy']

    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            test_model = GraphMamba(config).to(device)
            test_model.eval()

            start_time = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            if device.type == 'cuda':
                torch.cuda.synchronize()  # finish GPU work before stopping the clock
            end_time = time.time()

            print(f"βœ… {strategy:12} | Shape: {h.shape} | Time: {(end_time - start_time) * 1000:.2f}ms")
        except Exception as e:
            print(f"❌ {strategy:12} | Failed: {e}")

    # restore the configured ordering so later stages aren't left on 'community'
    config['ordering']['strategy'] = original_strategy

    # Initialize trainer
    print("\nπŸ‹οΈ Testing training system...")
    try:
        trainer = GraphMambaTrainer(model, config, device)
        print("βœ… Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
    except Exception as e:
        print(f"❌ Trainer initialization failed: {e}")
        return
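
    # Note: the schedule fields in config['training'] are interpreted by
    # GraphMambaTrainer; going by their names, 'patience' presumably drives
    # early stopping on validation accuracy, while 'warmup_epochs' and
    # 'min_lr' shape the learning-rate schedule.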

    # Run training
    print("\n🎯 Running training...")
    try:
        start_time = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - start_time

        print("βœ… Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        return

    # Test evaluation
    print("\nπŸ“Š Testing evaluation...")
    try:
        test_results = trainer.test(data)

        print("βœ… Evaluation completed!")
        print(f"   Test accuracy: {test_results['test_acc']:.4f}")
        print(f"   Test loss: {test_results['test_loss']:.4f}")

        # Per-class results
        class_accs = test_results['class_acc']
        print("   Per-class accuracy:")
        for i, acc in enumerate(class_accs):
            print(f"     Class {i}: {acc:.4f}")
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        return

    # Test visualization (no early return here: the summary below only needs
    # results that already exist at this point)
    print("\n🎨 Testing visualization...")
    try:
        # Create visualizations
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_results)
        training_fig = GraphVisualizer.create_training_history_plot(history)

        print("βœ… Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")

        # Save plots
        graph_fig.write_html("graph_visualization.html")
        metrics_fig.write_html("metrics_plot.html")
        training_fig.write_html("training_history.html")
        print("   Plots saved as HTML files")
    except Exception as e:
        print(f"❌ Visualization failed: {e}")
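
    # write_html() suggests the visualizer returns Plotly figures; the saved
    # HTML files are self-contained and open in any browser.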

    # Performance summary
    print("\nπŸ† Performance Summary")
    print("=" * 40)
    print(f"πŸ“Š Dataset: Cora ({data.num_nodes:,} nodes)")
    print(f"🧠 Model: {total_params:,} parameters")
    print(f"⏱️ Training: {training_time:.2f}s ({len(history['train_loss'])} epochs)")
    print(f"🎯 Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc'] * 100:.2f}%)")
    print(f"πŸ… Best Val Accuracy: {trainer.best_val_acc:.4f} ({trainer.best_val_acc * 100:.2f}%)")
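
    # The reference numbers below are approximate, commonly cited accuracies
    # on the standard Planetoid split of Cora (e.g., GCN β‰ˆ 81.5%, GAT β‰ˆ 83.0%);
    # exact values vary across papers and reruns.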
    # Compare with baselines
    cora_baselines = {
        'GCN': 0.815,
        'GAT': 0.830,
        'GraphSAGE': 0.824,
        'GIN': 0.800
    }

    print("\nπŸ“ˆ Comparison with Baselines:")
    test_acc = test_results['test_acc']
    for model_name, baseline in cora_baselines.items():
        diff = test_acc - baseline
        status = "🟒" if diff > 0 else "🟑" if diff > -0.05 else "πŸ”΄"
        print(f"   {status} {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")

    print("\n✨ Test completed successfully!")
    print("πŸš€ Ready for production deployment!")


if __name__ == "__main__":
    main()