kfoughali committed · verified · Commit e4d5cc2 · Parent: 8e24e05

Update core/graph_mamba.py

Files changed (1): core/graph_mamba.py (+81 -36)
core/graph_mamba.py CHANGED
@@ -5,8 +5,7 @@ from .graph_sequencer import GraphSequencer, PositionalEncoder
 
 class GraphMamba(nn.Module):
     """
-    Production Graph-Mamba model
-    Device-safe implementation with dynamic handling
+    Production Graph-Mamba model with training optimizations
     """
 
     def __init__(self, config):
@@ -19,13 +18,13 @@ class GraphMamba(nn.Module):
         self.ordering_strategy = config['ordering']['strategy']
 
         # Input projection (dynamic input dimension)
-        self.input_proj = None  # Will be initialized on first forward
+        self.input_proj = None
 
         # Positional encoding
         self.pos_encoder = PositionalEncoder()
-        self.pos_embed = nn.Linear(11, self.d_model)  # 1 + 10 distances
+        self.pos_embed = nn.Linear(11, self.d_model)
 
-        # Mamba layers
+        # Mamba layers with residual connections
         self.mamba_layers = nn.ModuleList([
             MambaBlock(
                 d_model=self.d_model,
@@ -48,27 +47,36 @@ class GraphMamba(nn.Module):
         # Graph sequencer
         self.sequencer = GraphSequencer()
 
-        # Classification head (for demo)
+        # Classification head (initialized dynamically)
        self.classifier = None
 
+        # Cache for efficiency
+        self._cache = {}
+
    def _init_input_proj(self, input_dim, device):
        """Initialize input projection dynamically"""
        if self.input_proj is None:
-            self.input_proj = nn.Linear(input_dim, self.d_model).to(device)
+            self.input_proj = nn.Sequential(
+                nn.Linear(input_dim, self.d_model),
+                nn.LayerNorm(self.d_model),
+                nn.ReLU(),
+                nn.Dropout(self.dropout * 0.5)
+            ).to(device)
 
    def _init_classifier(self, num_classes, device):
        """Initialize classifier dynamically"""
        if self.classifier is None:
-            self.classifier = nn.Linear(self.d_model, num_classes).to(device)
+            self.classifier = nn.Sequential(
+                nn.Linear(self.d_model, self.d_model // 2),
+                nn.LayerNorm(self.d_model // 2),
+                nn.ReLU(),
+                nn.Dropout(self.dropout),
+                nn.Linear(self.d_model // 2, num_classes)
+            ).to(device)
 
    def forward(self, x, edge_index, batch=None):
        """
-        Forward pass with device-safe handling
-
-        Args:
-            x: Node features (num_nodes, input_dim)
-            edge_index: Edge connectivity (2, num_edges)
-            batch: Batch assignment (num_nodes,) - optional
+        Forward pass with training optimizations
        """
        num_nodes = x.size(0)
        input_dim = x.size(1)
@@ -93,22 +101,31 @@
        return h
 
    def _process_single_graph(self, h, edge_index):
-        """Process a single graph - device safe"""
+        """Process a single graph with caching"""
        num_nodes = h.size(0)
        device = h.device
 
        # Ensure edge_index is on correct device
        edge_index = edge_index.to(device)
 
-        # Get ordering
-        if self.ordering_strategy == "spectral":
-            order = self.sequencer.spectral_ordering(edge_index, num_nodes)
-        elif self.ordering_strategy == "degree":
-            order = self.sequencer.degree_ordering(edge_index, num_nodes)
-        elif self.ordering_strategy == "community":
-            order = self.sequencer.community_ordering(edge_index, num_nodes)
-        else:  # default to BFS
-            order = self.sequencer.bfs_ordering(edge_index, num_nodes)
+        # Cache key for ordering
+        cache_key = f"{self.ordering_strategy}_{num_nodes}_{edge_index.shape[1]}"
+
+        # Get ordering (with caching during training)
+        if cache_key not in self._cache or not self.training:
+            if self.ordering_strategy == "spectral":
+                order = self.sequencer.spectral_ordering(edge_index, num_nodes)
+            elif self.ordering_strategy == "degree":
+                order = self.sequencer.degree_ordering(edge_index, num_nodes)
+            elif self.ordering_strategy == "community":
+                order = self.sequencer.community_ordering(edge_index, num_nodes)
+            else:  # default to BFS
+                order = self.sequencer.bfs_ordering(edge_index, num_nodes)
+
+            if self.training:
+                self._cache[cache_key] = order
+        else:
+            order = self._cache[cache_key]
 
        # Ensure order is on correct device
        order = order.to(device)
@@ -125,10 +142,17 @@
        h_ordered = h[order] + pos_embed[order]  # Add positional encoding
        h_ordered = h_ordered.unsqueeze(0)  # (1, num_nodes, d_model)
 
-        # Process through Mamba layers
-        for mamba, ln in zip(self.mamba_layers, self.layer_norms):
-            # Pre-norm residual connection
-            h_ordered = h_ordered + self.dropout_layer(mamba(ln(h_ordered)))
+        # Process through Mamba layers with residual connections
+        for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
+            # Pre-norm residual connection with gradient scaling
+            residual = h_ordered
+            h_ordered = ln(h_ordered)
+            h_ordered = mamba(h_ordered)
+            h_ordered = residual + self.dropout_layer(h_ordered)
+
+            # Layer-wise learning rate scaling
+            if self.training:
+                h_ordered = h_ordered * (1.0 - 0.1 * i / self.n_layers)
 
        # Restore original order
        h_out = h_ordered.squeeze(0)  # (num_nodes, d_model)
@@ -140,7 +164,7 @@
        return h_final
 
    def _process_batch(self, h, edge_index, batch):
-        """Process batched graphs - device safe"""
+        """Process batched graphs efficiently"""
        device = h.device
        batch = batch.to(device)
        edge_index = edge_index.to(device)
@@ -180,12 +204,19 @@
        return h_out
 
    def get_graph_embedding(self, h, batch=None):
-        """Get graph-level representation"""
+        """Get graph-level representation with multiple pooling"""
        if batch is None:
-            # Single graph - mean pooling
-            return h.mean(dim=0, keepdim=True)
+            # Single graph - multiple pooling strategies
+            mean_pool = h.mean(dim=0, keepdim=True)
+            max_pool = h.max(dim=0)[0].unsqueeze(0)
+
+            # Attention pooling
+            attn_weights = torch.softmax(h.sum(dim=1), dim=0)
+            attn_pool = (h * attn_weights.unsqueeze(1)).sum(dim=0, keepdim=True)
+
+            return torch.cat([mean_pool, max_pool, attn_pool], dim=1)
        else:
-            # Batched graphs - manual pooling to avoid dependencies
+            # Batched graphs
            device = h.device
            batch = batch.to(device)
            batch_size = batch.max().item() + 1
@@ -194,9 +225,23 @@
        for b in range(batch_size):
            mask = batch == b
            if mask.any():
-                graph_emb = h[mask].mean(dim=0)
+                batch_h = h[mask]
+
+                # Multiple pooling for this graph
+                mean_pool = batch_h.mean(dim=0)
+                max_pool = batch_h.max(dim=0)[0]
+
+                attn_weights = torch.softmax(batch_h.sum(dim=1), dim=0)
+                attn_pool = (batch_h * attn_weights.unsqueeze(1)).sum(dim=0)
+
+                graph_emb = torch.cat([mean_pool, max_pool, attn_pool])
                graph_embeddings.append(graph_emb)
            else:
-                graph_embeddings.append(torch.zeros(h.size(1), device=device))
+                # Empty graph
+                graph_embeddings.append(torch.zeros(h.size(1) * 3, device=device))
 
-        return torch.stack(graph_embeddings)
+        return torch.stack(graph_embeddings)
+
+    def clear_cache(self):
+        """Clear ordering cache"""
+        self._cache.clear()
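
For context, a minimal usage sketch of the updated module, not part of the commit. It assumes forward() returns per-node embeddings of width d_model, as the diff suggests; the config keys other than ordering.strategy are hypothetical placeholders, since only config['ordering']['strategy'] is visible here. With the new pooling, get_graph_embedding() concatenates mean, max, and attention pooling, so the graph-level vector is 3 * d_model wide, and clear_cache() resets the cached node orderings.

# Illustrative sketch only; config keys besides 'ordering' are hypothetical.
import torch
from core.graph_mamba import GraphMamba

config = {
    'model': {'d_model': 64, 'n_layers': 4, 'dropout': 0.1},  # hypothetical keys
    'ordering': {'strategy': 'bfs'},  # strategy read in __init__ per the diff
}
model = GraphMamba(config)

# Toy graph: 4 nodes with 3-dimensional features and 3 directed edges
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

h = model(x, edge_index)          # per-node embeddings, (4, d_model)
g = model.get_graph_embedding(h)  # (1, 3 * d_model): mean | max | attention pooling
model.clear_cache()               # drop cached orderings, e.g. when switching datasets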