# serpent / demo.py
# Author: kfoughali -- commit ba4e201 ("Create demo.py"), 4.05 kB
#!/usr/bin/env python3
"""
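Quick demo script that smoke-tests the GraphMamba implementation on Cora.
Device-safe: runs on CPU when the SPACE_ID environment variable is set
(e.g. on Hugging Face Spaces), otherwise on CUDA when it is available.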
"""
import torch
import os
from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
def _select_device():
    """Pick the torch device for the demo.

    Forces CPU whenever the ``SPACE_ID`` environment variable is set
    (presumably a Hugging Face Space -- confirm for other hosts); otherwise
    uses CUDA when it is available and falls back to CPU.
    """
    if os.getenv('SPACE_ID'):
        return torch.device('cpu')
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _test_ordering_strategies(config, data, device, strategies):
    """Build and run a fresh GraphMamba per ordering strategy; return the names that failed."""
    failed = []
    for strategy in strategies:
        # Derive a per-strategy config instead of mutating the shared one:
        # the main model was built from ``config`` and may still reference it.
        strategy_config = {
            **config,
            'ordering': {**config['ordering'], 'strategy': strategy},
        }
        try:
            model_test = GraphMamba(strategy_config).to(device)
            model_test.eval()
            with torch.no_grad():
                h = model_test(data.x, data.edge_index)
            print(f"✅ {strategy}: Success - Shape {h.shape}")
        except Exception as e:
            print(f"❌ {strategy}: Failed - {str(e)}")
            failed.append(strategy)
    return failed


def _evaluation_mask(data, device):
    """Return ``data.test_mask``, or a mask over the second half of the nodes when it is missing or None."""
    mask = getattr(data, 'test_mask', None)
    if mask is None:
        mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
        mask[data.num_nodes // 2:] = True
    return mask


def main():
    """Smoke-test GraphMamba end to end on the Cora dataset.

    Steps: load Cora, build the model, run one forward pass, repeat the
    forward pass for every node-ordering strategy, then run a
    node-classification evaluation. Progress is printed to stdout.

    Loading, model construction and the first forward pass are prerequisites,
    so a failure there stops the demo early. Failures in the later steps are
    reported and the demo continues.

    Returns:
        int: 0 if every step succeeded, 1 otherwise, so the value can be
        used directly as a process exit code.
    """
    print("🧠 Testing Mamba Graph Neural Network")
    print("=" * 50)

    # Configuration
    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1,
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2,
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True,
        },
    }

    device = _select_device()
    print(f"Device: {device}")

    # Load dataset: nothing else can run without it.
    print("\n📊 Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)
        print("✅ Success!")
        print(f"Nodes: {data.num_nodes}")
        print(f"Edges: {data.num_edges}")
        print(f"Features: {info['num_features']}")
        print(f"Classes: {info['num_classes']}")
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return 1

    # Initialize model
    print("\n🏗️ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print("✅ Model initialized!")
        print(f"Parameters: {total_params:,}")
    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return 1

    # Forward pass test
    print("\n🚀 Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)
        print("✅ Forward pass successful!")
        print(f"Input shape: {data.x.shape}")
        print(f"Output shape: {h.shape}")
        print(f"Output range: [{h.min():.3f}, {h.max():.3f}]")
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return 1

    # Non-fatal checks: collect the names of whatever fails.
    print("\n🔄 Testing ordering strategies...")
    failed = _test_ordering_strategies(
        config, data, device, ['bfs', 'spectral', 'degree', 'community']
    )

    # Test evaluation
    print("\n📈 Testing evaluation...")
    try:
        model._init_classifier(info['num_classes'], device)
        mask = _evaluation_mask(data, device)
        metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)
        print("✅ Evaluation successful!")
        for metric, value in metrics.items():
            # Only float-valued metrics are printed.
            if isinstance(value, float):
                print(f" {metric}: {value:.4f}")
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        failed.append('evaluation')

    print("\n✨ Demo completed!")
    if failed:
        # Do not claim readiness when any check failed.
        print(f"⚠️ Some checks failed: {', '.join(failed)}")
        return 1
    print("🚀 Ready for production deployment!")
    return 0
if __name__ == "__main__":
    # Propagate main()'s result as the exit status (None -> 0, int -> that code)
    # so shells/CI can tell whether the demo succeeded.
    raise SystemExit(main())