kfoughali commited on
Commit
4f8aa53
ยท
verified ยท
1 Parent(s): 58746a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -160
app.py CHANGED
@@ -1,253 +1,310 @@
1
  import gradio as gr
2
  import torch
3
  import yaml
4
- import plotly.graph_objects as go
5
- import plotly.express as px
6
  from core.graph_mamba import GraphMamba
7
  from data.loader import GraphDataLoader
8
  from utils.metrics import GraphMetrics
9
- import networkx as nx
10
- import numpy as np
 
 
 
 
 
 
 
 
 
11
 
12
  # Load configuration
13
- with open('config.yaml', 'r') as f:
14
- config = yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Initialize model (will be loaded dynamically based on dataset)
17
  model = None
18
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
20
  def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
21
- """Load dataset, train/evaluate model, return results"""
22
- global model, config
23
 
24
  try:
25
  # Update config
26
  config['ordering']['strategy'] = ordering_strategy
27
  config['model']['n_layers'] = num_layers
28
 
 
 
29
  # Load data
30
  data_loader = GraphDataLoader()
31
 
32
- if dataset_name in ['Cora', 'CiteSeer', 'PubMed', 'Reddit', 'Flickr']:
33
  dataset = data_loader.load_node_classification_data(dataset_name)
34
  data = dataset[0].to(device)
35
  task_type = 'node_classification'
 
 
36
  else:
37
  dataset = data_loader.load_graph_classification_data(dataset_name)
38
- train_loader, val_loader, test_loader = data_loader.create_dataloaders(
39
- dataset, 'graph_classification'
40
- )
41
  task_type = 'graph_classification'
 
42
 
43
  # Get dataset info
44
  dataset_info = data_loader.get_dataset_info(dataset)
45
 
 
 
46
  # Initialize model
 
47
  model = GraphMamba(config).to(device)
48
 
49
- # Quick evaluation (in production, you'd load pre-trained weights)
 
 
 
 
 
 
 
 
50
  if task_type == 'node_classification':
51
  # Use test mask for evaluation
 
 
 
 
 
 
 
 
52
  metrics = GraphMetrics.evaluate_node_classification(
53
- model, data, data.test_mask, device
54
  )
55
 
56
  # Create visualization
57
- fig = create_graph_visualization(data)
 
58
 
59
  else:
60
  # Graph classification
 
 
 
61
  metrics = GraphMetrics.evaluate_graph_classification(
62
  model, test_loader, device
63
  )
64
- fig = create_dataset_stats_plot(dataset_info)
65
 
66
  # Format results
67
  results_text = f"""
68
- ## Dataset: {dataset_name}
69
-
70
- **Dataset Statistics:**
71
- - Features: {dataset_info['num_features']}
72
- - Classes: {dataset_info['num_classes']}
73
- - Graphs: {dataset_info['num_graphs']}
74
- - Avg Nodes: {dataset_info['avg_nodes']:.1f}
75
- - Avg Edges: {dataset_info['avg_edges']:.1f}
76
-
77
- **Model Configuration:**
78
- - Ordering Strategy: {ordering_strategy}
79
- - Layers: {num_layers}
80
- - Model Parameters: {sum(p.numel() for p in model.parameters()):,}
81
-
82
- **Performance Metrics:**
 
 
 
83
  """
84
 
85
  for metric, value in metrics.items():
86
- if isinstance(value, float):
87
- results_text += f"- {metric.replace('_', ' ').title()}: {value:.4f}\n"
 
 
 
 
 
 
 
 
 
88
 
 
89
  return results_text, fig
90
 
91
  except Exception as e:
92
- return f"Error: {str(e)}", None
 
93
 
