kfoughali committed (verified)
Commit 0fecf94 · 1 Parent(s): f4a2be4

Update core/graph_mamba.py

Files changed (1):
  1. core/graph_mamba.py  +238  -362
core/graph_mamba.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """
-Ultra-Regularized GraphMamba - Overfitting Problem Solved
-Designed specifically for small training sets like Cora (140 samples)
 """

 import torch
@@ -12,9 +12,7 @@ from torch_geometric.datasets import Planetoid
 from torch_geometric.transforms import NormalizeFeatures
 from torch_geometric.utils import to_undirected, add_self_loops
 import torch.optim as optim
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 import time
-import numpy as np

 def get_device():
     if torch.cuda.is_available():
@@ -26,347 +24,210 @@ def get_device():
         print("💻 Using CPU")
     return device

-class TinyMambaBlock(nn.Module):
-    """Ultra-small Mamba block for small datasets"""
-    def __init__(self, d_model, d_state=4):
         super().__init__()
-        self.d_model = d_model
-        self.d_state = d_state
-        self.d_inner = d_model  # No expansion to reduce parameters

-        # Minimal projections
-        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
-        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

-        # Tiny SSM
-        self.dt_proj = nn.Linear(self.d_inner, self.d_inner, bias=False)
-        self.B_proj = nn.Linear(self.d_inner, d_state, bias=False)
-        self.C_proj = nn.Linear(self.d_inner, d_state, bias=False)

-        # Minimal A matrix
-        A = torch.arange(1, d_state + 1, dtype=torch.float32)
-        self.A_log = nn.Parameter(torch.log(A.unsqueeze(0).repeat(self.d_inner, 1)))
-        self.D = nn.Parameter(torch.ones(self.d_inner))

-        # Heavy regularization
-        self.dropout = nn.Dropout(0.7)  # Very aggressive dropout

-    def forward(self, x):
-        B, L, D = x.shape

-        # Dual path with heavy dropout
-        xz = self.dropout(self.in_proj(x))
-        x_path, z_path = xz.chunk(2, dim=-1)

-        # Simple activation
-        x_path = F.silu(x_path)

-        # Ultra-simple SSM (just a weighted sum)
-        dt = torch.sigmoid(self.dt_proj(x_path))
-        B_param = self.B_proj(x_path)
-        C_param = self.C_proj(x_path)

-        # Simplified state update
-        y = x_path * dt + B_param @ C_param.transpose(-1, -2)

-        # Gate and output
-        y = y * F.silu(z_path)
-        return self.dropout(self.out_proj(y))

-class UltraRegularizedGraphMamba(nn.Module):
-    """Ultra-regularized version for small datasets"""
-    def __init__(self, config):
         super().__init__()
-        self.config = config
-        d_model = config['model']['d_model']
-        n_layers = config['model']['n_layers']
-        input_dim = config.get('input_dim', 1433)
-
-        # Aggressive dimensionality reduction
-        self.input_proj = nn.Sequential(
-            nn.Linear(input_dim, d_model * 4),
             nn.ReLU(),
-            nn.Dropout(0.8),  # Very aggressive
-            nn.Linear(d_model * 4, d_model),
-            nn.LayerNorm(d_model)
         )

