kfoughali committed
Commit c681cda · verified · 1 Parent(s): f3d5bea

Create core/graph_mamba.py

Files changed (1)
  1. core/graph_mamba.py +162 -0
core/graph_mamba.py ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
from .mamba_block import MambaBlock
from .graph_sequencer import GraphSequencer, PositionalEncoder


class GraphMamba(nn.Module):
    """
    Production Graph-Mamba model.
    Dynamically handles any graph size and structure.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.d_model = config['model']['d_model']
        self.n_layers = config['model']['n_layers']
        self.dropout = config['model']['dropout']
        self.ordering_strategy = config['ordering']['strategy']

        # Input projection (dynamic input dimension)
        self.input_proj = None  # Initialized lazily on the first forward pass

        # Positional encoding
        self.pos_encoder = PositionalEncoder()
        self.pos_embed = nn.Linear(11, self.d_model)  # 11 positional features: 1 sequence position + 10 distances

        # Mamba layers
        self.mamba_layers = nn.ModuleList([
            MambaBlock(
                d_model=self.d_model,
                d_state=config['model']['d_state'],
                d_conv=config['model']['d_conv'],
                expand=config['model']['expand']
            )
            for _ in range(self.n_layers)
        ])

        # Layer norms
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.d_model)
            for _ in range(self.n_layers)
        ])

        # Dropout
        self.dropout_layer = nn.Dropout(self.dropout)

        # Graph sequencer (node-ordering strategies)
        self.sequencer = GraphSequencer()

    def _init_input_proj(self, input_dim):
        """Initialize the input projection lazily, once the input dimension is known."""
        if self.input_proj is None:
            # Create the projection on the same device as the rest of the model.
            # Note: since this parameter only exists after the first forward pass,
            # optimizers must be constructed (or refreshed) after that pass to train it.
            self.input_proj = nn.Linear(input_dim, self.d_model).to(self.pos_embed.weight.device)

    def forward(self, x, edge_index, batch=None):
        """
        Forward pass with dynamic graph handling.

        Args:
            x: Node features (num_nodes, input_dim)
            edge_index: Edge connectivity (2, num_edges)
            batch: Batch assignment (num_nodes,) - optional
        """
        num_nodes = x.size(0)
        input_dim = x.size(1)

        # Initialize input projection if needed
        self._init_input_proj(input_dim)

        # Project input features
        h = self.input_proj(x)  # (num_nodes, d_model)

        if batch is None:
            # Single-graph processing
            h = self._process_single_graph(h, edge_index)
        else:
            # Batch processing
            h = self._process_batch(h, edge_index, batch)

        return h

    def _process_single_graph(self, h, edge_index):
        """Process a single graph"""
        num_nodes = h.size(0)

        # Get node ordering
        if self.ordering_strategy == "multi_view":
            # Use BFS as the primary view for now (can be extended)
            order = self.sequencer.bfs_ordering(edge_index, num_nodes)
        elif self.ordering_strategy == "spectral":
            order = self.sequencer.spectral_ordering(edge_index, num_nodes)
        elif self.ordering_strategy == "degree":
            order = self.sequencer.degree_ordering(edge_index, num_nodes)
        else:  # default to BFS
            order = self.sequencer.bfs_ordering(edge_index, num_nodes)

        # Positional encoding
        seq_pos, distances = self.pos_encoder.encode_positions(h, edge_index, order)
        pos_features = torch.cat([seq_pos, distances], dim=1)  # (num_nodes, 11)
        pos_embed = self.pos_embed(pos_features)

        # Reorder nodes for sequential processing and add positional encoding
        h_ordered = h[order] + pos_embed[order]
        h_ordered = h_ordered.unsqueeze(0)  # (1, num_nodes, d_model)

        # Process through Mamba layers
        for mamba, ln in zip(self.mamba_layers, self.layer_norms):
            # Pre-norm residual connection
            h_ordered = h_ordered + self.dropout_layer(mamba(ln(h_ordered)))

        # Restore original node order
        h_out = h_ordered.squeeze(0)  # (num_nodes, d_model)

        # Invert the ordering permutation
        inverse_order = torch.argsort(order)
        h_final = h_out[inverse_order]

        return h_final

    def _process_batch(self, h, edge_index, batch):
        """Process batched graphs"""
        batch_size = batch.max().item() + 1
        outputs = []

        for b in range(batch_size):
            # Extract this graph's nodes
            mask = batch == b
            batch_h = h[mask]

            # Keep only edges whose endpoints both belong to this graph
            edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
            batch_edges = edge_index[:, edge_mask]

            # Reindex edges to local (per-graph) node indices
            node_indices = torch.where(mask)[0]
            node_map = torch.zeros(h.size(0), dtype=torch.long, device=h.device)
            node_map[node_indices] = torch.arange(batch_h.size(0), device=h.device)
            batch_edges_local = node_map[batch_edges]

            # Process subgraph
            batch_output = self._process_single_graph(batch_h, batch_edges_local)
            outputs.append(batch_output)

        # Reconstruct the full batch in original node order
        h_out = torch.zeros_like(h)
        for b, output in enumerate(outputs):
            mask = batch == b
            h_out[mask] = output

        return h_out

    def get_graph_embedding(self, h, batch=None):
        """Get graph-level representation"""
        if batch is None:
            # Single graph - mean pooling over all nodes
            return h.mean(dim=0, keepdim=True)
        else:
            # Batched graphs - pool per graph (requires torch_geometric)
            from torch_geometric.nn import global_mean_pool
            return global_mean_pool(h, batch)
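For reference, a minimal usage sketch of the added module follows. The config keys mirror the ones read in __init__, but the concrete values (d_model=64, n_layers=4, and so on) are placeholders, and the import path assumes core/ is an importable package containing the companion mamba_block and graph_sequencer modules, which are not part of this commit.

import torch
from core.graph_mamba import GraphMamba  # assumes core/ is a package on the path

# Illustrative config: keys match those read in GraphMamba.__init__, values are placeholders.
config = {
    'model': {
        'd_model': 64,
        'n_layers': 4,
        'dropout': 0.1,
        'd_state': 16,
        'd_conv': 4,
        'expand': 2,
    },
    'ordering': {
        'strategy': 'bfs',  # or 'spectral', 'degree', 'multi_view'
    },
}

model = GraphMamba(config)

# Toy input: 5 nodes with 3 features each, two components (0-1-2 and 3-4),
# with each undirected edge stored in both directions.
x = torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])

# Single-graph forward pass and graph-level readout.
h = model(x, edge_index)             # (5, d_model) node embeddings
g = model.get_graph_embedding(h)     # (1, d_model) graph embedding

# Batched forward pass: nodes 0-2 form graph 0, nodes 3-4 form graph 1.
batch = torch.tensor([0, 0, 0, 1, 1])
h_b = model(x, edge_index, batch)    # (5, d_model) node embeddings
g_b = model.get_graph_embedding(h_b, batch)  # (2, d_model), requires torch_geometric

Because input_proj is created lazily on the first forward pass, any optimizer should be built (or have its parameter groups refreshed) after that first call so the projection's weights are actually trained.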