kfoughali committed on
Commit
929a171
·
verified ·
1 Parent(s): 29178ec

Update core/graph_mamba.py

Browse files
Files changed (1) hide show
  1. core/graph_mamba.py +0 -358
core/graph_mamba.py CHANGED
@@ -1,358 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch_geometric.utils import degree, to_dense_adj
5
- from torch_geometric.nn import GCNConv
6
- import math
7
- import logging
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
class CognitiveMomentumEngine(nn.Module):
    """Residual "force" layer with momentum/energy bookkeeping buffers.

    ``forward`` adds a small learned correction to its input:
    ``x + 0.1 * (attraction(x) - 0.1 * repulsion(x))``, where both
    projections are ``d_model -> d_model`` linear layers applied to the last
    axis.  The momentum, mass and energy buffers are only touched by
    ``update_momentum`` and ``crystallize_knowledge``; they do not influence
    the tensor returned by ``forward``.

    Args:
        d_model: Size of the last (feature) axis of the inputs.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # Momentum tracking state.  Registered as buffers so it moves with
        # the module (.to/.cuda) and is saved in the state dict, but is not
        # a trainable parameter.
        self.register_buffer('momentum_vectors', torch.zeros(d_model))
        self.register_buffer('cognitive_mass', torch.ones(d_model))
        self.register_buffer('kinetic_energy', torch.zeros(d_model))
        self.register_buffer('potential_energy', torch.zeros(d_model))

        # Field interactions
        self.attraction_projection = nn.Linear(d_model, d_model)
        self.repulsion_projection = nn.Linear(d_model, d_model)

        # Components with |momentum| below this value get "crystallized".
        self.crystallization_threshold = 0.1
        self.memory_decay = 0.99

    def update_momentum(self, concept_features, force, dt=0.1):
        """Advance the momentum buffer by one explicit integration step.

        Args:
            concept_features: Accepted for interface compatibility; not read.
            force: Tensor broadcastable against the ``(d_model,)`` buffers.
            dt: Integration step size.

        Returns:
            The updated ``momentum_vectors`` buffer.
        """
        # F = ma  =>  a = F / m  (epsilon guards against zero mass)
        acceleration = force / (self.cognitive_mass + 1e-8)

        # v = v0 + a * dt
        current_velocity = self.momentum_vectors / (self.cognitive_mass + 1e-8)
        new_velocity = current_velocity + acceleration * dt

        # p = m * v
        self.momentum_vectors = self.cognitive_mass * new_velocity

        # E_k = 1/2 * m * v^2
        self.kinetic_energy = 0.5 * self.cognitive_mass * (new_velocity ** 2)

        return self.momentum_vectors

    def crystallize_knowledge(self):
        """Collapse low-momentum components onto a damped shared value.

        Every component whose absolute momentum is below
        ``crystallization_threshold`` is overwritten in place with
        ``0.1 * mean`` of those components.

        Returns:
            A 0-dim tensor holding the mean of the selected components, or a
            0-dim zero when no component is below the threshold (the mean of
            an empty selection would otherwise be NaN).
        """
        low_momentum_mask = torch.abs(self.momentum_vectors) < self.crystallization_threshold

        if not bool(low_momentum_mask.any()):
            # Nothing to crystallize: leave the buffer untouched.
            return self.momentum_vectors.new_zeros(())

        crystallized_pattern = self.momentum_vectors[low_momentum_mask].mean()

        # Reset crystallized components to a damped copy of the pattern.
        self.momentum_vectors[low_momentum_mask] = crystallized_pattern * 0.1

        return crystallized_pattern

    def forward(self, x):
        """Apply the residual force correction to ``x``.

        Args:
            x: Tensor whose last axis has size ``d_model``.  A 2-D input gets
                a leading axis of size 1 prepended (and the output keeps it);
                inputs of any higher rank are processed as they are.

        Returns:
            Tensor of the same shape as the (possibly unsqueezed) input.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)

        # The projections act on the last axis only, so the rank of ``x`` is
        # not restricted here; unpacking ``x.shape`` into three names made
        # 4-D inputs raise ValueError.
        attraction_force = self.attraction_projection(x)
        repulsion_force = self.repulsion_projection(x)

        # Net force: attraction minus a down-weighted repulsion.
        net_force = attraction_force - repulsion_force * 0.1

        # Simple momentum application (small residual step).
        momentum_enhanced = x + net_force * 0.1

        # Crystallize on roughly 10% of calls.
        if torch.rand(1) < 0.1:
            self.crystallize_knowledge()

        return momentum_enhanced