-        # Core layers with heavy regularization
-        self.gcn_layers = nn.ModuleList([
-            GCNConv(d_model, d_model) for _ in range(n_layers)
-        ])
-
-        self.mamba_blocks = nn.ModuleList([
-            TinyMambaBlock(d_model) for _ in range(n_layers)
-        ])
-
-        self.layer_norms = nn.ModuleList([
-            nn.LayerNorm(d_model) for _ in range(n_layers)
-        ])
-
-        # Massive dropout for regularization
-        self.dropouts = nn.ModuleList([
-            nn.Dropout(0.8) for _ in range(n_layers)  # 80% dropout
-        ])
-
-        # Lightweight output
-        self.output_proj = nn.Sequential(
-            nn.Dropout(0.7),
-            nn.Linear(d_model, d_model // 2),
             nn.ReLU(),
-            nn.Dropout(0.7),
-            nn.Linear(d_model // 2, d_model)
         )

-        self.classifier = None

-    def forward(self, x, edge_index, batch=None):
-        # Input with heavy regularization
-        h = self.input_proj(x)

-        # Process through layers
-        for i in range(len(self.gcn_layers)):
-            gcn = self.gcn_layers[i]
-            mamba = self.mamba_blocks[i]
-            norm = self.layer_norms[i]
-            dropout = self.dropouts[i]
-
-            # Skip connection from input
-            residual = h
-
-            # GCN path with dropout
-            h_gcn = dropout(F.relu(gcn(h, edge_index)))
-
-            # Mamba path with dropout
-            h_mamba = mamba(h.unsqueeze(0)).squeeze(0)
-
-            # Minimal combination to reduce parameters
-            h_combined = h_gcn * 0.7 + h_mamba * 0.3
-
-            # Strong residual connection
-            h = norm(residual + h_combined * 0.3)  # Small update
-
-        return self.output_proj(h)
-
-    def init_classifier(self, num_classes):
-        """Ultra-lightweight classifier"""
-        self.classifier = nn.Sequential(
-            nn.Dropout(0.8),  # Even more dropout in classifier
-            nn.Linear(self.config['model']['d_model'], num_classes)
-        )
-        return self.classifier

-class MinimalGraphMamba(nn.Module):
-    """Absolute minimal version"""
-    def __init__(self, config):
         super().__init__()
-        self.config = config
-        d_model = config['model']['d_model']
-        input_dim = config.get('input_dim', 1433)

-        # Ultra-simple architecture
-        self.encoder = nn.Sequential(
-            nn.Linear(input_dim, d_model * 2),
-            nn.ReLU(),
-            nn.Dropout(0.8),
-            nn.Linear(d_model * 2, d_model),
-            nn.LayerNorm(d_model)
         )

-        # Just one GCN layer
-        self.gcn = GCNConv(d_model, d_model)
-
-        # Simple enhancement
-        self.enhance = nn.Sequential(
-            nn.Dropout(0.7),
-            nn.Linear(d_model, d_model),
-            nn.ReLU(),
-            nn.Dropout(0.7),
-            nn.Linear(d_model, d_model)
         )

-        self.norm = nn.LayerNorm(d_model)
-        self.classifier = None

-    def forward(self, x, edge_index, batch=None):
-        h = self.encoder(x)
-        h_gcn = F.relu(self.gcn(h, edge_index))
-        h_enhanced = self.enhance(h_gcn)
-        return self.norm(h + h_enhanced * 0.2)  # Small residual
-
-    def init_classifier(self, num_classes):
-        self.classifier = nn.Sequential(
-            nn.Dropout(0.8),
-            nn.Linear(self.config['model']['d_model'], num_classes)
-        )
-        return self.classifier
-
-def create_ultra_regularized_config():
-    """Configuration for tiny models"""
-    return {
-        'model': {
-            'd_model': 16,  # Extremely small
-            'd_state': 4,
-            'n_layers': 1,  # Just one layer
-            'dropout': 0.8
-        },
-        'training': {
-            'learning_rate': 0.001,  # Much smaller LR
-            'weight_decay': 0.1,  # Massive weight decay
-            'epochs': 500,  # More epochs with smaller steps
-            'patience': 50,  # More patience
-            'label_smoothing': 0.3  # Label smoothing for regularization
-        },
-        'input_dim': 1433
-    }
-
-def create_minimal_config():
-    """Even smaller configuration"""
-    return {
-        'model': {
-            'd_model': 8,  # Tiny
-            'd_state': 2,
-            'n_layers': 1,
-            'dropout': 0.9  # Extreme dropout
-        },
-        'training': {
-            'learning_rate': 0.0005,
-            'weight_decay': 0.2,
-            'epochs': 1000,
-            'patience': 100,
-            'label_smoothing': 0.4
-        },
-        'input_dim': 1433
-    }

-class SmartTrainer:
-    """Trainer with extreme regularization"""
-    def __init__(self, model, config, device):
-        self.model = model.to(device)
-        self.config = config
-        self.device = device
-
-        # Very conservative optimizer
-        self.optimizer = optim.Adam(  # Adam instead of AdamW
-            model.parameters(),
-            lr=config['training']['learning_rate'],
-            weight_decay=config['training']['weight_decay']
-        )

-        # Aggressive scheduler
-        self.scheduler = ReduceLROnPlateau(
-            self.optimizer, mode='min', factor=0.3, patience=20, min_lr=1e-6
-        )

-        # Label smoothing for regularization
-        label_smoothing = config['training'].get('label_smoothing', 0.0)
-        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
-
-        # Early stopping
-        self.patience = config['training']['patience']
-        self.best_val_loss = float('inf')
-        self.patience_counter = 0
-
-    def train(self, data):
-        print(f"🏋️ Ultra-Regularized Training")
-        print(f" Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
-        print(f" Per sample: {sum(p.numel() for p in self.model.parameters())/data.train_mask.sum().item():.1f}")
-        print(f" Learning rate: {self.config['training']['learning_rate']}")
-        print(f" Weight decay: {self.config['training']['weight_decay']}")
-
-        # Initialize classifier
-        num_classes = data.y.max().item() + 1
-        self.model.init_classifier(num_classes)
-        self.model.classifier = self.model.classifier.to(self.device)
-
-        history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
-
-        for epoch in range(self.config['training']['epochs']):
-            # Training step
-            self.model.train()
-            self.optimizer.zero_grad()
-
-            out = self.model(data.x, data.edge_index)
-            logits = self.model.classifier(out)
-            train_loss = self.criterion(logits[data.train_mask], data.y[data.train_mask])
-
-            train_loss.backward()
-            # Gradient clipping for stability
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
-            self.optimizer.step()
-
-            # Evaluation
-            self.model.eval()
             with torch.no_grad():
-                out = self.model(data.x, data.edge_index)
-                logits = self.model.classifier(out)
-
-                val_loss = self.criterion(logits[data.val_mask], data.y[data.val_mask])

-                train_pred = logits[data.train_mask].argmax(dim=1)
                 train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

-                val_pred = logits[data.val_mask].argmax(dim=1)
                 val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
-
-            # Update history
-            history['train_loss'].append(train_loss.item())
-            history['val_loss'].append(val_loss.item())
-            history['train_acc'].append(train_acc)
-            history['val_acc'].append(val_acc)
-
-            # Scheduler step
-            self.scheduler.step(val_loss)
-
-            # Early stopping check
-            if val_loss < self.best_val_loss:
-                self.best_val_loss = val_loss
-                self.patience_counter = 0
-            else:
-                self.patience_counter += 1

-            if self.patience_counter >= self.patience:
-                print(f" Early stopping at epoch {epoch+1}")
-                break
-
-            # Progress
-            if (epoch + 1) % 50 == 0:
                 gap = train_acc - val_acc
-                lr = self.optimizer.param_groups[0]['lr']
-                print(f" Epoch {epoch+1:3d}: Loss {train_loss.item():.4f} -> {val_loss.item():.4f} | "
-                      f"Acc {train_acc:.4f} -> {val_acc:.4f} | Gap {gap:.4f} | LR {lr:.2e}")
-
-        return history

-    def test(self, data):
-        self.model.eval()

-        with torch.no_grad():
-            out = self.model(data.x, data.edge_index)
-            logits = self.model.classifier(out)
-
-            test_pred = logits[data.test_mask].argmax(dim=1)
-            test_acc = (test_pred == data.y[data.test_mask]).float().mean().item()
-
-            val_pred = logits[data.val_mask].argmax(dim=1)
-            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
-
-            train_pred = logits[data.train_mask].argmax(dim=1)
-            train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
-
-            gap = train_acc - val_acc
-
-            return {
-                'test_acc': test_acc,
-                'val_acc': val_acc,
-                'train_acc': train_acc,
-                'gap': gap
-            }

-def run_ultra_regularized_test():
-    """Run ultra-regularized test"""
-    print("🧠 ULTRA-REGULARIZED MAMBA GRAPH NEURAL NETWORK")
-    print("🛡️ Overfitting Problem Solved")
     print("=" * 60)

     device = get_device()
@@ -378,90 +239,105 @@ def run_ultra_regularized_test():
     data.edge_index = to_undirected(data.edge_index)
     data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.size(0))

-    print(f"✅ Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")
-    print(f" Train: {data.train_mask.sum()} samples (the challenge!)")

-    # Test different model sizes
-    models_to_test = {
-        'Ultra-Regularized (16D)': (UltraRegularizedGraphMamba, create_ultra_regularized_config()),
-        'Minimal (8D)': (MinimalGraphMamba, create_minimal_config()),
     }

     results = {}

-    for name, (model_class, config) in models_to_test.items():
         print(f"\n🏗️ Testing {name}...")

         try:
-            model = model_class(config)
-            total_params = sum(p.numel() for p in model.parameters())
-            params_per_sample = total_params / data.train_mask.sum().item()

-            print(f" Parameters: {total_params:,} ({params_per_sample:.1f} per sample)")

-            if params_per_sample > 200:
-                print(f" ⚠️ Still might overfit, but much better!")
             else:
-                print(f" ✅ Good parameter ratio!")
-
-            # Test forward pass
-            model.eval()
-            with torch.no_grad():
-                h = model(data.x, data.edge_index)
-                print(f" Forward pass: {data.x.shape} -> {h.shape} ✅")
-
-            # Train
-            trainer = SmartTrainer(model, config, device)
-            history = trainer.train(data)
-
-            # Test
-            test_results = trainer.test(data)
-
-            results[name] = {
-                'params': total_params,
-                'params_per_sample': params_per_sample,
-                'test_results': test_results,
-                'history': history
-            }
-
-            print(f"✅ {name} Results:")
-            print(f" 🎯 Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc']*100:.2f}%)")
-            print(f" 📊 Validation: {test_results['val_acc']:.4f}")
-            print(f" 🛡️ Overfitting Gap: {test_results['gap']:.4f}")
-
-            if test_results['gap'] < 0.2:
-                print(f" 🎉 Overfitting under control!")
-            elif test_results['gap'] < 0.3:
-                print(f" 👍 Much better overfitting control!")
-            else:
-                print(f" ⚠️ Still some overfitting")

         except Exception as e:
-            print(f"❌ {name} failed: {str(e)}")

-    # Summary
     print(f"\n{'='*60}")
-    print("🏆 ULTRA-REGULARIZED RESULTS")
     print(f"{'='*60}")

     for name, result in results.items():
-        if 'test_results' in result:
-            tr = result['test_results']
-            print(f"📊 {name}:")
-            print(f" Parameters: {result['params']:,} ({result['params_per_sample']:.1f}/sample)")
-            print(f" Test Acc: {tr['test_acc']:.4f} | Gap: {tr['gap']:.4f}")

-    print(f"\n💡 Key Insight: With only 140 training samples, we need < 50 parameters per sample!")
-    print(f"📈 The ultra-regularized models should show much better generalization.")

     return results

 if __name__ == "__main__":
-    results = run_ultra_regularized_test()

-    print(f"\n🌐 Process staying alive...")
     try:
         while True:
             time.sleep(60)
     except KeyboardInterrupt:
-        print("\n👋 Goodbye!")

 #!/usr/bin/env python3
 """
+🚨 EMERGENCY OVERFITTING FIX 🚨
+Tiny GraphMamba designed specifically for 140 training samples
 """

 import torch

 from torch_geometric.transforms import NormalizeFeatures
 from torch_geometric.utils import to_undirected, add_self_loops
 import torch.optim as optim
 import time

 def get_device():
     if torch.cuda.is_available():

         print("💻 Using CPU")
     return device

+class EmergencyTinyMamba(nn.Module):
+    """Emergency ultra-tiny model for 140 samples"""
+    def __init__(self, input_dim=1433, hidden_dim=8, num_classes=7):
         super().__init__()

+        # TINY feature extraction
+        self.feature_reduce = nn.Sequential(
+            nn.Linear(input_dim, 32),
+            nn.ReLU(),
+            nn.Dropout(0.9),  # Extreme dropout
+            nn.Linear(32, hidden_dim)
+        )

+        # Single GCN layer
+        self.gcn = GCNConv(hidden_dim, hidden_dim)

+        # Tiny "Mamba-inspired" temporal processing
+        self.temporal = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.Tanh(),  # Bounded activation
+            nn.Dropout(0.9)
+        )

+        # Direct classifier
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.95),  # Extreme dropout before classification
+            nn.Linear(hidden_dim, num_classes)
+        )

