kfoughali committed (verified)
Commit 4ac02ee · 1 parent: 6d0498a

Create app.py

Files changed (1)
  1. app.py +253 -0
app.py ADDED
@@ -0,0 +1,253 @@
import gradio as gr
import torch
import yaml
import plotly.graph_objects as go
import plotly.express as px
from core.graph_mamba import GraphMamba
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
import networkx as nx
import numpy as np

# Load configuration
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Initialize model (will be loaded dynamically based on dataset)
model = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
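
config.yaml is not included in this commit, so its full schema is unknown; the handler below only ever touches the ordering strategy and the layer count. A minimal, hypothetical shape of what yaml.safe_load() would have to return:

# Hypothetical minimal contents of config.yaml, written as the dict that
# yaml.safe_load() would return. app.py itself only writes these two keys;
# the real file presumably carries more settings for GraphMamba (hidden size, etc.).
config = {
    'model': {'n_layers': 4},          # overridden by the "Number of Mamba Layers" slider
    'ordering': {'strategy': 'bfs'},   # overridden by the "Graph Ordering Strategy" dropdown
}
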
def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
    """Load dataset, train/evaluate model, return results"""
    global model, config

    try:
        # Update config
        config['ordering']['strategy'] = ordering_strategy
        config['model']['n_layers'] = num_layers

        # Load data
        data_loader = GraphDataLoader()

        if dataset_name in ['Cora', 'CiteSeer', 'PubMed', 'Reddit', 'Flickr']:
            dataset = data_loader.load_node_classification_data(dataset_name)
            data = dataset[0].to(device)
            task_type = 'node_classification'
        else:
            dataset = data_loader.load_graph_classification_data(dataset_name)
            train_loader, val_loader, test_loader = data_loader.create_dataloaders(
                dataset, 'graph_classification'
            )
            task_type = 'graph_classification'

        # Get dataset info
        dataset_info = data_loader.get_dataset_info(dataset)

        # Initialize model
        model = GraphMamba(config).to(device)

        # Quick evaluation (in production, you'd load pre-trained weights)
        if task_type == 'node_classification':
            # Use test mask for evaluation
            metrics = GraphMetrics.evaluate_node_classification(
                model, data, data.test_mask, device
            )

            # Create visualization
            fig = create_graph_visualization(data)

        else:
            # Graph classification
            metrics = GraphMetrics.evaluate_graph_classification(
                model, test_loader, device
            )
            fig = create_dataset_stats_plot(dataset_info)

        # Format results
        results_text = f"""
## Dataset: {dataset_name}

**Dataset Statistics:**
- Features: {dataset_info['num_features']}
- Classes: {dataset_info['num_classes']}
- Graphs: {dataset_info['num_graphs']}
- Avg Nodes: {dataset_info['avg_nodes']:.1f}
- Avg Edges: {dataset_info['avg_edges']:.1f}

**Model Configuration:**
- Ordering Strategy: {ordering_strategy}
- Layers: {num_layers}
- Model Parameters: {sum(p.numel() for p in model.parameters()):,}

**Performance Metrics:**
"""

        for metric, value in metrics.items():
            if isinstance(value, float):
                results_text += f"- {metric.replace('_', ' ').title()}: {value:.4f}\n"

        return results_text, fig

    except Exception as e:
        return f"Error: {str(e)}", None
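
The imported modules core.graph_mamba, data.loader and utils.metrics are not part of this commit. Reconstructed purely from the call sites in load_and_evaluate(), they are assumed to expose roughly the following interface (a sketch, not the actual implementations). Note also that no checkpoint is loaded, so the reported metrics describe a randomly initialized GraphMamba unless a model.load_state_dict(...) call is added.

# Interface assumed by load_and_evaluate(), inferred only from how it is called above;
# the real classes live in core/, data/ and utils/ and are not shown in this commit.
class GraphDataLoader:
    def load_node_classification_data(self, name): ...   # PyG-style dataset; dataset[0] is a Data object
    def load_graph_classification_data(self, name): ...  # PyG-style multi-graph dataset
    def create_dataloaders(self, dataset, task): ...     # -> (train_loader, val_loader, test_loader)
    def get_dataset_info(self, dataset): ...             # -> dict with num_features, num_classes, num_graphs, avg_nodes, avg_edges

class GraphMetrics:
    @staticmethod
    def evaluate_node_classification(model, data, mask, device): ...  # -> dict of metric name to float
    @staticmethod
    def evaluate_graph_classification(model, loader, device): ...     # -> dict of metric name to float
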
def create_graph_visualization(data):
    """Create interactive graph visualization"""
    try:
        # Convert to NetworkX
        G = nx.Graph()
        edge_list = data.edge_index.t().cpu().numpy()
        G.add_edges_from(edge_list)

        # Limit to first 1000 nodes for visualization
        if len(G.nodes()) > 1000:
            nodes_to_keep = list(G.nodes())[:1000]
            G = G.subgraph(nodes_to_keep)

        # Layout
        pos = nx.spring_layout(G, k=1, iterations=50)

        # Node colors based on labels if available
        node_colors = []
        if hasattr(data, 'y') and data.y is not None:
            labels = data.y.cpu().numpy()
            for node in G.nodes():
                if node < len(labels):
                    node_colors.append(labels[node])
                else:
                    node_colors.append(0)
        else:
            node_colors = [0] * len(G.nodes())

        # Create traces
        edge_x, edge_y = [], []
        for edge in G.edges():
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        node_x = [pos[node][0] for node in G.nodes()]
        node_y = [pos[node][1] for node in G.nodes()]

        fig = go.Figure()

        # Add edges
        fig.add_trace(go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines'
        ))

        # Add nodes
        fig.add_trace(go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=[f'Node {i}' for i in G.nodes()],
            marker=dict(
                size=8,
                color=node_colors,
                colorscale='Viridis',
                line=dict(width=2)
            )
        ))

        fig.update_layout(
            title='Graph Visualization',
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text="Graph structure visualization",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002,
                    xanchor='left', yanchor='bottom',
                    font=dict(color="black", size=12)
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
        )

        return fig

    except Exception as e:
        # Return empty plot on error
        fig = go.Figure()
        fig.add_annotation(text=f"Visualization error: {str(e)}", x=0.5, y=0.5)
        return fig
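
A quick way to exercise create_graph_visualization() without downloading a benchmark is to feed it a tiny hand-built graph. The snippet assumes the data layer returns torch_geometric-style Data objects (edge_index as a [2, E] LongTensor, y as node labels); that library is not confirmed by this commit, so treat the import as an assumption.

# Hypothetical smoke test on a 4-node cycle; assumes torch_geometric as the data library.
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]], dtype=torch.long)
toy = Data(x=torch.randn(4, 16), edge_index=edge_index, y=torch.tensor([0, 1, 0, 1]))
create_graph_visualization(toy).show()  # opens the Plotly figure in a browser
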
def create_dataset_stats_plot(dataset_info):
    """Create dataset statistics visualization"""
    stats = [
        ['Features', dataset_info['num_features']],
        ['Classes', dataset_info['num_classes']],
        ['Avg Nodes', dataset_info['avg_nodes']],
        ['Avg Edges', dataset_info['avg_edges']]
    ]

    fig = go.Figure(data=[
        go.Bar(
            x=[stat[0] for stat in stats],
            y=[stat[1] for stat in stats],
            marker_color='lightblue'
        )
    ])

    fig.update_layout(
        title='Dataset Statistics',
        xaxis_title='Metric',
        yaxis_title='Value'
    )

    return fig
# Gradio interface
with gr.Blocks(title="Mamba Graph Neural Network") as demo:
    gr.Markdown("""
    # 🧠 Mamba Graph Neural Network

    Real-time evaluation of Graph-Mamba on standard benchmarks.
    Runs on real datasets - no synthetic data; the model is built from the selected configuration.
    """)

    with gr.Row():
        with gr.Column():
            dataset_choice = gr.Dropdown(
                choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES', 'PROTEINS'],
                value='Cora',
                label="Dataset"
            )

            ordering_choice = gr.Dropdown(
                choices=['bfs', 'spectral', 'degree', 'community'],
                value='bfs',
                label="Graph Ordering Strategy"
            )

            layers_slider = gr.Slider(
                minimum=2, maximum=8, value=4, step=1,
                label="Number of Mamba Layers"
            )

            evaluate_btn = gr.Button("Evaluate Model", variant="primary")

        with gr.Column():
            results_text = gr.Markdown("Select parameters and click 'Evaluate Model'")

    with gr.Row():
        visualization = gr.Plot(label="Graph Visualization")

    # Event handlers
    evaluate_btn.click(
        fn=load_and_evaluate,
        inputs=[dataset_choice, ordering_choice, layers_slider],
        outputs=[results_text, visualization]
    )

if __name__ == "__main__":
    demo.launch(share=True)
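
For a sanity check outside the UI, the click handler can be called directly; this is a hypothetical snippet, and without pre-trained weights the metrics reflect a freshly initialized model. On Hugging Face Spaces the share=True flag is unnecessary, and demo.launch() alone is enough.

# Hypothetical headless check that bypasses Gradio and reuses the widget defaults.
from app import load_and_evaluate

text, fig = load_and_evaluate('Cora', 'bfs', 4)
print(text)                            # markdown summary shown in the results panel
if fig is not None:                    # fig is None when load_and_evaluate hits an error
    fig.write_html('cora_graph.html')  # save the Plotly visualization for inspection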