"""Gradio demo that evaluates Graph-Mamba on standard graph benchmarks.

The UI lets a user pick a dataset, a graph-ordering strategy and a layer
count. ``load_and_evaluate`` builds a ``GraphMamba`` model with those
settings, evaluates it on the dataset's test split via ``GraphMetrics`` and
returns a Markdown summary plus a Plotly figure.
"""
import copy
import logging

import gradio as gr
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import yaml

from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics

logger = logging.getLogger(__name__)

# Dataset names routed to the node-classification path; every other name is
# treated as a graph-classification dataset.
NODE_CLASSIFICATION_DATASETS = ('Cora', 'CiteSeer', 'PubMed', 'Reddit', 'Flickr')

# Load the base configuration once. Per-request overrides are applied to a
# copy inside load_and_evaluate so this dict is never mutated by requests.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Most recently built model (rebuilt on every evaluation request).
model = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _metric_as_float(value):
    """Return *value* as a Python float if it is a floating-point scalar, else None.

    Accepts builtin floats, numpy floating scalars (e.g. ``np.float32``, which
    is not a ``float`` subclass) and single-element floating-point tensors.
    """
    if isinstance(value, (float, np.floating)):
        return float(value)
    if (isinstance(value, torch.Tensor) and value.numel() == 1
            and value.is_floating_point()):
        return float(value.item())
    return None


def _format_results(dataset_name, dataset_info, ordering_strategy, num_layers,
                    num_params, metrics):
    """Build the Markdown summary shown in the results panel.

    Lines are assembled explicitly (rather than via an indented triple-quoted
    string) so no line carries leading whitespace that Markdown would render
    as a code block.
    """
    lines = [
        f"## Dataset: {dataset_name}",
        "",
        "**Dataset Statistics:**",
        f"- Features: {dataset_info['num_features']}",
        f"- Classes: {dataset_info['num_classes']}",
        f"- Graphs: {dataset_info['num_graphs']}",
        f"- Avg Nodes: {dataset_info['avg_nodes']:.1f}",
        f"- Avg Edges: {dataset_info['avg_edges']:.1f}",
        "",
        "**Model Configuration:**",
        f"- Ordering Strategy: {ordering_strategy}",
        f"- Layers: {num_layers}",
        f"- Model Parameters: {num_params:,}",
        "",
        "**Performance Metrics:**",
    ]
    for metric, value in metrics.items():
        number = _metric_as_float(value)
        # Non-numeric entries (e.g. reports, arrays) are intentionally skipped.
        if number is not None:
            lines.append(f"- {metric.replace('_', ' ').title()}: {number:.4f}")
    return "\n".join(lines) + "\n"


def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
    """Load a dataset, build a model and evaluate it on the test split.

    Args:
        dataset_name: Benchmark name. Names in ``NODE_CLASSIFICATION_DATASETS``
            use the node-classification path; all others use graph
            classification.
        ordering_strategy: Value written to ``config['ordering']['strategy']``.
        num_layers: Number of model layers (cast to ``int`` because Gradio
            sliders may deliver floats).

    Returns:
        ``(markdown_text, figure)``. On any failure the text is
        ``"Error: <message>"`` and the figure is ``None``.
    """
    global model
    try:
        num_layers = int(num_layers)

        # Copy so one request's overrides never leak into the shared
        # module-level config or interfere with another request.
        run_config = copy.deepcopy(config)
        run_config['ordering']['strategy'] = ordering_strategy
        run_config['model']['n_layers'] = num_layers

        data_loader = GraphDataLoader()
        if dataset_name in NODE_CLASSIFICATION_DATASETS:
            dataset = data_loader.load_node_classification_data(dataset_name)
            data = dataset[0].to(device)
            task_type = 'node_classification'
        else:
            dataset = data_loader.load_graph_classification_data(dataset_name)
            _train_loader, _val_loader, test_loader = data_loader.create_dataloaders(
                dataset, 'graph_classification'
            )
            task_type = 'graph_classification'

        dataset_info = data_loader.get_dataset_info(dataset)

        # Quick evaluation: the model is freshly constructed here; no
        # pre-trained weights are loaded (in production you would load them).
        model = GraphMamba(run_config).to(device)

        if task_type == 'node_classification':
            metrics = GraphMetrics.evaluate_node_classification(
                model, data, data.test_mask, device
            )
            fig = create_graph_visualization(data)
        else:
            metrics = GraphMetrics.evaluate_graph_classification(
                model, test_loader, device
            )
            fig = create_dataset_stats_plot(dataset_info)

        num_params = sum(p.numel() for p in model.parameters())
        results_text = _format_results(
            dataset_name, dataset_info, ordering_strategy, num_layers,
            num_params, metrics,
        )
        return results_text, fig

    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback
        # in the logs instead of discarding it.
        logger.exception("Evaluation failed for dataset %r", dataset_name)
        return f"Error: {str(e)}", None