81
-
82
class AstrocyteLayer(nn.Module):
    """Gated mix of a per-position "fast" path and a sequence-level "slow" path.

    The fast path is ``tanh(Linear(x))`` per position.  The slow path keeps an
    exponential moving average (``astrocyte_memory``) of the sequence-mean of
    a wider projection and broadcasts its read-out to every position.  Both
    paths are scaled by sigmoid gates and summed.

    Args:
        d_model: Size of the feature axis.
        astrocyte_ratio: Width of the slow path relative to ``d_model``.
    """

    def __init__(self, d_model, astrocyte_ratio=2.0):
        super().__init__()
        self.d_model = d_model
        self.d_astrocyte = int(d_model * astrocyte_ratio)

        # Fast neuronal processing
        self.neuron_fast = nn.Linear(d_model, d_model)
        self.neuron_dropout = nn.Dropout(0.1)

        # Slow astrocyte processing
        self.astrocyte_slow = nn.Linear(d_model, self.d_astrocyte)
        self.astrocyte_integration = nn.Linear(self.d_astrocyte, d_model)
        self.astrocyte_dropout = nn.Dropout(0.1)

        # Cognitive momentum
        self.momentum_engine = CognitiveMomentumEngine(d_model)

        # Multi-timescale gates
        self.fast_gate = nn.Linear(d_model, d_model)
        self.slow_gate = nn.Linear(self.d_astrocyte, d_model)

        # Memory for slow dynamics (one row per batch element).
        self.register_buffer('astrocyte_memory', torch.zeros(1, self.d_astrocyte))
        self.memory_decay = 0.9

    def forward(self, x):
        """Run the layer and update the slow memory.

        Args:
            x: ``(seq, d_model)`` or ``(batch, seq, d_model)`` tensor.

        Returns:
            ``(batch, seq, d_model)`` tensor, or ``(seq, d_model)`` when the
            batch axis has size 1 (it is squeezed away).
        """
        batch_size = x.size(0) if x.dim() == 3 else 1
        if x.dim() == 2:
            x = x.unsqueeze(0)

        # The stored memory is always detached (see below); detaching again
        # also covers a buffer that was set from outside this method.
        previous_memory = self.astrocyte_memory.detach()
        if previous_memory.size(0) != batch_size:
            # Batch size changed: restart the slow state, matching the
            # input's device and dtype.
            previous_memory = x.new_zeros(batch_size, self.d_astrocyte)

        # Apply cognitive momentum
        x_momentum = self.momentum_engine(x)

        # Fast neuronal response
        fast_out = self.neuron_dropout(torch.tanh(self.neuron_fast(x_momentum)))

        # Slow astrocyte integration: EMA over calls of the sequence mean.
        astrocyte_input = self.astrocyte_slow(x_momentum)
        memory = (
            self.memory_decay * previous_memory
            + (1 - self.memory_decay) * astrocyte_input.mean(dim=1)
        )

        # Persist only the values.  Storing the graph-attached tensor made
        # the next training step backpropagate into the already-freed graph
        # of this one.  ``memory`` itself stays attached so this call's
        # gradients still reach ``astrocyte_slow``.
        self.astrocyte_memory = memory.detach()

        seq_len = x.size(1)
        slow_out = self.astrocyte_dropout(
            torch.tanh(self.astrocyte_integration(memory))
        ).unsqueeze(1).expand(-1, seq_len, -1)

        # Multi-timescale gating
        fast_gate = torch.sigmoid(self.fast_gate(x_momentum))
        slow_gate = torch.sigmoid(self.slow_gate(memory)).unsqueeze(1).expand(-1, seq_len, -1)

        # Combine both paths
        output = fast_gate * fast_out + slow_gate * slow_out

        return output.squeeze(0) if output.size(0) == 1 else output
