File size: 5,149 Bytes
c6e11c4
 
 
 
 
 
 
 
8aa0616
c6e11c4
a7b5172
c6e11c4
 
 
 
 
 
cf02254
c6e11c4
a7b5172
 
c6e11c4
 
 
 
 
 
 
 
 
 
 
 
 
cf02254
c6e11c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fb1716
c6e11c4
 
 
 
 
3fb1716
c6e11c4
 
 
3fb1716
c6e11c4
 
 
3fb1716
c6e11c4
 
454d2b9
c6e11c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454d2b9
c6e11c4
 
 
 
 
 
 
454d2b9
c6e11c4
 
 
 
 
 
 
 
454d2b9
c6e11c4
 
 
 
a7b5172
c6e11c4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
"""
Enhanced Mamba Graph with structure preservation and interface fix
"""

import os
os.environ['OMP_NUM_THREADS'] = '4'

import torch
import time
import logging
import threading
import signal
from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.visualization import GraphVisualizer

# Configure the root logger to emit INFO-level records (e.g. get_device()'s
# device choice), then grab this module's own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

def get_device():
    """Choose the torch device for this run and log which one was picked.

    Returns:
        ``torch.device('cuda')`` when ``torch.cuda.is_available()`` is true
        (the GPU name is logged), otherwise ``torch.device('cpu')``.
    """
    if not torch.cuda.is_available():
        logger.info("πŸ’» Using CPU")
        return torch.device('cpu')

    gpu_name = torch.cuda.get_device_name()
    logger.info(f"πŸš€ CUDA available - using GPU: {gpu_name}")
    return torch.device('cuda')

def _load_cora(device):
    """Load the Cora dataset, move its first graph to *device*, and return it."""
    print("\nπŸ“Š Loading Cora dataset...")
    data_loader = GraphDataLoader()
    dataset = data_loader.load_node_classification_data('Cora')
    data = dataset[0].to(device)
    # NOTE(review): the returned info is not used; the call is kept in case
    # get_dataset_info has side effects (e.g. logging) — confirm, drop if not.
    data_loader.get_dataset_info(dataset)
    print(f"βœ… Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")
    return data


