# NOTE(review): extraction artifact removed here — a file-size header,
# a git-blame hash gutter, and a line-number gutter that were not part
# of the original Python source.
#!/usr/bin/env python3
"""
Enhanced Mamba Graph with structure preservation and interface fix
"""
import os
# Cap OpenMP worker threads; must be set before importing torch so the
# native libraries read it at load time. (Why exactly 4: presumably the
# deployment host's core budget — TODO confirm.)
os.environ['OMP_NUM_THREADS'] = '4'
import torch
import time
import logging
import threading
import signal  # NOTE(review): imported but unused in this file
# Project-local modules: models/config factory, trainer, data loading,
# and visualization (GraphVisualizer is imported but unused here).
from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.visualization import GraphVisualizer
# Module-wide logger; INFO level so device-selection messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_device():
    """Pick the compute device: CUDA when available, else CPU.

    Logs which device was chosen (including the GPU name when on CUDA)
    and returns the corresponding ``torch.device``.
    """
    if not torch.cuda.is_available():
        chosen = torch.device('cpu')
        logger.info("π» Using CPU")
        return chosen
    chosen = torch.device('cuda')
    logger.info(f"π CUDA available - using GPU: {torch.cuda.get_device_name()}")
    return chosen
def run_comprehensive_test():
    """Train and evaluate both GraphMamba variants on Cora, then compare them.

    Loads the Cora node-classification dataset, trains the enhanced and the
    hybrid model, and prints per-model test/validation accuracy, the
    train/val gap, parameter count, and wall-clock training time, followed
    by a comparison table and a check against published GNN baselines.

    Errors are caught, logged, and reported (never re-raised) because the
    hosting process must stay alive to serve the interface.
    """
    print("π§ Enhanced Mamba Graph Neural Network")
    print("=" * 60)

    config = create_regularized_config()
    device = get_device()

    try:
        # --- Data loading ---
        print("\nπ Loading Cora dataset...")
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        # Result was bound to an unused local in the original; the call is
        # kept in case get_dataset_info logs or caches — TODO confirm.
        data_loader.get_dataset_info(dataset)
        # NOTE(review): this f-string was split across two physical lines in
        # the original (broken literal); rejoined here with a single space.
        print(f"β Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")

        # --- Train and evaluate each model variant ---
        models_to_test = [
            ("Enhanced GraphMamba", GraphMamba),
            ("Hybrid GraphMamba", HybridGraphMamba),
        ]
        results = {}
        for model_name, model_class in models_to_test:
            print(f"\nποΈ Testing {model_name}...")
            model = model_class(config).to(device)
            total_params = sum(p.numel() for p in model.parameters())
            train_samples = data.train_mask.sum().item()
            print(f" Parameters: {total_params:,} ({total_params/train_samples:.1f} per sample)")

            # Training (history return value was unused in the original)
            trainer = GraphMambaTrainer(model, config, device)
            print(f" Strategy: {config['ordering']['strategy']}")
            start_time = time.time()
            trainer.train_node_classification(data, verbose=False)
            training_time = time.time() - start_time

            # Evaluation
            test_metrics = trainer.test(data)
            results[model_name] = {
                'test_acc': test_metrics['test_acc'],
                'val_acc': trainer.best_val_acc,
                'gap': trainer.best_gap,
                'params': total_params,
                'time': training_time,
            }
            # NOTE(review): rejoined another literal split across two lines.
            print(f" β Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
            print(f" π Validation: {trainer.best_val_acc:.4f}")
            print(f" π― Gap: {trainer.best_gap:.4f}")
            print(f" β±οΈ Time: {training_time:.1f}s")

        # --- Comparison table ---
        print(f"\nπ Model Comparison:")
        print(f"{'Model':<20} {'Test Acc':<10} {'Val Acc':<10} {'Gap':<8} {'Params':<8}")
        print("-" * 60)
        for name, result in results.items():
            print(f"{name:<20} {result['test_acc']:.4f} {result['val_acc']:.4f} "
                  f"{result['gap']:>6.3f} {result['params']/1000:.0f}K")

        # --- Best model ---
        best_model = max(results.items(), key=lambda x: x[1]['test_acc'])
        print(f"\nπ Best: {best_model[0]} - {best_model[1]['test_acc']*100:.2f}% accuracy")

        # --- Baseline comparison (published Cora accuracies) ---
        baselines = {'Random': 0.143, 'GCN': 0.815, 'GAT': 0.830}
        best_acc = best_model[1]['test_acc']
        print(f"\nπ vs Baselines:")
        for baseline, acc in baselines.items():
            diff = best_acc - acc
            status = "π’" if diff > 0 else "π΄"
            print(f" {status} {baseline}: {acc:.3f} (diff: {diff:+.3f})")

        print(f"\n⨠Testing complete! Process staying alive for interface...")
    except Exception as e:
        # Broad catch is deliberate: the process must survive to serve the
        # interface. Log the traceback so failures are diagnosable.
        logger.exception("Comprehensive test failed")
        print(f"β Error: {e}")
        print("Process staying alive despite error...")
def keep_alive():
    """Block indefinitely so the hosting process stays up for the interface.

    Sleeps in one-minute intervals; a KeyboardInterrupt (Ctrl-C) ends the
    wait with a graceful shutdown message instead of a traceback.
    """
    running = True
    while running:
        try:
            time.sleep(60)
        except KeyboardInterrupt:
            print("\nπ Shutting down gracefully...")
            running = False
def run_background():
    """Execute the comprehensive test on this (background) thread.

    Any exception is reported rather than propagated, and a completion
    message is always printed so the keep-alive loop can take over.
    """
    try:
        run_comprehensive_test()
    except Exception as err:
        print(f"Background test error: {err}")
    finally:
        print("Background test complete, keeping alive...")
if __name__ == "__main__":
# Start test in background thread
test_thread = threading.Thread(target=run_background, daemon=True)
test_thread.start()
# Keep main thread alive for interface
try:
keep_alive()
except KeyboardInterrupt:
print("\nExiting...")
except Exception as e:
print(f"Main thread error: {e}")
keep_alive() # Still try to keep alive |