#!/usr/bin/env python3
"""
Quick demo script to smoke-test the Mamba Graph implementation.

Device-safe version: runs on CPU when the SPACE_ID environment variable is
set, and otherwise uses CUDA if it is available.
"""
import torch
import os
from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
def main():
    """Run an end-to-end smoke test of GraphMamba on the Cora dataset.

    Steps, each reported on stdout:

    1. Select a device: CPU when the ``SPACE_ID`` environment variable is
       set, otherwise CUDA if available, else CPU.
    2. Load Cora via ``GraphDataLoader`` and print basic graph statistics.
    3. Build a ``GraphMamba`` model and print its parameter count.
    4. Run one no-grad forward pass and print the output shape and range.
    5. Rebuild the model once per node-ordering strategy and run a forward
       pass with each.
    6. Attach a classifier head and evaluate node classification on the
       test mask (or on the second half of the nodes if there is no mask).

    Failures in steps 2-4 print an error and abort the demo; failures in
    steps 5 and 6 are printed and the demo continues.

    Returns:
        None. All results are printed.
    """
    print("🧪 Testing Mamba Graph Neural Network")
    print("=" * 50)

    # Base configuration; per-strategy variants are derived from it in step 5
    # without mutating this dict.
    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1,
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2,
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True,
        },
    }

    # Setup device
    if os.getenv('SPACE_ID'):
        # NOTE(review): SPACE_ID appears to mark a hosted (presumably Hugging
        # Face Spaces) environment; CPU is forced there. Confirm intent.
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Load dataset
    print("\n📊 Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)

        # Dataset info
        info = data_loader.get_dataset_info(dataset)
        print("✅ Success!")
        print(f"Nodes: {data.num_nodes}")
        print(f"Edges: {data.num_edges}")
        print(f"Features: {info['num_features']}")
        print(f"Classes: {info['num_classes']}")
    except Exception as e:  # demo boundary: report and stop
        print(f"❌ Error loading dataset: {e}")
        return

    # Initialize model
    print("\n🏗️ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print("✅ Model initialized!")
        print(f"Parameters: {total_params:,}")
    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return

    # Forward pass test
    print("\n🚀 Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)
        print("✅ Forward pass successful!")
        print(f"Input shape: {data.x.shape}")
        print(f"Output shape: {h.shape}")
        print(f"Output range: [{h.min().item():.3f}, {h.max().item():.3f}]")
    except Exception as e:
        print(f"❌ Forward pass failed: {e}")
        return

    # Test different ordering strategies
    print("\n🔄 Testing ordering strategies...")
    strategies = ['bfs', 'spectral', 'degree', 'community']
    for strategy in strategies:
        # Derive a fresh config instead of mutating ``config`` in place:
        # GraphMamba may keep a reference to the dict it was built from
        # (not visible here), in which case mutation would silently change
        # the ordering used by ``model`` in the evaluation step below.
        strategy_config = {
            **config,
            'ordering': {**config['ordering'], 'strategy': strategy},
        }
        try:
            model_test = GraphMamba(strategy_config).to(device)
            model_test.eval()
            with torch.no_grad():
                h_test = model_test(data.x, data.edge_index)
            print(f"✅ {strategy}: Success - Shape {h_test.shape}")
        except Exception as e:
            print(f"❌ {strategy}: Failed - {e}")

    # Test evaluation
    print("\n📈 Testing evaluation...")
    try:
        # Attach a classifier head sized to the dataset's classes.
        # NOTE(review): relies on GraphMamba's private ``_init_classifier``.
        num_classes = info['num_classes']
        model._init_classifier(num_classes, device)

        # Use the dataset's test split when present and populated; otherwise
        # fall back to evaluating on the second half of the nodes.
        mask = getattr(data, 'test_mask', None)
        if mask is None:
            mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
            mask[data.num_nodes // 2:] = True

        metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)
        print("✅ Evaluation successful!")
        for metric, value in metrics.items():
            # Only float-valued metrics are printed.
            if isinstance(value, float):
                print(f" {metric}: {value:.4f}")
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")

    print("\n✨ Demo completed!")
    print("🚀 Ready for production deployment!")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()