#!/usr/bin/env python3
"""
Production test script for Mamba Graph implementation
Comprehensive testing with real data and enterprise validation
"""
import torch
import os
import time
import logging
from pathlib import Path
from core.graph_mamba import GraphMamba
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def get_device():
"""Get the best available device - GPU preferred"""
if torch.cuda.is_available():
device = torch.device('cuda')
logger.info(f"πŸš€ CUDA available - using GPU: {torch.cuda.get_device_name()}")
else:
device = torch.device('cpu')
logger.info("πŸ’» Using CPU")
return device
def run_comprehensive_test():
"""Run comprehensive test suite"""
print("🧠 Mamba Graph Neural Network - Complete Test")
print("=" * 60)
# Test configuration
config = {
'model': {
'd_model': 128,
'd_state': 8,
'd_conv': 4,
'expand': 2,
'n_layers': 3,
'dropout': 0.1
},
'data': {
'batch_size': 16,
'test_split': 0.2
},
'training': {
'learning_rate': 0.01,
'weight_decay': 0.0005,
'epochs': 50,
'patience': 10,
'warmup_epochs': 5,
'min_lr': 1e-6
},
'ordering': {
'strategy': 'bfs',
'preserve_locality': True
}
}
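    # Note: d_model / d_state / d_conv / expand follow the usual Mamba block
    # naming (hidden width, SSM state size, local conv width, expansion factor);
    # their exact interpretation here is defined by core.graph_mamba.GraphMamba.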
    # Setup device
    device = get_device()
    start_time = time.time()
    # Test results
    test_results = {
        'data_loading': False,
        'model_initialization': False,
        'forward_pass': False,
        'ordering_strategies': {},
        'training': False,
        'evaluation': False,
        'visualization': False
    }
    try:
        # Test 1: Data Loading
        print("\n📊 Loading Cora dataset...")
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)
        print(f"✅ Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {data.train_mask.sum().item()}")
        print(f"   Val nodes: {data.val_mask.sum().item()}")
        print(f"   Test nodes: {data.test_mask.sum().item()}")
        test_results['data_loading'] = True
    except Exception as e:
        print(f"❌ Data loading failed: {e}")
        return test_results
    try:
        # Test 2: Model Initialization
        print("\n🏗️ Initializing GraphMamba...")
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"✅ Model initialized!")
        print(f"   Parameters: {total_params:,}")
        # Rough estimate assuming 4 bytes (fp32) per parameter
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
        print(f"   Device: {device}")
        print(f"   dtype: {next(model.parameters()).dtype}")
        test_results['model_initialization'] = True
    except Exception as e:
        print(f"❌ Model initialization failed: {e}")
        return test_results
    try:
        # Test 3: Forward Pass
        print("\n🚀 Testing forward pass...")
        model.eval()
        with torch.no_grad():
            forward_start = time.time()
            h = model(data.x, data.edge_index)
            forward_time = time.time() - forward_start
        print(f"✅ Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Forward time: {forward_time*1000:.2f}ms")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")
        test_results['forward_pass'] = True
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return test_results
    # Test 4: Ordering Strategies
    print("\n🔄 Testing ordering strategies...")
    strategies = ['bfs', 'spectral', 'degree', 'community']
    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            test_model = GraphMamba(config).to(device)
            test_model.eval()
            strategy_start = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            strategy_time = time.time() - strategy_start
            print(f"✅ {strategy:12} | Shape: {h.shape} | Time: {strategy_time*1000:.2f}ms")
            test_results['ordering_strategies'][strategy] = True
        except Exception as e:
            print(f"❌ {strategy:12} | Failed: {str(e)}")
            test_results['ordering_strategies'][strategy] = False
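    # Note: the timings above wrap the full model call, so they include any
    # node-ordering preprocessing a strategy performs inside GraphMamba, not
    # just the Mamba scan itself.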
    try:
        # Test 5: Training
        print("\n🏋️ Testing training system...")
        # Reset to BFS for training
        config['ordering']['strategy'] = 'bfs'
        model = GraphMamba(config).to(device)
        trainer = GraphMambaTrainer(model, config, device)
        print(f"✅ Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
        # Run training
        print(f"\n🎯 Running training...")
        training_start = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - training_start
        print(f"✅ Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
        print(f"   Final train accuracy: {history['train_acc'][-1]:.4f}")
        test_results['training'] = True
    except Exception as e:
        print(f"❌ Training failed: {e}")
        return test_results
    try:
        # Test 6: Evaluation
        print("\n📊 Testing evaluation...")
        test_metrics = trainer.test(data)
        print(f"✅ Evaluation completed!")
        print(f"   Test accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Test loss: {test_metrics['test_loss']:.4f}")
        print(f"   F1 macro: {test_metrics.get('f1_macro', 0):.4f}")
        print(f"   F1 micro: {test_metrics.get('f1_micro', 0):.4f}")
        print(f"   Precision: {test_metrics.get('precision', 0):.4f}")
        print(f"   Recall: {test_metrics.get('recall', 0):.4f}")
        test_results['evaluation'] = True
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        return test_results
    try:
        # Test 7: Visualization
        print("\n🎨 Testing visualization...")
        # Create all visualizations
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_metrics)
        training_fig = GraphVisualizer.create_training_history_plot(history)
        print(f"✅ Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")
        test_results['visualization'] = True
    except Exception as e:
        print(f"❌ Visualization failed: {e}")
        return test_results
    # Final Summary
    print("\n" + "=" * 60)
    print("🏆 TEST SUMMARY")
    print("=" * 60)
    # Count passed tests correctly
    main_tests_passed = sum(1 for k, v in test_results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(test_results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed
    main_tests_total = len(test_results) - 1  # Exclude ordering_strategies
    ordering_tests_total = len(test_results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total
    print(f"📊 Overall: {total_passed}/{total_tests} tests passed")
    print(f"💾 Device: {device}")
    print(f"⏱️ Total time: {time.time() - start_time:.2f}s")
    # Detailed results
    for test_name, result in test_results.items():
        if test_name == 'ordering_strategies':
            print(f"🔄 Ordering strategies:")
            for strategy, strategy_result in result.items():
                status = "✅" if strategy_result else "❌"
                print(f"   {status} {strategy}")
        else:
            status = "✅" if result else "❌"
            print(f"{status} {test_name.replace('_', ' ').title()}")
    # Performance summary
    if test_results['evaluation']:
        print(f"\n🎯 Final Performance:")
        print(f"   Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Training Time: {training_time:.2f}s")
        print(f"   Model Size: {total_params:,} parameters")
        # Compare with baselines
        cora_baselines = {
            'Random': 0.143,
            'GCN': 0.815,
            'GAT': 0.830,
            'GraphSAGE': 0.824
        }
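        # Reference points: commonly reported Cora test accuracies from the
        # literature ('Random' is roughly 1/7 for Cora's 7 classes). Rough
        # comparison only - splits and tuning differ across papers.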
print(f"\nπŸ“ˆ Baseline Comparison (Cora):")
for model_name, baseline in cora_baselines.items():
diff = test_metrics['test_acc'] - baseline
status = "🟒" if diff > 0 else "🟑" if diff > -0.05 else "πŸ”΄"
print(f" {status} {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")
print(f"\n✨ All tests completed!")
if total_passed == total_tests:
print(f"πŸŽ‰ Perfect score! System is production-ready!")
elif total_passed >= total_tests * 0.8:
print(f"πŸ‘ Great! System is mostly functional.")
else:
print(f"⚠️ Some issues detected.")
return test_results
if __name__ == "__main__":
    results = run_comprehensive_test()
    # Exit with appropriate code
    main_tests_passed = sum(1 for k, v in results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed
    main_tests_total = len(results) - 1
    ordering_tests_total = len(results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total
    if total_passed == total_tests:
        exit(0)
    else:
        exit(1)