|
import gradio as gr |
|
import torch |
|
import yaml |
|
import plotly.graph_objects as go |
|
import plotly.express as px |
|
from core.graph_mamba import GraphMamba |
|
from data.loader import GraphDataLoader |
|
from utils.metrics import GraphMetrics |
|
import networkx as nx |
|
import numpy as np |
|
|
|
|
|
# Load the application configuration once, at import time.
# ``yaml.safe_load`` (rather than ``yaml.load``) only builds plain Python
# objects.  The encoding is explicit so decoding does not depend on the
# platform's locale (the default for ``open`` without ``encoding``).
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Most recently constructed model; None until an evaluation has run.
model = None

# Use a CUDA device when PyTorch reports one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
def _format_results(dataset_name, dataset_info, ordering_strategy, num_layers,
                    model, metrics):
    """Render dataset statistics, model settings and float metrics as Markdown.

    Lines are assembled explicitly (instead of an indented triple-quoted
    f-string) so no source indentation leaks into the Markdown, where 4+
    leading spaces would turn the text into a code block.
    """
    num_params = sum(p.numel() for p in model.parameters())
    lines = [
        f"## Dataset: {dataset_name}",
        "",
        "**Dataset Statistics:**",
        f"- Features: {dataset_info['num_features']}",
        f"- Classes: {dataset_info['num_classes']}",
        f"- Graphs: {dataset_info['num_graphs']}",
        f"- Avg Nodes: {dataset_info['avg_nodes']:.1f}",
        f"- Avg Edges: {dataset_info['avg_edges']:.1f}",
        "",
        "**Model Configuration:**",
        f"- Ordering Strategy: {ordering_strategy}",
        f"- Layers: {num_layers}",
        f"- Model Parameters: {num_params:,}",
        "",
        "**Performance Metrics:**",
    ]
    # Only float-valued metrics are reported, as "Snake Case" -> "Snake Case".
    lines.extend(
        f"- {metric.replace('_', ' ').title()}: {value:.4f}"
        for metric, value in metrics.items()
        if isinstance(value, float)
    )
    return "\n".join(lines) + "\n"


def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
    """Load a dataset, build a Graph-Mamba model and evaluate it.

    The chosen ordering strategy and layer count are written into the
    module-level ``config`` before ``GraphMamba(config)`` is constructed; the
    new model is stored in the module-level ``model``. This function performs
    no training step: the freshly constructed model is passed directly to the
    ``GraphMetrics`` evaluators.
    NOTE(review): unless ``GraphMamba`` loads weights itself, the reported
    metrics are those of an untrained model — confirm.

    Args:
        dataset_name: 'Cora', 'CiteSeer', 'PubMed', 'Reddit' and 'Flickr' are
            treated as node-classification datasets; any other name is loaded
            as a graph-classification dataset.
        ordering_strategy: Stored in ``config['ordering']['strategy']``.
        num_layers: Stored in ``config['model']['n_layers']``. Coerced to
            ``int`` because Gradio sliders deliver their value as a float.

    Returns:
        ``(markdown_text, figure)``. On any exception the text is
        ``"Error: <message>"`` and the figure is ``None``.
    """
    global model, config

    try:
        # gr.Slider hands the callback a float (e.g. 4.0); a layer count must
        # be integral (e.g. for range()).
        num_layers = int(num_layers)

        config['ordering']['strategy'] = ordering_strategy
        config['model']['n_layers'] = num_layers

        data_loader = GraphDataLoader()

        is_node_task = dataset_name in (
            'Cora', 'CiteSeer', 'PubMed', 'Reddit', 'Flickr'
        )
        if is_node_task:
            dataset = data_loader.load_node_classification_data(dataset_name)
            data = dataset[0].to(device)
        else:
            dataset = data_loader.load_graph_classification_data(dataset_name)
            # Only the test split is evaluated here.
            _, _, test_loader = data_loader.create_dataloaders(
                dataset, 'graph_classification'
            )

        dataset_info = data_loader.get_dataset_info(dataset)

        model = GraphMamba(config).to(device)

        if is_node_task:
            metrics = GraphMetrics.evaluate_node_classification(
                model, data, data.test_mask, device
            )
            fig = create_graph_visualization(data)
        else:
            metrics = GraphMetrics.evaluate_graph_classification(
                model, test_loader, device
            )
            fig = create_dataset_stats_plot(dataset_info)

        results_text = _format_results(
            dataset_name, dataset_info, ordering_strategy, num_layers,
            model, metrics,
        )
        return results_text, fig

    except Exception as e:
        # UI boundary: report the failure in the results panel, not a crash.
        return f"Error: {str(e)}", None