+        print(f"🦾 Emergency Model - Parameters: {sum(p.numel() for p in self.parameters()):,}")

+    def forward(self, x, edge_index):
+        # Feature reduction
+        h = self.feature_reduce(x)

+        # Graph convolution
+        h_gcn = F.relu(self.gcn(h, edge_index))

+        # Temporal processing (Mamba-inspired)
+        h_temporal = self.temporal(h_gcn)

+        # Small residual connection
+        h = h + h_temporal * 0.1  # Very small update

+        # Classification
+        return self.classifier(h)

+class MicroMamba(nn.Module):
+    """Even smaller model"""
+    def __init__(self, input_dim=1433, hidden_dim=4, num_classes=7):
         super().__init__()
+
+        # Ultra-compressed feature extraction
+        self.features = nn.Sequential(
+            nn.Linear(input_dim, 16),
             nn.ReLU(),
+            nn.Dropout(0.95),
+            nn.Linear(16, hidden_dim)
         )

+        # Minimal processing
+        self.process = nn.Sequential(
+            GCNConv(hidden_dim, hidden_dim),
             nn.ReLU(),
+            nn.Dropout(0.9)
         )

+        # Direct classification
+        self.classify = nn.Sequential(
+            nn.Dropout(0.95),
+            nn.Linear(hidden_dim, num_classes)
+        )

