kfoughali committed on
Commit c6e11c4 · verified · 1 Parent(s): 454d2b9

Update app.py

Files changed (1)
  1. app.py +133 -281
app.py CHANGED
@@ -1,296 +1,148 @@
 
 
 
 
 
 
 
 
  import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch_geometric.utils import degree, to_dense_adj
- from torch_geometric.nn import GCNConv
- import networkx as nx
  import logging

  logger = logging.getLogger(__name__)

- class MambaBlock(nn.Module):
-     def __init__(self, d_model, d_state=4, d_conv=4, expand=2):
-         super().__init__()
-         self.d_model = d_model
-         self.d_inner = int(expand * d_model)
-         self.d_state = d_state
-
-         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
-         self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv, groups=self.d_inner, padding=d_conv-1)
-         self.act = nn.SiLU()
-         self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
-         self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
-
-         A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
-         self.A_log = nn.Parameter(torch.log(A))
-         self.D = nn.Parameter(torch.ones(self.d_inner))
-         self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
-         self.dropout = nn.Dropout(0.3)
-
-     def forward(self, x):
-         batch, length, d_model = x.shape
-         xz = self.in_proj(x)
-         x, z = xz.chunk(2, dim=-1)
-
-         x = x.transpose(1, 2)
-         x = self.conv1d(x)[:, :, :length]
-         x = x.transpose(1, 2)
-         x = self.act(x)
-         x = self.dropout(x)
-
-         y = self.selective_scan(x)
-         y = y * self.act(z)
-         return self.dropout(self.out_proj(y))

-     def selective_scan(self, x):
-         batch, length, d_inner = x.shape
-         deltaBC = self.x_proj(x)
-         delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
-         delta = F.softplus(self.dt_proj(delta))
-
-         deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
-         deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
-
-         states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
-         outputs = []
-
-         for i in range(length):
-             states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
-             y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
-             outputs.append(y)

-         return torch.stack(outputs, dim=1)
-
-
- class GraphStructureEncoder(nn.Module):
-     """Encode graph structure to preserve in sequential processing"""
-     def __init__(self, d_model):
-         super().__init__()
-         self.adjacency_proj = nn.Linear(1, d_model)
-         self.structure_attention = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
-         self.norm = nn.LayerNorm(d_model)
-
-     def forward(self, x, edge_index):
-         # Create adjacency features
-         adj = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
-
-         # Add self-connections and normalize
-         adj = adj + torch.eye(adj.size(0), device=adj.device)
-         deg = adj.sum(dim=1, keepdim=True)
-         adj_norm = adj / (deg + 1e-8)
-
-         # Project adjacency to feature space
-         adj_features = self.adjacency_proj(adj_norm.unsqueeze(-1))
-
-         # Attention over structure
-         x_with_structure = x.unsqueeze(0)  # Add batch dim
-         adj_features = adj_features.unsqueeze(0)
-
-         attended, _ = self.structure_attention(x_with_structure, adj_features, adj_features)
-
-         return self.norm(x + attended.squeeze(0))
-
-
- class SpectralOrdering:
-     """Spectral graph ordering to preserve structure"""
-     @staticmethod
-     def compute_ordering(edge_index, num_nodes):
-         try:
-             # Create adjacency matrix
-             adj = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0)

-             # Add self-loops
-             adj = adj + torch.eye(num_nodes, device=adj.device)

-             # Compute degree matrix
-             deg = torch.diag(adj.sum(dim=1))

-             # Laplacian
-             L = deg - adj

-             # Eigendecomposition (use only first few eigenvectors)
-             try:
-                 eigenvals, eigenvecs = torch.linalg.eigh(L)
-                 # Sort by second smallest eigenvalue (Fiedler vector)
-                 fiedler = eigenvecs[:, 1]
-                 order = torch.argsort(fiedler)
-                 return order
-             except:
-                 # Fallback to degree ordering
-                 degrees = adj.sum(dim=1)
-                 return torch.argsort(degrees, descending=True)
-
-         except:
-             return torch.arange(num_nodes)


- class GraphMamba(nn.Module):
-     """Enhanced GraphMamba with structure preservation"""
-     def __init__(self, config):
-         super().__init__()
-
-         self.config = config
-         d_model = config['model']['d_model']
-         n_layers = config['model']['n_layers']
-         input_dim = config.get('input_dim', 1433)
-
-         # Input processing
-         self.input_proj = nn.Linear(input_dim, d_model)
-         self.input_dropout = nn.Dropout(0.5)
-
-         # Graph structure encoding
-         self.structure_encoder = GraphStructureEncoder(d_model)
-
-         # Positional encoding
-         self.pos_encoding = nn.Embedding(5000, d_model)
-         self.degree_encoding = nn.Embedding(100, d_model)
-
-         # Mamba layers
-         self.mamba_layers = nn.ModuleList([
-             MambaBlock(d_model, d_state=4) for _ in range(n_layers)
-         ])
-         self.layer_norms = nn.ModuleList([
-             nn.LayerNorm(d_model) for _ in range(n_layers)
-         ])
-
-         self.hidden_dropout = nn.Dropout(0.5)
-         self.output_proj = nn.Linear(d_model, d_model)
-
-         # Classifier
-         self.classifier = None
-
-     def _get_ordering(self, edge_index, num_nodes):
-         """Get node ordering based on strategy"""
-         strategy = self.config['ordering']['strategy']
-
-         if strategy == 'spectral':
-             return SpectralOrdering.compute_ordering(edge_index, num_nodes)
-         elif strategy == 'degree':
-             degrees = degree(edge_index[0], num_nodes)
-             return torch.argsort(degrees, descending=True)
-         else:  # bfs
-             return torch.arange(num_nodes)
-
-     def forward(self, x, edge_index, batch=None):
-         # Input projection
-         h = self.input_dropout(self.input_proj(x))
-
-         # Add structural information
-         h = self.structure_encoder(h, edge_index)
-
-         # Add positional encodings
-         positions = torch.arange(h.size(0), device=h.device).clamp(max=4999)
-         degrees = degree(edge_index[0], h.size(0)).long().clamp(max=99)
-
-         h = h + self.pos_encoding(positions) + self.degree_encoding(degrees)
-
-         # Get ordering
-         order = self._get_ordering(edge_index, h.size(0))
-         h_ordered = h[order].unsqueeze(0)
-
-         # Process through Mamba layers
-         for mamba, ln in zip(self.mamba_layers, self.layer_norms):
-             residual = h_ordered
-             h_ordered = ln(h_ordered)
-             h_ordered = residual + mamba(h_ordered)
-             h_ordered = self.hidden_dropout(h_ordered)
-
-         # Restore order
-         h_restored = torch.zeros_like(h_ordered.squeeze(0))
-         h_restored[order] = h_ordered.squeeze(0)
-
-         return self.output_proj(h_restored)
-
-     def _init_classifier(self, num_classes, device):
-         if self.classifier is None:
-             self.classifier = nn.Sequential(
-                 nn.Dropout(0.5),
-                 nn.Linear(self.config['model']['d_model'], num_classes)
-             ).to(device)
-
-     def get_performance_stats(self):
-         total_params = sum(p.numel() for p in self.parameters())
-         return {
-             'total_params': total_params,
-             'device': next(self.parameters()).device,
-             'dtype': next(self.parameters()).dtype,
-             'model_size': f"{total_params/1000:.1f}K parameters"
-         }
-

- class HybridGraphMamba(nn.Module):
-     """Hybrid approach: Mamba + minimal GCN"""
-     def __init__(self, config):
-         super().__init__()
-
-         d_model = config['model']['d_model']
-         input_dim = config.get('input_dim', 1433)
-
-         # Mamba branch
-         self.mamba = GraphMamba(config)
-
-         # GCN branch (single layer)
-         self.gcn = GCNConv(input_dim, d_model)
-
-         # Fusion
-         self.fusion = nn.Sequential(
-             nn.Linear(d_model * 2, d_model),
-             nn.ReLU(),
-             nn.Dropout(0.3),
-             nn.Linear(d_model, d_model)
-         )
-
-         self.classifier = None
-         self.config = config
-
-     def forward(self, x, edge_index, batch=None):
-         # Mamba branch
-         mamba_out = self.mamba(x, edge_index, batch)
-
-         # GCN branch
-         gcn_out = F.dropout(F.relu(self.gcn(x, edge_index)), 0.5, training=self.training)
-
-         # Fuse
-         combined = torch.cat([mamba_out, gcn_out], dim=-1)
-         return self.fusion(combined)

-     def _init_classifier(self, num_classes, device):
-         if self.classifier is None:
-             self.classifier = nn.Sequential(
-                 nn.Dropout(0.5),
-                 nn.Linear(self.config['model']['d_model'], num_classes)
-             ).to(device)
-
-     def get_performance_stats(self):
-         return self.mamba.get_performance_stats()
-
-
- def create_regularized_config():
-     """Optimized config with structure preservation"""
-     return {
-         'model': {
-             'd_model': 64,
-             'd_state': 4,
-             'd_conv': 4,
-             'expand': 2,
-             'n_layers': 2,
-             'dropout': 0.5
-         },
-         'data': {
-             'batch_size': 1,
-             'test_split': 0.2
-         },
-         'training': {
-             'learning_rate': 0.001,  # Slightly higher
-             'weight_decay': 0.01,
-             'epochs': 200,
-             'patience': 15,
-             'warmup_epochs': 10,
-             'min_lr': 1e-6
-         },
-         'ordering': {
-             'strategy': 'spectral',  # Changed from bfs
-             'preserve_locality': True
-         },
-         'input_dim': 1433
-     }
 
+ #!/usr/bin/env python3
+ """
+ Enhanced Mamba Graph with structure preservation and interface fix
+ """
+
+ import os
+ os.environ['OMP_NUM_THREADS'] = '4'
+
  import torch
+ import time
  import logging
+ import threading
+ import signal
+ from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config
+ from core.trainer import GraphMambaTrainer
+ from data.loader import GraphDataLoader
+ from utils.visualization import GraphVisualizer

+ logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ def get_device():
+     if torch.cuda.is_available():
+         device = torch.device('cuda')
+         logger.info(f"🚀 CUDA available - using GPU: {torch.cuda.get_device_name()}")
+     else:
+         device = torch.device('cpu')
+         logger.info("💻 Using CPU")
+     return device
+
+ def run_comprehensive_test():
+     """Enhanced test with structure preservation"""
+     print("🧠 Enhanced Mamba Graph Neural Network")
+     print("=" * 60)

+     config = create_regularized_config()
+     device = get_device()
+
+     try:
+         # Data loading
+         print("\n📊 Loading Cora dataset...")
+         data_loader = GraphDataLoader()
+         dataset = data_loader.load_node_classification_data('Cora')
+         data = dataset[0].to(device)
+         info = data_loader.get_dataset_info(dataset)
+
+         print(f"✅ Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")
+
+         # Test both models
+         models_to_test = [
+             ("Enhanced GraphMamba", GraphMamba),
+             ("Hybrid GraphMamba", HybridGraphMamba)
+         ]
+
+         results = {}
+
+         for model_name, model_class in models_to_test:
+             print(f"\n🏗️ Testing {model_name}...")

+             model = model_class(config).to(device)
+             total_params = sum(p.numel() for p in model.parameters())
+             train_samples = data.train_mask.sum().item()
+
+             print(f" Parameters: {total_params:,} ({total_params/train_samples:.1f} per sample)")

+             # Training
+             trainer = GraphMambaTrainer(model, config, device)
+             print(f" Strategy: {config['ordering']['strategy']}")

+             start_time = time.time()
+             history = trainer.train_node_classification(data, verbose=False)
+             training_time = time.time() - start_time

+             # Evaluation
+             test_metrics = trainer.test(data)

+             results[model_name] = {
+                 'test_acc': test_metrics['test_acc'],
+                 'val_acc': trainer.best_val_acc,
+                 'gap': trainer.best_gap,
+                 'params': total_params,
+                 'time': training_time
+             }
+
+             print(f" ✅ Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
+             print(f" 📊 Validation: {trainer.best_val_acc:.4f}")
+             print(f" 🎯 Gap: {trainer.best_gap:.4f}")
+             print(f" ⏱️ Time: {training_time:.1f}s")
+
+         # Comparison
+         print(f"\n📈 Model Comparison:")
+         print(f"{'Model':<20} {'Test Acc':<10} {'Val Acc':<10} {'Gap':<8} {'Params':<8}")
+         print("-" * 60)
+
+         for name, result in results.items():
+             print(f"{name:<20} {result['test_acc']:.4f} {result['val_acc']:.4f} "
+                   f"{result['gap']:>6.3f} {result['params']/1000:.0f}K")
+
+         # Best model
+         best_model = max(results.items(), key=lambda x: x[1]['test_acc'])
+         print(f"\n🏆 Best: {best_model[0]} - {best_model[1]['test_acc']*100:.2f}% accuracy")
+
+         # Baseline comparison
+         baselines = {'Random': 0.143, 'GCN': 0.815, 'GAT': 0.830}
+         best_acc = best_model[1]['test_acc']
+
+         print(f"\n📊 vs Baselines:")
+         for baseline, acc in baselines.items():
+             diff = best_acc - acc
+             status = "🟢" if diff > 0 else "🔴"
+             print(f" {status} {baseline}: {acc:.3f} (diff: {diff:+.3f})")
+
+         print(f"\n✨ Testing complete! Process staying alive for interface...")
+
+     except Exception as e:
+         print(f"❌ Error: {e}")
+         print("Process staying alive despite error...")

+ def keep_alive():
+     """Keep process running for interface"""
+     try:
+         while True:
+             time.sleep(60)
+     except KeyboardInterrupt:
+         print("\n👋 Shutting down gracefully...")

+ def run_background():
+     """Run test in background thread"""
+     try:
+         run_comprehensive_test()
+     except Exception as e:
+         print(f"Background test error: {e}")
+     finally:
+         print("Background test complete, keeping alive...")

+ if __name__ == "__main__":
+     # Start test in background thread
+     test_thread = threading.Thread(target=run_background, daemon=True)
+     test_thread.start()

+     # Keep main thread alive for interface
+     try:
+         keep_alive()
+     except KeyboardInterrupt:
+         print("\nExiting...")
+     except Exception as e:
+         print(f"Main thread error: {e}")
+         keep_alive()  # Still try to keep alive
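
For orientation, a minimal sketch of how the code deleted from app.py is consumed after this commit. It assumes, based on the new import line, that core/graph_mamba.py now hosts the GraphMamba, HybridGraphMamba and create_regularized_config definitions removed above; the snippet only constructs the two models and reports their size, mirroring the parameter count printed by run_comprehensive_test.

# Hypothetical snippet, not part of this commit: assumes core/graph_mamba.py defines
# the classes and config helper that this commit moved out of app.py.
from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config

config = create_regularized_config()  # d_model=64, n_layers=2, spectral ordering
for model_class in (GraphMamba, HybridGraphMamba):
    model = model_class(config)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{model_class.__name__}: {total_params/1000:.1f}K parameters")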