#!/usr/bin/env python3
"""
Production test script for Mamba Graph implementation
Fixed for overfitting with regularized configuration
"""

import os
os.environ['OMP_NUM_THREADS'] = '4'  # Cap OpenMP threads; must be set before importing torch

import sys
import time
import logging

import torch
from core.graph_mamba import GraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def get_device():
    """Get the best available device - GPU preferred"""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        logger.info(f"πŸš€ CUDA available - using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device('cpu')
        logger.info("πŸ’» Using CPU")
    return device

def run_comprehensive_test():
    """Run comprehensive test suite with overfitting fixes"""
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)
    
    # Use regularized configuration to prevent overfitting
    config = create_regularized_config()
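    # config is a plain nested dict (this script later touches
    # config['ordering']['strategy'] and config['training']['weight_decay']),
    # so individual settings can be overridden up front if needed, e.g.:
    #   config['ordering']['strategy'] = 'bfs'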
    
    # Setup device
    device = get_device()
    start_time = time.time()
    
    # Test results
    test_results = {
        'data_loading': False,
        'model_initialization': False,
        'forward_pass': False,
        'ordering_strategies': {},
        'training': False,
        'evaluation': False,
        'visualization': False
    }
    
    try:
        # Test 1: Data Loading
        print("\nπŸ“Š Loading Cora dataset...")
        
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        
        info = data_loader.get_dataset_info(dataset)
        
        print(f"βœ… Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {data.train_mask.sum().item()}")
        print(f"   Val nodes: {data.val_mask.sum().item()}")
        print(f"   Test nodes: {data.test_mask.sum().item()}")
        
        test_results['data_loading'] = True
        
    except Exception as e:
        print(f"❌ Data loading failed: {e}")
        return test_results
    
    try:
        # Test 2: Model Initialization with regularized config
        print("\nπŸ—οΈ Initializing GraphMamba (Regularized)...")
        
        model = GraphMamba(config).to(device)
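        # numel() counts the elements of each parameter tensor; the memory
        # estimate below assumes float32 (4 bytes/param) and covers weights only.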
        total_params = sum(p.numel() for p in model.parameters())
        
        print(f"βœ… Model initialized!")
        print(f"   Parameters: {total_params:,}")
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
        print(f"   Device: {device}")
        print(f"   Model type: Regularized (Anti-overfitting)")
        
        # Check if parameter count is reasonable for small training set
        train_samples = data.train_mask.sum().item()
        params_per_sample = total_params / train_samples
        print(f"   Params per training sample: {params_per_sample:.1f}")
        
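        # Heuristic thresholds: Cora's standard split has only 140 labeled
        # training nodes, so a high parameters-per-label ratio is a quick
        # proxy for memorization risk.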
        if params_per_sample < 500:
            print("   βœ… Good parameter ratio - low overfitting risk")
        elif params_per_sample < 1000:
            print("   ⚠️ Moderate parameter ratio - watch for overfitting")
        else:
            print("   🚨 High parameter ratio - high overfitting risk")
        
        test_results['model_initialization'] = True
        
    except Exception as e:
        print(f"❌ Model initialization failed: {e}")
        return test_results
    
    try:
        # Test 3: Forward Pass
        print("\nπŸš€ Testing forward pass...")
        
        model.eval()
        with torch.no_grad():
            forward_start = time.time()
            h = model(data.x, data.edge_index)
            if device.type == 'cuda':
                torch.cuda.synchronize()  # CUDA kernels run async; sync before stopping the clock
            forward_time = time.time() - forward_start
            
        print(f"βœ… Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Forward time: {forward_time*1000:.2f}ms")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")
        
        test_results['forward_pass'] = True
        
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return test_results
    
    # Test 4: Ordering Strategies (simplified for regularized model)
    print("\nπŸ”„ Testing ordering strategies...")
    
    # Only test BFS for regularized model to avoid complexity
    strategies = ['bfs']
    
    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            test_model = GraphMamba(config).to(device)
            test_model.eval()
            
            strategy_start = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            if device.type == 'cuda':
                torch.cuda.synchronize()  # see the forward-pass test: sync for honest GPU timing
            strategy_time = time.time() - strategy_start
            
            print(f"βœ… {strategy:12} | Shape: {h.shape} | Time: {strategy_time*1000:.2f}ms")
            test_results['ordering_strategies'][strategy] = True
            
        except Exception as e:
            print(f"❌ {strategy:12} | Failed: {str(e)}")
            test_results['ordering_strategies'][strategy] = False
    
    try:
        # Test 5: Regularized Training
        print("\nπŸ‹οΈ Testing regularized training system...")
        
        # Reset to BFS for training
        config['ordering']['strategy'] = 'bfs'
        model = GraphMamba(config).to(device)
        trainer = GraphMambaTrainer(model, config, device)
        
        print(f"βœ… Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
        print(f"   Weight decay: {config['training']['weight_decay']}")
        print(f"   Anti-overfitting: Enabled")
        
        # Run training
        print(f"\n🎯 Running regularized training...")
        training_start = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - training_start
        
        print(f"βœ… Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
        print(f"   Final train accuracy: {history['train_acc'][-1]:.4f}")
        print(f"   Overfitting gap: {trainer.best_gap:.4f}")
        
        test_results['training'] = True
        
    except Exception as e:
        print(f"❌ Training failed: {e}")
        return test_results
    
    try:
        # Test 6: Evaluation
        print("\nπŸ“Š Testing evaluation...")
        
        test_metrics = trainer.test(data)
        
        print(f"βœ… Evaluation completed!")
        print(f"   Test accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Test loss: {test_metrics['test_loss']:.4f}")
        print(f"   F1 macro: {test_metrics.get('f1_macro', 0):.4f}")
        print(f"   F1 micro: {test_metrics.get('f1_micro', 0):.4f}")
        print(f"   Precision: {test_metrics.get('precision', 0):.4f}")
        print(f"   Recall: {test_metrics.get('recall', 0):.4f}")
        
        test_results['evaluation'] = True
        
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        return test_results
    
    try:
        # Test 7: Visualization
        print("\n🎨 Testing visualization...")
        
        # Create all visualizations
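        # The figures are only constructed, not rendered or saved; this test
        # just verifies that each plot builds without raising.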
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_metrics)
        training_fig = GraphVisualizer.create_training_history_plot(history)
        
        print(f"βœ… Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")
        
        test_results['visualization'] = True
        
    except Exception as e:
        print(f"❌ Visualization failed: {e}")
        return test_results
    
    # Final Summary
    print("\n" + "=" * 60)
    print("πŸ† TEST SUMMARY")
    print("=" * 60)
    
    # Count passed tests: main pipeline stages plus per-strategy ordering results
    main_tests_passed = sum(1 for k, v in test_results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(test_results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed
    
    main_tests_total = len(test_results) - 1
    ordering_tests_total = len(test_results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total
    
    print(f"πŸ“Š Overall: {total_passed}/{total_tests} tests passed")
    print(f"πŸ’Ύ Device: {device}")
    print(f"⏱️ Total time: {time.time() - start_time:.2f}s")
    
    # Detailed results
    for test_name, result in test_results.items():
        if test_name == 'ordering_strategies':
            print(f"πŸ”„ Ordering strategies:")
            for strategy, strategy_result in result.items():
                status = "βœ…" if strategy_result else "❌"
                print(f"   {status} {strategy}")
        else:
            status = "βœ…" if result else "❌"
            print(f"{status} {test_name.replace('_', ' ').title()}")
    
    # Performance summary
    if test_results['evaluation']:
        print(f"\n🎯 Final Performance:")
        print(f"   Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Training Time: {training_time:.2f}s")
        print(f"   Model Size: {total_params:,} parameters")
        print(f"   Params per sample: {params_per_sample:.1f}")
        
        # Compare with baselines
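        # Reference points: random guessing over Cora's 7 classes is ~1/7 = 0.143;
        # the GCN (0.815) and GAT (0.830) figures are the accuracies reported in
        # their original papers on the standard Planetoid split.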
        cora_baselines = {
            'Random': 0.143,
            'Simple': 0.300,
            'GCN': 0.815,
            'GAT': 0.830
        }
        
        print(f"\nπŸ“ˆ Baseline Comparison (Cora):")
        for model_name, baseline in cora_baselines.items():
            diff = test_metrics['test_acc'] - baseline
            if diff > 0:
                status = "🟒"
                desc = f"(+{diff:.3f} better)"
            elif diff > -0.1:
                status = "🟑"
                desc = f"({diff:.3f} competitive)"
            else:
                status = "πŸ”΄"
                desc = f"({diff:.3f} gap)"
            print(f"   {status} {model_name:12}: {baseline:.3f} {desc}")
        
        # Overfitting analysis
        if trainer.best_gap < 0.1:
            print(f"\nπŸŽ‰ Excellent generalization! (gap: {trainer.best_gap:.3f})")
        elif trainer.best_gap < 0.2:
            print(f"\nπŸ‘ Good generalization (gap: {trainer.best_gap:.3f})")
        else:
            print(f"\n⚠️ Some overfitting detected (gap: {trainer.best_gap:.3f})")
    
    print(f"\n✨ All tests completed!")
    
    if total_passed == total_tests:
        print(f"πŸŽ‰ Perfect score! Regularized system working well!")
    elif total_passed >= total_tests * 0.8:
        print(f"πŸ‘ Great! System is mostly functional.")
    else:
        print(f"⚠️ Some issues detected.")
    
    return test_results

if __name__ == "__main__":
    results = run_comprehensive_test()

    # Exit with a CI-friendly status code: 0 only if every test passed
    main_tests_passed = sum(1 for k, v in results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed

    main_tests_total = len(results) - 1
    ordering_tests_total = len(results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total

    sys.exit(0 if total_passed == total_tests else 1)