def _train_and_evaluate(model_name, model_class, config, data, device):
    """Build, train and test one model; print its metrics and return a summary.

    Returns:
        dict with keys 'test_acc', 'val_acc', 'gap', 'params', 'time'.
    """
    print(f"\nπŸ—οΈ Testing {model_name}...")

    model = model_class(config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    train_samples = data.train_mask.sum().item()

    if train_samples:
        print(f"   Parameters: {total_params:,} ({total_params/train_samples:.1f} per sample)")
    else:
        # An empty train mask would otherwise raise ZeroDivisionError here.
        print(f"   Parameters: {total_params:,}")

    trainer = GraphMambaTrainer(model, config, device)
    print(f"   Strategy: {config['ordering']['strategy']}")

    # perf_counter is monotonic, so wall-clock adjustments cannot distort
    # the measured training duration (time.time() offers no such guarantee).
    start_time = time.perf_counter()
    trainer.train_node_classification(data, verbose=False)
    training_time = time.perf_counter() - start_time

    test_acc = trainer.test(data)['test_acc']

    print(f"   βœ… Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
    print(f"   πŸ“Š Validation: {trainer.best_val_acc:.4f}")
    print(f"   🎯 Gap: {trainer.best_gap:.4f}")
    print(f"   ⏱️ Time: {training_time:.1f}s")

    return {
        'test_acc': test_acc,
        'val_acc': trainer.best_val_acc,
        'gap': trainer.best_gap,
        'params': total_params,
        'time': training_time,
    }


def _print_comparison(results):
    """Print one table row per model: test/val accuracy, gap and size in K params."""
    print("\nπŸ“ˆ Model Comparison:")
    print(f"{'Model':<20} {'Test Acc':<10} {'Val Acc':<10} {'Gap':<8} {'Params':<8}")
    print("-" * 60)
    for name, result in results.items():
        print(f"{name:<20} {result['test_acc']:.4f}     {result['val_acc']:.4f}     "
              f"{result['gap']:>6.3f}   {result['params']/1000:.0f}K")


def _print_baseline_comparison(best_acc):
    """Print how *best_acc* compares with a fixed set of reference accuracies."""
    # Hard-coded reference accuracies. NOTE(review): source of these numbers
    # is not shown here; 0.143 is roughly 1/7 (chance level if 7 classes).
    baselines = {'Random': 0.143, 'GCN': 0.815, 'GAT': 0.830}
    print("\nπŸ“Š vs Baselines:")
    for baseline, acc in baselines.items():
        diff = best_acc - acc
        # A tie (diff == 0) is shown as not beating the baseline.
        status = "🟒" if diff > 0 else "πŸ”΄"
        print(f"   {status} {baseline}: {acc:.3f} (diff: {diff:+.3f})")


def run_comprehensive_test():
    """Train, evaluate and compare the GraphMamba variants on Cora.

    Loads Cora, then for each of ``GraphMamba`` and ``HybridGraphMamba``
    builds the model from ``create_regularized_config()``, trains it with
    ``GraphMambaTrainer`` and evaluates it on the test split. Per-model
    metrics, a comparison table, the best model and a comparison against
    fixed baseline accuracies are printed to stdout.

    Any exception raised after the banner is printed and logged with its
    traceback but not re-raised, so the hosting process keeps running for
    the interface.
    """
    print("🧠 Enhanced Mamba Graph Neural Network")
    print("=" * 60)

    try:
        # Created inside the try so a config or device failure is reported
        # like every other failure here instead of escaping to the caller.
        config = create_regularized_config()
        device = get_device()

        data = _load_cora(device)

        models_to_test = [
            ("Enhanced GraphMamba", GraphMamba),
            ("Hybrid GraphMamba", HybridGraphMamba),
        ]
        # Each model/trainer lives only inside its helper call, so this
        # function holds no reference to it once the next model is built.
        results = {
            name: _train_and_evaluate(name, model_class, config, data, device)
            for name, model_class in models_to_test
        }

        _print_comparison(results)

        best_name, best_result = max(results.items(), key=lambda item: item[1]['test_acc'])
        print(f"\nπŸ† Best: {best_name} - {best_result['test_acc']*100:.2f}% accuracy")

        _print_baseline_comparison(best_result['test_acc'])

        print("\n✨ Testing complete! Process staying alive for interface...")

    except Exception as e:
        print(f"❌ Error: {e}")
        # The message alone hides where the failure happened; keep the traceback.
        logger.exception("Comprehensive test failed")
        print("Process staying alive despite error...")

def keep_alive():
    """Block the calling thread indefinitely so the process stays up for the interface.

    Sleeps in 60-second chunks until a KeyboardInterrupt (Ctrl+C) arrives.
    It then prints a goodbye message and returns normally instead of
    propagating the interrupt.
    """
    idle_seconds = 60
    try:
        while 1:
            time.sleep(idle_seconds)
    except KeyboardInterrupt:
        print("\nπŸ‘‹ Shutting down gracefully...")

def run_background():
    """Run ``run_comprehensive_test``; used as the background thread's target.

    Any exception escaping the test is printed, and logged with its full
    traceback, rather than propagating out of the thread. A completion
    message is printed in every case.
    """
    try:
        run_comprehensive_test()
    except Exception as e:
        print(f"Background test error: {e}")
        # The printed message alone loses where the failure happened.
        logger.exception("Background test failed")
    finally:
        print("Background test complete, keeping alive...")

if __name__ == "__main__":
    # The evaluation runs on a daemon thread so it never blocks interpreter
    # exit; the main thread's only job is to idle for the interface.
    worker = threading.Thread(target=run_background, daemon=True)
    worker.start()

    try:
        keep_alive()
    except KeyboardInterrupt:
        # keep_alive() catches Ctrl+C itself, so this is only a fallback.
        print("\nExiting...")
    except Exception as err:
        print(f"Main thread error: {err}")
        keep_alive()  # fall back to idling even after an unexpected error