#!/usr/bin/env python3
"""
Quick demo script to test the Mamba Graph implementation.
Device-safe version.
"""

import os

import torch

from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics


def main():
    print("🧠 Testing Mamba Graph Neural Network")
    print("=" * 50)

    # Configuration
    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True
        }
    }

    # Setup device (force CPU on Hugging Face Spaces, otherwise use CUDA if available)
    if os.getenv('SPACE_ID'):
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Load dataset
    print("\n📊 Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)

        # Dataset info
        info = data_loader.get_dataset_info(dataset)
        print("✅ Success!")
        print(f"Nodes: {data.num_nodes}")
        print(f"Edges: {data.num_edges}")
        print(f"Features: {info['num_features']}")
        print(f"Classes: {info['num_classes']}")
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return

    # Initialize model
    print("\n🏗️ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print("✅ Model initialized!")
        print(f"Parameters: {total_params:,}")
    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return

    # Forward pass test
    print("\n🚀 Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)
        print("✅ Forward pass successful!")
        print(f"Input shape: {data.x.shape}")
        print(f"Output shape: {h.shape}")
        print(f"Output range: [{h.min():.3f}, {h.max():.3f}]")
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return

    # Test different ordering strategies
    print("\n🔄 Testing ordering strategies...")
    strategies = ['bfs', 'spectral', 'degree', 'community']

    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            model_test = GraphMamba(config).to(device)
            model_test.eval()
            with torch.no_grad():
                h = model_test(data.x, data.edge_index)
            print(f"✅ {strategy}: Success - Shape {h.shape}")
        except Exception as e:
            print(f"❌ {strategy}: Failed - {str(e)}")

    # Test evaluation
    print("\n📈 Testing evaluation...")
    try:
        # Initialize classifier head for the dataset's classes
        num_classes = info['num_classes']
        model._init_classifier(num_classes, device)

        # Create a test mask if the dataset does not provide one
        if hasattr(data, 'test_mask'):
            mask = data.test_mask
        else:
            mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
            mask[data.num_nodes // 2:] = True

        metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)
        print("✅ Evaluation successful!")
        for metric, value in metrics.items():
            if isinstance(value, float):
                print(f"  {metric}: {value:.4f}")
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")

    print("\n✨ Demo completed!")
    print("🚀 Ready for production deployment!")


if __name__ == "__main__":
    main()
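

# ---------------------------------------------------------------------------
# Optional extra: a synthetic-graph smoke test (a minimal sketch, NOT called by
# main()). It relies only on the GraphMamba(config) constructor and the
# model(x, edge_index) forward signature already exercised above; the random
# feature/edge sizes, the assumption that GraphMamba accepts arbitrary input
# feature dimensions, and the function name itself are illustrative, not part
# of the original demo. Call it manually, e.g. from a REPL:
#     >>> from demo import smoke_test_synthetic
#     >>> smoke_test_synthetic()
# ---------------------------------------------------------------------------
def smoke_test_synthetic(num_nodes=64, num_features=32, num_edges=256):
    """Forward-pass sanity check on a small random graph (CPU only)."""
    device = torch.device('cpu')
    config = {
        'model': {'d_model': 128, 'd_state': 8, 'd_conv': 4, 'expand': 2,
                  'n_layers': 2, 'dropout': 0.1},
        'data': {'batch_size': 16, 'test_split': 0.2},
        'ordering': {'strategy': 'bfs', 'preserve_locality': True}
    }

    # Random node features and a random edge_index in COO format [2, num_edges]
    x = torch.randn(num_nodes, num_features, device=device)
    edge_index = torch.randint(0, num_nodes, (2, num_edges), device=device)

    model = GraphMamba(config).to(device)
    model.eval()
    with torch.no_grad():
        h = model(x, edge_index)

    print(f"Synthetic smoke test output shape: {tuple(h.shape)}")
    return h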