kfoughali commited on
Commit
97c533b
Β·
verified Β·
1 Parent(s): 2a31a52

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +127 -41
demo.py CHANGED
@@ -1,18 +1,21 @@
1
  #!/usr/bin/env python3
2
  """
3
- Quick demo script to test Mamba Graph implementation
4
- Device-safe version
5
  """
6
 
7
  import torch
8
  import os
 
9
  from core.graph_mamba import GraphMamba
 
10
  from data.loader import GraphDataLoader
11
  from utils.metrics import GraphMetrics
 
12
 
13
  def main():
14
- print("🧠 Testing Mamba Graph Neural Network")
15
- print("=" * 50)
16
 
17
  # Configuration
18
  config = {
@@ -28,6 +31,14 @@ def main():
28
  'batch_size': 16,
29
  'test_split': 0.2
30
  },
 
 
 
 
 
 
 
 
31
  'ordering': {
32
  'strategy': 'bfs',
33
  'preserve_locality': True
@@ -39,7 +50,7 @@ def main():
39
  device = torch.device('cpu')
40
  else:
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
- print(f"Device: {device}")
43
 
44
  # Load dataset
45
  print("\nπŸ“Š Loading Cora dataset...")
@@ -48,13 +59,15 @@ def main():
48
  dataset = data_loader.load_node_classification_data('Cora')
49
  data = dataset[0].to(device)
50
 
51
- # Dataset info
52
  info = data_loader.get_dataset_info(dataset)
53
- print(f"βœ… Success!")
54
- print(f"Nodes: {data.num_nodes}")
55
- print(f"Edges: {data.num_edges}")
56
- print(f"Features: {info['num_features']}")
57
- print(f"Classes: {info['num_classes']}")
 
 
 
58
 
59
  except Exception as e:
60
  print(f"❌ Error loading dataset: {e}")
@@ -66,70 +79,143 @@ def main():
66
  model = GraphMamba(config).to(device)
67
  total_params = sum(p.numel() for p in model.parameters())
68
  print(f"βœ… Model initialized!")
69
- print(f"Parameters: {total_params:,}")
 
70
 
71
  except Exception as e:
72
  print(f"❌ Error initializing model: {e}")
73
  return
74
 
75
- # Forward pass test
76
  print("\nπŸš€ Testing forward pass...")
77
  try:
78
  model.eval()
79
  with torch.no_grad():
80
  h = model(data.x, data.edge_index)
81
  print(f"βœ… Forward pass successful!")
82
- print(f"Input shape: {data.x.shape}")
83
- print(f"Output shape: {h.shape}")
84
- print(f"Output range: [{h.min():.3f}, {h.max():.3f}]")
85
 
86
  except Exception as e:
87
  print(f"❌ Forward pass failed: {e}")
88
  return
89
 
90
- # Test different ordering strategies
91
  print("\nπŸ”„ Testing ordering strategies...")
92
-
93
  strategies = ['bfs', 'spectral', 'degree', 'community']
94
 
95
  for strategy in strategies:
96
  try:
97
  config['ordering']['strategy'] = strategy
98
- model_test = GraphMamba(config).to(device)
99
- model_test.eval()
100
 
 
101
  with torch.no_grad():
102
- h = model_test(data.x, data.edge_index)
103
- print(f"βœ… {strategy}: Success - Shape {h.shape}")
 
 
104
 
105
  except Exception as e:
106
- print(f"❌ {strategy}: Failed - {str(e)}")
107
 
108
- # Test evaluation
109
- print("\nπŸ“ˆ Testing evaluation...")
110
  try:
111
- # Initialize classifier
112
- num_classes = info['num_classes']
113
- model._init_classifier(num_classes, device)
 
 
114
 
115
- # Create test mask if not available
116
- if hasattr(data, 'test_mask'):
117
- mask = data.test_mask
118
- else:
119
- mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
120
- mask[data.num_nodes//2:] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)
123
- print("βœ… Evaluation successful!")
124
- for metric, value in metrics.items():
125
- if isinstance(value, float):
126
- print(f" {metric}: {value:.4f}")
127
 
128
  except Exception as e:
129
  print(f"❌ Evaluation failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- print("\n✨ Demo completed!")
132
- print("πŸš€ Ready for production deployment!")
133
 
134
  if __name__ == "__main__":
135
  main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Complete test script for Mamba Graph implementation
4
+ Tests training, evaluation, and visualization
5
  """
6
 
7
  import torch
8
  import os
9
+ import time
10
  from core.graph_mamba import GraphMamba
11
+ from core.trainer import GraphMambaTrainer
12
  from data.loader import GraphDataLoader
13
  from utils.metrics import GraphMetrics
14
+ from utils.visualization import GraphVisualizer
15
 
16
  def main():
17
+ print("🧠 Mamba Graph Neural Network - Complete Test")
18
+ print("=" * 60)
19
 
20
  # Configuration
21
  config = {
 
31
  'batch_size': 16,
32
  'test_split': 0.2
33
  },
34
+ 'training': {
35
+ 'learning_rate': 0.01,
36
+ 'weight_decay': 0.0005,
37
+ 'epochs': 50, # Quick test
38
+ 'patience': 10,
39
+ 'warmup_epochs': 5,
40
+ 'min_lr': 1e-6
41
+ },
42
  'ordering': {
43
  'strategy': 'bfs',
44
  'preserve_locality': True
 
50
  device = torch.device('cpu')
51
  else:
52
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53
+ print(f"πŸ’Ύ Device: {device}")
54
 
55
  # Load dataset
56
  print("\nπŸ“Š Loading Cora dataset...")
 
59
  dataset = data_loader.load_node_classification_data('Cora')
60
  data = dataset[0].to(device)
61
 
 
62
  info = data_loader.get_dataset_info(dataset)
63
+ print(f"βœ… Dataset loaded successfully!")
64
+ print(f" Nodes: {data.num_nodes:,}")
65
+ print(f" Edges: {data.num_edges:,}")
66
+ print(f" Features: {info['num_features']}")
67
+ print(f" Classes: {info['num_classes']}")
68
+ print(f" Train nodes: {data.train_mask.sum()}")
69
+ print(f" Val nodes: {data.val_mask.sum()}")
70
+ print(f" Test nodes: {data.test_mask.sum()}")
71
 
72
  except Exception as e:
73
  print(f"❌ Error loading dataset: {e}")
 
79
  model = GraphMamba(config).to(device)
80
  total_params = sum(p.numel() for p in model.parameters())
81
  print(f"βœ… Model initialized!")
82
+ print(f" Parameters: {total_params:,}")
83
+ print(f" Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
84
 
85
  except Exception as e:
86
  print(f"❌ Error initializing model: {e}")
87
  return
88
 
89
+ # Test forward pass
90
  print("\nπŸš€ Testing forward pass...")
91
  try:
92
  model.eval()
93
  with torch.no_grad():
94
  h = model(data.x, data.edge_index)
95
  print(f"βœ… Forward pass successful!")
96
+ print(f" Input shape: {data.x.shape}")
97
+ print(f" Output shape: {h.shape}")
98
+ print(f" Output range: [{h.min():.3f}, {h.max():.3f}]")
99
 
100
  except Exception as e:
101
  print(f"❌ Forward pass failed: {e}")
102
  return
103
 
104
+ # Test ordering strategies
105
  print("\nπŸ”„ Testing ordering strategies...")
 
106
  strategies = ['bfs', 'spectral', 'degree', 'community']
107
 
108
  for strategy in strategies:
109
  try:
110
  config['ordering']['strategy'] = strategy
111
+ test_model = GraphMamba(config).to(device)
112
+ test_model.eval()
113
 
114
+ start_time = time.time()
115
  with torch.no_grad():
116
+ h = test_model(data.x, data.edge_index)
117
+ end_time = time.time()
118
+
119
+ print(f"βœ… {strategy:12} | Shape: {h.shape} | Time: {(end_time-start_time)*1000:.2f}ms")
120
 
121
  except Exception as e:
122
+ print(f"❌ {strategy:12} | Failed: {str(e)}")
123
 
124
+ # Initialize trainer
125
+ print("\nπŸ‹οΈ Testing training system...")
126
  try:
127
+ trainer = GraphMambaTrainer(model, config, device)
128
+ print(f"βœ… Trainer initialized!")
129
+ print(f" Optimizer: {type(trainer.optimizer).__name__}")
130
+ print(f" Learning rate: {trainer.lr}")
131
+ print(f" Epochs: {trainer.epochs}")
132
 
133
+ except Exception as e:
134
+ print(f"❌ Trainer initialization failed: {e}")
135
+ return
136
+
137
+ # Run training
138
+ print("\n🎯 Running training...")
139
+ try:
140
+ start_time = time.time()
141
+ history = trainer.train_node_classification(data, verbose=True)
142
+ training_time = time.time() - start_time
143
+
144
+ print(f"βœ… Training completed!")
145
+ print(f" Training time: {training_time:.2f}s")
146
+ print(f" Epochs trained: {len(history['train_loss'])}")
147
+ print(f" Best val accuracy: {trainer.best_val_acc:.4f}")
148
+
149
+ except Exception as e:
150
+ print(f"❌ Training failed: {e}")
151
+ return
152
+
153
+ # Test evaluation
154
+ print("\nπŸ“Š Testing evaluation...")
155
+ try:
156
+ test_results = trainer.test(data)
157
+ print(f"βœ… Evaluation completed!")
158
+ print(f" Test accuracy: {test_results['test_acc']:.4f}")
159
+ print(f" Test loss: {test_results['test_loss']:.4f}")
160
 
161
+ # Per-class results
162
+ class_accs = test_results['class_acc']
163
+ print(f" Per-class accuracy:")
164
+ for i, acc in enumerate(class_accs):
165
+ print(f" Class {i}: {acc:.4f}")
166
 
167
  except Exception as e:
168
  print(f"❌ Evaluation failed: {e}")
169
+ return
170
+
171
+ # Test visualization
172
+ print("\n🎨 Testing visualization...")
173
+ try:
174
+ # Create visualizations
175
+ graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
176
+ metrics_fig = GraphVisualizer.create_metrics_plot(test_results)
177
+ training_fig = GraphVisualizer.create_training_history_plot(history)
178
+
179
+ print(f"βœ… Visualizations created!")
180
+ print(f" Graph plot: {type(graph_fig).__name__}")
181
+ print(f" Metrics plot: {type(metrics_fig).__name__}")
182
+ print(f" Training plot: {type(training_fig).__name__}")
183
+
184
+ # Save plots
185
+ graph_fig.write_html("graph_visualization.html")
186
+ metrics_fig.write_html("metrics_plot.html")
187
+ training_fig.write_html("training_history.html")
188
+ print(f" Plots saved as HTML files")
189
+
190
+ except Exception as e:
191
+ print(f"❌ Visualization failed: {e}")
192
+
193
+ # Performance summary
194
+ print("\nπŸ† Performance Summary")
195
+ print("=" * 40)
196
+ print(f"πŸ“Š Dataset: Cora ({data.num_nodes:,} nodes)")
197
+ print(f"🧠 Model: {total_params:,} parameters")
198
+ print(f"⏱️ Training: {training_time:.2f}s ({len(history['train_loss'])} epochs)")
199
+ print(f"🎯 Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc']*100:.2f}%)")
200
+ print(f"πŸ… Best Val Accuracy: {trainer.best_val_acc:.4f} ({trainer.best_val_acc*100:.2f}%)")
201
+
202
+ # Compare with baselines
203
+ cora_baselines = {
204
+ 'GCN': 0.815,
205
+ 'GAT': 0.830,
206
+ 'GraphSAGE': 0.824,
207
+ 'GIN': 0.800
208
+ }
209
+
210
+ print(f"\nπŸ“ˆ Comparison with Baselines:")
211
+ test_acc = test_results['test_acc']
212
+ for model_name, baseline in cora_baselines.items():
213
+ diff = test_acc - baseline
214
+ status = "🟒" if diff > 0 else "🟑" if diff > -0.05 else "πŸ”΄"
215
+ print(f" {status} {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")
216
 
217
+ print(f"\n✨ Test completed successfully!")
218
+ print(f"πŸš€ Ready for production deployment!")
219
 
220
  if __name__ == "__main__":
221
  main()