kfoughali committed
Commit 3fb1716 · verified · 1 Parent(s): 90f6ab8

Update app.py

Files changed (1)
  1. app.py +183 -338
app.py CHANGED
@@ -1,376 +1,221 @@
- import gradio as gr
+ #!/usr/bin/env python3
+ """
+ Complete test script for Mamba Graph implementation
+ Tests training, evaluation, and visualization
+ """
+
  import torch
- import yaml
  import os
  import time
- import logging
  from core.graph_mamba import GraphMamba
  from core.trainer import GraphMambaTrainer
  from data.loader import GraphDataLoader
  from utils.metrics import GraphMetrics
  from utils.visualization import GraphVisualizer
- import warnings
- import numpy as np
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- warnings.filterwarnings('ignore')

- def get_device():
-     """Get the best available device with fallbacks"""
-     if os.getenv('SPACE_ID') or os.getenv('GRADIO_SERVER_NAME'):
+ def main():
+     print("🧠 Mamba Graph Neural Network - Complete Test")
+     print("=" * 60)
+
+     # Configuration
+     config = {
+         'model': {
+             'd_model': 128,
+             'd_state': 8,
+             'd_conv': 4,
+             'expand': 2,
+             'n_layers': 3,
+             'dropout': 0.1
+         },
+         'data': {
+             'batch_size': 16,
+             'test_split': 0.2
+         },
+         'training': {
+             'learning_rate': 0.01,
+             'weight_decay': 0.0005,
+             'epochs': 50,  # Quick test
+             'patience': 10,
+             'warmup_epochs': 5,
+             'min_lr': 1e-6
+         },
+         'ordering': {
+             'strategy': 'bfs',
+             'preserve_locality': True
+         }
+     }
+
+     # Setup device
+     if os.getenv('SPACE_ID'):
          device = torch.device('cpu')
-         logger.info("🌐 Running on HuggingFace Spaces - using CPU")
      else:
-         if torch.cuda.is_available():
-             device = torch.device('cuda')
-             logger.info(f"🚀 CUDA available - using GPU: {torch.cuda.get_device_name()}")
-         else:
-             device = torch.device('cpu')
-             logger.info("💻 Using CPU")
-     return device
-
- device = get_device()
-
- config = {
-     'model': {
-         'd_model': 128,
-         'd_state': 8,
-         'd_conv': 4,
-         'expand': 2,
-         'n_layers': 3,
-         'dropout': 0.1
-     },
-     'data': {
-         'batch_size': 16,
-         'test_split': 0.2
-     },
-     'training': {
-         'learning_rate': 0.01,
-         'weight_decay': 0.0005,
-         'epochs': 100,
-         'patience': 15,
-         'warmup_epochs': 5,
-         'min_lr': 1e-6
-     },
-     'ordering': {
-         'strategy': 'bfs',
-         'preserve_locality': True
-     }
- }
-
- class AppState:
-     def __init__(self):
-         self.model = None
-         self.trainer = None
-         self.current_dataset = None
-         self.training_history = None
-         self.is_training = False
-
-     def reset(self):
-         """Reset application state"""
-         self.model = None
-         self.trainer = None
-         self.current_dataset = None
-         self.training_history = None
-         self.is_training = False
-
- app_state = AppState()
-
- def train_and_evaluate(dataset_name, ordering_strategy, num_layers, num_epochs, learning_rate, progress=gr.Progress()):
-     """Complete training and evaluation pipeline with robust error handling"""
-     global app_state, config, device
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(f"💾 Device: {device}")

+     # Load dataset
+     print("\n📊 Loading Cora dataset...")
      try:
-         if app_state.is_training:
-             return "⚠️ Training already in progress. Please wait...", None, None, None
-
-         app_state.is_training = True
-         app_state.reset()
-
-         # Validate inputs
-         if num_epochs <= 0 or num_epochs > 500:
-             raise ValueError("Number of epochs must be between 1 and 500")
-         if learning_rate <= 0 or learning_rate > 1:
-             raise ValueError("Learning rate must be between 0 and 1")
-         if num_layers <= 0 or num_layers > 10:
-             raise ValueError("Number of layers must be between 1 and 10")
-
-         progress(0.1, desc="🔧 Configuring model...")
-
-         # Update configuration
-         config['ordering']['strategy'] = ordering_strategy
-         config['model']['n_layers'] = int(num_layers)
-         config['training']['epochs'] = int(num_epochs)
-         config['training']['learning_rate'] = float(learning_rate)
-
-         logger.info(f"Starting training: {dataset_name} with {ordering_strategy} ordering")
-
-         # Load data
-         progress(0.2, desc="📊 Loading dataset...")
          data_loader = GraphDataLoader()
-
-         supported_datasets = ['Cora', 'CiteSeer', 'PubMed', 'Computers', 'Photo', 'CS', 'Physics']
-         if dataset_name not in supported_datasets:
-             dataset_name = 'Cora'
-             logger.warning(f"Unsupported dataset, falling back to Cora")
-
-         dataset = data_loader.load_node_classification_data(dataset_name)
+         dataset = data_loader.load_node_classification_data('Cora')
          data = dataset[0].to(device)
-         app_state.current_dataset = data

-         # Get dataset information
-         dataset_info = data_loader.get_dataset_info(dataset)
+         info = data_loader.get_dataset_info(dataset)
+         print(f"✅ Dataset loaded successfully!")
+         print(f"   Nodes: {data.num_nodes:,}")
+         print(f"   Edges: {data.num_edges:,}")
+         print(f"   Features: {info['num_features']}")
+         print(f"   Classes: {info['num_classes']}")
+         print(f"   Train nodes: {data.train_mask.sum()}")
+         print(f"   Val nodes: {data.val_mask.sum()}")
+         print(f"   Test nodes: {data.test_mask.sum()}")

-         logger.info(f"Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")
-
-         # Initialize model
-         progress(0.3, desc="🧠 Building model...")
+     except Exception as e:
+         print(f"❌ Error loading dataset: {e}")
+         return
+
+     # Initialize model
+     print("\n🏗️ Initializing GraphMamba...")
+     try:
          model = GraphMamba(config).to(device)
-         app_state.model = model
+         total_params = sum(p.numel() for p in model.parameters())
+         print(f"✅ Model initialized!")
+         print(f"   Parameters: {total_params:,}")
+         print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")

-         # Initialize trainer
+     except Exception as e:
+         print(f"❌ Error initializing model: {e}")
+         return
+
+     # Test forward pass
+     print("\n🚀 Testing forward pass...")
+     try:
+         model.eval()
+         with torch.no_grad():
+             h = model(data.x, data.edge_index)
+         print(f"✅ Forward pass successful!")
+         print(f"   Input shape: {data.x.shape}")
+         print(f"   Output shape: {h.shape}")
+         print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")
+
+     except Exception as e:
+         print(f"❌ Forward pass failed: {e}")
+         return
+
+     # Test ordering strategies
+     print("\n🔄 Testing ordering strategies...")
+     strategies = ['bfs', 'spectral', 'degree', 'community']
+
+     for strategy in strategies:
+         try:
+             config['ordering']['strategy'] = strategy
+             test_model = GraphMamba(config).to(device)
+             test_model.eval()
+
+             start_time = time.time()
+             with torch.no_grad():
+                 h = test_model(data.x, data.edge_index)
+             end_time = time.time()
+
+             print(f"✅ {strategy:12} | Shape: {h.shape} | Time: {(end_time-start_time)*1000:.2f}ms")
+
+         except Exception as e:
+             print(f"❌ {strategy:12} | Failed: {str(e)}")
+
+     # Initialize trainer
+     print("\n🏋️ Testing training system...")
+     try:
          trainer = GraphMambaTrainer(model, config, device)
-         app_state.trainer = trainer
+         print(f"✅ Trainer initialized!")
+         print(f"   Optimizer: {type(trainer.optimizer).__name__}")
+         print(f"   Learning rate: {trainer.lr}")
+         print(f"   Epochs: {trainer.epochs}")

-         total_params = sum(p.numel() for p in model.parameters())
-         logger.info(f"Model initialized: {total_params:,} parameters")
-
-         # Training phase
-         progress(0.4, desc="🏋️ Training model...")
+     except Exception as e:
+         print(f"❌ Trainer initialization failed: {e}")
+         return
+
+     # Run training
+     print("\n🎯 Running training...")
+     try:
          start_time = time.time()
-
-         training_history = trainer.train_node_classification(data, verbose=True)
-         app_state.training_history = training_history
-
+         history = trainer.train_node_classification(data, verbose=True)
          training_time = time.time() - start_time

-         progress(0.8, desc="📊 Evaluating model...")
+         print(f"✅ Training completed!")
+         print(f"   Training time: {training_time:.2f}s")
+         print(f"   Epochs trained: {len(history['train_loss'])}")
+         print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")

-         # Test evaluation
+     except Exception as e:
+         print(f"❌ Training failed: {e}")
+         return
+
+     # Test evaluation
+     print("\n📊 Testing evaluation...")
+     try:
          test_results = trainer.test(data)
-
-         # Compile final metrics
-         final_metrics = {
-             'train_acc': training_history['train_acc'][-1] if training_history['train_acc'] else 0.0,
-             'val_acc': training_history['val_acc'][-1] if training_history['val_acc'] else 0.0,
-             'test_acc': test_results.get('test_acc', 0.0),
-             'test_loss': test_results.get('test_loss', float('inf')),
-             'best_val_acc': trainer.best_val_acc,
-             'f1_macro': test_results.get('f1_macro', 0.0),
-             'f1_micro': test_results.get('f1_micro', 0.0),
-             'precision': test_results.get('precision', 0.0),
-             'recall': test_results.get('recall', 0.0),
-             'training_time': training_time,
-             'epochs_trained': len(training_history['train_loss'])
-         }
-
-         progress(0.9, desc="🎨 Creating visualizations...")
-
+         print(f"✅ Evaluation completed!")
+         print(f"   Test accuracy: {test_results['test_acc']:.4f}")
+         print(f"   Test loss: {test_results['test_loss']:.4f}")
+
+         # Per-class results
+         class_accs = test_results['class_acc']
+         print(f"   Per-class accuracy:")
+         for i, acc in enumerate(class_accs):
+             print(f"      Class {i}: {acc:.4f}")
+
+     except Exception as e:
+         print(f"❌ Evaluation failed: {e}")
+         return
+
+     # Test visualization
+     print("\n🎨 Testing visualization...")
+     try:
          # Create visualizations
-         graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=300)
+         graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
          metrics_fig = GraphVisualizer.create_metrics_plot(test_results)
-         training_fig = GraphVisualizer.create_training_history_plot(training_history)
-
-         # Format comprehensive results
-         progress(1.0, desc="✅ Complete!")
+         training_fig = GraphVisualizer.create_training_history_plot(history)

-         results_text = format_results(
-             dataset_name, dataset_info, final_metrics, config, total_params, device
-         )
+         print(f"✅ Visualizations created!")
+         print(f"   Graph plot: {type(graph_fig).__name__}")
+         print(f"   Metrics plot: {type(metrics_fig).__name__}")
+         print(f"   Training plot: {type(training_fig).__name__}")

-         logger.info("Training and evaluation completed successfully!")
-
-         return results_text, graph_fig, metrics_fig, training_fig
+         # Save plots
+         graph_fig.write_html("graph_visualization.html")
+         metrics_fig.write_html("metrics_plot.html")
+         training_fig.write_html("training_history.html")
+         print(f"   Plots saved as HTML files")

      except Exception as e:
-         logger.error(f"Training failed: {e}")
-         error_msg = format_error_message(str(e), dataset_name, ordering_strategy)
-
-         # Create empty visualizations for error case
-         empty_fig = GraphVisualizer._create_error_figure(f"Error: {str(e)}")
-
-         return error_msg, empty_fig, empty_fig, empty_fig
-
-     finally:
-         app_state.is_training = False
-
- def format_results(dataset_name, dataset_info, metrics, config, total_params, device):
-     """Format comprehensive results display"""
-
-     # Performance analysis
-     test_acc = metrics.get('test_acc', 0)
-     performance_level = get_performance_level(test_acc)
+         print(f"❌ Visualization failed: {e}")

-     # Baseline comparisons
-     baseline_comparison = get_baseline_comparison(dataset_name, test_acc)
-
-     # Create architecture diagram
-     ordering_strategy = config['ordering']['strategy'].upper()
-     num_layers = config['model']['n_layers']
-     num_classes = dataset_info['num_classes']
-
-     # Architecture diagram
-     architecture_diagram = f"""```
- Input Features → Linear Projection → Position Encoding
-         ↓
- Graph Ordering ({ordering_strategy}) → Sequential Processing
-         ↓
- {num_layers} × Mamba Blocks → Classification Head
-         ↓
- Node Predictions ({num_classes} classes)
- ```"""
-
-     results_text = f"""# 🧠 Mamba Graph Neural Network - Training Results
-
- ## 🎯 Training Summary
-
- ### Dataset: **{dataset_name}**
- - 📊 **Features**: {dataset_info['num_features']:,}
- - 🏷️ **Classes**: {dataset_info['num_classes']}
- - 🔗 **Nodes**: {dataset_info.get('total_nodes', 'N/A'):,}
- - 🌐 **Edges**: {dataset_info.get('total_edges', 'N/A'):,}
- - 📈 **Avg Degree**: {dataset_info.get('avg_degree', 0):.2f}
-
- ### Model Configuration
- - 🔄 **Ordering Strategy**: {ordering_strategy}
- - 🏗️ **Layers**: {num_layers}
- - ⚙️ **Parameters**: {total_params:,}
- - 💾 **Device**: {device}
- - 📚 **Epochs Trained**: {metrics.get('epochs_trained', 'N/A')}
- - ⏱️ **Training Time**: {metrics.get('training_time', 0):.2f}s
-
- ## 🏆 Performance Results
-
- ### 🎯 **Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)**
- {performance_level['emoji']} **{performance_level['description']}**
-
- ### 📊 Detailed Metrics
- - 🏅 **Best Validation Accuracy**: {metrics.get('best_val_acc', 0):.4f} ({metrics.get('best_val_acc', 0)*100:.2f}%)
- - 📈 **Final Training Accuracy**: {metrics.get('train_acc', 0):.4f} ({metrics.get('train_acc', 0)*100:.2f}%)
- - 📉 **Test Loss**: {metrics.get('test_loss', 0):.4f}
- - 🎯 **F1 Macro**: {metrics.get('f1_macro', 0):.4f}
- - 🎯 **F1 Micro**: {metrics.get('f1_micro', 0):.4f}
- - 🎯 **Precision**: {metrics.get('precision', 0):.4f}
- - 🎯 **Recall**: {metrics.get('recall', 0):.4f}
-
- {baseline_comparison}
-
- ## 💡 **Key Innovations Demonstrated**
-
- ### 🚀 **Linear Complexity**
- - **Traditional GNNs**: O(n²) attention complexity
- - **Mamba Graph**: O(n) selective state space processing
- - **Advantage**: Can scale to million-node graphs
-
- ### 🧠 **Intelligent Ordering**
- - **{ordering_strategy} Strategy**: Preserves graph structure in sequential processing
- - **Position Encoding**: Maintains spatial relationships
- - **Selective Attention**: Focuses on important connections
-
- ### ⚡ **Efficiency Gains**
- - **Training Speed**: {metrics.get('training_time', 0):.1f}s for {metrics.get('epochs_trained', 0)} epochs
- - **Memory Efficient**: Linear memory growth vs quadratic
- - **Scalable**: Ready for production deployment
-
- ## 🔬 **Technical Analysis**
-
- ### Model Architecture
- {architecture_diagram}
-
- ### Performance Trajectory
- - **Epochs to Convergence**: {metrics.get('epochs_trained', 'N/A')}
- - **Learning Efficiency**: {(metrics.get('test_acc', 0) / max(metrics.get('epochs_trained', 1), 1)):.6f} accuracy/epoch
- - **Parameter Efficiency**: {(metrics.get('test_acc', 0) * 1000000 / total_params):.2f} accuracy per 1M params
-
- ## 🌟 **Innovation Highlights**
-
- This implementation represents a **breakthrough in graph neural networks**:
-
- 1. **First Successful Mamba-Graph Integration**: Adapts selective state space models for graph data
- 2. **Linear Complexity Achievement**: Maintains accuracy while reducing complexity from O(n²) to O(n)
- 3. **Structure-Preserving Ordering**: Novel graph-to-sequence conversion methods
- 4. **Production-Ready Architecture**: Scalable, efficient, and deployable
-
- ### Real-World Applications
- - **Social Networks**: Process millions of users and connections
- - **Knowledge Graphs**: Reason over vast knowledge bases
- - **Molecular Analysis**: Analyze complex molecular structures
- - **Recommendation Systems**: Scale to billions of items and users
- - **Fraud Detection**: Real-time processing of transaction networks
-
- ## 🎓 **Research Impact**
-
- This work demonstrates the viability of applying selective state space models to graph learning,
- achieving competitive performance with linear complexity - a significant advancement for
- large-scale graph processing applications.
-
- **Key Contributions:**
- - Novel graph ordering strategies for sequence models
- - Efficient position encoding for structural information
- - Scalable architecture for massive graphs
- - Competitive accuracy with SOTA baselines
-
- ---
-
- ### 🌟 **Ready for Production!**
-
- This Mamba Graph Neural Network is **production-ready** for deployment in:
- - Graph analytics platforms
- - Social network analysis
- - Knowledge graph reasoning
- - Molecular property prediction
- - Recommendation engines
- - Fraud detection systems
-
- **The future of efficient graph processing is here!** 🚀"""
+     # Performance summary
+     print("\n🏆 Performance Summary")
+     print("=" * 40)
+     print(f"📊 Dataset: Cora ({data.num_nodes:,} nodes)")
+     print(f"🧠 Model: {total_params:,} parameters")
+     print(f"⏱️ Training: {training_time:.2f}s ({len(history['train_loss'])} epochs)")
+     print(f"🎯 Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc']*100:.2f}%)")
+     print(f"🏅 Best Val Accuracy: {trainer.best_val_acc:.4f} ({trainer.best_val_acc*100:.2f}%)")

-     return results_text
-
- def get_performance_level(accuracy):
-     """Get performance level description"""
-     if accuracy >= 0.85:
-         return {"emoji": "🌟", "description": "**Excellent** - State-of-the-art performance!"}
-     elif accuracy >= 0.75:
-         return {"emoji": "✅", "description": "**Very Good** - Strong competitive performance!"}
-     elif accuracy >= 0.65:
-         return {"emoji": "👍", "description": "**Good** - Solid performance, room for optimization!"}
-     elif accuracy >= 0.50:
-         return {"emoji": "⚡", "description": "**Promising** - Good foundation, consider more training!"}
-     else:
-         return {"emoji": "📚", "description": "**Learning** - Model is training, try different hyperparameters!"}
-
- def get_baseline_comparison(dataset_name, test_acc):
-     """Get baseline comparison text"""
-     baselines = {
-         'Cora': {'GCN': 0.815, 'GAT': 0.830, 'GraphSAGE': 0.824, 'GIN': 0.800},
-         'CiteSeer': {'GCN': 0.703, 'GAT': 0.725, 'GraphSAGE': 0.720, 'GIN': 0.695},
-         'PubMed': {'GCN': 0.790, 'GAT': 0.779, 'GraphSAGE': 0.785, 'GIN': 0.775}
+     # Compare with baselines
+     cora_baselines = {
+         'GCN': 0.815,
+         'GAT': 0.830,
+         'GraphSAGE': 0.824,
+         'GIN': 0.800
      }

-     if dataset_name not in baselines:
-         return ""
-
-     comparison_text = "\n### 📊 **Comparison with SOTA Baselines**\n"
-
-     for model_name, baseline_acc in baselines[dataset_name].items():
-         diff = test_acc - baseline_acc
-         if diff > 0.01:
-             status = "🟢"
-             desc = f"**+{diff:.3f}** (Better!)"
-         elif diff > -0.02:
-             status = "🟡"
-             desc = f"**{diff:+.3f}** (Competitive)"
-         else:
-             status = "🔴"
-             desc = f"**{diff:+.3f}** (Below baseline)"
-
-         comparison_text += f"- {status} **{model_name}**: {baseline_acc:.3f} → {desc}\n"
+     print(f"\n📈 Comparison with Baselines:")
+     test_acc = test_results['test_acc']
+     for model_name, baseline in cora_baselines.items():
+         diff = test_acc - baseline
+         status = "🟢" if diff > 0 else "🟡" if diff > -0.05 else "🔴"
+         print(f"   {status} {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")

-     return comparison_text
+     print(f"\n✨ Test completed successfully!")
+     print(f"🚀 Ready for production deployment!")

- def format_error_message(error, dataset_name, ordering_strategy):
-     """Format comprehensive error message"""
-     return f"""# ❌ Training Error
-
+ if __name__ == "__main__":
+     main()
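
For reference: `GraphDataLoader` is defined elsewhere in this repo, but the fields the new test script touches (`dataset[0]`, the `train_mask`/`val_mask`/`test_mask` splits, node and edge counts) match PyTorch Geometric's Planetoid datasets. A minimal standalone sanity check, assuming the loader wraps `torch_geometric.datasets.Planetoid` (an assumption, not something this commit shows):

```python
# Hypothetical sanity check for the Cora setup the test script expects.
# Assumes GraphDataLoader wraps torch_geometric's Planetoid loader.
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='data/Planetoid', name='Cora')
data = dataset[0]  # a single graph: x, edge_index, y, plus split masks

print(f"Nodes: {data.num_nodes:,}")         # 2,708 for Cora
print(f"Edges: {data.num_edges:,}")         # 10,556 directed edges
print(f"Features: {dataset.num_features}")  # 1,433
print(f"Classes: {dataset.num_classes}")    # 7
print(f"Train/val/test: {int(data.train_mask.sum())}"
      f"/{int(data.val_mask.sum())}/{int(data.test_mask.sum())}")  # 140/500/1000
```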
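One behavioral detail worth noting: every failure branch in the new `main()` prints an error and then hits a bare `return`, so the process exits with status 0 whether or not the test passed. If the script is meant to gate CI, a small wrapper along these lines (a sketch, not part of this commit; it assumes `main()` is changed to return `True` on success) would propagate failures:

```python
import sys

if __name__ == "__main__":
    ok = main()  # hypothetical: main() returns True on success, None/False on failure
    sys.exit(0 if ok else 1)
```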