|
|
|
""" |
|
End-to-end test script for the GraphMamba implementation.

Exercises Cora data loading, model construction, the forward pass, node-ordering
strategies, training, evaluation, and visualization.
|
""" |
|
|
|
import copy
import os
import time
import traceback

import torch

from core.graph_mamba import GraphMamba
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer
|
|
|
# Reference Cora test accuracies for common GNN baselines (hard-coded; used only
# for the comparison table printed at the end of the run).
CORA_BASELINES = {
    'GCN': 0.815,
    'GAT': 0.830,
    'GraphSAGE': 0.824,
    'GIN': 0.800,
}

# Node-ordering strategies benchmarked with one throwaway model each.
ORDERING_STRATEGIES = ('bfs', 'spectral', 'degree', 'community')


def _default_config():
    """Return a fresh hyperparameter config dict for this test run."""
    return {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1,
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2,
        },
        'training': {
            'learning_rate': 0.01,
            'weight_decay': 0.0005,
            'epochs': 50,
            'patience': 10,
            'warmup_epochs': 5,
            'min_lr': 1e-6,
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True,
        },
    }


def _config_with_strategy(config, strategy):
    """Return a deep copy of *config* with ``ordering.strategy`` set to *strategy*.

    Copying (rather than assigning into *config*) keeps the strategy benchmark
    from leaking its last strategy into the config later handed to the trainer.
    """
    variant = copy.deepcopy(config)
    variant['ordering']['strategy'] = strategy
    return variant


