import os
import traceback
import warnings

import gradio as gr
import torch
import yaml

from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer

warnings.filterwarnings('ignore')

# Force CPU for HuggingFace Spaces
if os.getenv('SPACE_ID') or os.getenv('GRADIO_SERVER_NAME'):
    device = torch.device('cpu')
    print("Running on HuggingFace Spaces - using CPU")
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running locally - using {device}")

# Load configuration
config = {
    'model': {
        'd_model': 128,  # Smaller for demo
        'd_state': 8,
        'd_conv': 4,
        'expand': 2,
        'n_layers': 3,  # Fewer layers for speed
        'dropout': 0.1,
    },
    'data': {
        'batch_size': 16,
        'test_split': 0.2,
    },
    'ordering': {
        'strategy': 'bfs',
        'preserve_locality': True,
    },
}

# Global model holder
model = None
current_dataset = None

# Datasets routed to the node-classification path; every other dataset name is
# loaded as a graph-classification dataset.
NODE_CLASSIFICATION_DATASETS = ('Cora', 'CiteSeer', 'PubMed')


def _node_test_mask(data):
    """Return ``data.test_mask`` or, if absent, a mask selecting the second half of the nodes.

    The fallback mask is allocated on the module-level ``device`` so it matches
    ``data``, which the caller has already moved there.
    """
    mask = getattr(data, 'test_mask', None)
    if mask is not None:
        return mask
    num_nodes = data.num_nodes
    mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    mask[num_nodes // 2:] = True
    return mask


def _format_report(dataset_name, dataset_info, ordering_strategy, num_layers,
                   total_params, metrics):
    """Build the Markdown summary shown after an evaluation run.

    Args:
        dataset_name: Name selected in the UI.
        dataset_info: Mapping from ``GraphDataLoader.get_dataset_info``; the keys
            read here are num_features, num_classes, num_graphs, avg_nodes, avg_edges.
        ordering_strategy: Graph-to-sequence ordering used for the model.
        num_layers: Number of Mamba layers the model was built with.
        total_params: Total parameter count of the model.
        metrics: Mapping of metric name to value. Float values are listed with
            four decimals; an ``'error'`` entry is reported as a warning; other
            values are skipped.

    Returns:
        The report as a single Markdown string.
    """
    lines = [
        "## 🧠 Mamba Graph Neural Network Results",
        "",
        f"### Dataset: {dataset_name}",
        "",
        "**Dataset Statistics:**",
        f"- 📊 Features: {dataset_info['num_features']}",
        f"- 🏷️ Classes: {dataset_info['num_classes']}",
        f"- 📈 Graphs: {dataset_info['num_graphs']}",
        f"- 🔗 Avg Nodes: {dataset_info['avg_nodes']:.1f}",
        f"- 🌐 Avg Edges: {dataset_info['avg_edges']:.1f}",
        "",
        "**Model Configuration:**",
        f"- 🔄 Ordering Strategy: {ordering_strategy}",
        f"- 🏗️ Layers: {num_layers}",
        f"- ⚙️ Parameters: {total_params:,}",
        f"- 💾 Device: {device}",
        "",
        "**Performance Metrics:**",
    ]
    for metric, value in metrics.items():
        if metric == 'error':
            lines.append(f"- ⚠️ Error: {value}")
        elif isinstance(value, float):
            lines.append(f"- 📈 {metric.replace('_', ' ').title()}: {value:.4f}")

    # Do not claim success when the metrics helper reported an error.
    if 'error' in metrics:
        status = "**Status:** ⚠️ Model loaded, but evaluation reported an error."
    else:
        status = "**Status:** ✅ Model successfully loaded and evaluated!"
    lines += [
        "",
        status,
        "",
        "*Note: This is a demo with random weights. "
        "In production, the model would be trained on the dataset.*",
    ]
    return "\n".join(lines)


def _error_response(error, dataset_name, ordering_strategy):
    """Return ``(markdown, figure)`` describing *error* for display in the UI."""
    # Imported lazily so plotly is only needed when an error is actually shown.
    import plotly.graph_objects as go

    error_msg = "\n".join([
        "## ❌ Error Loading Model",
        "",
        f"**Error:** {error}",
        "",
        "**Troubleshooting:**",
        "- Check dataset availability",
        "- Verify device compatibility",
        "- Try different ordering strategy",
        "",
        "**Debug Info:**",
        f"- Device: {device}",
        f"- Dataset: {dataset_name}",
        f"- Strategy: {ordering_strategy}",
    ])

    # Placeholder figure carrying the error text.
    fig = go.Figure()
    fig.add_annotation(
        text=f"Error: {error}",
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        showarrow=False,
    )
    return error_msg, fig


def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
    """Load a dataset, build a GraphMamba model, evaluate it, and report the results.

    The model is evaluated with its freshly initialised (untrained) weights.

    Side effects: mutates the module-level ``config`` (ordering strategy and layer
    count), rebinds the global ``model``, and, for node-classification datasets,
    rebinds ``current_dataset`` to the loaded graph.

    Args:
        dataset_name: One of ``NODE_CLASSIFICATION_DATASETS`` for node
            classification; any other name is loaded as a graph-classification dataset.
        ordering_strategy: Graph ordering strategy name stored in
            ``config['ordering']['strategy']``.
        num_layers: Number of Mamba layers. May arrive as a float from the
            Gradio slider and is converted to ``int``.

    Returns:
        ``(markdown_report, figure)`` for the Markdown and Plot outputs. On any
        exception, returns an error report and a plotly figure with the message.
    """
    global model, current_dataset

    try:
        # gr.Slider hands values to the callback as floats (e.g. 3.0); a layer
        # count must be an int for the model config.
        num_layers = int(num_layers)

        # Update config
        config['ordering']['strategy'] = ordering_strategy
        config['model']['n_layers'] = num_layers

        print(f"Loading dataset: {dataset_name}")

        # Load data
        data_loader = GraphDataLoader()
        if dataset_name in NODE_CLASSIFICATION_DATASETS:
            dataset = data_loader.load_node_classification_data(dataset_name)
            data = dataset[0].to(device)
            task_type = 'node_classification'
            current_dataset = data
            print(f"Loaded {dataset_name}: {data.num_nodes} nodes, {data.num_edges} edges")
        else:
            dataset = data_loader.load_graph_classification_data(dataset_name)
            task_type = 'graph_classification'
            print(f"Loaded {dataset_name}: {len(dataset)} graphs")

        # Get dataset info
        dataset_info = data_loader.get_dataset_info(dataset)
        print(f"Dataset info: {dataset_info}")

        # Initialize model
        print("Initializing GraphMamba model...")
        model = GraphMamba(config).to(device)

        # Initialize classifier for evaluation
        model._init_classifier(dataset_info['num_classes'], device)

        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {total_params:,}")

        # Quick evaluation (random weights for demo)
        print("Running evaluation...")
        if task_type == 'node_classification':
            mask = _node_test_mask(data)
            metrics = GraphMetrics.evaluate_node_classification(model, data, mask, device)

            # Create visualization
            print("Creating visualization...")
            fig = GraphVisualizer.create_graph_plot(data)
        else:
            # Graph classification: only the test split is evaluated.
            _, _, test_loader = data_loader.create_dataloaders(
                dataset, 'graph_classification'
            )
            metrics = GraphMetrics.evaluate_graph_classification(model, test_loader, device)
            fig = GraphVisualizer.create_metrics_plot(metrics)

        report = _format_report(
            dataset_name, dataset_info, ordering_strategy, num_layers,
            total_params, metrics,
        )

        print("Evaluation completed successfully!")
        return report, fig

    except Exception as e:  # UI boundary: surface every failure to the user.
        print(f"Error: {e}")
        traceback.print_exc()
        return _error_response(e, dataset_name, ordering_strategy)


# Gradio interface
with gr.Blocks(
    title="🧠 Mamba Graph Neural Network",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    """,
) as demo:

    gr.Markdown("""
    # 🧠 Mamba Graph Neural Network

    **Real-time evaluation of Graph-Mamba on standard benchmarks.**

    This demonstrates the revolutionary combination of Mamba's linear complexity with graph neural networks.
    Uses actual datasets and real model architectures - no synthetic data.

    🚀 **Features:**
    - Linear O(n) complexity for massive graphs
    - Multiple graph ordering strategies
    - Real benchmark datasets (Cora, CiteSeer, etc.)
    - Interactive visualizations
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎮 Model Configuration")

            dataset_choice = gr.Dropdown(
                choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES'],
                value='Cora',
                label="📊 Dataset",
                info="Choose a graph dataset for evaluation",
            )

            ordering_choice = gr.Dropdown(
                choices=['bfs', 'spectral', 'degree', 'community'],
                value='bfs',
                label="🔄 Graph Ordering Strategy",
                info="How to convert graph to sequence",
            )

            layers_slider = gr.Slider(
                minimum=2,
                maximum=6,
                value=3,
                step=1,
                label="🏗️ Number of Mamba Layers",
                info="More layers = more capacity",
            )

            evaluate_btn = gr.Button(
                "🚀 Evaluate Model",
                variant="primary",
                size="lg",
            )

            gr.Markdown("""
            ### 📖 Ordering Strategies:
            - **BFS**: Breadth-first traversal
            - **Spectral**: Eigenvalue-based ordering
            - **Degree**: High-degree nodes first
            - **Community**: Cluster-aware ordering
            """)

        with gr.Column(scale=2):
            results_text = gr.Markdown("""
            ### 👋 Welcome!

            Select your parameters and click **'🚀 Evaluate Model'** to see Mamba Graph in action.

            The model will:
            1. 📄 Load the selected dataset
            2. 🔄 Apply graph ordering strategy
            3. 🧠 Process through Mamba layers
            4. 📊 Evaluate performance
            5. 📈 Show results and visualization
            """)

    with gr.Row():
        with gr.Column():
            visualization = gr.Plot(
                label="📈 Graph Visualization",
                container=True,
            )

    # Event handlers
    evaluate_btn.click(
        fn=load_and_evaluate,
        inputs=[dataset_choice, ordering_choice, layers_slider],
        outputs=[results_text, visualization],
        show_progress=True,
    )

    # Example section
    gr.Markdown("""
    ---

    ### 🎯 What Makes This Special?

    **Graph Transformers:** O(n²) full attention limits them to small graphs

    **Mamba Graph:** O(n) complexity enables processing of massive graphs

    **Key Innovation:** Smart graph-to-sequence conversion preserves structural information while enabling linear-time processing.

    ### 🔬 Technical Details:
    - **Selective State Space Models** for sequence processing
    - **Structure-preserving ordering** algorithms
    - **Position encoding** to maintain graph relationships
    - **Multi-scale processing** for different graph properties

    ### 📚 References:
    - Mamba: Linear-Time Sequence Modeling (Gu & Dao, 2023)
    - Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)
    - Spectral Graph Theory applications
    """)

if __name__ == "__main__":
    print("🧠 Starting Mamba Graph Demo...")
    print(f"Device: {device}")
    print("Loading Gradio interface...")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,  # Set to False for HuggingFace Spaces
    )