136
-
137
class PhysicsInformedMamba(nn.Module):
    """Selective state-space block whose output is rescaled to the input norm.

    Pipeline: input projection into two ``d_inner`` streams, causal depthwise
    convolution and SiLU on the first stream, a sequential selective scan,
    gating by SiLU of the second stream, output projection, then per-position
    rescaling so each output vector has the L2 norm of the matching input.

    Args:
        d_model: Size of the feature axis.
        d_state: State size per inner channel.
        d_conv: Kernel size of the depthwise convolution.
        expand: ``d_inner = d_model * expand``.
    """

    def __init__(self, d_model, d_state=8, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = int(d_model * expand)
        self.d_state = d_state
        self.d_conv = d_conv

        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # padding = kernel - 1 on both sides; forward keeps the first
        # ``length`` outputs, so position t only sees inputs <= t.
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner, d_conv,
            groups=self.d_inner, padding=d_conv - 1,
        )
        # Produces [delta (1) | B (d_state) | C (d_state)] per position.
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        # A is stored as log(1..d_state) per channel; the scan uses -exp(A_log).
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Not used by forward; kept so existing state dicts still load.
        self.energy_projection = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Process a sequence.

        Args:
            x: ``(length, d_model)`` or ``(batch, length, d_model)`` tensor.

        Returns:
            ``(batch, length, d_model)`` tensor; a 2-D input yields
            ``batch == 1``.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)

        length = x.size(1)

        # Per-position input norm, restored on the output below.
        total_energy = x.norm(dim=-1, keepdim=True)

        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)

        # Causal depthwise convolution (Conv1d wants channels-first).
        x_inner = x_inner.transpose(1, 2)
        x_inner = self.conv1d(x_inner)[:, :, :length]
        x_inner = x_inner.transpose(1, 2)
        x_inner = F.silu(x_inner)

        # State-space scan, gated by the second stream.
        y = self.selective_scan(x_inner)
        y = y * F.silu(z)

        # Rescale each output vector to the norm of its input vector.
        output = self.out_proj(y)
        output_energy = output.norm(dim=-1, keepdim=True)
        energy_scale = total_energy / (output_energy + 1e-8)
        output = output * energy_scale

        return output

    def selective_scan(self, x):
        """Sequential input-dependent state-space recurrence.

        Args:
            x: ``(batch, length, d_inner)`` tensor.

        Returns:
            ``(batch, length, d_inner)`` tensor.
        """
        batch, length, d_inner = x.shape

        deltaBC = self.x_proj(x)
        delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner)

        A = -torch.exp(self.A_log)  # (d_inner, d_state), strictly negative
        deltaA = torch.exp(delta.unsqueeze(-1) * A)       # (batch, length, d_inner, d_state)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)     # (batch, length, d_inner, d_state)

        # Match the input's device *and* dtype so the matmul below also
        # works when the module runs in a non-float32 precision.
        states = x.new_zeros(batch, d_inner, self.d_state)
        outputs = []

        for i in range(length):
            states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
            y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
            outputs.append(y)

        if not outputs:
            # torch.stack rejects an empty list.
            return x.new_zeros(batch, 0, d_inner)

        return torch.stack(outputs, dim=1)