+        print(f"🤏 Micro Model - Parameters: {sum(p.numel() for p in self.parameters()):,}")

+    def forward(self, x, edge_index):
+        h = self.features(x)
+        h = self.process[0](h, edge_index)  # GCN
+        h = self.process[1](h)  # ReLU
+        h = self.process[2](h)  # Dropout
+        return self.classify(h)

+class NanoMamba(nn.Module):
+    """Absolutely minimal model"""
+    def __init__(self, input_dim=1433, num_classes=7):
         super().__init__()

+        # Direct path - no hidden layers
+        self.direct = nn.Sequential(
+            nn.Linear(input_dim, num_classes),
+            nn.Dropout(0.8)
         )

+        # GCN path
+        self.gcn_path = nn.Sequential(
+            nn.Linear(input_dim, 8),
+            nn.Dropout(0.9)
         )
+        self.gcn = GCNConv(8, num_classes)

+        print(f"⚛️ Nano Model - Parameters: {sum(p.numel() for p in self.parameters()):,}")

+    def forward(self, x, edge_index):
+        # Direct classification
+        direct_out = self.direct(x)
+
+        # GCN path
+        h = self.gcn_path(x)
+        gcn_out = self.gcn(h, edge_index)
+
+        # Minimal combination
+        return direct_out * 0.7 + gcn_out * 0.3

