#!/usr/bin/env python3
"""
Quick demo script to smoke-test the GraphMamba implementation.
Device-safe: falls back to CPU when running on Hugging Face Spaces.
"""

import torch
import os
from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics

def main():
    print("🧠 Testing Mamba Graph Neural Network")
    print("=" * 50)
    
    # Configuration
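    # Mamba hyperparameters: d_model is the hidden width, d_state the SSM state
    # size, d_conv the local convolution kernel width, and expand the block
    # expansion factor.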
    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True
        }
    }
    
    # Setup device
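    # SPACE_ID is set in the Hugging Face Spaces environment; force CPU there
    # since a GPU may not be provisioned for the Space.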
    if os.getenv('SPACE_ID'):
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    
    # Load dataset
    print("\nπŸ“Š Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        
        # Dataset info
        info = data_loader.get_dataset_info(dataset)
        print(f"βœ… Success!")
        print(f"Nodes: {data.num_nodes}")
        print(f"Edges: {data.num_edges}")
        print(f"Features: {info['num_features']}")
        print(f"Classes: {info['num_classes']}")
        
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return
    
    # Initialize model
    print("\nπŸ—οΈ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"βœ… Model initialized!")
        print(f"Parameters: {total_params:,}")
        
    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return
    
    # Forward pass test
    print("\nπŸš€ Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)
            print(f"βœ… Forward pass successful!")
            print(f"Input shape: {data.x.shape}")
            print(f"Output shape: {h.shape}")
            print(f"Output range: [{h.min():.3f}, {h.max():.3f}]")
            
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return
    
    # Test different ordering strategies
    print("\nπŸ”„ Testing ordering strategies...")
    
    strategies = ['bfs', 'spectral', 'degree', 'community']
    
    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            model_test = GraphMamba(config).to(device)
            model_test.eval()
            
            with torch.no_grad():
                h = model_test(data.x, data.edge_index)
            print(f"βœ… {strategy}: Success - Shape {h.shape}")
            
        except Exception as e:
            print(f"❌ {strategy}: Failed - {str(e)}")
    
    # Test evaluation
    print("\nπŸ“ˆ Testing evaluation...")
    try:
        # Initialize classifier
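        # _init_classifier (from core.graph_mamba) is assumed to attach a
        # classification head sized for num_classes on the given device.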
        num_classes = info['num_classes']
        model._init_classifier(num_classes, device)
        
        # Use the dataset's test mask when present; otherwise fall back to
        # treating the second half of the nodes as the test set.
        if getattr(data, 'test_mask', None) is not None:
            mask = data.test_mask
        else:
            mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
            mask[data.num_nodes // 2:] = True
        
        metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)
        print("βœ… Evaluation successful!")
        for metric, value in metrics.items():
            if isinstance(value, float):
                print(f"  {metric}: {value:.4f}")
            
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
    
    print("\n✨ Demo completed!")
    print("πŸš€ Ready for production deployment!")

if __name__ == "__main__":
    main()