def _select_device():
    """Pick the compute device.

    Forces CPU when the ``SPACE_ID`` environment variable is set (presumably a
    hosted Space without a usable GPU — confirm); otherwise uses CUDA if available.
    """
    if os.getenv('SPACE_ID'):
        return torch.device('cpu')
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _synchronize(device):
    """Wait for queued CUDA work on *device* so wall-clock timings are meaningful."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def _timed_forward(model, data, device):
    """Run one no-grad forward pass in eval mode.

    Returns:
        ``(output, elapsed_ms)`` where ``elapsed_ms`` is the wall-clock time of
        the forward call, measured after synchronizing the device.
    """
    model.eval()
    with torch.no_grad():
        _synchronize(device)
        start = time.perf_counter()
        out = model(data.x, data.edge_index)
        _synchronize(device)
        elapsed_ms = (time.perf_counter() - start) * 1000
    return out, elapsed_ms


def _report_failure(label, exc):
    """Print a one-line failure message followed by the active traceback."""
    print(f"❌ {label}: {exc}")
    traceback.print_exc()


def _benchmark_orderings(config, data, device, strategies=ORDERING_STRATEGIES):
    """Build a throwaway GraphMamba per ordering strategy and time one forward pass.

    Each strategy gets its own copy of *config*, so *config* is left untouched.
    Failures are reported per strategy and never abort the run.
    """
    for strategy in strategies:
        try:
            variant_model = GraphMamba(_config_with_strategy(config, strategy)).to(device)
            h, elapsed_ms = _timed_forward(variant_model, data, device)
            print(f"✅ {strategy:12} | Shape: {h.shape} | Time: {elapsed_ms:.2f}ms")
        except Exception as e:
            print(f"❌ {strategy:12} | Failed: {str(e)}")
        finally:
            # Drop references to this strategy's model/output so they can be
            # freed before the next model is built.
            variant_model = h = None


def _render_plots(data, test_results, history):
    """Build the graph, metrics and training-history figures and save them as HTML.

    Returns:
        True on success, False if any step raised (visualization is optional,
        so errors are reported rather than propagated).
    """
    try:
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_results)
        training_fig = GraphVisualizer.create_training_history_plot(history)

        print("✅ Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")

        graph_fig.write_html("graph_visualization.html")
        metrics_fig.write_html("metrics_plot.html")
        training_fig.write_html("training_history.html")
        print("   Plots saved as HTML files")
    except Exception as e:
        _report_failure("Visualization failed", e)
        return False
    return True


def _baseline_status(diff):
    """Marker for *diff* = ours - baseline: green if ahead, yellow if behind by < 0.05, else red."""
    if diff > 0:
        return "🟢"
    if diff > -0.05:
        return "🟡"
    return "🔴"


def _print_summary(num_nodes, total_params, training_time, epochs_trained,
                   test_acc, best_val_acc):
    """Print the final performance summary and the comparison with CORA_BASELINES."""
    print("\n🏆 Performance Summary")
    print("=" * 40)
    print(f"📊 Dataset: Cora ({num_nodes:,} nodes)")
    print(f"🧠 Model: {total_params:,} parameters")
    print(f"⏱️ Training: {training_time:.2f}s ({epochs_trained} epochs)")
    print(f"🎯 Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
    print(f"🏅 Best Val Accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")

    print("\n📈 Comparison with Baselines:")
    for model_name, baseline in CORA_BASELINES.items():
        diff = test_acc - baseline
        print(f"   {_baseline_status(diff)} {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")


def main():
    """Run the end-to-end GraphMamba test on the Cora dataset.

    Steps: load data, build the model, check a forward pass, benchmark the
    node-ordering strategies, train, evaluate, render plots, print a summary.

    Returns:
        0 when every required step succeeds, 1 as soon as a required step
        fails. Ordering-strategy and visualization failures are reported but
        are not fatal.
    """
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    config = _default_config()
    device = _select_device()
    print(f"💾 Device: {device}")

    # --- Data -------------------------------------------------------------
    print("\n📊 Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)

        print("✅ Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {int(data.train_mask.sum())}")
        print(f"   Val nodes: {int(data.val_mask.sum())}")
        print(f"   Test nodes: {int(data.test_mask.sum())}")
    except Exception as e:
        _report_failure("Error loading dataset", e)
        return 1

    # --- Model ------------------------------------------------------------
    print("\n🏗️ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        params = list(model.parameters())
        total_params = sum(p.numel() for p in params)
        # Use each parameter's actual element size rather than assuming float32.
        param_mb = sum(p.numel() * p.element_size() for p in params) / 1024**2
        print("✅ Model initialized!")
        print(f"   Parameters: {total_params:,}")
        print(f"   Memory usage: ~{param_mb:.1f} MB")
    except Exception as e:
        _report_failure("Error initializing model", e)
        return 1

    # --- Forward pass -----------------------------------------------------
    print("\n🔄 Testing forward pass...")
    try:
        h, _ = _timed_forward(model, data, device)
        print("✅ Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")
    except Exception as e:
        _report_failure("Forward pass failed", e)
        return 1

    # --- Ordering strategies (non-fatal; config is not mutated) -----------
    print("\n🔀 Testing ordering strategies...")
    _benchmark_orderings(config, data, device)

    # --- Trainer ----------------------------------------------------------
    print("\n🏋️ Testing training system...")
    try:
        trainer = GraphMambaTrainer(model, config, device)
        print("✅ Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
    except Exception as e:
        _report_failure("Trainer initialization failed", e)
        return 1

    # --- Training ---------------------------------------------------------
    print("\n🎯 Running training...")
    try:
        start = time.perf_counter()
        history = trainer.train_node_classification(data, verbose=True)
        _synchronize(device)
        training_time = time.perf_counter() - start
        epochs_trained = len(history['train_loss'])

        print("✅ Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {epochs_trained}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
    except Exception as e:
        _report_failure("Training failed", e)
        return 1

    # --- Evaluation -------------------------------------------------------
    print("\n📈 Testing evaluation...")
    try:
        test_results = trainer.test(data)
        print("✅ Evaluation completed!")
        print(f"   Test accuracy: {test_results['test_acc']:.4f}")
        print(f"   Test loss: {test_results['test_loss']:.4f}")
        print("   Per-class accuracy:")
        for class_idx, acc in enumerate(test_results['class_acc']):
            print(f"      Class {class_idx}: {acc:.4f}")
    except Exception as e:
        _report_failure("Evaluation failed", e)
        return 1

    # --- Visualization (non-fatal) ----------------------------------------
    print("\n🎨 Testing visualization...")
    visualization_ok = _render_plots(data, test_results, history)

    _print_summary(
        data.num_nodes,
        total_params,
        training_time,
        epochs_trained,
        test_results['test_acc'],
        trainer.best_val_acc,
    )

    if visualization_ok:
        print("\n✨ Test completed successfully!")
        print("🚀 Ready for production deployment!")
    else:
        # Only claim full success when every stage, including plots, worked.
        print("\n⚠️ Test completed, but visualization failed (see above).")
    return 0
|
|
|
if __name__ == "__main__":
    # Propagate main()'s status code so shells/CI see a non-zero exit on failure.
    raise SystemExit(main())