kfoughali committed
Commit 991b7c0 · verified · 1 Parent(s): 929a171

Update core/graph_mamba.py

Files changed (1)
  1. core/graph_mamba.py +358 -0
core/graph_mamba.py CHANGED
@@ -0,0 +1,358 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_geometric.utils import degree, to_dense_adj
+ from torch_geometric.nn import GCNConv
+ import math
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ class CognitiveMomentumEngine(nn.Module):
+     """Core cognitive momentum system from the document"""
+     def __init__(self, d_model):
+         super().__init__()
+         self.d_model = d_model
+
+         # Momentum tracking
+         self.register_buffer('momentum_vectors', torch.zeros(d_model))
+         self.register_buffer('cognitive_mass', torch.ones(d_model))
+         self.register_buffer('kinetic_energy', torch.zeros(d_model))
+         self.register_buffer('potential_energy', torch.zeros(d_model))
+
+         # Field interactions
+         self.attraction_projection = nn.Linear(d_model, d_model)
+         self.repulsion_projection = nn.Linear(d_model, d_model)
+
+         # Crystallization threshold
+         self.crystallization_threshold = 0.1
+         self.memory_decay = 0.99
+
+     def update_momentum(self, concept_features, force, dt=0.1):
+         """Apply cognitive momentum physics"""
+         # F = ma => a = F/m
+         acceleration = force / (self.cognitive_mass + 1e-8)
+
+         # Update velocity: v = v₀ + at
+         current_velocity = self.momentum_vectors / (self.cognitive_mass + 1e-8)
+         new_velocity = current_velocity + acceleration * dt
+
+         # Update momentum: p = mv
+         self.momentum_vectors = self.cognitive_mass * new_velocity
+
+         # Update energy
+         self.kinetic_energy = 0.5 * self.cognitive_mass * (new_velocity ** 2)
+
+         return self.momentum_vectors
+
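+     # Worked example of the update above: with unit mass (m = 1) and dt = 0.1,
+     # a constant force F changes momentum as p' = p + 0.1·F and kinetic energy
+     # as KE = 0.5·p'², i.e. momentum integrates applied force over time.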
+     def crystallize_knowledge(self):
+         """Compress low-momentum concepts"""
+         low_momentum_mask = torch.abs(self.momentum_vectors) < self.crystallization_threshold
+         if not low_momentum_mask.any():
+             # Nothing below threshold; avoid mean() over an empty selection (NaN)
+             return torch.tensor(0.0, device=self.momentum_vectors.device)
+
+         # Compress crystallized knowledge
+         crystallized_pattern = self.momentum_vectors[low_momentum_mask].mean()
+
+         # Reset crystallized components
+         self.momentum_vectors[low_momentum_mask] = crystallized_pattern * 0.1
+
+         return crystallized_pattern
+
+     def forward(self, x):
+         """Apply momentum to features"""
+         squeeze_out = x.dim() == 2
+         if squeeze_out:
+             x = x.unsqueeze(0)
+
+         # Compute forces from feature interactions; the projections act on the
+         # last dim, so inputs of any rank are supported (the top-level model
+         # passes a 4-D tensor here)
+         attraction_force = self.attraction_projection(x)
+         repulsion_force = self.repulsion_projection(x)
+
+         # Net force
+         net_force = attraction_force - repulsion_force * 0.1
+
+         # Simple momentum application
+         momentum_enhanced = x + net_force * 0.1
+
+         # Crystallize periodically (training only)
+         if self.training and torch.rand(1).item() < 0.1:
+             self.crystallize_knowledge()
+
+         return momentum_enhanced.squeeze(0) if squeeze_out else momentum_enhanced
+
+ class AstrocyteLayer(nn.Module):
+     """Multi-timescale processing with momentum"""
+     def __init__(self, d_model, astrocyte_ratio=2.0):
+         super().__init__()
+         self.d_model = d_model
+         self.d_astrocyte = int(d_model * astrocyte_ratio)
+
+         # Fast neuronal processing
+         self.neuron_fast = nn.Linear(d_model, d_model)
+         self.neuron_dropout = nn.Dropout(0.1)
+
+         # Slow astrocyte processing
+         self.astrocyte_slow = nn.Linear(d_model, self.d_astrocyte)
+         self.astrocyte_integration = nn.Linear(self.d_astrocyte, d_model)
+         self.astrocyte_dropout = nn.Dropout(0.1)
+
+         # Cognitive momentum
+         self.momentum_engine = CognitiveMomentumEngine(d_model)
+
+         # Multi-timescale gates
+         self.fast_gate = nn.Linear(d_model, d_model)
+         self.slow_gate = nn.Linear(self.d_astrocyte, d_model)
+
+         # Memory for slow dynamics
+         self.register_buffer('astrocyte_memory', torch.zeros(1, self.d_astrocyte))
+         self.memory_decay = 0.9
+
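+     # Note: astrocyte_memory is an exponential moving average,
+     #   m ← 0.9·m + 0.1·mean_over_nodes(W_slow·h),
+     # so the slow path integrates over roughly 1 / (1 - 0.9) = 10 forward
+     # passes, while the fast neuronal path reacts within a single pass.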
+     def forward(self, x):
+         batch_size = x.size(0) if x.dim() == 3 else 1
+         if x.dim() == 2:
+             x = x.unsqueeze(0)
+
+         if self.astrocyte_memory.size(0) != batch_size:
+             self.astrocyte_memory = torch.zeros(batch_size, self.d_astrocyte, device=x.device)
+
+         # Apply cognitive momentum
+         x_momentum = self.momentum_engine(x)
+
+         # Fast neuronal response
+         fast_out = self.neuron_dropout(torch.tanh(self.neuron_fast(x_momentum)))
+
+         # Slow astrocyte integration; detach the carried-over state so backprop
+         # does not reach into graphs from previous forward passes
+         astrocyte_input = self.astrocyte_slow(x_momentum)
+         self.astrocyte_memory = (self.memory_decay * self.astrocyte_memory.detach()
+                                  + (1 - self.memory_decay) * astrocyte_input.mean(dim=1))
+         slow_out = self.astrocyte_dropout(torch.tanh(self.astrocyte_integration(self.astrocyte_memory))).unsqueeze(1).expand(-1, x.size(1), -1)
+
+         # Multi-timescale gating
+         fast_gate = torch.sigmoid(self.fast_gate(x_momentum))
+         slow_gate = torch.sigmoid(self.slow_gate(self.astrocyte_memory)).unsqueeze(1).expand(-1, x.size(1), -1)
+
+         # Combine with momentum
+         output = fast_gate * fast_out + slow_gate * slow_out
+
+         return output.squeeze(0) if output.size(0) == 1 else output
+
+ class PhysicsInformedMamba(nn.Module):
+     """Mamba with physics constraints and momentum"""
+     def __init__(self, d_model, d_state=8):
+         super().__init__()
+         self.d_model = d_model
+         self.d_inner = d_model * 2
+         self.d_state = d_state
+
+         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
+         self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, 4, groups=self.d_inner, padding=3)
+         self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
+         self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
+
+         # Physics constraints
+         A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.D = nn.Parameter(torch.ones(self.d_inner))
+         self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
+
+         # Energy conservation
+         self.energy_projection = nn.Linear(d_model, d_model)
+
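+     # "Energy conservation" below is a per-token rescaling that preserves the
+     # input L2 norm: out ← out · ‖x‖ / (‖out‖ + 1e-8).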
+     def forward(self, x):
+         if x.dim() == 2:
+             x = x.unsqueeze(0)
+
+         batch, length, _ = x.shape
+
+         # Energy conservation
+         total_energy = x.norm(dim=-1, keepdim=True)
+
+         xz = self.in_proj(x)
+         x_inner, z = xz.chunk(2, dim=-1)
+
+         # Convolution (causal: left-padded, then truncated to sequence length)
+         x_inner = x_inner.transpose(1, 2)
+         x_inner = self.conv1d(x_inner)[:, :, :length]
+         x_inner = x_inner.transpose(1, 2)
+         x_inner = F.silu(x_inner)
+
+         # State space with physics
+         y = self.selective_scan(x_inner)
+         y = y * F.silu(z)
+
+         # Apply energy conservation
+         output = self.out_proj(y)
+         output_energy = output.norm(dim=-1, keepdim=True)
+         energy_scale = total_energy / (output_energy + 1e-8)
+         output = output * energy_scale
+
+         return output
+
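+     # The scan below is a sequential reference implementation of the
+     # discretized state-space recurrence used by Mamba-style models:
+     #   h_t = exp(Δ_t · A) ⊙ h_{t-1} + (Δ_t B_t) · x_t
+     #   y_t = C_t · h_t + D ⊙ x_t
+     # A is kept negative via A = -exp(A_log), so exp(Δ·A) < 1 and the
+     # recurrence is stable; a fused parallel scan would be faster, but this
+     # loop favors readability.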
+     def selective_scan(self, x):
+         batch, length, d_inner = x.shape
+
+         deltaBC = self.x_proj(x)
+         delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
+         delta = F.softplus(self.dt_proj(delta))
+
+         deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
+         deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
+
+         states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
+         outputs = []
+
+         for i in range(length):
+             states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
+             y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
+             outputs.append(y)
+
+         return torch.stack(outputs, dim=1)
+
+ class CognitiveMambaGraphMamba(nn.Module):
+     """Cognitive momentum architecture: fuses GCN, astrocyte, and
+     physics-informed Mamba paths over graph node features"""
+     def __init__(self, config):
+         super().__init__()
+
+         self.config = config
+         d_model = config['model']['d_model']
+         n_layers = config['model']['n_layers']
+         input_dim = config.get('input_dim', 1433)
+
+         # Input processing
+         self.input_proj = nn.Linear(input_dim, d_model)
+         self.input_norm = nn.LayerNorm(d_model)
+
+         # GCN backbone for graph structure
+         self.gcn_layers = nn.ModuleList([
+             GCNConv(d_model, d_model) for _ in range(n_layers)
+         ])
+
+         # Multi-timescale astrocyte layers
+         self.astrocyte_layers = nn.ModuleList([
+             AstrocyteLayer(d_model) for _ in range(n_layers)
+         ])
+
+         self.physics_mamba = PhysicsInformedMamba(d_model)
+
+         # Global cognitive momentum
+         self.global_momentum = CognitiveMomentumEngine(d_model)
+
+         # Layer norms
+         self.norms = nn.ModuleList([
+             nn.LayerNorm(d_model) for _ in range(n_layers)
+         ])
+
+         # Multi-path fusion
+         self.fusion_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.3]))  # GCN, Astrocyte, Mamba
+
+         self.dropout = nn.Dropout(0.1)
+         self.classifier = None
+
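+     # The three paths are fused with softmax-normalized learnable weights;
+     # the initialization [0.4, 0.3, 0.3] softmaxes to roughly [0.36, 0.32, 0.32],
+     # a mild prior toward the GCN path, plus a 0.1-scaled momentum correction.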
+     def forward(self, x, edge_index, batch=None):
+         # Input processing
+         h = self.input_norm(self.input_proj(x))
+
+         # Multi-path processing with momentum
+         for i in range(len(self.gcn_layers)):
+             gcn = self.gcn_layers[i]
+             astrocyte = self.astrocyte_layers[i]
+             norm = self.norms[i]
+
+             # Path 1: GCN (graph structure)
+             h_gcn = F.relu(gcn(h, edge_index))
+             h_gcn = self.dropout(h_gcn)
+
+             # Path 2: Astrocyte (multi-timescale with momentum)
+             h_astrocyte = astrocyte(h.unsqueeze(0)).squeeze(0)
+
+             # Path 3: Physics-informed Mamba (sequential with physics)
+             h_mamba = self.physics_mamba(h.unsqueeze(0)).squeeze(0)
+
+             # Apply global cognitive momentum
+             h_combined = torch.stack([h_gcn, h_astrocyte, h_mamba], dim=0)  # (3, nodes, features)
+             h_combined = h_combined.permute(1, 0, 2)  # (nodes, 3, features)
+             h_momentum = self.global_momentum(h_combined.unsqueeze(0)).squeeze(0)  # (nodes, 3, features)
+             h_momentum = h_momentum.mean(dim=1)  # (nodes, features)
+
+             # Weighted fusion
+             weights = F.softmax(self.fusion_weights, dim=0)
+             h_fused = weights[0] * h_gcn + weights[1] * h_astrocyte + weights[2] * h_mamba + h_momentum * 0.1
+
+             # Residual + norm
+             h = norm(h + h_fused)
+
+         return h
+
+     def _init_classifier(self, num_classes, device):
+         if self.classifier is None:
+             self.classifier = nn.Sequential(
+                 nn.Dropout(0.1),
+                 nn.Linear(self.config['model']['d_model'], num_classes)
+             ).to(device)
+
+     def get_performance_stats(self):
+         total_params = sum(p.numel() for p in self.parameters())
+         return {
+             'total_params': total_params,
+             'device': next(self.parameters()).device,
+             'dtype': next(self.parameters()).dtype,
+             'model_size': f"{total_params/1000:.1f}K parameters"
+         }
+
+ class LegacyGraphMamba(nn.Module):
+     """Fallback simple version"""
+     def __init__(self, config):
+         super().__init__()
+         self.cognitive_mamba = CognitiveMambaGraphMamba(config)
+         self.config = config
+         self.classifier = None
+
+     def forward(self, x, edge_index, batch=None):
+         return self.cognitive_mamba(x, edge_index, batch)
+
+     def _init_classifier(self, num_classes, device):
+         self.classifier = nn.Sequential(
+             nn.Dropout(0.1),
+             nn.Linear(self.config['model']['d_model'], num_classes)
+         ).to(device)
+         self.cognitive_mamba.classifier = self.classifier
+         return self.classifier
+
+     def get_performance_stats(self):
+         return self.cognitive_mamba.get_performance_stats()
+
+ def create_astrocyte_config():
+     """Default cognitive momentum configuration (input_dim matches Cora's 1433 features)"""
+     return {
+         'model': {
+             'd_model': 128,
+             'd_state': 8,
+             'd_conv': 4,
+             'expand': 2,
+             'n_layers': 4,
+             'dropout': 0.1
+         },
+         'data': {
+             'batch_size': 1,
+             'test_split': 0.2
+         },
+         'training': {
+             'learning_rate': 0.003,
+             'weight_decay': 0.001,
+             'epochs': 500,
+             'patience': 100,
+             'warmup_epochs': 25,
+             'min_lr': 1e-7,
+             'label_smoothing': 0.0,
+             'max_gap': 0.3
+         },
+         'ordering': {
+             'strategy': 'none',
+             'preserve_locality': True
+         },
+         'input_dim': 1433
+     }
+
+ # Aliases
+ AstrocyteGraphMamba = CognitiveMambaGraphMamba
+ GraphMamba = CognitiveMambaGraphMamba
+ HybridGraphMamba = LegacyGraphMamba
+ QuantumEnhancedGraphMamba = LegacyGraphMamba
+ create_regularized_config = create_astrocyte_config
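
For reference, a minimal usage sketch (not part of the commit): it assumes torch and torch_geometric are installed, that the repo root is on PYTHONPATH, and uses a toy four-node graph in place of a real dataset.

import torch
from core.graph_mamba import GraphMamba, create_astrocyte_config

config = create_astrocyte_config()
model = GraphMamba(config)  # alias for CognitiveMambaGraphMamba

# Toy graph: 4 nodes with Cora-sized features, a 4-edge directed cycle
x = torch.randn(4, config['input_dim'])
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

h = model(x, edge_index)                                 # (4, d_model) node embeddings
model._init_classifier(num_classes=7, device=h.device)  # lazy classifier head
logits = model.classifier(h)                             # (4, 7) class scores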