"""
Production test script for the Mamba Graph implementation.

Uses a regularized configuration to reduce overfitting.
"""

import os

os.environ['OMP_NUM_THREADS'] = '4'
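# NOTE: OMP_NUM_THREADS must be set before `import torch`; the OpenMP runtime
# reads it once when torch initializes its thread pool.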

import sys
import time
import logging
from pathlib import Path

import torch

from core.graph_mamba import GraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def get_device():
    """Get the best available device (GPU preferred)."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        logger.info(f"🚀 CUDA available - using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device('cpu')
        logger.info("💻 Using CPU")
    return device


def run_comprehensive_test():
    """Run the comprehensive test suite with overfitting fixes."""
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    config = create_regularized_config()

    device = get_device()
    start_time = time.time()

    test_results = {
        'data_loading': False,
        'model_initialization': False,
        'forward_pass': False,
        'ordering_strategies': {},
        'training': False,
        'evaluation': False,
        'visualization': False
    }

    try:
        print("\n📊 Loading Cora dataset...")

        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)

        info = data_loader.get_dataset_info(dataset)

        print("✅ Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {data.train_mask.sum().item():,}")
        print(f"   Val nodes: {data.val_mask.sum().item():,}")
        print(f"   Test nodes: {data.test_mask.sum().item():,}")

        test_results['data_loading'] = True

    except Exception as e:
        print(f"❌ Data loading failed: {e}")
        return test_results

    try:
        print("\n🏗️ Initializing GraphMamba (Regularized)...")

        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())

        print("✅ Model initialized!")
        print(f"   Parameters: {total_params:,}")
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
        print(f"   Device: {device}")
        print("   Model type: Regularized (Anti-overfitting)")

        train_samples = data.train_mask.sum().item()
        params_per_sample = total_params / train_samples
        print(f"   Params per training sample: {params_per_sample:.1f}")

        if params_per_sample < 500:
            print("   ✅ Good parameter ratio - low overfitting risk")
        elif params_per_sample < 1000:
            print("   ⚠️ Moderate parameter ratio - watch for overfitting")
        else:
            print("   🚨 High parameter ratio - high overfitting risk")

        test_results['model_initialization'] = True

    except Exception as e:
        print(f"❌ Model initialization failed: {e}")
        return test_results

    try:
        print("\n🔄 Testing forward pass...")

        model.eval()
        with torch.no_grad():
            forward_start = time.time()
            h = model(data.x, data.edge_index)
            forward_time = time.time() - forward_start

        print("✅ Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Forward time: {forward_time*1000:.2f}ms")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")

        test_results['forward_pass'] = True

    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return test_results

    print("\n🔀 Testing ordering strategies...")

    strategies = ['bfs']
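    # Only BFS is exercised here; extend this list if the GraphMamba config
    # supports other orderings (strategy names depend on the implementation).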

    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            test_model = GraphMamba(config).to(device)
            test_model.eval()

            strategy_start = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            strategy_time = time.time() - strategy_start

            print(f"✅ {strategy:12} | Shape: {h.shape} | Time: {strategy_time*1000:.2f}ms")
            test_results['ordering_strategies'][strategy] = True

        except Exception as e:
            print(f"❌ {strategy:12} | Failed: {str(e)}")
            test_results['ordering_strategies'][strategy] = False

    try:
        print("\n🏋️ Testing regularized training system...")

        config['ordering']['strategy'] = 'bfs'
        model = GraphMamba(config).to(device)
        trainer = GraphMambaTrainer(model, config, device)

        print("✅ Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
        print(f"   Weight decay: {config['training']['weight_decay']}")
        print("   Anti-overfitting: Enabled")

        print("\n🎯 Running regularized training...")
        training_start = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - training_start

        print("✅ Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
        print(f"   Final train accuracy: {history['train_acc'][-1]:.4f}")
        print(f"   Overfitting gap: {trainer.best_gap:.4f}")

        test_results['training'] = True

    except Exception as e:
        print(f"❌ Training failed: {e}")
        return test_results

    try:
        print("\n📊 Testing evaluation...")

        test_metrics = trainer.test(data)

        print("✅ Evaluation completed!")
        print(f"   Test accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Test loss: {test_metrics['test_loss']:.4f}")
        print(f"   F1 macro: {test_metrics.get('f1_macro', 0):.4f}")
        print(f"   F1 micro: {test_metrics.get('f1_micro', 0):.4f}")
        print(f"   Precision: {test_metrics.get('precision', 0):.4f}")
        print(f"   Recall: {test_metrics.get('recall', 0):.4f}")

        test_results['evaluation'] = True

    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        return test_results

    try:
        print("\n🎨 Testing visualization...")

        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_metrics)
        training_fig = GraphVisualizer.create_training_history_plot(history)

        print("✅ Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")

        test_results['visualization'] = True

    except Exception as e:
        print(f"❌ Visualization failed: {e}")
        return test_results

    print("\n" + "=" * 60)
    print("📋 TEST SUMMARY")
    print("=" * 60)

    # 'ordering_strategies' holds per-strategy results, so it is counted
    # separately from the boolean top-level entries.
    main_tests_passed = sum(1 for k, v in test_results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(test_results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed

    main_tests_total = len(test_results) - 1
    ordering_tests_total = len(test_results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total

    print(f"📊 Overall: {total_passed}/{total_tests} tests passed")
    print(f"💾 Device: {device}")
    print(f"⏱️ Total time: {time.time() - start_time:.2f}s")

    for test_name, result in test_results.items():
        if test_name == 'ordering_strategies':
            print("🔀 Ordering strategies:")
            for strategy, strategy_result in result.items():
                status = "✅" if strategy_result else "❌"
                print(f"   {status} {strategy}")
        else:
            status = "✅" if result else "❌"
            print(f"{status} {test_name.replace('_', ' ').title()}")

    if test_results['evaluation']:
        print("\n🎯 Final Performance:")
        print(f"   Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Training Time: {training_time:.2f}s")
        print(f"   Model Size: {total_params:,} parameters")
        print(f"   Params per sample: {params_per_sample:.1f}")

        cora_baselines = {
            'Random': 0.143,
            'Simple': 0.300,
            'GCN': 0.815,
            'GAT': 0.830
        }
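        # 'Random' is chance level for Cora's 7 classes (1/7 ~ 0.143); the GCN
        # and GAT figures are the accuracies reported in their original papers
        # (Kipf & Welling 2017; Velickovic et al. 2018).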

        print("\n📊 Baseline Comparison (Cora):")
        for model_name, baseline in cora_baselines.items():
            diff = test_metrics['test_acc'] - baseline
            if diff > 0:
                status = "🟢"
                desc = f"(+{diff:.3f} better)"
            elif diff > -0.1:
                status = "🟡"
                desc = f"({diff:.3f} competitive)"
            else:
                status = "🔴"
                desc = f"({diff:.3f} gap)"
            print(f"   {status} {model_name:12}: {baseline:.3f} {desc}")

    if trainer.best_gap < 0.1:
        print(f"\n🎉 Excellent generalization! (gap: {trainer.best_gap:.3f})")
    elif trainer.best_gap < 0.2:
        print(f"\n👍 Good generalization (gap: {trainer.best_gap:.3f})")
    else:
        print(f"\n⚠️ Some overfitting detected (gap: {trainer.best_gap:.3f})")

    print("\n✨ All tests completed!")

    if total_passed == total_tests:
        print("🏆 Perfect score! Regularized system working well!")
    elif total_passed >= total_tests * 0.8:
        print("👍 Great! System is mostly functional.")
    else:
        print("⚠️ Some issues detected.")

    return test_results


if __name__ == "__main__":
    results = run_comprehensive_test()

    # Recompute the pass counts (mirroring the summary logic in
    # run_comprehensive_test) so the exit code reflects overall success.
    main_tests_passed = sum(1 for k, v in results.items() if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed

    main_tests_total = len(results) - 1
    ordering_tests_total = len(results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total

    sys.exit(0 if total_passed == total_tests else 1)