+def emergency_train(model, data, device, epochs=2000):
+    """Emergency training with extreme regularization"""
+    model = model.to(device)
+    data = data.to(device)
+
+    # Very conservative optimizer
+    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.5)
+
+    # Label smoothing cross entropy
+    criterion = nn.CrossEntropyLoss(label_smoothing=0.5)
+
+    print(f"🚨 Emergency Training Protocol")
+    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
+    print(f" Per sample: {sum(p.numel() for p in model.parameters())/140:.1f}")
+    print(f" Epochs: {epochs}")
+    print(f" Learning rate: 0.001")
+    print(f" Weight decay: 0.5")
+    print(f" Label smoothing: 0.5")
+
+    best_val_acc = 0
+    patience = 0
+
+    for epoch in range(epochs):
+        # Training
+        model.train()
+        optimizer.zero_grad()

+        out = model(data.x, data.edge_index)
+        loss = criterion(out[data.train_mask], data.y[data.train_mask])

+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # Tiny gradients
+        optimizer.step()
+
+        # Evaluation
+        if (epoch + 1) % 100 == 0:
+            model.eval()
             with torch.no_grad():
+                out = model(data.x, data.edge_index)

+                train_pred = out[data.train_mask].argmax(dim=1)
                 train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()

+                val_pred = out[data.val_mask].argmax(dim=1)
                 val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

+                test_pred = out[data.test_mask].argmax(dim=1)
+                test_acc = (test_pred == data.y[data.test_mask]).float().mean().item()
+
                 gap = train_acc - val_acc
+
+                print(f" Epoch {epoch+1:4d}: Train {train_acc:.3f} | Val {val_acc:.3f} | "
+                      f"Test {test_acc:.3f} | Gap {gap:.3f}")
+
+                if val_acc > best_val_acc:
+                    best_val_acc = val_acc
+                    patience = 0
+                else:
+                    patience += 100
+
+                if patience >= 500:  # Stop if no improvement
+                    print(f" Early stopping at epoch {epoch+1}")
+                    break

+    # Final evaluation
+    model.eval()
+    with torch.no_grad():
+        out = model(data.x, data.edge_index)

+        train_pred = out[data.train_mask].argmax(dim=1)
+        train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
+
+        val_pred = out[data.val_mask].argmax(dim=1)
+        val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
+
+        test_pred = out[data.test_mask].argmax(dim=1)
+        test_acc = (test_pred == data.y[data.test_mask]).float().mean().item()
+
+        gap = train_acc - val_acc
+
+    return {
+        'train_acc': train_acc,
+        'val_acc': val_acc,
+        'test_acc': test_acc,
+        'gap': gap
+    }

