serpent / app.py
#!/usr/bin/env python3
"""
Enhanced Mamba Graph with structure preservation and interface fix
"""
import os
# Limit OpenMP threads; must be set before torch is imported so it takes effect.
os.environ['OMP_NUM_THREADS'] = '4'

import torch
import time
import logging
import threading
import signal

from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.visualization import GraphVisualizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_device():
    """Select the compute device: GPU if CUDA is available, otherwise CPU."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        logger.info(f"🚀 CUDA available - using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device('cpu')
        logger.info("💻 Using CPU")
    return device


def run_comprehensive_test():
    """Enhanced test with structure preservation"""
    print("🧠 Enhanced Mamba Graph Neural Network")
    print("=" * 60)

    config = create_regularized_config()
    device = get_device()

    try:
        # Data loading
        print("\n📊 Loading Cora dataset...")
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)
        print(f"✅ Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")

        # Test both models
        models_to_test = [
            ("Enhanced GraphMamba", GraphMamba),
            ("Hybrid GraphMamba", HybridGraphMamba)
        ]

        results = {}
        for model_name, model_class in models_to_test:
            print(f"\n🏗️ Testing {model_name}...")
            model = model_class(config).to(device)
            total_params = sum(p.numel() for p in model.parameters())
            train_samples = data.train_mask.sum().item()
            print(f" Parameters: {total_params:,} ({total_params/train_samples:.1f} per sample)")

            # Training
            trainer = GraphMambaTrainer(model, config, device)
            print(f" Strategy: {config['ordering']['strategy']}")
            start_time = time.time()
            history = trainer.train_node_classification(data, verbose=False)
            training_time = time.time() - start_time

            # Evaluation
            test_metrics = trainer.test(data)
            results[model_name] = {
                'test_acc': test_metrics['test_acc'],
                'val_acc': trainer.best_val_acc,
                'gap': trainer.best_gap,
                'params': total_params,
                'time': training_time
            }

            print(f" ✅ Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
            print(f" 📊 Validation: {trainer.best_val_acc:.4f}")
            print(f" 🎯 Gap: {trainer.best_gap:.4f}")
            print(f" ⏱️ Time: {training_time:.1f}s")

        # Comparison
        print(f"\n📈 Model Comparison:")
        print(f"{'Model':<20} {'Test Acc':<10} {'Val Acc':<10} {'Gap':<8} {'Params':<8}")
        print("-" * 60)
        for name, result in results.items():
print(f"{name:<20} {result['test_acc']:.4f} {result['val_acc']:.4f} "
f"{result['gap']:>6.3f} {result['params']/1000:.0f}K")
        # Best model
        best_model = max(results.items(), key=lambda x: x[1]['test_acc'])
        print(f"\n🏆 Best: {best_model[0]} - {best_model[1]['test_acc']*100:.2f}% accuracy")

        # Baseline comparison
        baselines = {'Random': 0.143, 'GCN': 0.815, 'GAT': 0.830}
        best_acc = best_model[1]['test_acc']
        print(f"\n📊 vs Baselines:")
        for baseline, acc in baselines.items():
            diff = best_acc - acc
            status = "🟢" if diff > 0 else "🔴"
            print(f" {status} {baseline}: {acc:.3f} (diff: {diff:+.3f})")

        print(f"\n✨ Testing complete! Process staying alive for interface...")

    except Exception as e:
        print(f"❌ Error: {e}")
        print("Process staying alive despite error...")


def keep_alive():
    """Keep process running for interface"""
    try:
        while True:
            time.sleep(60)
    except KeyboardInterrupt:
        print("\n👋 Shutting down gracefully...")


def run_background():
    """Run test in background thread"""
    try:
        run_comprehensive_test()
    except Exception as e:
        print(f"Background test error: {e}")
    finally:
        print("Background test complete, keeping alive...")


if __name__ == "__main__":
    # Start test in background thread
    test_thread = threading.Thread(target=run_background, daemon=True)
    test_thread.start()

    # Keep main thread alive for interface
    try:
        keep_alive()
    except KeyboardInterrupt:
        print("\nExiting...")
    except Exception as e:
        print(f"Main thread error: {e}")
        keep_alive()  # Still try to keep alive