|
|
|
"""
Quick demo script that smoke-tests the GraphMamba implementation on Cora.

Device-safe version: forces CPU when the SPACE_ID environment variable is set.
"""
|
|
|
import copy
import os

import torch

from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
|
|
|
def _select_device():
    """Return the torch device the demo should run on.

    Forces CPU whenever the ``SPACE_ID`` environment variable is set to a
    non-empty value (presumably a hosted environment without a usable GPU --
    confirm); otherwise prefers CUDA when it is available.
    """
    if os.getenv('SPACE_ID'):
        return torch.device('cpu')
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_dataset(device):
    """Load Cora, move its first graph to *device* and print a summary.

    Returns ``(data, info)`` on success, where ``info`` is the dict produced by
    ``GraphDataLoader.get_dataset_info``. On any error the message is printed
    and ``None`` is returned so the caller can abort the demo.
    """
    print("\n📊 Loading Cora dataset...")
    try:
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)

        print("✅ Success!")
        print(f"Nodes: {data.num_nodes}")
        print(f"Edges: {data.num_edges}")
        print(f"Features: {info['num_features']}")
        print(f"Classes: {info['num_classes']}")
    except Exception as e:  # demo boundary: report and let the caller abort
        print(f"❌ Error loading dataset: {e}")
        return None
    return data, info


def _build_model(config, device):
    """Construct a GraphMamba from *config* on *device* and print its size.

    Returns the model, or ``None`` (after printing the error) on failure.
    """
    print("\n🏗️ Initializing GraphMamba...")
    try:
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print("✅ Model initialized!")
        print(f"Parameters: {total_params:,}")
    except Exception as e:  # demo boundary: report and let the caller abort
        print(f"❌ Error initializing model: {e}")
        return None
    return model


def _test_forward_pass(model, data):
    """Run one inference-mode forward pass and print input/output stats.

    Returns ``True`` on success, ``False`` (after printing the error) otherwise.
    """
    print("\n🔄 Testing forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            h = model(data.x, data.edge_index)
        print("✅ Forward pass successful!")
        print(f"Input shape: {data.x.shape}")
        print(f"Output shape: {h.shape}")
        print(f"Output range: [{h.min():.3f}, {h.max():.3f}]")
    except Exception as e:  # demo boundary: report and let the caller abort
        print(f"❌ Forward pass failed: {e}")
        return False
    return True


def _test_ordering_strategies(config, data, device):
    """Build and run a fresh GraphMamba for each node-ordering strategy.

    Each variant is built from a deep copy of *config*, so the caller's dict
    (which the main model was built from, and which GraphMamba may keep a
    reference to -- unverified) is never mutated.

    Returns the list of strategy names whose build or forward pass failed.
    """
    print("\n🔀 Testing ordering strategies...")
    failed = []
    for strategy in ('bfs', 'spectral', 'degree', 'community'):
        strategy_config = copy.deepcopy(config)
        strategy_config['ordering']['strategy'] = strategy
        try:
            model_test = GraphMamba(strategy_config).to(device)
            model_test.eval()
            with torch.no_grad():
                h = model_test(data.x, data.edge_index)
            print(f"✅ {strategy}: Success - Shape {h.shape}")
        except Exception as e:  # keep going: one bad strategy shouldn't stop the rest
            print(f"❌ {strategy}: Failed - {str(e)}")
            failed.append(strategy)
    return failed


def _evaluation_mask(data, device):
    """Return the boolean node mask to evaluate on.

    Uses ``data.test_mask`` when the graph provides a non-``None`` one;
    otherwise marks the second half of the nodes (by index) as the test set.
    """
    mask = getattr(data, 'test_mask', None)
    if mask is not None:
        return mask
    mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=device)
    mask[data.num_nodes // 2:] = True
    return mask


def _test_evaluation(model, data, info, device):
    """Attach a classifier head and print node-classification metrics.

    Returns ``True`` on success, ``False`` (after printing the error) otherwise.
    """
    print("\n📈 Testing evaluation...")
    try:
        # The classifier head is sized from the dataset's class count.
        model._init_classifier(info['num_classes'], device)
        mask = _evaluation_mask(data, device)
        metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)
        print("✅ Evaluation successful!")
        for metric, value in metrics.items():
            # Only float-valued metrics are printed; other entries are skipped.
            if isinstance(value, float):
                print(f" {metric}: {value:.4f}")
    except Exception as e:  # non-fatal: reported in the final summary
        print(f"❌ Evaluation failed: {e}")
        return False
    return True


def main():
    """Run an end-to-end smoke test of GraphMamba on the Cora dataset.

    Stages: load Cora, build the model, run a forward pass, rebuild and run the
    model under each node-ordering strategy, then evaluate node classification.
    A failure while loading, building the model, or in the first forward pass
    aborts the demo. Failures in the later stages are reported, and the final
    summary reflects them.
    """
    print("🧪 Testing Mamba Graph Neural Network")
    print("=" * 50)

    config = {
        'model': {
            'd_model': 128,
            'd_state': 8,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 3,
            'dropout': 0.1,
        },
        'data': {
            'batch_size': 16,
            'test_split': 0.2,
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True,
        },
    }

    device = _select_device()
    print(f"Device: {device}")

    loaded = _load_dataset(device)
    if loaded is None:
        return
    data, info = loaded

    model = _build_model(config, device)
    if model is None:
        return

    if not _test_forward_pass(model, data):
        return

    failed_strategies = _test_ordering_strategies(config, data, device)
    evaluation_ok = _test_evaluation(model, data, info, device)

    print("\n✨ Demo completed!")
    if failed_strategies or not evaluation_ok:
        # Don't claim readiness when a non-fatal stage failed.
        print("⚠️ Some checks failed; see the messages above.")
    else:
        print("🚀 Ready for production deployment!")


if __name__ == "__main__":
    main()