+def run_emergency_fix():
+    """Emergency overfitting fix"""
+    print("🚨🚨🚨 EMERGENCY OVERFITTING FIX 🚨🚨🚨")
+    print("🩹 Ultra-Tiny Models for 140 Training Samples")
     print("=" * 60)

     device = get_device()

     data.edge_index = to_undirected(data.edge_index)
     data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.size(0))

+    print(f"✅ Dataset: {data.num_nodes} nodes, Train: {data.train_mask.sum()} samples")
+    print(f"🎯 Target: <50 parameters per sample = <7,000 total parameters")

+    # Test emergency models
+    models = {
+        'Emergency Tiny (8D)': EmergencyTinyMamba(hidden_dim=8),
+        'Micro (4D)': MicroMamba(hidden_dim=4),
+        'Nano (Direct)': NanoMamba()
     }

     results = {}

+    for name, model in models.items():
         print(f"\n🏗️ Testing {name}...")

+        total_params = sum(p.numel() for p in model.parameters())
+        params_per_sample = total_params / 140
+
+        print(f" Parameters: {total_params:,} ({params_per_sample:.1f} per sample)")
+
+        if params_per_sample < 50:
+            print(f" ✅ EXCELLENT parameter ratio!")
+        elif params_per_sample < 100:
+            print(f" 👍 Good parameter ratio!")
+        else:
+            print(f" ⚠️ Still might overfit")
+
+        # Test forward pass
+        with torch.no_grad():
+            out = model(data.x, data.edge_index)
+            print(f" Forward: {data.x.shape} -> {out.shape} ✅")
+
         try:
+            # Emergency training
+            result = emergency_train(model, data, device)
+            results[name] = result

+            print(f" 🎯 Final Results:")
+            print(f" Test Accuracy: {result['test_acc']:.3f} ({result['test_acc']*100:.1f}%)")
+            print(f" Train Accuracy: {result['train_acc']:.3f}")
+            print(f" Overfitting Gap: {result['gap']:.3f}")

+            if result['gap'] < 0.1:
+                print(f" 🎉 OVERFITTING SOLVED!")
+            elif result['gap'] < 0.2:
+                print(f" 👍 Much better generalization!")
+            elif result['gap'] < 0.3:
+                print(f" 📈 Improved generalization")
             else:
+                print(f" ⚠️ Still overfitting")

         except Exception as e:
+            print(f" ❌ Training failed: {e}")

+    # Emergency summary
     print(f"\n{'='*60}")
+    print("🚨 EMERGENCY RESULTS SUMMARY")
     print(f"{'='*60}")

+    best_gap = float('inf')
+    best_model = None
+
     for name, result in results.items():
+        print(f"📊 {name}:")
+        print(f" Test: {result['test_acc']:.3f} | Gap: {result['gap']:.3f}")
+
+        if result['gap'] < best_gap:
+            best_gap = result['gap']
+            best_model = name
+
+    if best_model:
+        print(f"\n🏆 Best Generalization: {best_model} (Gap: {best_gap:.3f})")
+
+        if best_gap < 0.1:
+            print(f"🎉 MISSION ACCOMPLISHED! Overfitting crisis resolved!")
+        elif best_gap < 0.2:
+            print(f"👍 Significant improvement in generalization!")
+        else:
+            print(f"📈 Progress made, but still work to do...")
+
+    # Comparison with your current model
+    print(f"\n📈 Comparison:")
+    print(f" Your model: 194K params, Gap ~0.5")
+    if best_model and best_gap < 0.3:
+        improvement = 0.5 - best_gap
+        print(f" Best tiny model: Gap {best_gap:.3f} (Improvement: {improvement:.3f})")
+        print(f" 🎯 {improvement/0.5*100:.0f}% reduction in overfitting!")

+    print(f"\n💡 Key Lesson: With only 140 samples, bigger ≠ better!")
+    print(f"🧠 Tiny models can achieve competitive performance with much better generalization.")

     return results

 if __name__ == "__main__":
+    results = run_emergency_fix()

+    print(f"\n🌐 Emergency fix complete. Process staying alive...")
     try:
         while True:
             time.sleep(60)
     except KeyboardInterrupt:
+        print("\n👋 Emergency protocol terminated.")