File size: 7,528 Bytes
ba4e201
 
97c533b
 
ba4e201
 
 
 
97c533b
ba4e201
97c533b
ba4e201
 
97c533b
ba4e201
 
97c533b
 
ba4e201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97c533b
 
 
 
 
 
 
 
ba4e201
 
 
 
 
 
 
 
 
 
 
97c533b
ba4e201
 
 
 
 
 
 
 
 
97c533b
 
 
 
 
 
 
 
ba4e201
 
 
 
 
 
 
 
 
 
 
97c533b
 
ba4e201
 
 
 
 
97c533b
ba4e201
 
 
 
 
 
97c533b
 
 
ba4e201
 
 
 
 
97c533b
ba4e201
 
 
 
 
 
97c533b
 
ba4e201
97c533b
ba4e201
97c533b
 
 
 
ba4e201
 
97c533b
ba4e201
97c533b
 
ba4e201
97c533b
 
 
 
 
ba4e201
97c533b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba4e201
97c533b
 
 
 
 
ba4e201
 
 
97c533b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba4e201
97c533b
 
ba4e201
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#!/usr/bin/env python3
"""
Complete test script for Mamba Graph implementation
Tests training, evaluation, and visualization
"""

import torch
import os
import time
from core.graph_mamba import GraphMamba
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer

def main():
    """Run the end-to-end GraphMamba smoke test.

    Exercises, in order: dataset loading (Cora), model construction, a
    forward pass, all four node-ordering strategies, trainer setup, a short
    training run, evaluation, and visualization. Each stage is wrapped in a
    broad try/except on purpose — this is a best-effort diagnostic script,
    and a failure in one stage should print a clear marker and stop (or, for
    visualization, continue) rather than raise a traceback.
    """
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    # Configuration shared by every stage. NOTE: treated as read-only below;
    # per-strategy variants are built as copies (see ordering test) so this
    # dict is never mutated after construction.
    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2
        },
        'training': {
            'learning_rate': 0.01,
            'weight_decay': 0.0005,
            'epochs': 50,  # Quick test
            'patience': 10,
            'warmup_epochs': 5,
            'min_lr': 1e-6
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True
        }
    }

    # Setup device. SPACE_ID is set on Hugging Face Spaces, where we force
    # CPU; otherwise prefer CUDA when available.
    if os.getenv('SPACE_ID'):
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"💾 Device: {device}")

    # Load dataset
    print("\n📊 Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)

        info = data_loader.get_dataset_info(dataset)
        print("✅ Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {data.train_mask.sum()}")
        print(f"   Val nodes: {data.val_mask.sum()}")
        print(f"   Test nodes: {data.test_mask.sum()}")

    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return

    # Initialize model
    print("\n🏗️ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print("✅ Model initialized!")
        print(f"   Parameters: {total_params:,}")
        # Rough estimate: 4 bytes per float32 parameter.
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")

    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return

    # Test forward pass
    print("\n🚀 Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)
            print("✅ Forward pass successful!")
            print(f"   Input shape: {data.x.shape}")
            print(f"   Output shape: {h.shape}")
            print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")

    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return

    # Test ordering strategies
    print("\n🔄 Testing ordering strategies...")
    strategies = ['bfs', 'spectral', 'degree', 'community']

    for strategy in strategies:
        try:
            # Build a per-strategy config copy instead of mutating the shared
            # `config` dict. (Previously the loop wrote the strategy back into
            # `config`, leaving 'community' behind and silently changing the
            # config handed to the trainer below.)
            strategy_config = {
                **config,
                'ordering': {**config['ordering'], 'strategy': strategy},
            }
            test_model = GraphMamba(strategy_config).to(device)
            test_model.eval()

            start_time = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            end_time = time.time()

            print(f"✅ {strategy:12} | Shape: {h.shape} | Time: {(end_time-start_time)*1000:.2f}ms")

        except Exception as e:
            print(f"❌ {strategy:12} | Failed: {str(e)}")

    # Initialize trainer
    print("\n🏋️ Testing training system...")
    try:
        trainer = GraphMambaTrainer(model, config, device)
        print("✅ Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")

    except Exception as e:
        print(f"❌ Trainer initialization failed: {e}")
        return

    # Run training
    print("\n🎯 Running training...")
    try:
        start_time = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - start_time

        print("✅ Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")

    except Exception as e:
        print(f"❌ Training failed: {e}")
        return

    # Test evaluation
    print("\n📊 Testing evaluation...")
    try:
        test_results = trainer.test(data)
        print("✅ Evaluation completed!")
        print(f"   Test accuracy: {test_results['test_acc']:.4f}")
        print(f"   Test loss: {test_results['test_loss']:.4f}")

        # Per-class results
        class_accs = test_results['class_acc']
        print("   Per-class accuracy:")
        for i, acc in enumerate(class_accs):
            print(f"     Class {i}: {acc:.4f}")

    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        return

    # Test visualization. Deliberately non-fatal: a plotting failure should
    # not prevent the performance summary below.
    print("\n🎨 Testing visualization...")
    try:
        # Create visualizations
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_results)
        training_fig = GraphVisualizer.create_training_history_plot(history)

        print("✅ Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")

        # Save plots
        graph_fig.write_html("graph_visualization.html")
        metrics_fig.write_html("metrics_plot.html")
        training_fig.write_html("training_history.html")
        print("   Plots saved as HTML files")

    except Exception as e:
        print(f"❌ Visualization failed: {e}")

    # Performance summary
    print("\n🏆 Performance Summary")
    print("=" * 40)
    print(f"📊 Dataset: Cora ({data.num_nodes:,} nodes)")
    print(f"🧠 Model: {total_params:,} parameters")
    print(f"⏱️ Training: {training_time:.2f}s ({len(history['train_loss'])} epochs)")
    print(f"🎯 Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc']*100:.2f}%)")
    print(f"🏅 Best Val Accuracy: {trainer.best_val_acc:.4f} ({trainer.best_val_acc*100:.2f}%)")

    # Compare with baselines — published Cora test accuracies for common GNNs.
    cora_baselines = {
        'GCN': 0.815,
        'GAT': 0.830,
        'GraphSAGE': 0.824,
        'GIN': 0.800
    }

    print("\n📈 Comparison with Baselines:")
    test_acc = test_results['test_acc']
    for model_name, baseline in cora_baselines.items():
        diff = test_acc - baseline
        # Green: beats baseline; yellow: within 5 points; red: worse than that.
        status = "🟢" if diff > 0 else "🟡" if diff > -0.05 else "🔴"
        print(f"   {status} {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")

    print("\n✨ Test completed successfully!")
    print("🚀 Ready for production deployment!")

# Entry-point guard: run the full smoke test only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()