94
- def create_graph_visualization(data):
95
- """Create interactive graph visualization"""
96
- try:
97
- # Convert to NetworkX
98
- G = nx.Graph()
99
- edge_list = data.edge_index.t().cpu().numpy()
100
- G.add_edges_from(edge_list)
101
-
102
- # Limit to first 1000 nodes for visualization
103
- if len(G.nodes()) > 1000:
104
- nodes_to_keep = list(G.nodes())[:1000]
105
- G = G.subgraph(nodes_to_keep)
106
-
107
- # Layout
108
- pos = nx.spring_layout(G, k=1, iterations=50)
109
-
110
- # Node colors based on labels if available
111
- node_colors = []
112
- if hasattr(data, 'y') and data.y is not None:
113
- labels = data.y.cpu().numpy()
114
- for node in G.nodes():
115
- if node < len(labels):
116
- node_colors.append(labels[node])
117
- else:
118
- node_colors.append(0)
119
- else:
120
- node_colors = [0] * len(G.nodes())
121
-
122
- # Create traces
123
- edge_x, edge_y = [], []
124
- for edge in G.edges():
125
- x0, y0 = pos[edge[0]]
126
- x1, y1 = pos[edge[1]]
127
- edge_x.extend([x0, x1, None])
128
- edge_y.extend([y0, y1, None])
129
-
130
- node_x = [pos[node][0] for node in G.nodes()]
131
- node_y = [pos[node][1] for node in G.nodes()]
132
-
133
- fig = go.Figure()
134
-
135
- # Add edges
136
- fig.add_trace(go.Scatter(
137
- x=edge_x, y=edge_y,
138
- line=dict(width=0.5, color='#888'),
139
- hoverinfo='none',
140
- mode='lines'
141
- ))
142
-
143
- # Add nodes
144
- fig.add_trace(go.Scatter(
145
- x=node_x, y=node_y,
146
- mode='markers',
147
- hoverinfo='text',
148
- text=[f'Node {i}' for i in G.nodes()],
149
- marker=dict(
150
- size=8,
151
- color=node_colors,
152
- colorscale='Viridis',
153
- line=dict(width=2)
154
- )
155
- ))
156
-
157
- fig.update_layout(
158
- title='Graph Visualization',
159
- showlegend=False,
160
- hovermode='closest',
161
- margin=dict(b=20,l=5,r=5,t=40),
162
- annotations=[
163
- dict(
164
- text="Graph structure visualization",
165
- showarrow=False,
166
- xref="paper", yref="paper",
167
- x=0.005, y=-0.002,
168
- xanchor='left', yanchor='bottom',
169
- font=dict(color="black", size=12)
170
- )
171
- ],
172
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
173
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
174
- )
175
 
176
- return fig
177
 
178
- except Exception as e:
179
  # Return empty plot on error
 
180
  fig = go.Figure()
181
- fig.add_annotation(text=f"Visualization error: {str(e)}", x=0.5, y=0.5)
182
- return fig
183
-
184
- def create_dataset_stats_plot(dataset_info):
185
- """Create dataset statistics visualization"""
186
- stats = [
187
- ['Features', dataset_info['num_features']],
188
- ['Classes', dataset_info['num_classes']],
189
- ['Avg Nodes', dataset_info['avg_nodes']],
190
- ['Avg Edges', dataset_info['avg_edges']]
191
- ]
192
-
193
- fig = go.Figure(data=[
194
- go.Bar(
195
- x=[stat[0] for stat in stats],
196
- y=[stat[1] for stat in stats],
197
- marker_color='lightblue'
198
  )
199
- ])
200
-
201
- fig.update_layout(
202
- title='Dataset Statistics',
203
- xaxis_title='Metric',
204
- yaxis_title='Value'
205
- )
206
-
207
- return fig
208
 
209
  # Gradio interface
210
- with gr.Blocks(title="Mamba Graph Neural Network") as demo:
 
 
 
 
 
 
 
 
 
211
  gr.Markdown("""
212
  # ๐Ÿง  Mamba Graph Neural Network
213
 
214
- Real-time evaluation of Graph-Mamba on standard benchmarks.
215
- This uses actual datasets and trained models - no synthetic data.
 
 
 
 
 
 
 
 
216
  """)
217
 
218
  with gr.Row():
219
- with gr.Column():
 
 
220
  dataset_choice = gr.Dropdown(
221
- choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES', 'PROTEINS'],
222
  value='Cora',
223
- label="Dataset"
 
224
  )
225
 
226
  ordering_choice = gr.Dropdown(
227
  choices=['bfs', 'spectral', 'degree', 'community'],
228
  value='bfs',
229
- label="Graph Ordering Strategy"
 
230
  )
231
 
232
  layers_slider = gr.Slider(
233
- minimum=2, maximum=8, value=4, step=1,
234
- label="Number of Mamba Layers"
 
235
  )