def create_graph_visualization(data, max_nodes=1000):
    """Create an interactive Plotly drawing of the graph in *data*.

    Args:
        data: Graph object exposing ``edge_index`` (a 2 x E tensor) and
            optionally ``num_nodes`` and ``y`` (labels used for node colors).
        max_nodes: Only the first ``max_nodes`` nodes are drawn, to keep the
            spring layout tractable.

    Returns:
        A ``go.Figure``. On failure, a figure containing the error message.
    """
    try:
        G = nx.Graph()
        num_nodes = getattr(data, 'num_nodes', None)
        if num_nodes:
            # Register every node up front so isolated nodes (no edges) are
            # drawn too and "first N nodes" means node ids 0..N-1.
            G.add_nodes_from(range(num_nodes))
        G.add_edges_from(data.edge_index.t().cpu().tolist())

        if G.number_of_nodes() > max_nodes:
            G = G.subgraph(list(G.nodes())[:max_nodes])

        pos = nx.spring_layout(G, k=1, iterations=50)

        # Color nodes by label when labels are available; fall back to 0.
        if getattr(data, 'y', None) is not None:
            labels = data.y.cpu().numpy()
            node_colors = [
                labels[node] if node < len(labels) else 0 for node in G.nodes()
            ]
        else:
            node_colors = [0] * G.number_of_nodes()

        # Edges are drawn as one polyline; None breaks the line between edges.
        edge_x, edge_y = [], []
        for u, v in G.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        node_x = [pos[node][0] for node in G.nodes()]
        node_y = [pos[node][1] for node in G.nodes()]

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines',
        ))
        fig.add_trace(go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=[f'Node {i}' for i in G.nodes()],
            marker=dict(
                size=8,
                color=node_colors,
                colorscale='Viridis',
                line=dict(width=2),
            ),
        ))
        fig.update_layout(
            title='Graph Visualization',
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text="Graph structure visualization",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002,
                    xanchor='left', yanchor='bottom',
                    font=dict(color="black", size=12),
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        )
        return fig

    except Exception as e:
        # Best-effort: a broken drawing should not fail the whole evaluation.
        logger.exception("Graph visualization failed")
        fig = go.Figure()
        fig.add_annotation(
            text=f"Visualization error: {str(e)}",
            x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False,
        )
        return fig


def create_dataset_stats_plot(dataset_info):
    """Return a bar chart of feature/class counts and average graph size.

    Args:
        dataset_info: Mapping with ``num_features``, ``num_classes``,
            ``avg_nodes`` and ``avg_edges`` keys.
    """
    stats = [
        ('Features', dataset_info['num_features']),
        ('Classes', dataset_info['num_classes']),
        ('Avg Nodes', dataset_info['avg_nodes']),
        ('Avg Edges', dataset_info['avg_edges']),
    ]
    names = [name for name, _ in stats]
    values = [value for _, value in stats]

    fig = go.Figure(data=[
        go.Bar(x=names, y=values, marker_color='lightblue')
    ])
    fig.update_layout(
        title='Dataset Statistics',
        xaxis_title='Metric',
        yaxis_title='Value',
    )
    return fig


# Gradio interface
with gr.Blocks(title="Mamba Graph Neural Network") as demo:
    gr.Markdown(
        "# 🧠 Mamba Graph Neural Network\n"
        "\n"
        "Real-time evaluation of Graph-Mamba on standard benchmarks.\n"
        "This uses actual datasets - no synthetic data. The model is built "
        "fresh for each evaluation; this app does not load pre-trained weights.\n"
    )

    with gr.Row():
        with gr.Column():
            dataset_choice = gr.Dropdown(
                choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES', 'PROTEINS'],
                value='Cora',
                label="Dataset",
            )
            ordering_choice = gr.Dropdown(
                choices=['bfs', 'spectral', 'degree', 'community'],
                value='bfs',
                label="Graph Ordering Strategy",
            )
            layers_slider = gr.Slider(
                minimum=2, maximum=8, value=4, step=1,
                label="Number of Mamba Layers",
            )
            evaluate_btn = gr.Button("Evaluate Model", variant="primary")

        with gr.Column():
            results_text = gr.Markdown("Select parameters and click 'Evaluate Model'")

    with gr.Row():
        visualization = gr.Plot(label="Graph Visualization")

    # Event handlers
    evaluate_btn.click(
        fn=load_and_evaluate,
        inputs=[dataset_choice, ordering_choice, layers_slider],
        outputs=[results_text, visualization],
    )


if __name__ == "__main__":
    demo.launch(share=True)