|
|
|
""" |
|
Enhanced Mamba Graph with structure preservation and interface fix |
|
""" |
|
|
|
import os |
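# Cap OpenMP threads; this must be set before torch is imported for it to take effect.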
|
os.environ['OMP_NUM_THREADS'] = '4' |
|
|
|
import torch |
|
import time |
|
import logging |
|
import threading |
|
import signal |
|
from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config |
|
from core.trainer import GraphMambaTrainer |
|
from data.loader import GraphDataLoader |
|
from utils.visualization import GraphVisualizer |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
def get_device(): |
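    """Return a CUDA device when available, otherwise fall back to the CPU."""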
|
if torch.cuda.is_available(): |
|
device = torch.device('cuda') |
|
logger.info(f"π CUDA available - using GPU: {torch.cuda.get_device_name()}") |
|
else: |
|
device = torch.device('cpu') |
|
logger.info("π» Using CPU") |
|
return device |
|
|
|
def run_comprehensive_test(): |
|
"""Enhanced test with structure preservation""" |
|
print("π§ Enhanced Mamba Graph Neural Network") |
|
print("=" * 60) |
|
|
|
config = create_regularized_config() |
|
device = get_device() |
|
|
|
try: |
|
|
|
print("\nπ Loading Cora dataset...") |
|
data_loader = GraphDataLoader() |
|
dataset = data_loader.load_node_classification_data('Cora') |
|
data = dataset[0].to(device) |
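        # Cora is a single-graph dataset: element 0 carries the features, edges, and train/val/test masks.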
|
info = data_loader.get_dataset_info(dataset) |
|
|
|
print(f"β
Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges") |
|
|
|
|
|
models_to_test = [ |
|
("Enhanced GraphMamba", GraphMamba), |
|
("Hybrid GraphMamba", HybridGraphMamba) |
|
] |
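        # Both variants are trained from the same config so their results are directly comparable.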
|
|
|
results = {} |
|
|
|
for model_name, model_class in models_to_test: |
|
print(f"\nποΈ Testing {model_name}...") |
|
|
|
model = model_class(config).to(device) |
|
total_params = sum(p.numel() for p in model.parameters()) |
|
train_samples = data.train_mask.sum().item() |
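            # Parameters per labelled training node gives a rough sense of overfitting risk.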
|
|
|
print(f" Parameters: {total_params:,} ({total_params/train_samples:.1f} per sample)") |
|
|
|
|
|
trainer = GraphMambaTrainer(model, config, device) |
|
print(f" Strategy: {config['ordering']['strategy']}") |
|
|
|
start_time = time.time() |
|
history = trainer.train_node_classification(data, verbose=False) |
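            # verbose=False suppresses per-epoch logging so the comparison output stays compact.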
|
training_time = time.time() - start_time |
|
|
|
|
|
test_metrics = trainer.test(data) |
|
|
|
results[model_name] = { |
|
'test_acc': test_metrics['test_acc'], |
|
'val_acc': trainer.best_val_acc, |
|
'gap': trainer.best_gap, |
|
'params': total_params, |
|
'time': training_time |
|
} |
|
|
|
print(f" β
Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)") |
|
print(f" π Validation: {trainer.best_val_acc:.4f}") |
|
print(f" π― Gap: {trainer.best_gap:.4f}") |
|
print(f" β±οΈ Time: {training_time:.1f}s") |
|
|
|
|
|
print(f"\nπ Model Comparison:") |
|
print(f"{'Model':<20} {'Test Acc':<10} {'Val Acc':<10} {'Gap':<8} {'Params':<8}") |
|
print("-" * 60) |
|
|
|
for name, result in results.items(): |
|
print(f"{name:<20} {result['test_acc']:.4f} {result['val_acc']:.4f} " |
|
f"{result['gap']:>6.3f} {result['params']/1000:.0f}K") |
|
|
|
|
|
best_model = max(results.items(), key=lambda x: x[1]['test_acc']) |
|
print(f"\nπ Best: {best_model[0]} - {best_model[1]['test_acc']*100:.2f}% accuracy") |
|
|
|
|
|
baselines = {'Random': 0.143, 'GCN': 0.815, 'GAT': 0.830} |
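        # Published Cora accuracies (GCN ~81.5%, GAT ~83.0%); random guessing over 7 classes is ~14.3%.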
|
best_acc = best_model[1]['test_acc'] |
|
|
|
print(f"\nπ vs Baselines:") |
|
for baseline, acc in baselines.items(): |
|
diff = best_acc - acc |
|
status = "π’" if diff > 0 else "π΄" |
|
print(f" {status} {baseline}: {acc:.3f} (diff: {diff:+.3f})") |
|
|
|
print(f"\n⨠Testing complete! Process staying alive for interface...") |
|
|
|
except Exception as e: |
|
print(f"β Error: {e}") |
|
print("Process staying alive despite error...") |
|
|
|
def keep_alive(): |
|
"""Keep process running for interface""" |
|
try: |
|
while True: |
|
time.sleep(60) |
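            # Long sleeps keep the loop cheap; a KeyboardInterrupt still wakes it immediately.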
|
except KeyboardInterrupt: |
|
print("\nπ Shutting down gracefully...") |
|
|
|
def run_background(): |
|
"""Run test in background thread""" |
|
try: |
|
run_comprehensive_test() |
|
except Exception as e: |
|
print(f"Background test error: {e}") |
|
finally: |
|
print("Background test complete, keeping alive...") |
|
|
|
if __name__ == "__main__": |
|
|
|
test_thread = threading.Thread(target=run_background, daemon=True) |
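    # daemon=True lets the interpreter exit without waiting for this thread to finish.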
|
test_thread.start() |
|
|
|
|
|
try: |
|
keep_alive() |
|
except KeyboardInterrupt: |
|
print("\nExiting...") |
|
except Exception as e: |
|
print(f"Main thread error: {e}") |
|
keep_alive() |