# serpent/app.py
import gradio as gr
import torch
import yaml
import plotly.graph_objects as go
import plotly.express as px
from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
import networkx as nx
import numpy as np
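# core.graph_mamba, data.loader, and utils.metrics are presumably local modules of
# this Space (the GraphMamba model, dataset loading helpers, and evaluation metrics
# respectively), not third-party packages.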
# Load configuration
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)
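# Only two config keys are written below (ordering.strategy and model.n_layers), so a
# minimal config.yaml compatible with this script would look roughly like the sketch
# here; the actual file in the repo may carry more settings:
#
#   ordering:
#     strategy: bfs
#   model:
#     n_layers: 4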
# Initialize model (will be loaded dynamically based on dataset)
model = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
"""Load dataset, train/evaluate model, return results"""
global model, config
try:
# Update config
config['ordering']['strategy'] = ordering_strategy
config['model']['n_layers'] = num_layers
# Load data
data_loader = GraphDataLoader()
if dataset_name in ['Cora', 'CiteSeer', 'PubMed', 'Reddit', 'Flickr']:
dataset = data_loader.load_node_classification_data(dataset_name)
data = dataset[0].to(device)
task_type = 'node_classification'
else:
dataset = data_loader.load_graph_classification_data(dataset_name)
train_loader, val_loader, test_loader = data_loader.create_dataloaders(
dataset, 'graph_classification'
)
task_type = 'graph_classification'
# Get dataset info
dataset_info = data_loader.get_dataset_info(dataset)
# Initialize model
model = GraphMamba(config).to(device)
# Quick evaluation (in production, you'd load pre-trained weights)
if task_type == 'node_classification':
# Use test mask for evaluation
metrics = GraphMetrics.evaluate_node_classification(
model, data, data.test_mask, device
)
# Create visualization
fig = create_graph_visualization(data)
else:
# Graph classification
metrics = GraphMetrics.evaluate_graph_classification(
model, test_loader, device
)
fig = create_dataset_stats_plot(dataset_info)
# Format results
results_text = f"""
## Dataset: {dataset_name}
**Dataset Statistics:**
- Features: {dataset_info['num_features']}
- Classes: {dataset_info['num_classes']}
- Graphs: {dataset_info['num_graphs']}
- Avg Nodes: {dataset_info['avg_nodes']:.1f}
- Avg Edges: {dataset_info['avg_edges']:.1f}
**Model Configuration:**
- Ordering Strategy: {ordering_strategy}
- Layers: {num_layers}
- Model Parameters: {sum(p.numel() for p in model.parameters()):,}
**Performance Metrics:**
"""
for metric, value in metrics.items():
if isinstance(value, float):
results_text += f"- {metric.replace('_', ' ').title()}: {value:.4f}\n"
return results_text, fig
except Exception as e:
return f"Error: {str(e)}", None
def create_graph_visualization(data):
"""Create interactive graph visualization"""
try:
# Convert to NetworkX
G = nx.Graph()
edge_list = data.edge_index.t().cpu().numpy()
G.add_edges_from(edge_list)
# Limit to first 1000 nodes for visualization
if len(G.nodes()) > 1000:
nodes_to_keep = list(G.nodes())[:1000]
G = G.subgraph(nodes_to_keep)
# Layout
pos = nx.spring_layout(G, k=1, iterations=50)
# Node colors based on labels if available
node_colors = []
if hasattr(data, 'y') and data.y is not None:
labels = data.y.cpu().numpy()
for node in G.nodes():
if node < len(labels):
node_colors.append(labels[node])
else:
node_colors.append(0)
else:
node_colors = [0] * len(G.nodes())
# Create traces
edge_x, edge_y = [], []
for edge in G.edges():
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
edge_x.extend([x0, x1, None])
edge_y.extend([y0, y1, None])
node_x = [pos[node][0] for node in G.nodes()]
node_y = [pos[node][1] for node in G.nodes()]
fig = go.Figure()
# Add edges
fig.add_trace(go.Scatter(
x=edge_x, y=edge_y,
line=dict(width=0.5, color='#888'),
hoverinfo='none',
mode='lines'
))
# Add nodes
fig.add_trace(go.Scatter(
x=node_x, y=node_y,
mode='markers',
hoverinfo='text',
text=[f'Node {i}' for i in G.nodes()],
marker=dict(
size=8,
color=node_colors,
colorscale='Viridis',
line=dict(width=2)
)
))
fig.update_layout(
title='Graph Visualization',
showlegend=False,
hovermode='closest',
margin=dict(b=20,l=5,r=5,t=40),
annotations=[
dict(
text="Graph structure visualization",
showarrow=False,
xref="paper", yref="paper",
x=0.005, y=-0.002,
xanchor='left', yanchor='bottom',
font=dict(color="black", size=12)
)
],
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
)
return fig
except Exception as e:
# Return empty plot on error
fig = go.Figure()
fig.add_annotation(text=f"Visualization error: {str(e)}", x=0.5, y=0.5)
return fig
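# Standalone check (illustrative; assumes the loader yields PyTorch Geometric Data
# objects, which the edge_index / test_mask attributes above suggest):
#   from torch_geometric.data import Data
#   toy = Data(edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]),
#              y=torch.zeros(3, dtype=torch.long))
#   create_graph_visualization(toy).show()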
def create_dataset_stats_plot(dataset_info):
"""Create dataset statistics visualization"""
stats = [
['Features', dataset_info['num_features']],
['Classes', dataset_info['num_classes']],
['Avg Nodes', dataset_info['avg_nodes']],
['Avg Edges', dataset_info['avg_edges']]
]
fig = go.Figure(data=[
go.Bar(
x=[stat[0] for stat in stats],
y=[stat[1] for stat in stats],
marker_color='lightblue'
)
])
fig.update_layout(
title='Dataset Statistics',
xaxis_title='Metric',
yaxis_title='Value'
)
return fig
# Gradio interface
with gr.Blocks(title="Mamba Graph Neural Network") as demo:
    gr.Markdown("""
    # 🧠 Mamba Graph Neural Network

    Real-time evaluation of Graph-Mamba on standard benchmarks.
    The evaluation runs on real benchmark datasets - no synthetic data.
    """)

    with gr.Row():
        with gr.Column():
            dataset_choice = gr.Dropdown(
                choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES', 'PROTEINS'],
                value='Cora',
                label="Dataset"
            )

            ordering_choice = gr.Dropdown(
                choices=['bfs', 'spectral', 'degree', 'community'],
                value='bfs',
                label="Graph Ordering Strategy"
            )

            layers_slider = gr.Slider(
                minimum=2, maximum=8, value=4, step=1,
                label="Number of Mamba Layers"
            )

            evaluate_btn = gr.Button("Evaluate Model", variant="primary")

        with gr.Column():
            results_text = gr.Markdown("Select parameters and click 'Evaluate Model'")

    with gr.Row():
        visualization = gr.Plot(label="Graph Visualization")

    # Event handlers
    evaluate_btn.click(
        fn=load_and_evaluate,
        inputs=[dataset_choice, ordering_choice, layers_slider],
        outputs=[results_text, visualization]
    )
if __name__ == "__main__":
    demo.launch(share=True)