#!/usr/bin/env python3
"""
Production test script for Mamba Graph implementation
Fixed for overfitting with regularized configuration
"""
import os

os.environ['OMP_NUM_THREADS'] = '4'  # Limit OpenMP threads (silences an OpenMP warning); must be set before importing torch

import torch
import time
import logging
from pathlib import Path
from core.graph_mamba import GraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def get_device():
    """Get the best available device - GPU preferred."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        logger.info(f"🚀 CUDA available - using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device('cpu')
        logger.info("💻 Using CPU")
    return device


def run_comprehensive_test():
    """Run the comprehensive test suite with overfitting fixes."""
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    # Use the regularized configuration to prevent overfitting
    config = create_regularized_config()

    # Setup device
    device = get_device()
    start_time = time.time()

    # Test results
    test_results = {
        'data_loading': False,
        'model_initialization': False,
        'forward_pass': False,
        'ordering_strategies': {},
        'training': False,
        'evaluation': False,
        'visualization': False
    }
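
    # Each test below returns early on failure, so later stages can safely
    # rely on objects created by earlier stages (data, model, trainer, ...).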

    try:
        # Test 1: Data loading
        print("\n📊 Loading Cora dataset...")
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)

        print("✅ Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {data.train_mask.sum()}")
        print(f"   Val nodes: {data.val_mask.sum()}")
        print(f"   Test nodes: {data.test_mask.sum()}")

        test_results['data_loading'] = True

    except Exception as e:
        print(f"❌ Data loading failed: {e}")
        return test_results

    try:
        # Test 2: Model initialization with the regularized config
        print("\n🏗️ Initializing GraphMamba (Regularized)...")
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())

        print("✅ Model initialized!")
        print(f"   Parameters: {total_params:,}")
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")  # assumes float32 (4 bytes per parameter)
        print(f"   Device: {device}")
        print("   Model type: Regularized (Anti-overfitting)")

        # Check if the parameter count is reasonable for the small training set
        train_samples = data.train_mask.sum().item()
        params_per_sample = total_params / train_samples
        print(f"   Params per training sample: {params_per_sample:.1f}")
        if params_per_sample < 500:
            print("   ✅ Good parameter ratio - low overfitting risk")
        elif params_per_sample < 1000:
            print("   ⚠️ Moderate parameter ratio - watch for overfitting")
        else:
            print("   🚨 High parameter ratio - high overfitting risk")

        test_results['model_initialization'] = True

    except Exception as e:
        print(f"❌ Model initialization failed: {e}")
        return test_results

    try:
        # Test 3: Forward pass
        print("\n🚀 Testing forward pass...")
        model.eval()
        with torch.no_grad():
            forward_start = time.time()
            h = model(data.x, data.edge_index)
            forward_time = time.time() - forward_start

        print("✅ Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Forward time: {forward_time*1000:.2f}ms")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")

        test_results['forward_pass'] = True

    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return test_results
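
    # Note on the timing above: CUDA kernels launch asynchronously, so the
    # wall-clock forward time can under-report GPU compute time unless
    # torch.cuda.synchronize() is called before stopping the timer.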

    # Test 4: Ordering strategies (simplified for the regularized model)
    print("\n🔄 Testing ordering strategies...")

    # Only test BFS for the regularized model to avoid complexity
    strategies = ['bfs']

    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            test_model = GraphMamba(config).to(device)
            test_model.eval()

            strategy_start = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            strategy_time = time.time() - strategy_start

            print(f"✅ {strategy:12} | Shape: {h.shape} | Time: {strategy_time*1000:.2f}ms")
            test_results['ordering_strategies'][strategy] = True

        except Exception as e:
            print(f"❌ {strategy:12} | Failed: {str(e)}")
            test_results['ordering_strategies'][strategy] = False

    try:
        # Test 5: Regularized training
        print("\n🏋️ Testing regularized training system...")

        # Reset to BFS for training
        config['ordering']['strategy'] = 'bfs'
        model = GraphMamba(config).to(device)
        trainer = GraphMambaTrainer(model, config, device)

        print("✅ Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
        print(f"   Weight decay: {config['training']['weight_decay']}")
        print("   Anti-overfitting: Enabled")

        # Run training
        print("\n🎯 Running regularized training...")
        training_start = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - training_start

        print("✅ Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
        print(f"   Final train accuracy: {history['train_acc'][-1]:.4f}")
        print(f"   Overfitting gap: {trainer.best_gap:.4f}")

        test_results['training'] = True

    except Exception as e:
        print(f"❌ Training failed: {e}")
        return test_results
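
    # Note: `trainer.best_gap` is reported above as the "overfitting gap" and is
    # interpreted in the summary below as the train/validation accuracy gap at
    # the best checkpoint; this reading is inferred from how the script uses it,
    # not from the GraphMambaTrainer source.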

    try:
        # Test 6: Evaluation
        print("\n📊 Testing evaluation...")
        test_metrics = trainer.test(data)

        print("✅ Evaluation completed!")
        print(f"   Test accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Test loss: {test_metrics['test_loss']:.4f}")
        print(f"   F1 macro: {test_metrics.get('f1_macro', 0):.4f}")
        print(f"   F1 micro: {test_metrics.get('f1_micro', 0):.4f}")
        print(f"   Precision: {test_metrics.get('precision', 0):.4f}")
        print(f"   Recall: {test_metrics.get('recall', 0):.4f}")

        test_results['evaluation'] = True

    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        return test_results

    try:
        # Test 7: Visualization
        print("\n🎨 Testing visualization...")

        # Create all visualizations
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_metrics)
        training_fig = GraphVisualizer.create_training_history_plot(history)

        print("✅ Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")

        test_results['visualization'] = True

    except Exception as e:
        print(f"❌ Visualization failed: {e}")
        return test_results

    # Final summary
    print("\n" + "=" * 60)
    print("🏆 TEST SUMMARY")
    print("=" * 60)

    # Count passed tests (main tests plus per-strategy ordering tests)
    main_tests_passed = sum(1 for k, v in test_results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(test_results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed

    main_tests_total = len(test_results) - 1
    ordering_tests_total = len(test_results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total

    print(f"📊 Overall: {total_passed}/{total_tests} tests passed")
    print(f"💾 Device: {device}")
    print(f"⏱️ Total time: {time.time() - start_time:.2f}s")

    # Detailed results
    for test_name, result in test_results.items():
        if test_name == 'ordering_strategies':
            print("🔄 Ordering strategies:")
            for strategy, strategy_result in result.items():
                status = "✅" if strategy_result else "❌"
                print(f"   {status} {strategy}")
        else:
            status = "✅" if result else "❌"
            print(f"{status} {test_name.replace('_', ' ').title()}")

    # Performance summary
    if test_results['evaluation']:
        print("\n🎯 Final Performance:")
        print(f"   Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Training Time: {training_time:.2f}s")
        print(f"   Model Size: {total_params:,} parameters")
        print(f"   Params per sample: {params_per_sample:.1f}")

        # Compare with baselines
        cora_baselines = {
            'Random': 0.143,
            'Simple': 0.300,
            'GCN': 0.815,
            'GAT': 0.830
        }
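
        # For reference: the GCN (~0.815) and GAT (~0.830) figures match the
        # accuracies commonly reported for these models on the standard
        # Planetoid split of Cora; 0.143 is roughly 1/7, a random guess over
        # Cora's 7 classes.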
print(f"\nπŸ“ˆ Baseline Comparison (Cora):")
for model_name, baseline in cora_baselines.items():
diff = test_metrics['test_acc'] - baseline
if diff > 0:
status = "🟒"
desc = f"(+{diff:.3f} better)"
elif diff > -0.1:
status = "🟑"
desc = f"({diff:.3f} competitive)"
else:
status = "πŸ”΄"
desc = f"({diff:.3f} gap)"
print(f" {status} {model_name:12}: {baseline:.3f} {desc}")

        # Overfitting analysis
        if trainer.best_gap < 0.1:
            print(f"\n🎉 Excellent generalization! (gap: {trainer.best_gap:.3f})")
        elif trainer.best_gap < 0.2:
            print(f"\n👍 Good generalization (gap: {trainer.best_gap:.3f})")
        else:
            print(f"\n⚠️ Some overfitting detected (gap: {trainer.best_gap:.3f})")

    print("\n✨ All tests completed!")
    if total_passed == total_tests:
        print("🎉 Perfect score! Regularized system working well!")
    elif total_passed >= total_tests * 0.8:
        print("👍 Great! System is mostly functional.")
    else:
        print("⚠️ Some issues detected.")

    return test_results


if __name__ == "__main__":
    results = run_comprehensive_test()

    # Exit with the appropriate code
    main_tests_passed = sum(1 for k, v in results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed

    main_tests_total = len(results) - 1
    ordering_tests_total = len(results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total

    if total_passed == total_tests:
        exit(0)
    else:
        exit(1)