208
-
209
class CognitiveMambaGraphMamba(nn.Module):
    """Node encoder fusing GCN, astrocyte and Mamba paths in every layer.

    Each layer computes three views of the current node features (a GCN
    convolution, an ``AstrocyteLayer`` and a shared ``PhysicsInformedMamba``),
    mixes them with softmax-normalised learned weights, adds a small
    ``CognitiveMomentumEngine`` term, and applies a residual LayerNorm.

    Args:
        config: Dict with ``config['model']['d_model']``,
            ``config['model']['n_layers']`` and an optional top-level
            ``'input_dim'`` (default 1433).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        d_model = config['model']['d_model']
        n_layers = config['model']['n_layers']
        input_dim = config.get('input_dim', 1433)

        # Input processing
        self.input_proj = nn.Linear(input_dim, d_model)
        self.input_norm = nn.LayerNorm(d_model)

        # GCN backbone for graph structure
        self.gcn_layers = nn.ModuleList([
            GCNConv(d_model, d_model) for _ in range(n_layers)
        ])

        # One astrocyte layer per GCN layer
        self.astrocyte_layers = nn.ModuleList([
            AstrocyteLayer(d_model) for _ in range(n_layers)
        ])

        # A single Mamba block shared by all layers
        self.physics_mamba = PhysicsInformedMamba(d_model)

        # Global cognitive momentum
        self.global_momentum = CognitiveMomentumEngine(d_model)

        # Layer norms
        self.norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])

        # Multi-path fusion logits, in the order GCN, Astrocyte, Mamba
        self.fusion_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.3]))

        self.dropout = nn.Dropout(0.1)
        # Created lazily by _init_classifier; forward never applies it.
        self.classifier = None

    def forward(self, x, edge_index, batch=None):
        """Encode node features.

        Args:
            x: ``(nodes, input_dim)`` node feature matrix.
            edge_index: Edge index tensor forwarded to every ``GCNConv``.
            batch: Accepted for interface compatibility; not read.

        Returns:
            ``(nodes, d_model)`` node embeddings (no classifier applied).
        """
        # Input processing
        h = self.input_norm(self.input_proj(x))

        for gcn, astrocyte, norm in zip(self.gcn_layers, self.astrocyte_layers, self.norms):
            # Path 1: GCN (graph structure)
            h_gcn = self.dropout(F.relu(gcn(h, edge_index)))

            # Path 2: Astrocyte.  For a 2-D input the layer adds and removes
            # its own batch axis, so no extra squeeze is applied here; a
            # second squeeze(0) would drop the node axis of a 1-node graph.
            h_astrocyte = astrocyte(h)

            # Path 3: Mamba over the node sequence; it always returns a
            # leading batch axis of size 1, removed here.
            h_mamba = self.physics_mamba(h.unsqueeze(0)).squeeze(0)

            # Global momentum over the stacked paths: (nodes, 3, features).
            # Passed as a 3-D tensor; the momentum engine unpacks exactly
            # three shape entries, so a 4-D input raised ValueError.
            h_combined = torch.stack([h_gcn, h_astrocyte, h_mamba], dim=1)
            h_momentum = self.global_momentum(h_combined).mean(dim=1)  # (nodes, features)

            # Weighted fusion
            weights = F.softmax(self.fusion_weights, dim=0)
            h_fused = (
                weights[0] * h_gcn
                + weights[1] * h_astrocyte
                + weights[2] * h_mamba
                + h_momentum * 0.1
            )

            # Residual + norm
            h = norm(h + h_fused)

        return h

    def _init_classifier(self, num_classes, device):
        """Create the classification head once; later calls are no-ops."""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(self.config['model']['d_model'], num_classes)
            ).to(device)

    def get_performance_stats(self):
        """Return parameter count plus device/dtype of the first parameter."""
        total_params = sum(p.numel() for p in self.parameters())
        first_param = next(self.parameters())
        return {
            'total_params': total_params,
            'device': first_param.device,
            'dtype': first_param.dtype,
            'model_size': f"{total_params/1000:.1f}K parameters"
        }
298
-
299
class LegacyGraphMamba(nn.Module):
    """Thin wrapper that delegates everything to a CognitiveMambaGraphMamba.

    The wrapped model lives in ``cognitive_mamba``.  The classification head
    built by ``_init_classifier`` is stored on both the wrapper and the
    wrapped model.
    """

    def __init__(self, config):
        super().__init__()
        self.cognitive_mamba = CognitiveMambaGraphMamba(config)
        self.config = config
        self.classifier = None

    def forward(self, x, edge_index, batch=None):
        """Forward all arguments to the wrapped model and return its output."""
        inner = self.cognitive_mamba
        return inner(x, edge_index, batch)

    def _init_classifier(self, num_classes, device):
        """Build a new Dropout+Linear head on ``device``, share it, return it.

        Unlike the wrapped model's own initialiser, this rebuilds the head on
        every call.
        """
        hidden = self.config['model']['d_model']
        head = nn.Sequential(nn.Dropout(0.1), nn.Linear(hidden, num_classes))
        head = head.to(device)

        self.classifier = head
        self.cognitive_mamba.classifier = head
        return head

    def get_performance_stats(self):
        """Return the wrapped model's performance statistics."""
        inner = self.cognitive_mamba
        return inner.get_performance_stats()
320
-
321
def create_astrocyte_config():
    """Build and return a fresh default configuration dict.

    The result has four sections (``'model'``, ``'data'``, ``'training'``,
    ``'ordering'``) plus a top-level ``'input_dim'``.  Every call creates new
    dict objects, so callers may mutate the result freely.
    """
    model_section = {
        'd_model': 128,
        'd_state': 8,
        'd_conv': 4,
        'expand': 2,
        'n_layers': 4,
        'dropout': 0.1,
    }
    data_section = {
        'batch_size': 1,
        'test_split': 0.2,
    }
    training_section = {
        'learning_rate': 0.003,
        'weight_decay': 0.001,
        'epochs': 500,
        'patience': 100,
        'warmup_epochs': 25,
        'min_lr': 1e-7,
        'label_smoothing': 0.0,
        'max_gap': 0.3,
    }
    ordering_section = {
        'strategy': 'none',
        'preserve_locality': True,
    }

    config = {}
    config['model'] = model_section
    config['data'] = data_section
    config['training'] = training_section
    config['ordering'] = ordering_section
    config['input_dim'] = 1433
    return config
352
-
353
# Alternative public names bound to the definitions above.
# Both of these refer to the full cognitive model class ...
AstrocyteGraphMamba = GraphMamba = CognitiveMambaGraphMamba
# ... both of these refer to the legacy wrapper class ...
HybridGraphMamba = QuantumEnhancedGraphMamba = LegacyGraphMamba
# ... and this is a second name for the default config factory.
create_regularized_config = create_astrocyte_config