#!/usr/bin/env python3
"""
Complete test script for Mamba Graph implementation
Tests training, evaluation, and visualization
"""

import copy
import os
import time

import torch

from core.graph_mamba import GraphMamba
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer

# Node-ordering strategies exercised by the forward-pass benchmark.
ORDERING_STRATEGIES = ('bfs', 'spectral', 'degree', 'community')

# Reference accuracies that the final test accuracy is compared against.
CORA_BASELINES = {
    'GCN': 0.815,
    'GAT': 0.830,
    'GraphSAGE': 0.824,
    'GIN': 0.800,
}


def _default_config():
    """Return a fresh configuration dict for the model, data and trainer."""
    return {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1,
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2,
        },
        'training': {
            'learning_rate': 0.01,
            'weight_decay': 0.0005,
            'epochs': 50,  # Quick test
            'patience': 10,
            'warmup_epochs': 5,
            'min_lr': 1e-6,
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True,
        },
    }


def _select_device():
    """Pick the torch device: CPU when SPACE_ID is set, else CUDA if available."""
    # NOTE(review): SPACE_ID is presumably set by the hosting environment
    # (e.g. a hosted demo space) where no GPU should be used -- confirm.
    if os.getenv('SPACE_ID'):
        return torch.device('cpu')
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _synchronize(device):
    """Wait for queued CUDA work so wall-clock timings cover the real compute."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def main():
    """Run the end-to-end GraphMamba smoke test and print a report.

    The stages run in order: dataset loading, model construction, a forward
    pass, a forward-pass benchmark per ordering strategy, trainer
    construction, training, evaluation, visualization and a summary that
    compares the test accuracy with ``CORA_BASELINES``.

    A failure in any stage other than the ordering benchmark or the
    visualization prints an error message and returns early. Those two
    stages are best-effort: their errors are reported and the run continues.

    Returns:
        None.
    """
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    # Configuration
    config = _default_config()

    # Setup device
    device = _select_device()
    print(f"💾 Device: {device}")

    # Load dataset
    print("\n📊 Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)

        print("✅ Dataset loaded successfully!")
        print(f" Nodes: {data.num_nodes:,}")
        print(f" Edges: {data.num_edges:,}")
        print(f" Features: {info['num_features']}")
        print(f" Classes: {info['num_classes']}")
        print(f" Train nodes: {data.train_mask.sum()}")
        print(f" Val nodes: {data.val_mask.sum()}")
        print(f" Test nodes: {data.test_mask.sum()}")
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return

    # Initialize model
    print("\n🏗️ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())

        print("✅ Model initialized!")
        print(f" Parameters: {total_params:,}")
        # Rough estimate: 4 bytes per parameter, reported in MiB.
        print(f" Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return

    # Test forward pass
    print("\n🚀 Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)

        print("✅ Forward pass successful!")
        print(f" Input shape: {data.x.shape}")
        print(f" Output shape: {h.shape}")
        print(f" Output range: [{h.min():.3f}, {h.max():.3f}]")
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return

    # Test ordering strategies
    print("\n🔄 Testing ordering strategies...")
    for strategy in ORDERING_STRATEGIES:
        try:
            # Work on a private copy so the config later handed to the
            # trainer still describes the model that is actually trained.
            strategy_config = copy.deepcopy(config)
            strategy_config['ordering']['strategy'] = strategy

            test_model = GraphMamba(strategy_config).to(device)
            test_model.eval()

            # Synchronize around the timed region; otherwise a CUDA run
            # would mostly measure kernel-launch overhead.
            _synchronize(device)
            start_time = time.perf_counter()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            _synchronize(device)
            elapsed_ms = (time.perf_counter() - start_time) * 1000

            print(f"✅ {strategy:12} | Shape: {h.shape} | Time: {elapsed_ms:.2f}ms")
        except Exception as e:
            # Best-effort: one failing strategy must not abort the others.
            print(f"❌ {strategy:12} | Failed: {str(e)}")

    # Initialize trainer
    print("\n🏋️ Testing training system...")
    try:
        trainer = GraphMambaTrainer(model, config, device)

        print("✅ Trainer initialized!")
        print(f" Optimizer: {type(trainer.optimizer).__name__}")
        print(f" Learning rate: {trainer.lr}")
        print(f" Epochs: {trainer.epochs}")
    except Exception as e:
        print(f"❌ Trainer initialization failed: {e}")
        return

    # Run training
    print("\n🎯 Running training...")
    try:
        start_time = time.perf_counter()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.perf_counter() - start_time

        print("✅ Training completed!")
        print(f" Training time: {training_time:.2f}s")
        print(f" Epochs trained: {len(history['train_loss'])}")
        print(f" Best val accuracy: {trainer.best_val_acc:.4f}")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        return

    # Test evaluation
    print("\n📊 Testing evaluation...")
    try:
        test_results = trainer.test(data)

        print("✅ Evaluation completed!")
        print(f" Test accuracy: {test_results['test_acc']:.4f}")
        print(f" Test loss: {test_results['test_loss']:.4f}")

        # Per-class results
        class_accs = test_results['class_acc']
        print(" Per-class accuracy:")
        for i, acc in enumerate(class_accs):
            print(f" Class {i}: {acc:.4f}")
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        return

    # Test visualization
    print("\n🎨 Testing visualization...")
    try:
        # Create visualizations
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_results)
        training_fig = GraphVisualizer.create_training_history_plot(history)

        print("✅ Visualizations created!")
        print(f" Graph plot: {type(graph_fig).__name__}")
        print(f" Metrics plot: {type(metrics_fig).__name__}")
        print(f" Training plot: {type(training_fig).__name__}")

        # Save plots
        graph_fig.write_html("graph_visualization.html")
        metrics_fig.write_html("metrics_plot.html")
        training_fig.write_html("training_history.html")
        print(" Plots saved as HTML files")
    except Exception as e:
        # Best-effort: plotting problems do not invalidate the results.
        print(f"❌ Visualization failed: {e}")

    # Performance summary
    print("\n🏆 Performance Summary")
    print("=" * 40)
    print(f"📊 Dataset: Cora ({data.num_nodes:,} nodes)")
    print(f"🧠 Model: {total_params:,} parameters")
    print(f"⏱️ Training: {training_time:.2f}s ({len(history['train_loss'])} epochs)")
    print(f"🎯 Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc']*100:.2f}%)")
    print(f"🏅 Best Val Accuracy: {trainer.best_val_acc:.4f} ({trainer.best_val_acc*100:.2f}%)")

    # Compare with baselines
    print("\n📈 Comparison with Baselines:")
    test_acc = test_results['test_acc']
    for model_name, baseline in CORA_BASELINES.items():
        diff = test_acc - baseline
        # Green: beats the baseline; yellow: within 0.05 below; red: worse.
        status = "🟢" if diff > 0 else "🟡" if diff > -0.05 else "🔴"
        print(f" {status} {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")

    print("\n✨ Test completed successfully!")
    print("🚀 Ready for production deployment!")


if __name__ == "__main__":
    main()