|
import gradio as gr |
|
import torch |
|
import yaml |
|
import os |
|
from core.graph_mamba import GraphMamba |
|
from data.loader import GraphDataLoader |
|
from utils.metrics import GraphMetrics |
|
from utils.visualization import GraphVisualizer |
|
import warnings

# Silence every Python warning for the whole process, presumably to keep the
# demo's console output uncluttered.
# NOTE(review): this also hides warnings that may point at real problems
# (e.g. API deprecations); consider narrowing with `category=` / `module=`.
warnings.filterwarnings('ignore')
|
|
|
|
|
def _select_device():
    """Choose the torch device used by the whole demo and report it on stdout.

    If either the ``SPACE_ID`` or the ``GRADIO_SERVER_NAME`` environment
    variable is set, the CPU is forced. Otherwise CUDA is used when
    ``torch.cuda.is_available()`` is true, with the CPU as fallback.

    NOTE(review): ``GRADIO_SERVER_NAME`` can also be set in a local run; the
    CPU is then still forced and the "HuggingFace Spaces" message is printed.
    """
    hosted = bool(os.getenv('SPACE_ID') or os.getenv('GRADIO_SERVER_NAME'))
    if hosted:
        chosen = torch.device('cpu')
        print("Running on HuggingFace Spaces - using CPU")
        return chosen

    chosen = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running locally - using {chosen}")
    return chosen


# Module-wide device used for data transfer, model placement and reporting.
device = _select_device()
|
|
|
|
|
# Default run configuration for the demo.
# 'model'    -> keyword settings for the GraphMamba architecture,
# 'data'     -> batching / split settings,
# 'ordering' -> how a graph is linearised into a sequence.
# 'ordering.strategy' and 'model.n_layers' are overwritten from the UI
# selections by load_and_evaluate on every run.
config = {
    'model': dict(
        d_model=128,
        d_state=8,
        d_conv=4,
        expand=2,
        n_layers=3,
        dropout=0.1,
    ),
    'data': dict(batch_size=16, test_split=0.2),
    'ordering': dict(strategy='bfs', preserve_locality=True),
}

# Filled in by load_and_evaluate: the most recently built model and the most
# recently loaded node-classification graph. Both start out empty.
model = current_dataset = None
|
|
|
# Planetoid citation graphs are handled as node-classification tasks; every
# other dataset name goes through the graph-classification loader.
_NODE_CLASSIFICATION_DATASETS = ('Cora', 'CiteSeer', 'PubMed')


def _format_results_markdown(dataset_name, dataset_info, metrics,
                             ordering_strategy, num_layers, total_params,
                             run_device):
    """Build the Markdown report shown after a successful evaluation.

    The report is assembled from unindented lines joined with newlines, so it
    can never be rendered as an indented code block regardless of how this
    source file is indented. Float-valued metrics are printed with four
    decimals, an ``'error'`` entry becomes a warning line, and any other
    metric value is omitted.
    """
    lines = [
        "## 🧠 Mamba Graph Neural Network Results",
        "",
        f"### Dataset: {dataset_name}",
        "",
        "**Dataset Statistics:**",
        f"- 📊 Features: {dataset_info['num_features']}",
        f"- 🏷️ Classes: {dataset_info['num_classes']}",
        f"- 📈 Graphs: {dataset_info['num_graphs']}",
        f"- 📍 Avg Nodes: {dataset_info['avg_nodes']:.1f}",
        f"- 🔗 Avg Edges: {dataset_info['avg_edges']:.1f}",
        "",
        "**Model Configuration:**",
        f"- 🔄 Ordering Strategy: {ordering_strategy}",
        f"- 🏗️ Layers: {num_layers}",
        f"- ⚙️ Parameters: {total_params:,}",
        f"- 💾 Device: {run_device}",
        "",
        "**Performance Metrics:**",
    ]
    for name, value in metrics.items():
        if name == 'error':
            lines.append(f"- ⚠️ Error: {value}")
        elif isinstance(value, float):
            lines.append(f"- 📈 {name.replace('_', ' ').title()}: {value:.4f}")
    lines += [
        "",
        "**Status:** ✅ Model successfully loaded and evaluated!",
        "",
        "*Note: This is a demo with random weights. "
        "In production, the model would be trained on the dataset.*",
    ]
    return "\n".join(lines)


