kfoughali committed
Commit a7a0326 · verified · 1 Parent(s): 93db32e

Update core/graph_mamba.py

Files changed (1)
  1. core/graph_mamba.py +113 -200
core/graph_mamba.py CHANGED
@@ -1,16 +1,15 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from torch_geometric.utils import degree, to_dense_batch
  import networkx as nx
- import numpy as np
  import logging

  logger = logging.getLogger(__name__)

  class MambaBlock(nn.Module):
-     """Enhanced Mamba block with optimizations"""
-     def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
          super().__init__()
          self.d_model = d_model
          self.d_inner = int(expand * d_model)
@@ -27,6 +26,9 @@ class MambaBlock(nn.Module):
          self.D = nn.Parameter(torch.ones(self.d_inner))
          self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

      def forward(self, x):
          batch, length, d_model = x.shape
          xz = self.in_proj(x)
@@ -36,10 +38,11 @@
          x = self.conv1d(x)[:, :, :length]
          x = x.transpose(1, 2)
          x = self.act(x)

          y = self.selective_scan(x)
          y = y * self.act(z)
-         return self.out_proj(y)

      def selective_scan(self, x):
          batch, length, d_inner = x.shape
@@ -61,255 +64,165 @@ class MambaBlock(nn.Module):
          return torch.stack(outputs, dim=1)


- class EnhancedGraphOrdering:
-     """Advanced graph ordering strategies"""

      @staticmethod
-     def pagerank_ordering(edge_index, num_nodes):
-         """PageRank-based ordering preserving importance"""
-         try:
-             G = nx.Graph()
-             if edge_index.size(1) > 0:
-                 edges = edge_index.t().cpu().numpy()
-                 G.add_edges_from(edges)
-             G.add_nodes_from(range(num_nodes))
-
-             pagerank = nx.pagerank(G, max_iter=50)
-             order = sorted(range(num_nodes), key=lambda x: pagerank.get(x, 0), reverse=True)
-             return torch.tensor(order, dtype=torch.long)
-         except:
-             return torch.arange(num_nodes, dtype=torch.long)

      @staticmethod
-     def community_aware_ordering(edge_index, num_nodes):
-         """Community-preserving ordering"""
-         try:
-             G = nx.Graph()
-             if edge_index.size(1) > 0:
-                 edges = edge_index.t().cpu().numpy()
-                 G.add_edges_from(edges)
-             G.add_nodes_from(range(num_nodes))
-
-             communities = nx.community.greedy_modularity_communities(G)
-             order = []
-             for community in communities:
-                 # Sort within community by degree
-                 community_list = list(community)
-                 degrees = {node: G.degree(node) for node in community_list}
-                 community_sorted = sorted(community_list, key=lambda x: degrees[x], reverse=True)
-                 order.extend(community_sorted)
-
-             return torch.tensor(order, dtype=torch.long)
-         except:
-             return torch.arange(num_nodes, dtype=torch.long)


- class StructuralEncoding(nn.Module):
-     """Multi-faceted structural encoding"""
-     def __init__(self, d_model, max_nodes=5000, max_degree=100):
          super().__init__()
-         self.pos_encoding = nn.Embedding(max_nodes, d_model)
          self.degree_encoding = nn.Embedding(max_degree, d_model)
-         self.centrality_proj = nn.Linear(1, d_model)
-         self.layer_norm = nn.LayerNorm(d_model)
-
-     def forward(self, x, edge_index, node_order=None):
-         num_nodes = x.size(0)
-         device = x.device
-
-         # Position encoding
-         positions = torch.arange(num_nodes, device=device).clamp(max=self.pos_encoding.num_embeddings-1)
-         pos_emb = self.pos_encoding(positions)
-
-         # Degree encoding
-         degrees = degree(edge_index[0], num_nodes).long().clamp(max=self.degree_encoding.num_embeddings-1)
-         degree_emb = self.degree_encoding(degrees)
-
-         # Simple centrality (normalized degree)
-         centrality = degrees.float() / max(degrees.max().item(), 1.0)
-         centrality_emb = self.centrality_proj(centrality.unsqueeze(-1))
-
-         # Combine encodings
-         structural_emb = pos_emb + degree_emb + centrality_emb
-         return self.layer_norm(x + structural_emb)
-
-
- class MultiScaleGraphMamba(nn.Module):
-     """Multi-scale processing with different orderings"""
-     def __init__(self, d_model, n_layers=3):
-         super().__init__()
-         self.d_model = d_model
-
-         # Different scale processors
-         self.local_mamba = nn.ModuleList([MambaBlock(d_model) for _ in range(n_layers//2)])
-         self.global_mamba = nn.ModuleList([MambaBlock(d_model) for _ in range(n_layers//2)])
-
-         # Fusion layers
-         self.scale_fusion = nn.Linear(d_model * 2, d_model)
          self.layer_norm = nn.LayerNorm(d_model)

      def forward(self, x, edge_index):
          num_nodes = x.size(0)

-         # Different orderings
-         local_order = torch.arange(num_nodes) # BFS equivalent
-         global_order = EnhancedGraphOrdering.pagerank_ordering(edge_index, num_nodes)
-
-         # Process local scale
-         x_local = x[local_order].unsqueeze(0)
-         for layer in self.local_mamba:
-             x_local = x_local + layer(x_local)
-         x_local = x_local.squeeze(0)
-
-         # Process global scale
-         x_global = x[global_order].unsqueeze(0)
-         for layer in self.global_mamba:
-             x_global = x_global + layer(x_global)
-         x_global = x_global.squeeze(0)
-
-         # Restore original order
-         local_restored = torch.zeros_like(x_local)
-         global_restored = torch.zeros_like(x_global)
-
-         local_restored[local_order] = x_local
-         global_restored[global_order] = x_global

-         # Fuse scales
-         fused = torch.cat([local_restored, global_restored], dim=-1)
-         return self.layer_norm(self.scale_fusion(fused))


  class GraphMamba(nn.Module):
-     """Enhanced GraphMamba with accuracy improvements"""
      def __init__(self, config):
          super().__init__()

          self.config = config
-         d_model = config['model']['d_model']
-         n_layers = config['model']['n_layers']
-         self.ordering_strategy = config['ordering']['strategy']
-
-         # Input projection
-         self.input_proj = nn.Linear(config.get('input_dim', 1433), d_model)

-         # Structural encoding
-         self.structural_encoding = StructuralEncoding(d_model)

-         # Multi-scale processing
-         self.multi_scale = MultiScaleGraphMamba(d_model, n_layers)

-         # Additional Mamba layers
          self.mamba_layers = nn.ModuleList([
-             MambaBlock(d_model) for _ in range(max(1, n_layers - 2))
          ])

-         # Layer norms
          self.layer_norms = nn.ModuleList([
-             nn.LayerNorm(d_model) for _ in range(len(self.mamba_layers))
          ])

-         # Output projection
          self.output_proj = nn.Linear(d_model, d_model)
-         self.dropout = nn.Dropout(config['model']['dropout'])

-         # For node classification
          self.classifier = None

-     def _get_ordering(self, edge_index, num_nodes):
-         """Get node ordering based on strategy"""
-         if self.ordering_strategy == 'pagerank':
-             return EnhancedGraphOrdering.pagerank_ordering(edge_index, num_nodes)
-         elif self.ordering_strategy == 'community':
-             return EnhancedGraphOrdering.community_aware_ordering(edge_index, num_nodes)
-         elif self.ordering_strategy == 'spectral':
-             return self._spectral_ordering(edge_index, num_nodes)
-         else: # BFS default
-             return torch.arange(num_nodes, dtype=torch.long)
-
-     def _spectral_ordering(self, edge_index, num_nodes):
-         """Spectral ordering with fallback"""
-         try:
-             from torch_geometric.utils import get_laplacian
-             edge_index_lap, edge_weight = get_laplacian(edge_index, num_nodes=num_nodes)
-
-             # Simple degree-based approximation
-             degrees = degree(edge_index[0], num_nodes)
-             return torch.argsort(degrees, descending=True)
-         except:
-             return torch.arange(num_nodes, dtype=torch.long)
-
      def forward(self, x, edge_index, batch=None):
-         """Enhanced forward pass"""
-         # Input projection
-         h = self.input_proj(x)

-         # Add structural information
-         h = self.structural_encoding(h, edge_index)

-         # Multi-scale processing
-         h = self.multi_scale(h, edge_index)

-         # Additional sequential processing
-         order = self._get_ordering(edge_index, h.size(0))
          h_ordered = h[order].unsqueeze(0)

-         for mamba, ln in zip(self.mamba_layers, self.layer_norms):
              residual = h_ordered
              h_ordered = ln(h_ordered)
-             h_ordered = residual + self.dropout(mamba(h_ordered))

-         # Restore original order
-         h_restored = torch.zeros_like(h_ordered.squeeze(0))
-         h_restored[order] = h_ordered.squeeze(0)

-         return self.output_proj(h_restored)

      def _init_classifier(self, num_classes, device):
-         """Initialize classifier head"""
          if self.classifier is None:
-             self.classifier = nn.Linear(self.config['model']['d_model'], num_classes).to(device)

      def get_performance_stats(self):
-         """Get model performance statistics"""
          total_params = sum(p.numel() for p in self.parameters())
          return {
              'total_params': total_params,
              'device': next(self.parameters()).device,
              'dtype': next(self.parameters()).dtype,
-             'ordering_strategy': self.ordering_strategy
          }


- class HybridGraphMamba(nn.Module):
-     """Hybrid approach with minimal GNN overhead"""
-     def __init__(self, config):
-         super().__init__()
-         from torch_geometric.nn import GCNConv
-
-         d_model = config['model']['d_model']
-         self.graph_mamba = GraphMamba(config)
-         self.gcn = GCNConv(d_model, d_model)
-         self.gate = nn.Linear(d_model, 1)
-         self.fusion = nn.Linear(d_model * 2, d_model)
-
-     def forward(self, x, edge_index, batch=None):
-         # Get both representations
-         mamba_out = self.graph_mamba(x, edge_index, batch)
-         gcn_out = self.gcn(mamba_out, edge_index)
-
-         # Learned fusion
-         gate_weight = torch.sigmoid(self.gate(mamba_out))
-         weighted = gate_weight * mamba_out + (1 - gate_weight) * gcn_out
-
-         # Final fusion
-         combined = torch.cat([mamba_out, weighted], dim=-1)
-         return self.fusion(combined)
-
-     def _init_classifier(self, num_classes, device):
-         """Initialize classifier for hybrid model"""
-         if not hasattr(self, 'classifier') or self.classifier is None:
-             self.classifier = nn.Linear(self.config['model']['d_model'], num_classes).to(device)
-
-     def get_performance_stats(self):
-         """Get hybrid model stats"""
-         return self.graph_mamba.get_performance_stats()

@@ -1,16 +1,15 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from torch_geometric.utils import degree
  import networkx as nx
  import logging

  logger = logging.getLogger(__name__)

  class MambaBlock(nn.Module):
+     """Heavily regularized Mamba block"""
+     def __init__(self, d_model, d_state=4, d_conv=4, expand=2):
          super().__init__()
          self.d_model = d_model
          self.d_inner = int(expand * d_model)
@@ -27,6 +26,9 @@ class MambaBlock(nn.Module):
          self.D = nn.Parameter(torch.ones(self.d_inner))
          self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

+         # Heavy regularization
+         self.dropout = nn.Dropout(0.3)
+
      def forward(self, x):
          batch, length, d_model = x.shape
          xz = self.in_proj(x)
@@ -36,10 +38,11 @@
          x = self.conv1d(x)[:, :, :length]
          x = x.transpose(1, 2)
          x = self.act(x)
+         x = self.dropout(x)

          y = self.selective_scan(x)
          y = y * self.act(z)
+         return self.dropout(self.out_proj(y))

      def selective_scan(self, x):
          batch, length, d_inner = x.shape
@@ -61,255 +64,165 @@ class MambaBlock(nn.Module):
          return torch.stack(outputs, dim=1)


+ class GraphDataAugmentation:
+     """Data augmentation to combat overfitting"""

      @staticmethod
+     def augment_features(x, noise_level=0.1, dropout_prob=0.2):
+         if x.size(0) == 0:
+             return x
+         # Feature noise
+         noise = torch.randn_like(x) * noise_level
+         x_aug = x + noise
+
+         # Feature dropout
+         mask = torch.rand(x.shape[0], x.shape[1], device=x.device) > dropout_prob
+         x_aug = x_aug * mask.float()
+
+         return x_aug

      @staticmethod
+     def augment_edges(edge_index, drop_prob=0.1):
+         if edge_index.size(1) == 0:
+             return edge_index
+         # Edge dropout
+         edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) > drop_prob
+         return edge_index[:, edge_mask]


+ class LightStructuralEncoding(nn.Module):
+     """Lightweight structural encoding"""
+     def __init__(self, d_model, max_degree=50):
          super().__init__()
          self.degree_encoding = nn.Embedding(max_degree, d_model)
          self.layer_norm = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(0.5)

      def forward(self, x, edge_index):
          num_nodes = x.size(0)

+         # Only degree encoding (simpler)
+         degrees = degree(edge_index[0], num_nodes).long().clamp(max=49)
+         degree_emb = self.degree_encoding(degrees)

+         # Combine with heavy dropout
+         combined = self.layer_norm(x + degree_emb)
+         return self.dropout(combined)


  class GraphMamba(nn.Module):
+     """Heavily regularized GraphMamba to prevent overfitting"""
      def __init__(self, config):
          super().__init__()

          self.config = config
+         d_model = config['model']['d_model'] # Should be 64
+         n_layers = config['model']['n_layers'] # Should be 2
+         input_dim = config.get('input_dim', 1433)

+         # Minimal architecture
+         self.input_proj = nn.Linear(input_dim, d_model)
+         self.input_dropout = nn.Dropout(0.5)

+         # Light structural encoding
+         self.structural_encoding = LightStructuralEncoding(d_model)

+         # Minimal Mamba layers
          self.mamba_layers = nn.ModuleList([
+             MambaBlock(d_model, d_state=4) for _ in range(n_layers)
          ])

+         # Layer norms with dropout
          self.layer_norms = nn.ModuleList([
+             nn.LayerNorm(d_model) for _ in range(n_layers)
          ])

+         self.hidden_dropout = nn.Dropout(0.5)
+         self.output_dropout = nn.Dropout(0.3)
+
+         # Simple output
          self.output_proj = nn.Linear(d_model, d_model)

+         # Data augmentation
+         self.augmentation = GraphDataAugmentation()
+
+         # Classifier will be added later
          self.classifier = None

      def forward(self, x, edge_index, batch=None):
+         # Apply data augmentation during training
+         if self.training:
+             x = self.augmentation.augment_features(x)
+             edge_index = self.augmentation.augment_edges(edge_index)

+         # Input projection with dropout
+         h = self.input_dropout(self.input_proj(x))

+         # Add minimal structural information
+         h = self.structural_encoding(h, edge_index)

+         # Simple BFS ordering only
+         order = torch.arange(h.size(0), device=h.device)
          h_ordered = h[order].unsqueeze(0)

+         # Process through minimal Mamba layers
+         for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
              residual = h_ordered
              h_ordered = ln(h_ordered)
+             h_ordered = residual + mamba(h_ordered)
+             h_ordered = self.hidden_dropout(h_ordered)

+         # Restore order and final processing
+         h_restored = h_ordered.squeeze(0)
+         h_out = self.output_dropout(self.output_proj(h_restored))

+         return h_out

      def _init_classifier(self, num_classes, device):
+         """Initialize heavily regularized classifier"""
          if self.classifier is None:
+             self.classifier = nn.Sequential(
+                 nn.Dropout(0.5),
+                 nn.Linear(self.config['model']['d_model'], num_classes)
+             ).to(device)

      def get_performance_stats(self):
+         """Get model statistics"""
          total_params = sum(p.numel() for p in self.parameters())
          return {
              'total_params': total_params,
              'device': next(self.parameters()).device,
              'dtype': next(self.parameters()).dtype,
+             'model_size': f"{total_params/1000:.1f}K parameters"
          }


+ def create_regularized_config():
+     """Create config optimized for small training sets"""
+     return {
+         'model': {
+             'd_model': 64, # Reduced from 128
+             'd_state': 4, # Reduced from 8
+             'd_conv': 4,
+             'expand': 2,
+             'n_layers': 2, # Reduced from 3
+             'dropout': 0.5 # Increased from 0.1
+         },
+         'data': {
+             'batch_size': 1, # Full batch for small datasets
+             'test_split': 0.2
+         },
+         'training': {
+             'learning_rate': 0.0005, # Reduced from 0.001
+             'weight_decay': 0.01, # High regularization
+             'epochs': 200,
+             'patience': 10, # More patient early stopping
+             'warmup_epochs': 10,
+             'min_lr': 1e-6
+         },
+         'ordering': {
+             'strategy': 'bfs', # Simple strategy only
+             'preserve_locality': True
+         },
+         'input_dim': 1433
+     }
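
For reference, a minimal usage sketch of the updated module follows. It is illustrative only: it assumes the file is importable as core.graph_mamba, stands in a small random 7-class graph with 1433-dim features for a real dataset, and picks AdamW plus cross-entropy using the learning_rate and weight_decay values from create_regularized_config(); the commit itself does not prescribe an optimizer or loss.

# Illustrative usage sketch (assumptions: module importable as core.graph_mamba,
# toy random graph in place of a real dataset, AdamW + cross-entropy chosen
# here for demonstration only).
import torch
import torch.nn.functional as F
from core.graph_mamba import GraphMamba, create_regularized_config

config = create_regularized_config()
model = GraphMamba(config)

x = torch.randn(10, config['input_dim'])    # 10 nodes, 1433-dim features
edge_index = torch.randint(0, 10, (2, 40))  # random toy edges
labels = torch.randint(0, 7, (10,))         # toy node labels

model._init_classifier(num_classes=7, device=x.device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=config['training']['weight_decay'],
)

model.train()                      # enables augmentation and dropout paths
optimizer.zero_grad()
h = model(x, edge_index)           # node embeddings, shape (10, d_model)
logits = model.classifier(h)       # regularized classifier head
loss = F.cross_entropy(logits, labels)
loss.backward()
optimizer.step()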