kfoughali committed
Commit b09f924 · verified · 1 Parent(s): 617d132

Update core/graph_mamba.py

Files changed (1)
  1. core/graph_mamba.py  +277 -209
core/graph_mamba.py CHANGED
@@ -1,247 +1,315 @@
  import torch
  import torch.nn as nn
- from .mamba_block import MambaBlock
- from .graph_sequencer import GraphSequencer, PositionalEncoder


- class GraphMamba(nn.Module):
-     """
-     Production Graph-Mamba model with training optimizations
-     """
      def __init__(self, config):
          super().__init__()

          self.config = config
-         self.d_model = config['model']['d_model']
-         self.n_layers = config['model']['n_layers']
-         self.dropout = config['model']['dropout']
          self.ordering_strategy = config['ordering']['strategy']

-         # Input projection (dynamic input dimension)
-         self.input_proj = None

-         # Positional encoding
-         self.pos_encoder = PositionalEncoder()
-         self.pos_embed = nn.Linear(11, self.d_model)

-         # Mamba layers with residual connections
          self.mamba_layers = nn.ModuleList([
-             MambaBlock(
-                 d_model=self.d_model,
-                 d_state=config['model']['d_state'],
-                 d_conv=config['model']['d_conv'],
-                 expand=config['model']['expand']
-             )
-             for _ in range(self.n_layers)
          ])

          # Layer norms
          self.layer_norms = nn.ModuleList([
-             nn.LayerNorm(self.d_model)
-             for _ in range(self.n_layers)
          ])

-         # Dropout
-         self.dropout_layer = nn.Dropout(self.dropout)
-
-         # Graph sequencer
-         self.sequencer = GraphSequencer()

-         # Classification head (initialized dynamically)
          self.classifier = None

-         # Cache for efficiency
-         self._cache = {}
-
-     def _init_input_proj(self, input_dim, device):
-         """Initialize input projection dynamically"""
-         if self.input_proj is None:
-             self.input_proj = nn.Sequential(
-                 nn.Linear(input_dim, self.d_model),
-                 nn.LayerNorm(self.d_model),
-                 nn.ReLU(),
-                 nn.Dropout(self.dropout * 0.5)
-             ).to(device)
-
-     def _init_classifier(self, num_classes, device):
-         """Initialize classifier dynamically"""
-         if self.classifier is None:
-             self.classifier = nn.Sequential(
-                 nn.Linear(self.d_model, self.d_model // 2),
-                 nn.LayerNorm(self.d_model // 2),
-                 nn.ReLU(),
-                 nn.Dropout(self.dropout),
-                 nn.Linear(self.d_model // 2, num_classes)
-             ).to(device)
-
-     def forward(self, x, edge_index, batch=None):
-         """
-         Forward pass with training optimizations
-         """
-         num_nodes = x.size(0)
-         input_dim = x.size(1)
-         device = x.device
-
-         # Move all components to correct device
-         self.to(device)
-
-         # Initialize input projection if needed
-         self._init_input_proj(input_dim, device)
-
-         # Project input features
-         h = self.input_proj(x)  # (num_nodes, d_model)
-
-         if batch is None:
-             # Single graph processing
-             h = self._process_single_graph(h, edge_index)
-         else:
-             # Batch processing
-             h = self._process_batch(h, edge_index, batch)
-
-         return h

-     def _process_single_graph(self, h, edge_index):
-         """Process a single graph with caching"""
-         num_nodes = h.size(0)
-         device = h.device
-
-         # Ensure edge_index is on correct device
-         edge_index = edge_index.to(device)
-
-         # Cache key for ordering
-         cache_key = f"{self.ordering_strategy}_{num_nodes}_{edge_index.shape[1]}"
-
-         # Get ordering (with caching during training)
-         if cache_key not in self._cache or not self.training:
-             if self.ordering_strategy == "spectral":
-                 order = self.sequencer.spectral_ordering(edge_index, num_nodes)
-             elif self.ordering_strategy == "degree":
-                 order = self.sequencer.degree_ordering(edge_index, num_nodes)
-             elif self.ordering_strategy == "community":
-                 order = self.sequencer.community_ordering(edge_index, num_nodes)
-             else:  # default to BFS
-                 order = self.sequencer.bfs_ordering(edge_index, num_nodes)

-             if self.training:
-                 self._cache[cache_key] = order
-         else:
-             order = self._cache[cache_key]
-
-         # Ensure order is on correct device
-         order = order.to(device)

-         # Add positional encoding
-         seq_pos, distances = self.pos_encoder.encode_positions(h, edge_index, order)
-         seq_pos = seq_pos.to(device)
-         distances = distances.to(device)

-         pos_features = torch.cat([seq_pos, distances], dim=1)  # (num_nodes, 11)
-         pos_embed = self.pos_embed(pos_features)

-         # Reorder nodes for sequential processing
-         h_ordered = h[order] + pos_embed[order]  # Add positional encoding
-         h_ordered = h_ordered.unsqueeze(0)  # (1, num_nodes, d_model)

-         # Process through Mamba layers with residual connections
-         for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
-             # Pre-norm residual connection with gradient scaling
              residual = h_ordered
              h_ordered = ln(h_ordered)
-             h_ordered = mamba(h_ordered)
-             h_ordered = residual + self.dropout_layer(h_ordered)
-
-             # Layer-wise learning rate scaling
-             if self.training:
-                 h_ordered = h_ordered * (1.0 - 0.1 * i / self.n_layers)

          # Restore original order
-         h_out = h_ordered.squeeze(0)  # (num_nodes, d_model)

-         # Create inverse mapping
-         inverse_order = torch.argsort(order)
-         h_final = h_out[inverse_order]
-
-         return h_final

-     def _process_batch(self, h, edge_index, batch):
-         """Process batched graphs efficiently"""
-         device = h.device
-         batch = batch.to(device)
-         edge_index = edge_index.to(device)

-         batch_size = batch.max().item() + 1
-         outputs = []

-         for b in range(batch_size):
-             # Extract subgraph
-             mask = batch == b
-             batch_h = h[mask]
-
-             # Get edges for this graph
-             edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
-             batch_edges = edge_index[:, edge_mask]
-
-             if batch_edges.shape[1] > 0:
-                 # Reindex edges to local indices
-                 node_indices = torch.where(mask)[0]
-                 node_map = torch.zeros(h.size(0), dtype=torch.long, device=device)
-                 node_map[node_indices] = torch.arange(batch_h.size(0), device=device)
-                 batch_edges_local = node_map[batch_edges]
-             else:
-                 # Empty graph
-                 batch_edges_local = torch.empty((2, 0), dtype=torch.long, device=device)
-
-             # Process subgraph
-             batch_output = self._process_single_graph(batch_h, batch_edges_local)
-             outputs.append(batch_output)
-
-         # Reconstruct full batch
-         h_out = torch.zeros_like(h)
-         for b, output in enumerate(outputs):
-             mask = batch == b
-             h_out[mask] = output
-
-         return h_out

-     def get_graph_embedding(self, h, batch=None):
-         """Get graph-level representation with multiple pooling"""
-         if batch is None:
-             # Single graph - multiple pooling strategies
-             mean_pool = h.mean(dim=0, keepdim=True)
-             max_pool = h.max(dim=0)[0].unsqueeze(0)
-
-             # Attention pooling
-             attn_weights = torch.softmax(h.sum(dim=1), dim=0)
-             attn_pool = (h * attn_weights.unsqueeze(1)).sum(dim=0, keepdim=True)
-
-             return torch.cat([mean_pool, max_pool, attn_pool], dim=1)
-         else:
-             # Batched graphs
-             device = h.device
-             batch = batch.to(device)
-             batch_size = batch.max().item() + 1
-
-             graph_embeddings = []
-             for b in range(batch_size):
-                 mask = batch == b
-                 if mask.any():
-                     batch_h = h[mask]
-
-                     # Multiple pooling for this graph
-                     mean_pool = batch_h.mean(dim=0)
-                     max_pool = batch_h.max(dim=0)[0]
-
-                     attn_weights = torch.softmax(batch_h.sum(dim=1), dim=0)
-                     attn_pool = (batch_h * attn_weights.unsqueeze(1)).sum(dim=0)
-
-                     graph_emb = torch.cat([mean_pool, max_pool, attn_pool])
-                     graph_embeddings.append(graph_emb)
-                 else:
-                     # Empty graph
-                     graph_embeddings.append(torch.zeros(h.size(1) * 3, device=device))
-
-             return torch.stack(graph_embeddings)

-     def clear_cache(self):
-         """Clear ordering cache"""
-         self._cache.clear()
 
  import torch
  import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_geometric.utils import degree, to_dense_batch
+ import networkx as nx
+ import numpy as np
+ import logging

+ logger = logging.getLogger(__name__)
+
+ class MambaBlock(nn.Module):
+     """Enhanced Mamba block with optimizations"""
+     def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
+         super().__init__()
+         self.d_model = d_model
+         self.d_inner = int(expand * d_model)
+         self.d_state = d_state
+
+         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
+         self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv, groups=self.d_inner, padding=d_conv-1)
+         self.act = nn.SiLU()
+         self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
+         self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
+
+         A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.D = nn.Parameter(torch.ones(self.d_inner))
+         self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
+
+     def forward(self, x):
+         batch, length, d_model = x.shape
+         xz = self.in_proj(x)
+         x, z = xz.chunk(2, dim=-1)
+
+         x = x.transpose(1, 2)
+         x = self.conv1d(x)[:, :, :length]
+         x = x.transpose(1, 2)
+         x = self.act(x)
+
+         y = self.selective_scan(x)
+         y = y * self.act(z)
+         return self.out_proj(y)
+
+     def selective_scan(self, x):
+         batch, length, d_inner = x.shape
+         deltaBC = self.x_proj(x)
+         delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
+         delta = F.softplus(self.dt_proj(delta))
+
+         deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
+         deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
+
+         states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
+         outputs = []
+
+         for i in range(length):
+             states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
+             y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
+             outputs.append(y)
+
+         return torch.stack(outputs, dim=1)
+
+
+ class EnhancedGraphOrdering:
+     """Advanced graph ordering strategies"""

+     @staticmethod
+     def pagerank_ordering(edge_index, num_nodes):
+         """PageRank-based ordering preserving importance"""
+         try:
+             G = nx.Graph()
+             if edge_index.size(1) > 0:
+                 edges = edge_index.t().cpu().numpy()
+                 G.add_edges_from(edges)
+             G.add_nodes_from(range(num_nodes))
+
+             pagerank = nx.pagerank(G, max_iter=50)
+             order = sorted(range(num_nodes), key=lambda x: pagerank.get(x, 0), reverse=True)
+             return torch.tensor(order, dtype=torch.long)
+         except:
+             return torch.arange(num_nodes, dtype=torch.long)
+
+     @staticmethod
+     def community_aware_ordering(edge_index, num_nodes):
+         """Community-preserving ordering"""
+         try:
+             G = nx.Graph()
+             if edge_index.size(1) > 0:
+                 edges = edge_index.t().cpu().numpy()
+                 G.add_edges_from(edges)
+             G.add_nodes_from(range(num_nodes))
+
+             communities = nx.community.greedy_modularity_communities(G)
+             order = []
+             for community in communities:
+                 # Sort within community by degree
+                 community_list = list(community)
+                 degrees = {node: G.degree(node) for node in community_list}
+                 community_sorted = sorted(community_list, key=lambda x: degrees[x], reverse=True)
+                 order.extend(community_sorted)
+
+             return torch.tensor(order, dtype=torch.long)
+         except:
+             return torch.arange(num_nodes, dtype=torch.long)
+
+
+ class StructuralEncoding(nn.Module):
+     """Multi-faceted structural encoding"""
+     def __init__(self, d_model, max_nodes=5000, max_degree=100):
+         super().__init__()
+         self.pos_encoding = nn.Embedding(max_nodes, d_model)
+         self.degree_encoding = nn.Embedding(max_degree, d_model)
+         self.centrality_proj = nn.Linear(1, d_model)
+         self.layer_norm = nn.LayerNorm(d_model)
+
+     def forward(self, x, edge_index, node_order=None):
+         num_nodes = x.size(0)
+         device = x.device
+
+         # Position encoding
+         positions = torch.arange(num_nodes, device=device).clamp(max=self.pos_encoding.num_embeddings-1)
+         pos_emb = self.pos_encoding(positions)
+
+         # Degree encoding
+         degrees = degree(edge_index[0], num_nodes).long().clamp(max=self.degree_encoding.num_embeddings-1)
+         degree_emb = self.degree_encoding(degrees)
+
+         # Simple centrality (normalized degree)
+         centrality = degrees.float() / max(degrees.max().item(), 1.0)
+         centrality_emb = self.centrality_proj(centrality.unsqueeze(-1))
+
+         # Combine encodings
+         structural_emb = pos_emb + degree_emb + centrality_emb
+         return self.layer_norm(x + structural_emb)
+
+
+ class MultiScaleGraphMamba(nn.Module):
+     """Multi-scale processing with different orderings"""
+     def __init__(self, d_model, n_layers=3):
+         super().__init__()
+         self.d_model = d_model
+
+         # Different scale processors
+         self.local_mamba = nn.ModuleList([MambaBlock(d_model) for _ in range(n_layers//2)])
+         self.global_mamba = nn.ModuleList([MambaBlock(d_model) for _ in range(n_layers//2)])
+
+         # Fusion layers
+         self.scale_fusion = nn.Linear(d_model * 2, d_model)
+         self.layer_norm = nn.LayerNorm(d_model)
+
+     def forward(self, x, edge_index):
+         num_nodes = x.size(0)
+
+         # Different orderings
+         local_order = torch.arange(num_nodes)  # BFS equivalent
+         global_order = EnhancedGraphOrdering.pagerank_ordering(edge_index, num_nodes)
+
+         # Process local scale
+         x_local = x[local_order].unsqueeze(0)
+         for layer in self.local_mamba:
+             x_local = x_local + layer(x_local)
+         x_local = x_local.squeeze(0)
+
+         # Process global scale
+         x_global = x[global_order].unsqueeze(0)
+         for layer in self.global_mamba:
+             x_global = x_global + layer(x_global)
+         x_global = x_global.squeeze(0)
+
+         # Restore original order
+         local_restored = torch.zeros_like(x_local)
+         global_restored = torch.zeros_like(x_global)
+
+         local_restored[local_order] = x_local
+         global_restored[global_order] = x_global
+
+         # Fuse scales
+         fused = torch.cat([local_restored, global_restored], dim=-1)
+         return self.layer_norm(self.scale_fusion(fused))
+
+
+ class GraphMamba(nn.Module):
+     """Enhanced GraphMamba with accuracy improvements"""
      def __init__(self, config):
          super().__init__()

          self.config = config
+         d_model = config['model']['d_model']
+         n_layers = config['model']['n_layers']
          self.ordering_strategy = config['ordering']['strategy']

+         # Input projection
+         self.input_proj = nn.Linear(config.get('input_dim', 1433), d_model)
+
+         # Structural encoding
+         self.structural_encoding = StructuralEncoding(d_model)

+         # Multi-scale processing
+         self.multi_scale = MultiScaleGraphMamba(d_model, n_layers)

+         # Additional Mamba layers
          self.mamba_layers = nn.ModuleList([
+             MambaBlock(d_model) for _ in range(max(1, n_layers - 2))
          ])

          # Layer norms
          self.layer_norms = nn.ModuleList([
+             nn.LayerNorm(d_model) for _ in range(len(self.mamba_layers))
          ])

+         # Output projection
+         self.output_proj = nn.Linear(d_model, d_model)
+         self.dropout = nn.Dropout(config['model']['dropout'])

+         # For node classification
          self.classifier = None

+     def _get_ordering(self, edge_index, num_nodes):
+         """Get node ordering based on strategy"""
+         if self.ordering_strategy == 'pagerank':
+             return EnhancedGraphOrdering.pagerank_ordering(edge_index, num_nodes)
+         elif self.ordering_strategy == 'community':
+             return EnhancedGraphOrdering.community_aware_ordering(edge_index, num_nodes)
+         elif self.ordering_strategy == 'spectral':
+             return self._spectral_ordering(edge_index, num_nodes)
+         else:  # BFS default
+             return torch.arange(num_nodes, dtype=torch.long)

+     def _spectral_ordering(self, edge_index, num_nodes):
+         """Spectral ordering with fallback"""
+         try:
+             from torch_geometric.utils import get_laplacian
+             edge_index_lap, edge_weight = get_laplacian(edge_index, num_nodes=num_nodes)

+             # Simple degree-based approximation
+             degrees = degree(edge_index[0], num_nodes)
+             return torch.argsort(degrees, descending=True)
+         except:
+             return torch.arange(num_nodes, dtype=torch.long)
+
+     def forward(self, x, edge_index, batch=None):
+         """Enhanced forward pass"""
+         # Input projection
+         h = self.input_proj(x)

+         # Add structural information
+         h = self.structural_encoding(h, edge_index)

+         # Multi-scale processing
+         h = self.multi_scale(h, edge_index)

+         # Additional sequential processing
+         order = self._get_ordering(edge_index, h.size(0))
+         h_ordered = h[order].unsqueeze(0)

+         for mamba, ln in zip(self.mamba_layers, self.layer_norms):
              residual = h_ordered
              h_ordered = ln(h_ordered)
+             h_ordered = residual + self.dropout(mamba(h_ordered))

          # Restore original order
+         h_restored = torch.zeros_like(h_ordered.squeeze(0))
+         h_restored[order] = h_ordered.squeeze(0)

+         return self.output_proj(h_restored)
+
+     def _init_classifier(self, num_classes, device):
+         """Initialize classifier head"""
+         if self.classifier is None:
+             self.classifier = nn.Linear(self.config['model']['d_model'], num_classes).to(device)

+     def get_performance_stats(self):
+         """Get model performance statistics"""
+         total_params = sum(p.numel() for p in self.parameters())
+         return {
+             'total_params': total_params,
+             'device': next(self.parameters()).device,
+             'dtype': next(self.parameters()).dtype,
+             'ordering_strategy': self.ordering_strategy
+         }
+
+
+ class HybridGraphMamba(nn.Module):
+     """Hybrid approach with minimal GNN overhead"""
+     def __init__(self, config):
+         super().__init__()
+         from torch_geometric.nn import GCNConv

+         d_model = config['model']['d_model']
+         self.graph_mamba = GraphMamba(config)
+         self.gcn = GCNConv(d_model, d_model)
+         self.gate = nn.Linear(d_model, 1)
+         self.fusion = nn.Linear(d_model * 2, d_model)

+     def forward(self, x, edge_index, batch=None):
+         # Get both representations
+         mamba_out = self.graph_mamba(x, edge_index, batch)
+         gcn_out = self.gcn(mamba_out, edge_index)
+
+         # Learned fusion
+         gate_weight = torch.sigmoid(self.gate(mamba_out))
+         weighted = gate_weight * mamba_out + (1 - gate_weight) * gcn_out
+
+         # Final fusion
+         combined = torch.cat([mamba_out, weighted], dim=-1)
+         return self.fusion(combined)

+     def _init_classifier(self, num_classes, device):
+         """Initialize classifier for hybrid model"""
+         if not hasattr(self, 'classifier') or self.classifier is None:
+             self.classifier = nn.Linear(self.config['model']['d_model'], num_classes).to(device)

+     def get_performance_stats(self):
+         """Get hybrid model stats"""
+         return self.graph_mamba.get_performance_stats()
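As a quick sanity check of the refactored module, the sketch below is my own and is not part of the commit. It assumes the file is importable as core.graph_mamba and that the config dict carries the keys the new constructor reads (model.d_model, model.n_layers, model.dropout, ordering.strategy, plus the optional top-level input_dim); the concrete numbers (d_model=64, 10 nodes, 7 classes) are only illustrative, and calling the private _init_classifier helper directly mirrors how it is defined in the diff rather than a documented API.

# Hypothetical usage sketch for the updated GraphMamba (assumptions noted above).
import torch
from core.graph_mamba import GraphMamba  # assumed import path for this repo layout

config = {
    'model': {'d_model': 64, 'n_layers': 4, 'dropout': 0.1},
    'ordering': {'strategy': 'pagerank'},  # 'community' and 'spectral' are also handled; anything else falls back to index order
    'input_dim': 1433,                     # matches the default used by input_proj
}

model = GraphMamba(config).eval()

# Toy graph: 10 nodes with random features and 40 random edges
x = torch.randn(10, config['input_dim'])
edge_index = torch.randint(0, 10, (2, 40))

with torch.no_grad():
    h = model(x, edge_index)               # (10, d_model) node embeddings

model._init_classifier(num_classes=7, device=x.device)
logits = model.classifier(h)               # (10, 7) per-node class scores
print(h.shape, logits.shape)

HybridGraphMamba accepts the same constructor argument and forward signature, so the embedding call above works unchanged with that class swapped in.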