|
|
|
def create_graph_visualization(data, max_nodes=1000):
    """Draw the graph in ``data`` as an interactive Plotly node-link diagram.

    Args:
        data: Graph object exposing an ``edge_index`` tensor (transposed so each
            row is one ``(src, dst)`` pair — presumably shape (2, E); confirm),
            and optionally ``num_nodes`` and per-node labels ``y`` used as
            marker colours.
        max_nodes: At most this many nodes (the first ones in node order) are
            drawn; the spring layout gets slow on large graphs.

    Returns:
        A ``plotly.graph_objects.Figure``. If building the plot raises, a
        blank figure carrying the error message is returned instead.
    """
    try:
        G = nx.Graph()

        # Register every node first so isolated nodes (no incident edges) are
        # drawn as well, and node order is 0..N-1 instead of edge order.
        num_nodes = getattr(data, 'num_nodes', None)
        if num_nodes is not None:
            G.add_nodes_from(range(int(num_nodes)))

        # .tolist() yields plain Python ints rather than numpy scalars.
        G.add_edges_from(data.edge_index.t().cpu().tolist())

        if G.number_of_nodes() > max_nodes:
            G = G.subgraph(list(G.nodes())[:max_nodes])

        pos = nx.spring_layout(G, k=1, iterations=50)
        nodes = list(G.nodes())

        # Colour by label when labels exist; nodes beyond the label array
        # (or graphs without labels) get colour 0.
        labels = getattr(data, 'y', None)
        if labels is not None:
            labels = labels.cpu().numpy()
            n_labels = len(labels)
            node_colors = [labels[n] if n < n_labels else 0 for n in nodes]
        else:
            node_colors = [0] * len(nodes)

        # One line segment per edge; None breaks the polyline between edges.
        edge_x, edge_y = [], []
        for u, v in G.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        node_x = [pos[n][0] for n in nodes]
        node_y = [pos[n][1] for n in nodes]

        fig = go.Figure()

        fig.add_trace(go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines',
        ))

        fig.add_trace(go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=[f'Node {n}' for n in nodes],
            marker=dict(
                size=8,
                color=node_colors,
                colorscale='Viridis',
                line=dict(width=2),
            ),
        ))

        fig.update_layout(
            title='Graph Visualization',
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text="Graph structure visualization",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002,
                    xanchor='left', yanchor='bottom',
                    font=dict(color="black", size=12),
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        )

        return fig

    except Exception as e:
        # Best effort: show the problem inside the plot area instead of
        # failing the whole evaluation. Paper coordinates centre the message
        # regardless of the (empty) axes' ranges.
        fig = go.Figure()
        fig.add_annotation(
            text=f"Visualization error: {str(e)}",
            x=0.5, y=0.5,
            xref='paper', yref='paper',
            showarrow=False,
        )
        return fig
|
|
|
def create_dataset_stats_plot(dataset_info):
    """Return a Plotly bar chart of the headline numbers in ``dataset_info``.

    Reads the ``num_features``, ``num_classes``, ``avg_nodes`` and
    ``avg_edges`` entries (a missing key raises ``KeyError``) and plots one
    light-blue bar per entry.
    """
    # (axis label, dataset_info key) pairs, in display order.
    bar_spec = (
        ('Features', 'num_features'),
        ('Classes', 'num_classes'),
        ('Avg Nodes', 'avg_nodes'),
        ('Avg Edges', 'avg_edges'),
    )
    bar_values = [dataset_info[key] for _, key in bar_spec]
    bar_labels = [label for label, _ in bar_spec]

    bars = go.Bar(x=bar_labels, y=bar_values, marker_color='lightblue')
    chart = go.Figure(data=[bars])
    chart.update_layout(
        yaxis_title='Value',
        xaxis_title='Metric',
        title='Dataset Statistics',
    )
    return chart
|
|
|
|
|
# Gradio UI: controls on the left, Markdown results on the right, and the
# plot in a full-width row below. Clicking the button runs load_and_evaluate.
with gr.Blocks(title="Mamba Graph Neural Network") as demo:
    gr.Markdown("""
    # 🧠 Mamba Graph Neural Network

    Real-time evaluation of Graph-Mamba on standard benchmarks.
    This uses actual datasets and trained models - no synthetic data.
    """)

    with gr.Row():
        # Input controls.
        with gr.Column():
            dataset_choice = gr.Dropdown(
                label="Dataset",
                value='Cora',
                choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES', 'PROTEINS'],
            )
            ordering_choice = gr.Dropdown(
                label="Graph Ordering Strategy",
                value='bfs',
                choices=['bfs', 'spectral', 'degree', 'community'],
            )
            layers_slider = gr.Slider(
                label="Number of Mamba Layers",
                minimum=2,
                maximum=8,
                step=1,
                value=4,
            )
            evaluate_btn = gr.Button("Evaluate Model", variant="primary")

        # Results panel (Markdown text returned by the callback).
        with gr.Column():
            results_text = gr.Markdown("Select parameters and click 'Evaluate Model'")

    with gr.Row():
        visualization = gr.Plot(label="Graph Visualization")

    evaluate_btn.click(
        load_and_evaluate,
        inputs=[dataset_choice, ordering_choice, layers_slider],
        outputs=[results_text, visualization],
    )

# share=True asks Gradio for a public share link in addition to the local server.
if __name__ == "__main__":
    demo.launch(share=True)