#!/usr/bin/env python3
"""
Complete test script for Mamba Graph implementation
Tests training, evaluation, and visualization
"""
import torch
import os
import time
from core.graph_mamba import GraphMamba
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer
def main():
    """Run the end-to-end GraphMamba smoke test.

    Exercises, in order: dataset loading (Cora), model construction, a
    forward pass, all four node-ordering strategies, trainer setup,
    training, evaluation, visualization, and a final baseline comparison.
    Each stage prints its progress; a failure in a required stage returns
    early so later stages never see missing state. Visualization failure
    is non-fatal (no early return) since the summary does not depend on it.
    """
    print("Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    # Test configuration. 'epochs' is deliberately small for a quick run.
    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2
        },
        'training': {
            'learning_rate': 0.01,
            'weight_decay': 0.0005,
            'epochs': 50,  # quick test
            'patience': 10,
            'warmup_epochs': 5,
            'min_lr': 1e-6
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True
        }
    }

    # Device selection: SPACE_ID marks a Hugging Face Space (CPU-only
    # environment assumed — TODO confirm); otherwise prefer CUDA if present.
    if os.getenv('SPACE_ID'):
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # --- Dataset -------------------------------------------------------
    print("\nLoading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)
        print("[OK] Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {data.train_mask.sum()}")
        print(f"   Val nodes: {data.val_mask.sum()}")
        print(f"   Test nodes: {data.test_mask.sum()}")
    except Exception as e:
        print(f"[FAIL] Error loading dataset: {e}")
        return

    # --- Model ---------------------------------------------------------
    print("\nInitializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print("[OK] Model initialized!")
        print(f"   Parameters: {total_params:,}")
        # ~4 bytes per fp32 parameter.
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
    except Exception as e:
        print(f"[FAIL] Error initializing model: {e}")
        return

    # --- Forward pass ----------------------------------------------------
    print("\nTesting forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)
        print("[OK] Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")
    except Exception as e:
        print(f"[FAIL] Forward pass failed: {e}")
        return

    # --- Ordering strategies ---------------------------------------------
    # A fresh model is built per strategy; failures here are non-fatal.
    print("\nTesting ordering strategies...")
    strategies = ['bfs', 'spectral', 'degree', 'community']
    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            test_model = GraphMamba(config).to(device)
            test_model.eval()
            start_time = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            end_time = time.time()
            print(f"[OK] {strategy:12} | Shape: {h.shape} | "
                  f"Time: {(end_time - start_time) * 1000:.2f}ms")
        except Exception as e:
            print(f"[FAIL] {strategy:12} | Failed: {str(e)}")

    # --- Trainer ---------------------------------------------------------
    # NOTE(review): config['ordering']['strategy'] is left at 'community'
    # (last loop value) here — confirm that is intended for training.
    print("\nTesting training system...")
    try:
        trainer = GraphMambaTrainer(model, config, device)
        print("[OK] Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
    except Exception as e:
        print(f"[FAIL] Trainer initialization failed: {e}")
        return

    # --- Training --------------------------------------------------------
    print("\nRunning training...")
    try:
        start_time = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - start_time
        print("[OK] Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
    except Exception as e:
        print(f"[FAIL] Training failed: {e}")
        return

    # --- Evaluation ------------------------------------------------------
    print("\nTesting evaluation...")
    try:
        test_results = trainer.test(data)
        print("[OK] Evaluation completed!")
        print(f"   Test accuracy: {test_results['test_acc']:.4f}")
        print(f"   Test loss: {test_results['test_loss']:.4f}")
        # Per-class results
        class_accs = test_results['class_acc']
        print("   Per-class accuracy:")
        for i, acc in enumerate(class_accs):
            print(f"     Class {i}: {acc:.4f}")
    except Exception as e:
        print(f"[FAIL] Evaluation failed: {e}")
        return

    # --- Visualization (non-fatal) ---------------------------------------
    print("\nTesting visualization...")
    try:
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_results)
        training_fig = GraphVisualizer.create_training_history_plot(history)
        print("[OK] Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")
        # Save plots to the working directory as standalone HTML.
        graph_fig.write_html("graph_visualization.html")
        metrics_fig.write_html("metrics_plot.html")
        training_fig.write_html("training_history.html")
        print("   Plots saved as HTML files")
    except Exception as e:
        print(f"[FAIL] Visualization failed: {e}")

    # --- Performance summary ---------------------------------------------
    print("\nPerformance Summary")
    print("=" * 40)
    print(f"Dataset: Cora ({data.num_nodes:,} nodes)")
    print(f"Model: {total_params:,} parameters")
    print(f"Training: {training_time:.2f}s ({len(history['train_loss'])} epochs)")
    print(f"Test Accuracy: {test_results['test_acc']:.4f} "
          f"({test_results['test_acc'] * 100:.2f}%)")
    print(f"Best Val Accuracy: {trainer.best_val_acc:.4f} "
          f"({trainer.best_val_acc * 100:.2f}%)")

    # Published Cora test accuracies for common GNN baselines
    # (presumably from the standard Planetoid split — verify source).
    cora_baselines = {
        'GCN': 0.815,
        'GAT': 0.830,
        'GraphSAGE': 0.824,
        'GIN': 0.800
    }
    print("\nComparison with Baselines:")
    test_acc = test_results['test_acc']
    for model_name, baseline in cora_baselines.items():
        diff = test_acc - baseline
        # green: beats baseline; yellow: within 5 points; red: worse.
        status = "+" if diff > 0 else "~" if diff > -0.05 else "-"
        print(f"   [{status}] {model_name:12}: {baseline:.3f} (diff: {diff:+.3f})")

    print("\nTest completed successfully!")
    print("Ready for production deployment!")
# Script entry point.
if __name__ == "__main__":
    main()