import gradio as gr
import torch
import yaml
import os
from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer
import warnings
warnings.filterwarnings('ignore')
# Force CPU for HuggingFace Spaces
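# SPACE_ID is set by the Hugging Face Spaces runtime, and the default Spaces hardware is
# CPU-only, so the device is pinned to CPU there instead of probing for CUDA.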
if os.getenv('SPACE_ID') or os.getenv('GRADIO_SERVER_NAME'):
    device = torch.device('cpu')
    print("Running on HuggingFace Spaces - using CPU")
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running locally - using {device}")

# Load configuration
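# The d_model / d_state / d_conv / expand keys follow the usual Mamba block hyperparameter
# names; values are kept deliberately small so the demo stays responsive on CPU.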
config = {
    'model': {
        'd_model': 128,   # Smaller for demo
        'd_state': 8,
        'd_conv': 4,
        'expand': 2,
        'n_layers': 3,    # Fewer layers for speed
        'dropout': 0.1
    },
    'data': {
        'batch_size': 16,
        'test_split': 0.2
    },
    'ordering': {
        'strategy': 'bfs',
        'preserve_locality': True
    }
}
# Global holders for the model and the currently loaded dataset
model = None
current_dataset = None
def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
"""Load dataset, configure model, return results"""
global model, config, current_dataset
try:
# Update config
config['ordering']['strategy'] = ordering_strategy
config['model']['n_layers'] = num_layers
print(f"Loading dataset: {dataset_name}")
# Load data
data_loader = GraphDataLoader()
if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
dataset = data_loader.load_node_classification_data(dataset_name)
data = dataset[0].to(device)
task_type = 'node_classification'
current_dataset = data
print(f"Loaded {dataset_name}: {data.num_nodes} nodes, {data.num_edges} edges")
else:
dataset = data_loader.load_graph_classification_data(dataset_name)
task_type = 'graph_classification'
print(f"Loaded {dataset_name}: {len(dataset)} graphs")
# Get dataset info
dataset_info = data_loader.get_dataset_info(dataset)
print(f"Dataset info: {dataset_info}")
        # Initialize model
        print("Initializing GraphMamba model...")
        model = GraphMamba(config).to(device)

        # Initialize classifier for evaluation
        num_classes = dataset_info['num_classes']
        model._init_classifier(num_classes, device)

        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {total_params:,}")

        # Quick evaluation (random weights for demo)
        print("Running evaluation...")

        if task_type == 'node_classification':
            # Use test mask for evaluation
            if hasattr(data, 'test_mask'):
                mask = data.test_mask
            else:
                # Create a test mask if not available
                num_nodes = data.num_nodes
                mask = torch.zeros(num_nodes, dtype=torch.bool)
                mask[num_nodes // 2:] = True

            metrics = GraphMetrics.evaluate_node_classification(
                model, data, mask, device
            )

            # Create visualization
            print("Creating visualization...")
            fig = GraphVisualizer.create_graph_plot(data)
        else:
            # Graph classification
            train_loader, val_loader, test_loader = data_loader.create_dataloaders(
                dataset, 'graph_classification'
            )
            metrics = GraphMetrics.evaluate_graph_classification(
                model, test_loader, device
            )
            fig = GraphVisualizer.create_metrics_plot(metrics)
        # Format results
        results_text = f"""
## 🧠 Mamba Graph Neural Network Results

### Dataset: {dataset_name}

**Dataset Statistics:**
- 📊 Features: {dataset_info['num_features']}
- 🏷️ Classes: {dataset_info['num_classes']}
- 📈 Graphs: {dataset_info['num_graphs']}
- 🔗 Avg Nodes: {dataset_info['avg_nodes']:.1f}
- 🌐 Avg Edges: {dataset_info['avg_edges']:.1f}

**Model Configuration:**
- 🔄 Ordering Strategy: {ordering_strategy}
- 🏗️ Layers: {num_layers}
- ⚙️ Parameters: {total_params:,}
- 💾 Device: {device}

**Performance Metrics:**
"""

        for metric, value in metrics.items():
            if isinstance(value, float) and metric != 'error':
                results_text += f"- 📈 {metric.replace('_', ' ').title()}: {value:.4f}\n"
            elif metric == 'error':
                results_text += f"- ⚠️ Error: {value}\n"

        results_text += """
**Status:** ✅ Model successfully loaded and evaluated!

*Note: This is a demo with random weights. In production, the model would be trained on the dataset.*
"""

        print("Evaluation completed successfully!")
        return results_text, fig
    except Exception as e:
        error_msg = f"""
## ❌ Error Loading Model

**Error:** {str(e)}

**Troubleshooting:**
- Check dataset availability
- Verify device compatibility
- Try a different ordering strategy

**Debug Info:**
- Device: {device}
- Dataset: {dataset_name}
- Strategy: {ordering_strategy}
"""
        print(f"Error: {e}")

        # Return an empty annotated plot so the UI still renders on error
        import plotly.graph_objects as go
        fig = go.Figure()
        fig.add_annotation(
            text=f"Error: {str(e)}",
            x=0.5, y=0.5,
            xref="paper", yref="paper",
            showarrow=False
        )

        return error_msg, fig
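
# Illustrative sketch only (not wired into the app): one way the 'bfs' and 'degree'
# ordering strategies could turn a PyG-style edge_index into a node sequence. The real
# implementations live inside the GraphMamba package; the helper below is hypothetical.
from collections import deque

def _example_node_order(edge_index, num_nodes, strategy='bfs'):
    """Return a node permutation that linearizes the graph (demo sketch)."""
    # Build a plain adjacency list from the 2 x E edge_index tensor.
    adj = [[] for _ in range(num_nodes)]
    for src, dst in edge_index.t().tolist():
        adj[src].append(dst)

    if strategy == 'degree':
        # High-degree hubs appear first in the sequence.
        return sorted(range(num_nodes), key=lambda n: -len(adj[n]))

    # Default 'bfs': breadth-first from node 0 keeps neighbouring nodes close together
    # in the sequence, matching the preserve_locality intent in config['ordering'].
    order, seen, queue = [], set(), deque()
    if num_nodes > 0:
        seen.add(0)
        queue.append(0)
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in adj[node]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    # Nodes in disconnected components are appended at the end.
    order.extend(n for n in range(num_nodes) if n not in seen)
    return order
# Example usage (after a dataset has been loaded):
#   _example_node_order(current_dataset.edge_index, current_dataset.num_nodes, 'bfs')
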
# Gradio interface
with gr.Blocks(
    title="🧠 Mamba Graph Neural Network",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    """
) as demo:
gr.Markdown("""
# ๐Ÿง  Mamba Graph Neural Network
**Real-time evaluation of Graph-Mamba on standard benchmarks.**
This demonstrates the revolutionary combination of Mamba's linear complexity with graph neural networks.
Uses actual datasets and real model architectures - no synthetic data.
๐Ÿš€ **Features:**
- Linear O(n) complexity for massive graphs
- Multiple graph ordering strategies
- Real benchmark datasets (Cora, CiteSeer, etc.)
- Interactive visualizations
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎮 Model Configuration")

            dataset_choice = gr.Dropdown(
                choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES'],
                value='Cora',
                label="📊 Dataset",
                info="Choose a graph dataset for evaluation"
            )

            ordering_choice = gr.Dropdown(
                choices=['bfs', 'spectral', 'degree', 'community'],
                value='bfs',
                label="🔄 Graph Ordering Strategy",
                info="How to convert the graph into a sequence"
            )

            layers_slider = gr.Slider(
                minimum=2, maximum=6, value=3, step=1,
                label="🏗️ Number of Mamba Layers",
                info="More layers = more capacity"
            )

            evaluate_btn = gr.Button(
                "🚀 Evaluate Model",
                variant="primary",
                size="lg"
            )
            gr.Markdown("""
### 📖 Ordering Strategies:
- **BFS**: Breadth-first traversal
- **Spectral**: Eigenvalue-based ordering
- **Degree**: High-degree nodes first
- **Community**: Cluster-aware ordering
""")
        with gr.Column(scale=2):
            results_text = gr.Markdown("""
### 👋 Welcome!

Select your parameters and click **'🚀 Evaluate Model'** to see Mamba Graph in action.

The model will:
1. 📥 Load the selected dataset
2. 🔄 Apply the graph ordering strategy
3. 🧠 Process the node sequence through Mamba layers
4. 📊 Evaluate performance
5. 📈 Show results and a visualization
""")
    with gr.Row():
        with gr.Column():
            visualization = gr.Plot(
                label="📈 Graph Visualization",
                container=True
            )
    # Event handlers
    evaluate_btn.click(
        fn=load_and_evaluate,
        inputs=[dataset_choice, ordering_choice, layers_slider],
        outputs=[results_text, visualization],
        show_progress=True
    )
    # Example section
    gr.Markdown("""
---
### 🎯 What Makes This Special?

**Traditional attention-based GNNs:** O(n²) complexity limits them to small graphs

**Mamba Graph:** O(n) complexity enables processing of massive graphs

**Key Innovation:** Graph-to-sequence conversion that preserves structural information while enabling linear-time processing.

### 🔬 Technical Details:
- **Selective State Space Models** for sequence processing
- **Structure-preserving ordering** algorithms
- **Position encoding** to maintain graph relationships
- **Multi-scale processing** for different graph properties

### 📚 References:
- Gu & Dao (2023), "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
- Kipf & Welling (2017), "Semi-Supervised Classification with Graph Convolutional Networks"
- Applications of spectral graph theory
""")
if __name__ == "__main__":
    print("🧠 Starting Mamba Graph Demo...")
    print(f"Device: {device}")
    print("Loading Gradio interface...")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False  # Set to False for HuggingFace Spaces
    )