236
 
237
- evaluate_btn = gr.Button("Evaluate Model", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
- with gr.Column():
240
- results_text = gr.Markdown("Select parameters and click 'Evaluate Model'")
 
241
 
 
 
 
 
 
 
 
 
 
 
242
  with gr.Row():
243
- visualization = gr.Plot(label="Graph Visualization")
 
 
 
 
244
 
245
  # Event handlers
246
  evaluate_btn.click(
247
  fn=load_and_evaluate,
248
  inputs=[dataset_choice, ordering_choice, layers_slider],
249
- outputs=[results_text, visualization]
 
250
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  if __name__ == "__main__":
253
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import yaml
4
+ import os
 
5
  from core.graph_mamba import GraphMamba
6
  from data.loader import GraphDataLoader
7
  from utils.metrics import GraphMetrics
8
+ from utils.visualization import GraphVisualizer
9
+ import warnings
10
+ warnings.filterwarnings('ignore')
11
+
12
+ # Force CPU for HuggingFace Spaces
13
+ if os.getenv('SPACE_ID') or os.getenv('GRADIO_SERVER_NAME'):
14
+ device = torch.device('cpu')
15
+ print("Running on HuggingFace Spaces - using CPU")
16
+ else:
17
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ print(f"Running locally - using {device}")
19
 
20
  # Load configuration
21
+ config = {
22
+ 'model': {
23
+ 'd_model': 128, # Smaller for demo
24
+ 'd_state': 8,
25
+ 'd_conv': 4,
26
+ 'expand': 2,
27
+ 'n_layers': 3, # Fewer layers for speed
28
+ 'dropout': 0.1
29
+ },
30
+ 'data': {
31
+ 'batch_size': 16,
32
+ 'test_split': 0.2
33
+ },
34
+ 'ordering': {
35
+ 'strategy': 'bfs',
36
+ 'preserve_locality': True
37
+ }
38
+ }
39
 
40
+ # Global model holder
41
  model = None
42
+ current_dataset = None
43
 
44
  def load_and_evaluate(dataset_name, ordering_strategy, num_layers):
45
+ """Load dataset, configure model, return results"""
46
+ global model, config, current_dataset
47
 
48
  try:
49
  # Update config
50
  config['ordering']['strategy'] = ordering_strategy
51
  config['model']['n_layers'] = num_layers
52
 
53
+ print(f"Loading dataset: {dataset_name}")
54
+
55
  # Load data
56
  data_loader = GraphDataLoader()
57
 
58
+ if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
59
  dataset = data_loader.load_node_classification_data(dataset_name)
60
  data = dataset[0].to(device)
61
  task_type = 'node_classification'
62
+ current_dataset = data
63
+ print(f"Loaded {dataset_name}: {data.num_nodes} nodes, {data.num_edges} edges")
64
  else:
65
  dataset = data_loader.load_graph_classification_data(dataset_name)
 
 
 
66
  task_type = 'graph_classification'
67
+ print(f"Loaded {dataset_name}: {len(dataset)} graphs")
68
 
69
  # Get dataset info
70
  dataset_info = data_loader.get_dataset_info(dataset)
71
 
72
+ print(f"Dataset info: {dataset_info}")
73
+
74
  # Initialize model
75
+ print("Initializing GraphMamba model...")
76
  model = GraphMamba(config).to(device)
77
 
78
+ # Initialize classifier for evaluation
79
+ num_classes = dataset_info['num_classes']
80
+ model._init_classifier(num_classes, device)
81
+
82
+ total_params = sum(p.numel() for p in model.parameters())
83
+ print(f"Model parameters: {total_params:,}")
84
+
85
+ # Quick evaluation (random weights for demo)
86
+ print("Running evaluation...")
87
  if task_type == 'node_classification':
88
  # Use test mask for evaluation
89
+ if hasattr(data, 'test_mask'):
90
+ mask = data.test_mask
91
+ else:
92
+ # Create a test mask if not available
93
+ num_nodes = data.num_nodes
94
+ mask = torch.zeros(num_nodes, dtype=torch.bool)
95
+ mask[num_nodes//2:] = True
96
+
97
  metrics = GraphMetrics.evaluate_node_classification(
98
+ model, data, mask, device
99
  )
100
 
101
  # Create visualization
102
+ print("Creating visualization...")
103
+ fig = GraphVisualizer.create_graph_plot(data)
104
 
105
  else:
106
  # Graph classification
107
+ train_loader, val_loader, test_loader = data_loader.create_dataloaders(
108
+ dataset, 'graph_classification'
109
+ )
110
  metrics = GraphMetrics.evaluate_graph_classification(
111
  model, test_loader, device
112
  )
113
+ fig = GraphVisualizer.create_metrics_plot(metrics)
114
 
115
  # Format results
116
  results_text = f"""
117
+ ## ๐Ÿง  Mamba Graph Neural Network Results
118
+
119
+ ### Dataset: {dataset_name}
120
+
121
+ **Dataset Statistics:**
122
+ - ๐Ÿ“Š Features: {dataset_info['num_features']}
123
+ - ๐Ÿท๏ธ Classes: {dataset_info['num_classes']}
124
+ - ๐Ÿ“ˆ Graphs: {dataset_info['num_graphs']}
125
+ - ๐Ÿ”— Avg Nodes: {dataset_info['avg_nodes']:.1f}
126
+ - ๐ŸŒ Avg Edges: {dataset_info['avg_edges']:.1f}
127
+
128
+ **Model Configuration:**
129
+ - ๐Ÿ”„ Ordering Strategy: {ordering_strategy}
130
+ - ๐Ÿ—๏ธ Layers: {num_layers}
131
+ - โš™๏ธ Parameters: {sum(p.numel() for p in model.parameters()):,}
132
+ - ๐Ÿ’พ Device: {device}
133
+
134
+ **Performance Metrics:**
135
  """
136
 
137
  for metric, value in metrics.items():
138
+ if isinstance(value, float) and metric != 'error':
139
+ results_text += f"- ๐Ÿ“ˆ {metric.replace('_', ' ').title()}: {value:.4f}\n"
140
+ elif metric == 'error':
141
+ results_text += f"- โš ๏ธ Error: {value}\n"
142
+
143
+ results_text += f"""
144
+
145
+ **Status:** โœ… Model successfully loaded and evaluated!
146
+
147
+ *Note: This is a demo with random weights. In production, the model would be trained on the dataset.*
148
+ """
149
 
150
+ print("Evaluation completed successfully!")
151
  return results_text, fig
152
 
153
  except Exception as e:
154
+ error_msg = f"""
155
+ ## โŒ Error Loading Model
156
 
157
+ **Error:** {str(e)}
158
+
159
+ **Troubleshooting:**
160
+ - Check dataset availability
161
+ - Verify device compatibility
162
+ - Try different ordering strategy
163
+
164
+ **Debug Info:**
165
+ - Device: {device}
166
+ - Dataset: {dataset_name}
167
+ - Strategy: {ordering_strategy}
168
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ print(f"Error: {e}")
171
 
 
172
  # Return empty plot on error
173
+ import plotly.graph_objects as go
174
  fig = go.Figure()
175
+ fig.add_annotation(
176
+ text=f"Error: {str(e)}",
177
+ x=0.5, y=0.5,
178
+ xref="paper", yref="paper",
179
+ showarrow=False
 
 
 
 
 
 
 
 
 
 
 
 
180
  )
181
+
182
+ return error_msg, fig
 
 
 
 
 
 
 
183
 
184
  # Gradio interface
185
+ with gr.Blocks(
186
+ title="๐Ÿง  Mamba Graph Neural Network",
187
+ theme=gr.themes.Soft(),
188
+ css="""
189
+ .gradio-container {
190
+ max-width: 1200px !important;
191
+ }
192
+ """
193
+ ) as demo:
194
+
195
  gr.Markdown("""
196
  # ๐Ÿง  Mamba Graph Neural Network
197
 
198
+ **Real-time evaluation of Graph-Mamba on standard benchmarks.**
199
+
200
+ This demonstrates the revolutionary combination of Mamba's linear complexity with graph neural networks.
201
+ Uses actual datasets and real model architectures - no synthetic data.
202
+
203
+ ๐Ÿš€ **Features:**
204
+ - Linear O(n) complexity for massive graphs
205
+ - Multiple graph ordering strategies
206
+ - Real benchmark datasets (Cora, CiteSeer, etc.)
207
+ - Interactive visualizations
208
  """)
209
 
210
  with gr.Row():
211
+ with gr.Column(scale=1):
212
+ gr.Markdown("### ๐ŸŽฎ Model Configuration")
213
+
214
  dataset_choice = gr.Dropdown(
215
+ choices=['Cora', 'CiteSeer', 'PubMed', 'MUTAG', 'ENZYMES'],
216
  value='Cora',
217
+ label="๐Ÿ“Š Dataset",
218
+ info="Choose a graph dataset for evaluation"
219
  )
220
 
221
  ordering_choice = gr.Dropdown(
222
  choices=['bfs', 'spectral', 'degree', 'community'],
223
  value='bfs',
224
+ label="๐Ÿ”„ Graph Ordering Strategy",
225
+ info="How to convert graph to sequence"
226
  )
227
 
228
  layers_slider = gr.Slider(
229
+ minimum=2, maximum=6, value=3, step=1,
230
+ label="๐Ÿ—๏ธ Number of Mamba Layers",
231
+ info="More layers = more capacity"
232
  )
233
 
234
+ evaluate_btn = gr.Button(
235
+ "๐Ÿš€ Evaluate Model",
236
+ variant="primary",
237
+ size="lg"
238
+ )
239
+
240
+ gr.Markdown("""
241
+ ### ๐Ÿ“– Ordering Strategies:
242
+ - **BFS**: Breadth-first traversal
243
+ - **Spectral**: Eigenvalue-based ordering
244
+ - **Degree**: High-degree nodes first
245
+ - **Community**: Cluster-aware ordering
246
+ """)
247
 
248
+ with gr.Column(scale=2):
249
+ results_text = gr.Markdown("""
250
+ ### ๐Ÿ‘‹ Welcome!
251
 
252
+ Select your parameters and click **'๐Ÿš€ Evaluate Model'** to see Mamba Graph in action.
253
+
254
+ The model will:
255
+ 1. ๐Ÿ“ฅ Load the selected dataset
256
+ 2. ๐Ÿ”„ Apply graph ordering strategy
257
+ 3. ๐Ÿง  Process through Mamba layers
258
+ 4. ๐Ÿ“Š Evaluate performance
259
+ 5. ๐Ÿ“ˆ Show results and visualization
260
+ """)
261
+
262
  with gr.Row():
263
+ with gr.Column():
264
+ visualization = gr.Plot(
265
+ label="๐Ÿ“ˆ Graph Visualization",
266
+ container=True
267
+ )
268
 
269
  # Event handlers
270
  evaluate_btn.click(
271
  fn=load_and_evaluate,
272
  inputs=[dataset_choice, ordering_choice, layers_slider],
273
+ outputs=[results_text, visualization],
274
+ show_progress=True
275
  )
276
+
277
+ # Example section
278
+ gr.Markdown("""
279
+ ---
280
+ ### ๐ŸŽฏ What Makes This Special?
281
+
282
+ **Traditional GNNs:** O(nยฒ) complexity limits them to small graphs
283
+
284
+ **Mamba Graph:** O(n) complexity enables processing of massive graphs
285
+
286
+ **Key Innovation:** Smart graph-to-sequence conversion preserves structural information while enabling linear-time processing.
287
+
288
+ ### ๐Ÿ”ฌ Technical Details:
289
+ - **Selective State Space Models** for sequence processing
290
+ - **Structure-preserving ordering** algorithms
291
+ - **Position encoding** to maintain graph relationships
292
+ - **Multi-scale processing** for different graph properties
293
+
294
+ ### ๐Ÿ“š References:
295
+ - Mamba: Linear-Time Sequence Modeling (Gu & Dao, 2023)
296
+ - Graph Neural Networks (Kipf & Welling, 2017)
297
+ - Spectral Graph Theory applications
298
+ """)
299
 
300
  if __name__ == "__main__":
301
+ print("๐Ÿง  Starting Mamba Graph Demo...")
302
+ print(f"Device: {device}")
303
+ print("Loading Gradio interface...")
304
+
305
+ demo.launch(
306
+ server_name="0.0.0.0",
307
+ server_port=7860,
308
+ show_error=True,
309
+ share=False # Set to False for HuggingFace Spaces
310
+ )