kfoughali commited on
Commit
ba4e201
Β·
verified Β·
1 Parent(s): 4f8aa53

Create demo.py

Browse files
Files changed (1) hide show
  1. demo.py +135 -0
demo.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick demo script to test Mamba Graph implementation
4
+ Device-safe version
5
+ """
6
+
7
+ import torch
8
+ import os
9
+ from core.graph_mamba import GraphMamba
10
+ from data.loader import GraphDataLoader
11
+ from utils.metrics import GraphMetrics
12
+
13
+ def main():
14
+ print("🧠 Testing Mamba Graph Neural Network")
15
+ print("=" * 50)
16
+
17
+ # Configuration
18
+ config = {
19
+ 'model': {
20
+ 'd_model': 128,
21
+ 'd_state': 8,
22
+ 'd_conv': 4,
23
+ 'expand': 2,
24
+ 'n_layers': 3,
25
+ 'dropout': 0.1
26
+ },
27
+ 'data': {
28
+ 'batch_size': 16,
29
+ 'test_split': 0.2
30
+ },
31
+ 'ordering': {
32
+ 'strategy': 'bfs',
33
+ 'preserve_locality': True
34
+ }
35
+ }
36
+
37
+ # Setup device
38
+ if os.getenv('SPACE_ID'):
39
+ device = torch.device('cpu')
40
+ else:
41
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
+ print(f"Device: {device}")
43
+
44
+ # Load dataset
45
+ print("\nπŸ“Š Loading Cora dataset...")
46
+ try:
47
+ data_loader = GraphDataLoader()
48
+ dataset = data_loader.load_node_classification_data('Cora')
49
+ data = dataset[0].to(device)
50
+
51
+ # Dataset info
52
+ info = data_loader.get_dataset_info(dataset)
53
+ print(f"βœ… Success!")
54
+ print(f"Nodes: {data.num_nodes}")
55
+ print(f"Edges: {data.num_edges}")
56
+ print(f"Features: {info['num_features']}")
57
+ print(f"Classes: {info['num_classes']}")
58
+
59
+ except Exception as e:
60
+ print(f"❌ Error loading dataset: {e}")
61
+ return
62
+
63
+ # Initialize model
64
+ print("\nπŸ—οΈ Initializing GraphMamba...")
65
+ try:
66
+ model = GraphMamba(config).to(device)
67
+ total_params = sum(p.numel() for p in model.parameters())
68
+ print(f"βœ… Model initialized!")
69
+ print(f"Parameters: {total_params:,}")
70
+
71
+ except Exception as e:
72
+ print(f"❌ Error initializing model: {e}")
73
+ return
74
+
75
+ # Forward pass test
76
+ print("\nπŸš€ Testing forward pass...")
77
+ try:
78
+ model.eval()
79
+ with torch.no_grad():
80
+ h = model(data.x, data.edge_index)
81
+ print(f"βœ… Forward pass successful!")
82
+ print(f"Input shape: {data.x.shape}")
83
+ print(f"Output shape: {h.shape}")
84
+ print(f"Output range: [{h.min():.3f}, {h.max():.3f}]")
85
+
86
+ except Exception as e:
87
+ print(f"❌ Forward pass failed: {e}")
88
+ return
89
+
90
+ # Test different ordering strategies
91
+ print("\nπŸ”„ Testing ordering strategies...")
92
+
93
+ strategies = ['bfs', 'spectral', 'degree', 'community']
94
+
95
+ for strategy in strategies:
96
+ try:
97
+ config['ordering']['strategy'] = strategy
98
+ model_test = GraphMamba(config).to(device)
99
+ model_test.eval()
100
+
101
+ with torch.no_grad():
102
+ h = model_test(data.x, data.edge_index)
103
+ print(f"βœ… {strategy}: Success - Shape {h.shape}")
104
+
105
+ except Exception as e:
106
+ print(f"❌ {strategy}: Failed - {str(e)}")
107
+
108
+ # Test evaluation
109
+ print("\nπŸ“ˆ Testing evaluation...")
110
+ try:
111
+ # Initialize classifier
112
+ num_classes = info['num_classes']
113
+ model._init_classifier(num_classes, device)
114
+
115
+ # Create test mask if not available
116
+ if hasattr(data, 'test_mask'):
117
+ mask = data.test_mask
118
+ else:
119
+ mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
120
+ mask[data.num_nodes//2:] = True
121
+
122
+ metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)
123
+ print("βœ… Evaluation successful!")
124
+ for metric, value in metrics.items():
125
+ if isinstance(value, float):
126
+ print(f" {metric}: {value:.4f}")
127
+
128
+ except Exception as e:
129
+ print(f"❌ Evaluation failed: {e}")
130
+
131
+ print("\n✨ Demo completed!")
132
+ print("πŸš€ Ready for production deployment!")
133
+
134
+ if __name__ == "__main__":
135
+ main()