def _format_error_markdown(error, dataset_name, ordering_strategy, run_device):
    """Build the Markdown shown in the results panel when evaluation fails."""
    return "\n".join([
        "## ❌ Error Loading Model",
        "",
        f"**Error:** {error}",
        "",
        "**Troubleshooting:**",
        "- Check dataset availability",
        "- Verify device compatibility",
        "- Try different ordering strategy",
        "",
        "**Debug Info:**",
        f"- Device: {run_device}",
        f"- Dataset: {dataset_name}",
        f"- Strategy: {ordering_strategy}",
    ])


def _error_figure(error):
    """Return an empty plotly figure showing *error* as a centred annotation."""
    import plotly.graph_objects as go  # only needed on the failure path

    fig = go.Figure()
    fig.add_annotation(
        text=f"Error: {error}",
        x=0.5, y=0.5,
        xref="paper", yref="paper",
        showarrow=False,
    )
    return fig


def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
    """Load a dataset, build a freshly initialised GraphMamba model and evaluate it.

    Click handler for the "Evaluate Model" button. The model is evaluated
    with randomly initialised weights; no training happens here.

    Args:
        dataset_name: Dataset chosen in the UI. 'Cora', 'CiteSeer' and
            'PubMed' are evaluated as node classification; any other name is
            loaded as a graph-classification dataset.
        ordering_strategy: Graph-to-sequence ordering written to the model
            config under ``['ordering']['strategy']``.
        num_layers: Number of Mamba layers; coerced to ``int``.

    Returns:
        A ``(markdown_text, figure)`` tuple. Any exception is caught; the
        markdown then describes the error and the figure is a plotly Figure
        carrying the error message.

    Side effects:
        Updates the module globals ``config`` (ordering strategy and layer
        count), ``model`` and, for node-classification datasets,
        ``current_dataset``.
    """
    import copy
    import traceback

    global model, config, current_dataset

    try:
        # Defensive: slider values may arrive as floats (e.g. 3.0).
        num_layers = int(num_layers)

        # The model gets its own snapshot of the settings so that later
        # requests mutating the shared module-level dict cannot affect it.
        run_config = copy.deepcopy(config)
        run_config['ordering']['strategy'] = ordering_strategy
        run_config['model']['n_layers'] = num_layers
        # Keep mirroring the latest selection into the global config.
        config['ordering']['strategy'] = ordering_strategy
        config['model']['n_layers'] = num_layers

        print(f"Loading dataset: {dataset_name}")
        data_loader = GraphDataLoader()

        if dataset_name in _NODE_CLASSIFICATION_DATASETS:
            dataset = data_loader.load_node_classification_data(dataset_name)
            data = dataset[0].to(device)
            task_type = 'node_classification'
            current_dataset = data
            print(f"Loaded {dataset_name}: {data.num_nodes} nodes, {data.num_edges} edges")
        else:
            dataset = data_loader.load_graph_classification_data(dataset_name)
            task_type = 'graph_classification'
            print(f"Loaded {dataset_name}: {len(dataset)} graphs")

        dataset_info = data_loader.get_dataset_info(dataset)
        print(f"Dataset info: {dataset_info}")

        print("Initializing GraphMamba model...")
        model = GraphMamba(run_config).to(device)
        model._init_classifier(dataset_info['num_classes'], device)

        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {total_params:,}")

        print("Running evaluation...")
        if task_type == 'node_classification':
            mask = getattr(data, 'test_mask', None)
            if mask is None:
                # No predefined split: evaluate on the second half of the
                # node indices. The mask must live on the same device as
                # `data`, which was moved to `device` above.
                num_nodes = data.num_nodes
                mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
                mask[num_nodes // 2:] = True

            metrics = GraphMetrics.evaluate_node_classification(
                model, data, mask, device
            )

            print("Creating visualization...")
            fig = GraphVisualizer.create_graph_plot(data)
        else:
            # Only the test split is needed for evaluation.
            _, _, test_loader = data_loader.create_dataloaders(
                dataset, 'graph_classification'
            )
            metrics = GraphMetrics.evaluate_graph_classification(
                model, test_loader, device
            )
            fig = GraphVisualizer.create_metrics_plot(metrics)

        results_text = _format_results_markdown(
            dataset_name, dataset_info, metrics,
            ordering_strategy, num_layers, total_params, device,
        )

        print("Evaluation completed successfully!")
        return results_text, fig

    except Exception as e:
        # UI boundary: report the failure in the page instead of crashing the
        # handler, but keep the full traceback in the server log.
        print(f"Error: {e}")
        traceback.print_exc()
        error_msg = _format_error_markdown(e, dataset_name, ordering_strategy, device)
        return error_msg, _error_figure(e)
|
|
|
|
|
# Gradio UI layout:
# - a header,
# - a row with the configuration controls (left) and the results panel (right),
# - a plot row,
# - an explanatory footer.
# Emoji in the user-facing strings are the UTF-8 characters the text was meant
# to show; they replace the mis-encoded (TIS-620 mojibake) variants.
with gr.Blocks(
    title="🧠 Mamba Graph Neural Network",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    """,
) as demo:

    gr.Markdown("""
    # 🧠 Mamba Graph Neural Network

    **Real-time evaluation of Graph-Mamba on standard benchmarks.**

    This demonstrates the revolutionary combination of Mamba's linear complexity with graph neural networks.
    Uses actual datasets and real model architectures - no synthetic data.

    🚀 **Features:**
    - Linear O(n) complexity for massive graphs
    - Multiple graph ordering strategies
    - Real benchmark datasets (Cora, CiteSeer, etc.)
    - Interactive visualizations
    """)

    with gr.Row():
        # Left column: run configuration controls.
        with gr.Column(scale=1):
            gr.Markdown("### 🎮 Model Configuration")

            dataset_choice = gr.Dropdown(
                choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES'],
                value='Cora',
                label="📊 Dataset",
                info="Choose a graph dataset for evaluation",
            )

            ordering_choice = gr.Dropdown(
                choices=['bfs', 'spectral', 'degree', 'community'],
                value='bfs',
                label="🔄 Graph Ordering Strategy",
                info="How to convert graph to sequence",
            )

            layers_slider = gr.Slider(
                minimum=2, maximum=6, value=3, step=1,
                label="🏗️ Number of Mamba Layers",
                info="More layers = more capacity",
            )

            evaluate_btn = gr.Button(
                "🚀 Evaluate Model",
                variant="primary",
                size="lg",
            )

            gr.Markdown("""
            ### 📋 Ordering Strategies:
            - **BFS**: Breadth-first traversal
            - **Spectral**: Eigenvalue-based ordering
            - **Degree**: High-degree nodes first
            - **Community**: Cluster-aware ordering
            """)

        # Right column: results panel. Its content is replaced by the
        # Markdown returned from load_and_evaluate.
        with gr.Column(scale=2):
            results_text = gr.Markdown("""
            ### 👋 Welcome!

            Select your parameters and click **'🚀 Evaluate Model'** to see Mamba Graph in action.

            The model will:
            1. 📥 Load the selected dataset
            2. 🔄 Apply graph ordering strategy
            3. 🧠 Process through Mamba layers
            4. 📊 Evaluate performance
            5. 📈 Show results and visualization
            """)

    with gr.Row():
        with gr.Column():
            visualization = gr.Plot(
                label="📊 Graph Visualization",
                container=True,
            )

    # The button runs the evaluation handler. Its (markdown, figure) return
    # value fills the results panel and the plot.
    evaluate_btn.click(
        fn=load_and_evaluate,
        inputs=[dataset_choice, ordering_choice, layers_slider],
        outputs=[results_text, visualization],
        show_progress=True,
    )

    gr.Markdown("""
    ---

    ### 🎯 What Makes This Special?

    **Traditional GNNs:** O(n²) complexity limits them to small graphs

    **Mamba Graph:** O(n) complexity enables processing of massive graphs

    **Key Innovation:** Smart graph-to-sequence conversion preserves structural information while enabling linear-time processing.

    ### 🔬 Technical Details:
    - **Selective State Space Models** for sequence processing
    - **Structure-preserving ordering** algorithms
    - **Position encoding** to maintain graph relationships
    - **Multi-scale processing** for different graph properties

    ### 📚 References:
    - Mamba: Linear-Time Sequence Modeling (Gu & Dao, 2023)
    - Graph Neural Networks (Kipf & Welling, 2017)
    - Spectral Graph Theory applications
    """)
|
|
|
if __name__ == "__main__":
    # Console banner printed before the Gradio server is started.
    startup_messages = (
        "🧠 Starting Mamba Graph Demo...",
        f"Device: {device}",
        "Loading Gradio interface...",
    )
    for message in startup_messages:
        print(message)

    # Launch settings:
    # - listen on all interfaces, port 7860;
    # - show_error=True asks Gradio to surface errors in the browser;
    # - share=False means no public share link is requested.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    demo.launch(**launch_options)