#!/usr/bin/env python3
"""
Enhanced Mamba Graph with structure preservation and interface fix.

Trains and tests GraphMamba and HybridGraphMamba on the Cora node
classification dataset in a background (daemon) thread, prints a per-model
report, a comparison table and a comparison against hard-coded baseline
accuracies, then keeps the main thread alive until Ctrl+C (presumably so an
attached interface can keep using the process).
"""

import os

# Must happen before `import torch` so the OpenMP runtime sees it.
# setdefault lets an OMP_NUM_THREADS already exported by the user win.
os.environ.setdefault('OMP_NUM_THREADS', '4')

import torch
import time
import logging
import threading
import signal

from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.visualization import GraphVisualizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Reference test accuracies printed next to the best model's score.
# 0.143 is ~1/7, presumably chance level for Cora's seven classes; the GCN and
# GAT values are hard-coded reference figures (NOTE(review): cite the source).
_BASELINES = {'Random': 0.143, 'GCN': 0.815, 'GAT': 0.830}


def get_device():
    """Return ``torch.device('cuda')`` if CUDA is available, else the CPU device.

    Logs which device was chosen (and the GPU name when CUDA is used).
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
        logger.info(f"🚀 CUDA available - using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device('cpu')
        logger.info("💻 Using CPU")
    return device


def _train_and_evaluate(model_name, model_class, config, data, device):
    """Build, train and test one model class; print its report and return metrics.

    Returns a dict with keys ``test_acc``, ``val_acc``, ``gap``, ``params`` and
    ``time`` (training wall-clock seconds).
    """
    print(f"\n🏗️ Testing {model_name}...")
    model = model_class(config).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    train_samples = data.train_mask.sum().item()
    # An empty training mask would otherwise raise ZeroDivisionError here.
    per_sample = total_params / train_samples if train_samples else float('nan')
    print(f"   Parameters: {total_params:,} ({per_sample:.1f} per sample)")

    trainer = GraphMambaTrainer(model, config, device)
    print(f"   Strategy: {config['ordering']['strategy']}")

    # perf_counter is monotonic and intended for measuring durations.
    start_time = time.perf_counter()
    trainer.train_node_classification(data, verbose=False)  # history unused here
    training_time = time.perf_counter() - start_time

    test_metrics = trainer.test(data)
    result = {
        'test_acc': test_metrics['test_acc'],
        'val_acc': trainer.best_val_acc,
        'gap': trainer.best_gap,
        'params': total_params,
        'time': training_time,
    }

    print(f"   ✅ Test Accuracy: {result['test_acc']:.4f} ({result['test_acc'] * 100:.2f}%)")
    print(f"   📊 Validation: {result['val_acc']:.4f}")
    print(f"   🎯 Gap: {result['gap']:.4f}")
    print(f"   ⏱️ Time: {training_time:.1f}s")
    return result


def _print_comparison(results):
    """Print a table with one row per model, padded to the header's column widths."""
    print("\n📈 Model Comparison:")
    print(f"{'Model':<20} {'Test Acc':<10} {'Val Acc':<10} {'Gap':<8} {'Params':<8}")
    print("-" * 60)
    for name, result in results.items():
        print(f"{name:<20} {result['test_acc']:<10.4f} {result['val_acc']:<10.4f} "
              f"{result['gap']:<8.3f} {result['params'] / 1000:.0f}K")


def _print_baselines(best_acc):
    """Print the difference between ``best_acc`` and each entry of ``_BASELINES``."""
    print("\n📊 vs Baselines:")
    for baseline, acc in _BASELINES.items():
        diff = best_acc - acc
        # A tie counts as "not better" (red), matching the strict > comparison.
        status = "🟢" if diff > 0 else "🔴"
        print(f"   {status} {baseline}: {acc:.3f} (diff: {diff:+.3f})")


def run_comprehensive_test():
    """Train/test both GraphMamba variants on Cora and print a comparison.

    Returns the per-model results dict (model name -> metrics) on success, or
    ``None`` if anything inside the main try-block raised; such errors are
    printed and logged with traceback rather than propagated, so the process
    can stay alive. Errors from config creation or device selection are not
    caught here.
    """
    print("🧠 Enhanced Mamba Graph Neural Network")
    print("=" * 60)

    config = create_regularized_config()
    device = get_device()

    try:
        print("\n📊 Loading Cora dataset...")
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        # Return value is not used below; the call is kept in case the loader
        # relies on it for side effects (NOTE(review): confirm and drop if not).
        data_loader.get_dataset_info(dataset)
        print(f"✅ Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")

        models_to_test = [
            ("Enhanced GraphMamba", GraphMamba),
            ("Hybrid GraphMamba", HybridGraphMamba),
        ]
        results = {}
        for model_name, model_class in models_to_test:
            results[model_name] = _train_and_evaluate(
                model_name, model_class, config, data, device
            )

        _print_comparison(results)

        best_name, best_result = max(results.items(), key=lambda item: item[1]['test_acc'])
        print(f"\n🏆 Best: {best_name} - {best_result['test_acc'] * 100:.2f}% accuracy")

        _print_baselines(best_result['test_acc'])

        print("\n✨ Testing complete! Process staying alive for interface...")
        return results
    except Exception as e:
        print(f"❌ Error: {e}")
        # Keep the full traceback; str(e) alone hides where the failure happened.
        logger.exception("Comprehensive test failed")
        print("Process staying alive despite error...")
        return None


def keep_alive():
    """Block the calling thread, sleeping in 60 s steps, until Ctrl+C."""
    try:
        while True:
            time.sleep(60)
    except KeyboardInterrupt:
        print("\n👋 Shutting down gracefully...")


def run_background():
    """Thread target: run the comprehensive test without letting errors escape."""
    try:
        run_comprehensive_test()
    except Exception as e:
        print(f"Background test error: {e}")
        logger.exception("Background test failed")
    finally:
        print("Background test complete, keeping alive...")


if __name__ == "__main__":
    # Daemon thread: it does not keep the interpreter alive on its own, so
    # the main thread blocks in keep_alive() until the user interrupts.
    test_thread = threading.Thread(target=run_background, daemon=True)
    test_thread.start()

    try:
        keep_alive()
    except KeyboardInterrupt:
        # Only reachable if a second Ctrl+C lands outside keep_alive's sleep loop.
        print("\nExiting...")