kfoughali committed
Commit b74043a · verified · Parent(s): e4e8b6b

Update core/graph_mamba.py

Files changed (1)
  1. core/graph_mamba.py +299 -219
core/graph_mamba.py CHANGED
@@ -8,206 +8,173 @@ import logging

  logger = logging.getLogger(__name__)

- class CognitiveMomentumEngine(nn.Module):
-     """Core cognitive momentum system from the document"""
-     def __init__(self, d_model):
          super().__init__()
          self.d_model = d_model

-         # Momentum tracking
-         self.register_buffer('momentum_vectors', torch.zeros(d_model))
-         self.register_buffer('cognitive_mass', torch.ones(d_model))
-         self.register_buffer('kinetic_energy', torch.zeros(d_model))
-         self.register_buffer('potential_energy', torch.zeros(d_model))
-
-         # Field interactions
-         self.attraction_projection = nn.Linear(d_model, d_model)
-         self.repulsion_projection = nn.Linear(d_model, d_model)
-
-         # Crystallization threshold
-         self.crystallization_threshold = 0.1
-         self.memory_decay = 0.99
-
-     def update_momentum(self, concept_features, force, dt=0.1):
-         """Apply cognitive momentum physics"""
-         # F = ma => a = F/m
-         acceleration = force / (self.cognitive_mass + 1e-8)
-
-         # Update velocity: v = v₀ + at
-         current_velocity = self.momentum_vectors / (self.cognitive_mass + 1e-8)
-         new_velocity = current_velocity + acceleration * dt
-
-         # Update momentum: p = mv
-         self.momentum_vectors = self.cognitive_mass * new_velocity
-
-         # Update energy
-         self.kinetic_energy = 0.5 * self.cognitive_mass * (new_velocity ** 2)
-
-         return self.momentum_vectors
-
-     def crystallize_knowledge(self):
-         """Compress low-momentum concepts"""
-         low_momentum_mask = torch.abs(self.momentum_vectors) < self.crystallization_threshold
-
-         # Compress crystallized knowledge
-         crystallized_pattern = self.momentum_vectors[low_momentum_mask].mean()
-
-         # Reset crystallized components
-         self.momentum_vectors[low_momentum_mask] = crystallized_pattern * 0.1
-
-         return crystallized_pattern
-
-     def forward(self, x):
-         """Apply momentum to features"""
-         if x.dim() == 2:
-             x = x.unsqueeze(0)
-         batch_size, seq_len, d_model = x.shape
-
-         # Compute forces from feature interactions
-         attraction_force = self.attraction_projection(x)
-         repulsion_force = self.repulsion_projection(x)
-
-         # Net force
-         net_force = attraction_force - repulsion_force * 0.1
-
-         # Simple momentum application
-         momentum_enhanced = x + net_force * 0.1
-
-         # Crystallize periodically
-         if torch.rand(1) < 0.1:
-             self.crystallize_knowledge()
-
-         return momentum_enhanced
- class AstrocyteLayer(nn.Module):
-     """Multi-timescale processing with momentum"""
-     def __init__(self, d_model, astrocyte_ratio=2.0):
          super().__init__()
          self.d_model = d_model
-         self.d_astrocyte = int(d_model * astrocyte_ratio)
-
-         # Fast neuronal processing
-         self.neuron_fast = nn.Linear(d_model, d_model)
-         self.neuron_dropout = nn.Dropout(0.1)
-
-         # Slow astrocyte processing
-         self.astrocyte_slow = nn.Linear(d_model, self.d_astrocyte)
-         self.astrocyte_integration = nn.Linear(self.d_astrocyte, d_model)
-         self.astrocyte_dropout = nn.Dropout(0.1)
-
-         # Cognitive momentum
-         self.momentum_engine = CognitiveMomentumEngine(d_model)
-
-         # Multi-timescale gates
-         self.fast_gate = nn.Linear(d_model, d_model)
-         self.slow_gate = nn.Linear(self.d_astrocyte, d_model)
-
-         # Memory for slow dynamics
-         self.register_buffer('astrocyte_memory', torch.zeros(1, self.d_astrocyte))
-         self.memory_decay = 0.9

      def forward(self, x):
-         batch_size = x.size(0) if x.dim() == 3 else 1
          if x.dim() == 2:
-             x = x.unsqueeze(0)
-
-         if self.astrocyte_memory.size(0) != batch_size:
-             self.astrocyte_memory = torch.zeros(batch_size, self.d_astrocyte, device=x.device)
-
-         # Apply cognitive momentum
-         x_momentum = self.momentum_engine(x)
-
-         # Fast neuronal response
-         fast_out = self.neuron_dropout(torch.tanh(self.neuron_fast(x_momentum)))
-
-         # Slow astrocyte integration
-         astrocyte_input = self.astrocyte_slow(x_momentum)
-         self.astrocyte_memory = self.memory_decay * self.astrocyte_memory + (1 - self.memory_decay) * astrocyte_input.mean(dim=1)
-         slow_out = self.astrocyte_dropout(torch.tanh(self.astrocyte_integration(self.astrocyte_memory))).unsqueeze(1).expand(-1, x.size(1), -1)
-
-         # Multi-timescale gating
-         fast_gate = torch.sigmoid(self.fast_gate(x_momentum))
-         slow_gate = torch.sigmoid(self.slow_gate(self.astrocyte_memory)).unsqueeze(1).expand(-1, x.size(1), -1)
-
-         # Combine with momentum
-         output = fast_gate * fast_out + slow_gate * slow_out
-
-         return output.squeeze(0) if output.size(0) == 1 else output
- class PhysicsInformedMamba(nn.Module):
-     """Mamba with physics constraints and momentum"""
-     def __init__(self, d_model, d_state=8):
          super().__init__()
          self.d_model = d_model
-         self.d_inner = d_model * 2
-         self.d_state = d_state
-
-         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
-         self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, 4, groups=self.d_inner, padding=3)
-         self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
-         self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
-
-         # Physics constraints
-         A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
-         self.A_log = nn.Parameter(torch.log(A))
-         self.D = nn.Parameter(torch.ones(self.d_inner))
-         self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
-
-         # Energy conservation
-         self.energy_projection = nn.Linear(d_model, d_model)
-
-     def forward(self, x):
-         if x.dim() == 2:
-             x = x.unsqueeze(0)
-
-         batch, length, _ = x.shape
-
-         # Energy conservation
-         total_energy = x.norm(dim=-1, keepdim=True)
-
-         xz = self.in_proj(x)
-         x_inner, z = xz.chunk(2, dim=-1)
-
-         # Convolution
-         x_inner = x_inner.transpose(1, 2)
-         x_inner = self.conv1d(x_inner)[:, :, :length]
-         x_inner = x_inner.transpose(1, 2)
-         x_inner = F.silu(x_inner)
-
-         # State space with physics
-         y = self.selective_scan(x_inner)
-         y = y * F.silu(z)
-
-         # Apply energy conservation
-         output = self.out_proj(y)
-         output_energy = output.norm(dim=-1, keepdim=True)
-         energy_scale = total_energy / (output_energy + 1e-8)
-         output = output * energy_scale
-
-         return output
-
-     def selective_scan(self, x):
-         batch, length, d_inner = x.shape
-
-         deltaBC = self.x_proj(x)
-         delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
-         delta = F.softplus(self.dt_proj(delta))
-
-         deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
-         deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
-
-         states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
-         outputs = []
-
-         for i in range(length):
-             states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
-             y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
-             outputs.append(y)
-
-         return torch.stack(outputs, dim=1)
- class CognitiveMambaGraphMamba(nn.Module):
-     """Revolutionary cognitive momentum architecture"""
      def __init__(self, config):
          super().__init__()

@@ -219,129 +186,243 @@ class CognitiveMambaGraphMamba(nn.Module):
          # Input processing
          self.input_proj = nn.Linear(input_dim, d_model)
          self.input_norm = nn.LayerNorm(d_model)

-         # GCN backbone for graph structure
          self.gcn_layers = nn.ModuleList([
              GCNConv(d_model, d_model) for _ in range(n_layers)
          ])

-         # Revolutionary components
          self.astrocyte_layers = nn.ModuleList([
              AstrocyteLayer(d_model) for _ in range(n_layers)
          ])

-         self.physics_mamba = PhysicsInformedMamba(d_model)
-
-         # Global cognitive momentum
-         self.global_momentum = CognitiveMomentumEngine(d_model)
-
-         # Layer norms
-         self.norms = nn.ModuleList([
              nn.LayerNorm(d_model) for _ in range(n_layers)
          ])

-         # Multi-path fusion
-         self.fusion_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.3]))  # GCN, Astrocyte, Mamba
-
-         self.dropout = nn.Dropout(0.1)
          self.classifier = None

      def forward(self, x, edge_index, batch=None):
          # Input processing
-         h = self.input_norm(self.input_proj(x))
-
-         # Multi-path processing with momentum
          for i in range(len(self.gcn_layers)):
              gcn = self.gcn_layers[i]
-             astrocyte = self.astrocyte_layers[i]
-             norm = self.norms[i]
-             # Path 1: GCN (graph structure)
-             h_gcn = F.relu(gcn(h, edge_index))
-             h_gcn = self.dropout(h_gcn)
-
-             # Path 2: Astrocyte (multi-timescale with momentum)
-             h_astrocyte = astrocyte(h.unsqueeze(0)).squeeze(0)
-
-             # Path 3: Physics-informed Mamba (sequential with physics)
-             h_mamba = self.physics_mamba(h.unsqueeze(0)).squeeze(0)
-
-             # Apply global cognitive momentum
-             h_combined = torch.stack([h_gcn, h_astrocyte, h_mamba], dim=0)  # (3, nodes, features)
-             h_combined = h_combined.permute(1, 0, 2)  # (nodes, 3, features)
-             h_momentum = self.global_momentum(h_combined.unsqueeze(0)).squeeze(0)  # (nodes, 3, features)
-             h_momentum = h_momentum.mean(dim=1)  # (nodes, features)
-
-             # Weighted fusion
-             weights = F.softmax(self.fusion_weights, dim=0)
-             h_fused = weights[0] * h_gcn + weights[1] * h_astrocyte + weights[2] * h_mamba + h_momentum * 0.1
-
-             # Residual + norm
-             h = norm(h + h_fused)

          return h

      def _init_classifier(self, num_classes, device):
          if self.classifier is None:
              self.classifier = nn.Sequential(
-                 nn.Dropout(0.1),
                  nn.Linear(self.config['model']['d_model'], num_classes)
              ).to(device)

      def get_performance_stats(self):
          total_params = sum(p.numel() for p in self.parameters())
          return {
              'total_params': total_params,
              'device': next(self.parameters()).device,
              'dtype': next(self.parameters()).dtype,
              'model_size': f"{total_params/1000:.1f}K parameters"
          }
- class LegacyGraphMamba(nn.Module):
-     """Fallback simple version"""
      def __init__(self, config):
          super().__init__()
-         self.cognitive_mamba = CognitiveMambaGraphMamba(config)
          self.config = config
          self.classifier = None

      def forward(self, x, edge_index, batch=None):
-         return self.cognitive_mamba(x, edge_index, batch)

      def _init_classifier(self, num_classes, device):
-         self.classifier = nn.Sequential(
-             nn.Dropout(0.1),
-             nn.Linear(self.config['model']['d_model'], num_classes)
-         ).to(device)
-         self.cognitive_mamba.classifier = self.classifier
          return self.classifier

      def get_performance_stats(self):
-         return self.cognitive_mamba.get_performance_stats()
  def create_astrocyte_config():
-     """Revolutionary cognitive momentum configuration"""
      return {
          'model': {
-             'd_model': 128,
              'd_state': 8,
              'd_conv': 4,
              'expand': 2,
-             'n_layers': 4,
-             'dropout': 0.1
          },
          'data': {
              'batch_size': 1,
              'test_split': 0.2
          },
          'training': {
-             'learning_rate': 0.003,
-             'weight_decay': 0.001,
-             'epochs': 500,
-             'patience': 100,
-             'warmup_epochs': 25,
-             'min_lr': 1e-7,
              'label_smoothing': 0.0,
-             'max_gap': 0.3
          },
          'ordering': {
              'strategy': 'none',
@@ -350,9 +431,8 @@ def create_astrocyte_config():
          'input_dim': 1433
      }

- # Use simple working version for now
- AstrocyteGraphMamba = LegacyGraphMamba
- GraphMamba = LegacyGraphMamba
- HybridGraphMamba = LegacyGraphMamba
- QuantumEnhancedGraphMamba = LegacyGraphMamba
- create_regularized_config = create_astrocyte_config
  logger = logging.getLogger(__name__)

+ class GraphDataAugmentation:
+     """Enhanced data augmentation for overfitting prevention"""
+     @staticmethod
+     def augment_features(x, noise_level=0.1, dropout_prob=0.05):
+         if not torch.is_tensor(x) or x.size(0) == 0:
+             return x
+         # Feature noise
+         noise = torch.randn_like(x) * noise_level
+         x_aug = x + noise
+         # Feature masking
+         mask = torch.rand(x.shape, device=x.device) > dropout_prob
+         return x_aug * mask.float()
+
+     @staticmethod
+     def augment_edges(edge_index, drop_prob=0.1):
+         if not torch.is_tensor(edge_index) or edge_index.size(1) == 0:
+             return edge_index
+         edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) > drop_prob
+         return edge_index[:, edge_mask]
+
+ class SimpleMambaBlock(nn.Module):
+     """Simplified Mamba block that actually works"""
+     def __init__(self, d_model, d_state=16):
          super().__init__()
          self.d_model = d_model
+         self.d_state = d_state
+         self.d_inner = d_model * 2

+         # Core projections
+         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
+         self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, 3, groups=self.d_inner, padding=1)
+         self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

+         # State space parameters
+         self.dt_proj = nn.Linear(self.d_inner, self.d_inner, bias=True)
+         self.B_proj = nn.Linear(self.d_inner, d_state, bias=False)
+         self.C_proj = nn.Linear(self.d_inner, d_state, bias=False)

+         # Initialize A matrix
+         A = torch.arange(1, d_state + 1, dtype=torch.float32)
+         A = A.unsqueeze(0).repeat(self.d_inner, 1)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.D = nn.Parameter(torch.ones(self.d_inner))

+         self.dropout = nn.Dropout(0.1)

+     def forward(self, x):
+         batch_size, seq_len, d_model = x.shape

+         # Dual path
+         xz = self.in_proj(x)              # (B, L, 2*d_inner)
+         x_inner, z = xz.chunk(2, dim=-1)  # Each: (B, L, d_inner)

+         # Convolution
+         x_conv = x_inner.transpose(1, 2)  # (B, d_inner, L)
+         x_conv = self.conv1d(x_conv)      # (B, d_inner, L)
+         x_conv = x_conv.transpose(1, 2)   # (B, L, d_inner)
+         x_conv = F.silu(x_conv)

+         # State space
+         y = self.selective_scan(x_conv)

+         # Gate and output
+         y = y * F.silu(z)
+         output = self.out_proj(y)

+         return self.dropout(output)

+     def selective_scan(self, x):
+         """Simplified selective scan"""
+         batch_size, seq_len, d_inner = x.shape

+         # Get parameters
+         dt = F.softplus(self.dt_proj(x))  # (B, L, d_inner)
+         B = self.B_proj(x)                # (B, L, d_state)
+         C = self.C_proj(x)                # (B, L, d_state)

+         # Discretize A
+         A = -torch.exp(self.A_log)  # (d_inner, d_state)
+         deltaA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))  # (B, L, d_inner, d_state)
+         deltaB = dt.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, d_inner, d_state)

+         # Initialize state
+         h = torch.zeros(batch_size, d_inner, self.d_state, device=x.device)
+         outputs = []

+         # Sequential processing
+         for i in range(seq_len):
+             h = deltaA[:, i] * h + deltaB[:, i] * x[:, i].unsqueeze(-1)
+             y = torch.sum(h * C[:, i].unsqueeze(1), dim=-1) + self.D * x[:, i]
+             outputs.append(y)

+         return torch.stack(outputs, dim=1)
+ class CognitiveMomentumEngine(nn.Module):
+     """Simplified cognitive momentum"""
+     def __init__(self, d_model):
          super().__init__()
          self.d_model = d_model

+         # Momentum projections
+         self.momentum_proj = nn.Linear(d_model, d_model)
+         self.force_proj = nn.Linear(d_model, d_model)

+         # Memory
+         self.register_buffer('momentum_state', torch.zeros(d_model))
+         self.decay = 0.95

      def forward(self, x):
          if x.dim() == 2:
+             batch_size, d_model = x.shape
+             # Global momentum update
+             force = self.force_proj(x.mean(dim=0))
+             self.momentum_state = self.decay * self.momentum_state + (1 - self.decay) * force
+
+             # Apply momentum
+             momentum_effect = self.momentum_proj(self.momentum_state).unsqueeze(0).expand(batch_size, -1)
+             return x + momentum_effect * 0.1
+         else:
+             return x
+ class AstrocyteLayer(nn.Module):
+     """Simplified astrocyte processing"""
+     def __init__(self, d_model):
          super().__init__()
          self.d_model = d_model
+         self.d_astrocyte = d_model

+         # Fast pathway
+         self.fast_proj = nn.Linear(d_model, d_model)
+         self.fast_dropout = nn.Dropout(0.1)

+         # Slow pathway
+         self.slow_proj = nn.Linear(d_model, self.d_astrocyte)
+         self.slow_integrate = nn.Linear(self.d_astrocyte, d_model)
+         self.slow_dropout = nn.Dropout(0.1)

+         # Gating
+         self.gate = nn.Linear(d_model * 2, d_model)

+         # Memory
+         self.register_buffer('slow_memory', torch.zeros(self.d_astrocyte))
+         self.memory_decay = 0.9

+     def forward(self, x):
+         if x.dim() == 3:
+             x = x.squeeze(0)

+         batch_size = x.size(0)

+         # Fast processing
+         fast_out = self.fast_dropout(F.relu(self.fast_proj(x)))

+         # Slow processing with memory
+         slow_input = self.slow_proj(x.mean(dim=0))
+         self.slow_memory = self.memory_decay * self.slow_memory + (1 - self.memory_decay) * slow_input
+         slow_out = self.slow_dropout(F.relu(self.slow_integrate(self.slow_memory)))
+         slow_out = slow_out.unsqueeze(0).expand(batch_size, -1)

+         # Combine
+         combined = torch.cat([fast_out, slow_out], dim=-1)
+         gated = torch.sigmoid(self.gate(combined))

+         return fast_out * gated + slow_out * (1 - gated)
+ class RevolutionaryGraphMamba(nn.Module):
+     """Complete revolutionary implementation"""
      def __init__(self, config):
          super().__init__()

          # Input processing
          self.input_proj = nn.Linear(input_dim, d_model)
          self.input_norm = nn.LayerNorm(d_model)
+         self.input_dropout = nn.Dropout(0.2)

+         # Data augmentation
+         self.augmentation = GraphDataAugmentation()
+
+         # Core components
          self.gcn_layers = nn.ModuleList([
              GCNConv(d_model, d_model) for _ in range(n_layers)
          ])

          self.astrocyte_layers = nn.ModuleList([
              AstrocyteLayer(d_model) for _ in range(n_layers)
          ])

+         self.mamba_blocks = nn.ModuleList([
+             SimpleMambaBlock(d_model) for _ in range(n_layers)
+         ])

+         # Cognitive momentum
+         self.momentum_engine = CognitiveMomentumEngine(d_model)

+         # Layer processing
+         self.layer_norms = nn.ModuleList([
              nn.LayerNorm(d_model) for _ in range(n_layers)
          ])

+         self.layer_dropouts = nn.ModuleList([
+             nn.Dropout(0.1) for _ in range(n_layers)
+         ])
+
+         # Fusion
+         self.fusion_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.3]))
+         self.fusion_proj = nn.Linear(d_model * 3, d_model)
+
+         # Output
+         self.output_proj = nn.Linear(d_model, d_model)
+         self.output_dropout = nn.Dropout(0.2)

          self.classifier = None

+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.xavier_uniform_(module.weight)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             torch.nn.init.ones_(module.weight)
+             torch.nn.init.zeros_(module.bias)
+
      def forward(self, x, edge_index, batch=None):
+         # Apply data augmentation during training
+         if self.training:
+             x = self.augmentation.augment_features(x)
+             edge_index = self.augmentation.augment_edges(edge_index)
+
          # Input processing
+         h = self.input_dropout(self.input_norm(self.input_proj(x)))

+         # Apply cognitive momentum
+         h = self.momentum_engine(h)
+
+         # Multi-path processing
          for i in range(len(self.gcn_layers)):
              gcn = self.gcn_layers[i]
+             astrocyte = self.astrocyte_layers[i]
+             mamba = self.mamba_blocks[i]
+             norm = self.layer_norms[i]
+             dropout = self.layer_dropouts[i]

+             # Path 1: GCN (structural)
+             h_gcn = F.relu(gcn(h, edge_index))

+             # Path 2: Astrocyte (temporal)
+             h_astrocyte = astrocyte(h)

+             # Path 3: Mamba (sequential)
+             h_mamba = mamba(h.unsqueeze(0)).squeeze(0)

+             # Fusion
+             h_paths = torch.stack([h_gcn, h_astrocyte, h_mamba], dim=-1)  # (nodes, d_model, 3)
+             weights = F.softmax(self.fusion_weights, dim=0)               # (3,)
+             h_fused = torch.sum(h_paths * weights, dim=-1)                # (nodes, d_model)

+             # Residual connection
+             h = dropout(norm(h + h_fused))
+
+         # Output processing
+         h = self.output_dropout(self.output_proj(h))

          return h

      def _init_classifier(self, num_classes, device):
          if self.classifier is None:
              self.classifier = nn.Sequential(
+                 nn.Dropout(0.3),
                  nn.Linear(self.config['model']['d_model'], num_classes)
              ).to(device)
+         return self.classifier

      def get_performance_stats(self):
          total_params = sum(p.numel() for p in self.parameters())
+         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+
          return {
              'total_params': total_params,
+             'trainable_params': trainable_params,
              'device': next(self.parameters()).device,
              'dtype': next(self.parameters()).dtype,
              'model_size': f"{total_params/1000:.1f}K parameters"
          }
+ class SimpleGraphMamba(nn.Module):
+     """Simplified but working version"""
      def __init__(self, config):
          super().__init__()
          self.config = config
+         d_model = config['model']['d_model']
+         n_layers = config['model']['n_layers']
+         input_dim = config.get('input_dim', 1433)
+
+         # Simple architecture
+         self.input_proj = nn.Linear(input_dim, d_model)
+         self.input_norm = nn.LayerNorm(d_model)
+
+         # GCN backbone
+         self.gcn_layers = nn.ModuleList([
+             GCNConv(d_model, d_model) for _ in range(n_layers)
+         ])
+
+         # Enhanced features
+         self.enhancements = nn.ModuleList([
+             nn.Sequential(
+                 nn.Linear(d_model, d_model * 2),
+                 nn.ReLU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(d_model * 2, d_model)
+             ) for _ in range(n_layers)
+         ])
+
+         self.layer_norms = nn.ModuleList([
+             nn.LayerNorm(d_model) for _ in range(n_layers)
+         ])
+
+         self.dropout = nn.Dropout(0.2)
          self.classifier = None

      def forward(self, x, edge_index, batch=None):
+         h = self.input_norm(self.input_proj(x))
+
+         for i, (gcn, enhance, norm) in enumerate(zip(self.gcn_layers, self.enhancements, self.layer_norms)):
+             # GCN processing
+             h_gcn = F.relu(gcn(h, edge_index))
+
+             # Enhancement
+             h_enhanced = enhance(h_gcn)
+
+             # Residual + norm
+             h = norm(h + h_enhanced)
+             h = self.dropout(h)
+
+         return h

      def _init_classifier(self, num_classes, device):
+         if self.classifier is None:
+             self.classifier = nn.Sequential(
+                 nn.Dropout(0.3),
+                 nn.Linear(self.config['model']['d_model'], num_classes)
+             ).to(device)
          return self.classifier

      def get_performance_stats(self):
+         total_params = sum(p.numel() for p in self.parameters())
+         return {
+             'total_params': total_params,
+             'device': next(self.parameters()).device,
+             'model_size': f"{total_params/1000:.1f}K parameters"
+         }
  def create_astrocyte_config():
+     """Optimized configuration"""
      return {
          'model': {
+             'd_model': 64,  # Reduced to prevent overfitting
              'd_state': 8,
              'd_conv': 4,
              'expand': 2,
+             'n_layers': 2,  # Reduced layers
+             'dropout': 0.2
          },
          'data': {
              'batch_size': 1,
              'test_split': 0.2
          },
          'training': {
+             'learning_rate': 0.01,
+             'weight_decay': 0.005,
+             'epochs': 200,
+             'patience': 30,
+             'warmup_epochs': 10,
+             'min_lr': 1e-5,
              'label_smoothing': 0.0,
+             'max_gap': 0.15
+         },
+         'ordering': {
+             'strategy': 'none',
+             'preserve_locality': True
+         },
+         'input_dim': 1433
+     }
+
+ def create_regularized_config():
+     """Heavily regularized config for small datasets"""
+     return {
+         'model': {
+             'd_model': 32,  # Very small
+             'd_state': 4,
+             'd_conv': 4,
+             'expand': 2,
+             'n_layers': 2,
+             'dropout': 0.3
+         },
+         'data': {
+             'batch_size': 1,
+             'test_split': 0.2
+         },
+         'training': {
+             'learning_rate': 0.005,
+             'weight_decay': 0.01,
+             'epochs': 150,
+             'patience': 20,
+             'warmup_epochs': 5,
+             'min_lr': 1e-6,
+             'label_smoothing': 0.1,
+             'max_gap': 0.1
          },
          'ordering': {
              'strategy': 'none',
          'input_dim': 1433
      }

+ # Model aliases
+ GraphMamba = RevolutionaryGraphMamba
+ AstrocyteGraphMamba = RevolutionaryGraphMamba
+ HybridGraphMamba = SimpleGraphMamba  # Fallback to simple version
+ QuantumEnhancedGraphMamba = SimpleGraphMamba
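
Minimal usage sketch for the updated module (not part of the commit). It assumes torch and torch_geometric are installed, that the file is importable as core.graph_mamba, and that the RevolutionaryGraphMamba constructor stores config and reads d_model, n_layers and input_dim from it, as the surrounding code suggests (those lines fall outside the diff hunks). The toy tensors and the 7-class head are illustrative only.

import torch
from core.graph_mamba import GraphMamba, create_astrocyte_config

config = create_astrocyte_config()
model = GraphMamba(config)          # alias for RevolutionaryGraphMamba
model.eval()                        # disable augmentation and dropout for the sketch

# Toy graph: 10 nodes with the config's input_dim features and a handful of edges.
x = torch.randn(10, config['input_dim'])
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]], dtype=torch.long)

with torch.no_grad():
    h = model(x, edge_index)                            # (10, d_model) node embeddings
    classifier = model._init_classifier(7, h.device)    # illustrative 7-class head
    logits = classifier(h)                              # (10, 7) class scores

print(model.get_performance_stats()['model_size'])

The same calls work with create_regularized_config() or with SimpleGraphMamba, since both classes expose the same forward, _init_classifier and get_performance_stats interface.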