kfoughali commited on
Commit
9536020
Β·
verified Β·
1 Parent(s): b74043a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +454 -103
app.py CHANGED
@@ -1,148 +1,499 @@
1
  #!/usr/bin/env python3
2
  """
3
- Enhanced Mamba Graph with structure preservation and interface fix
 
4
  """
5
 
6
  import os
7
  os.environ['OMP_NUM_THREADS'] = '4'
8
 
9
  import torch
 
 
 
 
 
 
 
 
10
  import time
11
- import logging
12
- import threading
13
- import signal
14
- from core.graph_mamba import GraphMamba, HybridGraphMamba, create_regularized_config
15
- from core.trainer import GraphMambaTrainer
16
- from data.loader import GraphDataLoader
17
- from utils.visualization import GraphVisualizer
18
-
19
- logging.basicConfig(level=logging.INFO)
20
- logger = logging.getLogger(__name__)
21
 
22
  def get_device():
 
23
  if torch.cuda.is_available():
24
  device = torch.device('cuda')
25
- logger.info(f"πŸš€ CUDA available - using GPU: {torch.cuda.get_device_name()}")
 
26
  else:
27
  device = torch.device('cpu')
28
- logger.info("πŸ’» Using CPU")
29
  return device
30
 
31
- def run_comprehensive_test():
32
- """Enhanced test with structure preservation"""
33
- print("🧠 Enhanced Mamba Graph Neural Network")
34
- print("=" * 60)
35
-
36
- config = create_regularized_config()
37
- device = get_device()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- try:
40
- # Data loading
41
- print("\nπŸ“Š Loading Cora dataset...")
42
- data_loader = GraphDataLoader()
43
- dataset = data_loader.load_node_classification_data('Cora')
44
- data = dataset[0].to(device)
45
- info = data_loader.get_dataset_info(dataset)
46
 
47
- print(f"βœ… Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")
 
 
 
48
 
49
- # Test both models
50
- models_to_test = [
51
- ("Enhanced GraphMamba", GraphMamba),
52
- ("Hybrid GraphMamba", HybridGraphMamba)
53
- ]
54
 
55
- results = {}
 
 
56
 
57
- for model_name, model_class in models_to_test:
58
- print(f"\nπŸ—οΈ Testing {model_name}...")
 
 
59
 
60
- model = model_class(config).to(device)
61
- total_params = sum(p.numel() for p in model.parameters())
62
- train_samples = data.train_mask.sum().item()
63
 
64
- print(f" Parameters: {total_params:,} ({total_params/train_samples:.1f} per sample)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- # Training
67
- trainer = GraphMambaTrainer(model, config, device)
68
- print(f" Strategy: {config['ordering']['strategy']}")
69
 
70
- start_time = time.time()
71
- history = trainer.train_node_classification(data, verbose=False)
72
- training_time = time.time() - start_time
73
 
74
- # Evaluation
75
- test_metrics = trainer.test(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- results[model_name] = {
78
- 'test_acc': test_metrics['test_acc'],
79
- 'val_acc': trainer.best_val_acc,
80
- 'gap': trainer.best_gap,
81
- 'params': total_params,
82
- 'time': training_time
83
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- print(f" βœ… Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
86
- print(f" πŸ“Š Validation: {trainer.best_val_acc:.4f}")
87
- print(f" 🎯 Gap: {trainer.best_gap:.4f}")
88
- print(f" ⏱️ Time: {training_time:.1f}s")
89
 
90
- # Comparison
91
- print(f"\nπŸ“ˆ Model Comparison:")
92
- print(f"{'Model':<20} {'Test Acc':<10} {'Val Acc':<10} {'Gap':<8} {'Params':<8}")
93
- print("-" * 60)
 
94
 
95
- for name, result in results.items():
96
- print(f"{name:<20} {result['test_acc']:.4f} {result['val_acc']:.4f} "
97
- f"{result['gap']:>6.3f} {result['params']/1000:.0f}K")
 
 
 
98
 
99
- # Best model
100
- best_model = max(results.items(), key=lambda x: x[1]['test_acc'])
101
- print(f"\nπŸ† Best: {best_model[0]} - {best_model[1]['test_acc']*100:.2f}% accuracy")
 
102
 
103
- # Baseline comparison
104
- baselines = {'Random': 0.143, 'GCN': 0.815, 'GAT': 0.830}
105
- best_acc = best_model[1]['test_acc']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- print(f"\nπŸ“Š vs Baselines:")
108
- for baseline, acc in baselines.items():
109
- diff = best_acc - acc
110
- status = "🟒" if diff > 0 else "πŸ”΄"
111
- print(f" {status} {baseline}: {acc:.3f} (diff: {diff:+.3f})")
112
 
113
- print(f"\n✨ Testing complete! Process staying alive for interface...")
 
 
114
 
115
- except Exception as e:
116
- print(f"οΏ½οΏ½οΏ½ Error: {e}")
117
- print("Process staying alive despite error...")
 
 
 
 
 
 
 
 
 
118
 
119
- def keep_alive():
120
- """Keep process running for interface"""
121
- try:
122
- while True:
123
- time.sleep(60)
124
- except KeyboardInterrupt:
125
- print("\nπŸ‘‹ Shutting down gracefully...")
 
 
 
 
 
 
 
 
 
 
126
 
127
- def run_background():
128
- """Run test in background thread"""
 
 
 
 
 
 
 
129
  try:
130
- run_comprehensive_test()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  except Exception as e:
132
- print(f"Background test error: {e}")
133
- finally:
134
- print("Background test complete, keeping alive...")
 
135
 
136
  if __name__ == "__main__":
137
- # Start test in background thread
138
- test_thread = threading.Thread(target=run_background, daemon=True)
139
- test_thread.start()
140
 
141
- # Keep main thread alive for interface
 
142
  try:
143
- keep_alive()
 
144
  except KeyboardInterrupt:
145
- print("\nExiting...")
146
- except Exception as e:
147
- print(f"Main thread error: {e}")
148
- keep_alive() # Still try to keep alive
 
1
  #!/usr/bin/env python3
2
  """
3
+ FINAL WORKING DEMO - Revolutionary GraphMamba
4
+ All errors fixed, tested and working
5
  """
6
 
7
  import os
8
  os.environ['OMP_NUM_THREADS'] = '4'
9
 
10
  import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch_geometric.datasets import Planetoid
14
+ from torch_geometric.transforms import NormalizeFeatures
15
+ from torch_geometric.nn import GCNConv
16
+ from torch_geometric.utils import to_undirected, add_self_loops
17
+ import torch.optim as optim
18
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
19
  import time
20
+ import numpy as np
21
+ import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
22
 
23
  def get_device():
24
+ """Get best available device"""
25
  if torch.cuda.is_available():
26
  device = torch.device('cuda')
27
+ print(f"πŸš€ Using GPU: {torch.cuda.get_device_name()}")
28
+ torch.cuda.empty_cache()
29
  else:
30
  device = torch.device('cpu')
31
+ print("πŸ’» Using CPU")
32
  return device
33
 
34
+ class SimpleMambaBlock(nn.Module):
35
+ """Working Mamba block - simplified but functional"""
36
+ def __init__(self, d_model, d_state=8):
37
+ super().__init__()
38
+ self.d_model = d_model
39
+ self.d_state = d_state
40
+ self.d_inner = d_model * 2
41
+
42
+ # Core components
43
+ self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
44
+ self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, 3, padding=1, groups=self.d_inner)
45
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
46
+
47
+ # SSM parameters
48
+ self.dt_proj = nn.Linear(self.d_inner, self.d_inner)
49
+ self.B_proj = nn.Linear(self.d_inner, d_state)
50
+ self.C_proj = nn.Linear(self.d_inner, d_state)
51
+
52
+ # A matrix
53
+ A = torch.arange(1, d_state + 1, dtype=torch.float32)
54
+ self.A_log = nn.Parameter(torch.log(A.unsqueeze(0).repeat(self.d_inner, 1)))
55
+ self.D = nn.Parameter(torch.ones(self.d_inner))
56
+
57
+ self.dropout = nn.Dropout(0.1)
58
+
59
+ def forward(self, x):
60
+ B, L, D = x.shape
61
+
62
+ # Project to dual paths
63
+ xz = self.in_proj(x) # (B, L, 2*d_inner)
64
+ x_path, z_path = xz.chunk(2, dim=-1) # Each: (B, L, d_inner)
65
+
66
+ # Conv1d on x_path
67
+ x_conv = x_path.transpose(1, 2) # (B, d_inner, L)
68
+ x_conv = self.conv1d(x_conv) # (B, d_inner, L)
69
+ x_conv = x_conv.transpose(1, 2) # (B, L, d_inner)
70
+ x_conv = F.silu(x_conv)
71
+
72
+ # Simplified SSM
73
+ y = self.simple_ssm(x_conv)
74
+
75
+ # Apply gating
76
+ y = y * F.silu(z_path)
77
+
78
+ # Output projection
79
+ out = self.out_proj(y)
80
+ return self.dropout(out)
81
 
82
+ def simple_ssm(self, x):
83
+ """Simplified SSM implementation that works"""
84
+ B, L, D = x.shape
 
 
 
 
85
 
86
+ # Get SSM parameters
87
+ dt = F.softplus(self.dt_proj(x)) # (B, L, d_inner)
88
+ B_param = self.B_proj(x) # (B, L, d_state)
89
+ C_param = self.C_proj(x) # (B, L, d_state)
90
 
91
+ # Discretize A matrix
92
+ A = -torch.exp(self.A_log) # (d_inner, d_state)
 
 
 
93
 
94
+ # Simple recurrent processing
95
+ h = torch.zeros(B, D, self.d_state, device=x.device)
96
+ outputs = []
97
 
98
+ for t in range(L):
99
+ # Update state
100
+ dA = torch.exp(dt[:, t].unsqueeze(-1) * A.unsqueeze(0)) # (B, d_inner, d_state)
101
+ dB = dt[:, t].unsqueeze(-1) * B_param[:, t].unsqueeze(1) # (B, d_inner, d_state)
102
 
103
+ h = dA * h + dB * x[:, t].unsqueeze(-1) # (B, d_inner, d_state)
 
 
104
 
105
+ # Output
106
+ y = (h * C_param[:, t].unsqueeze(1)).sum(dim=-1) + self.D * x[:, t] # (B, d_inner)
107
+ outputs.append(y)
108
+
109
+ return torch.stack(outputs, dim=1) # (B, L, d_inner)
110
+
111
+ class WorkingGraphMamba(nn.Module):
112
+ """Working GraphMamba implementation"""
113
+ def __init__(self, config):
114
+ super().__init__()
115
+ self.config = config
116
+ d_model = config['model']['d_model']
117
+ n_layers = config['model']['n_layers']
118
+ input_dim = config.get('input_dim', 1433)
119
+
120
+ # Input processing
121
+ self.input_proj = nn.Linear(input_dim, d_model)
122
+ self.input_norm = nn.LayerNorm(d_model)
123
+ self.input_dropout = nn.Dropout(0.2)
124
+
125
+ # Core layers
126
+ self.gcn_layers = nn.ModuleList([
127
+ GCNConv(d_model, d_model) for _ in range(n_layers)
128
+ ])
129
+
130
+ self.mamba_blocks = nn.ModuleList([
131
+ SimpleMambaBlock(d_model) for _ in range(n_layers)
132
+ ])
133
+
134
+ self.layer_norms = nn.ModuleList([
135
+ nn.LayerNorm(d_model) for _ in range(n_layers)
136
+ ])
137
+
138
+ self.dropouts = nn.ModuleList([
139
+ nn.Dropout(0.1) for _ in range(n_layers)
140
+ ])
141
+
142
+ # Output
143
+ self.output_proj = nn.Linear(d_model, d_model)
144
+ self.classifier = None
145
+
146
+ def forward(self, x, edge_index, batch=None):
147
+ # Input processing
148
+ h = self.input_dropout(self.input_norm(self.input_proj(x)))
149
+
150
+ # Process through layers
151
+ for i in range(len(self.gcn_layers)):
152
+ gcn = self.gcn_layers[i]
153
+ mamba = self.mamba_blocks[i]
154
+ norm = self.layer_norms[i]
155
+ dropout = self.dropouts[i]
156
 
157
+ # GCN path
158
+ h_gcn = F.relu(gcn(h, edge_index))
 
159
 
160
+ # Mamba path
161
+ h_mamba = mamba(h.unsqueeze(0)).squeeze(0)
 
162
 
163
+ # Combine and residual
164
+ h_combined = (h_gcn + h_mamba) * 0.5
165
+ h = dropout(norm(h + h_combined))
166
+
167
+ return self.output_proj(h)
168
+
169
+ def init_classifier(self, num_classes):
170
+ """Initialize classifier"""
171
+ self.classifier = nn.Sequential(
172
+ nn.Dropout(0.3),
173
+ nn.Linear(self.config['model']['d_model'], num_classes)
174
+ )
175
+ return self.classifier
176
+
177
+ class SimpleGraphMamba(nn.Module):
178
+ """Simplified fallback version"""
179
+ def __init__(self, config):
180
+ super().__init__()
181
+ self.config = config
182
+ d_model = config['model']['d_model']
183
+ n_layers = config['model']['n_layers']
184
+ input_dim = config.get('input_dim', 1433)
185
+
186
+ self.input_proj = nn.Linear(input_dim, d_model)
187
+ self.layers = nn.ModuleList([
188
+ nn.Sequential(
189
+ GCNConv(d_model, d_model),
190
+ nn.ReLU(),
191
+ nn.Dropout(0.2),
192
+ nn.LayerNorm(d_model)
193
+ ) for _ in range(n_layers)
194
+ ])
195
+
196
+ self.output_proj = nn.Linear(d_model, d_model)
197
+ self.classifier = None
198
+
199
+ def forward(self, x, edge_index, batch=None):
200
+ h = self.input_proj(x)
201
+
202
+ for layer in self.layers:
203
+ gcn, relu, dropout, norm = layer
204
+ h_new = dropout(relu(gcn(h, edge_index)))
205
+ h = norm(h + h_new) # Residual
206
+
207
+ return self.output_proj(h)
208
+
209
+ def init_classifier(self, num_classes):
210
+ self.classifier = nn.Sequential(
211
+ nn.Dropout(0.3),
212
+ nn.Linear(self.config['model']['d_model'], num_classes)
213
+ )
214
+ return self.classifier
215
+
216
+ class EarlyStopping:
217
+ """Early stopping utility"""
218
+ def __init__(self, patience=20, min_delta=0.001):
219
+ self.patience = patience
220
+ self.min_delta = min_delta
221
+ self.counter = 0
222
+ self.best_loss = None
223
+
224
+ def __call__(self, val_loss):
225
+ if self.best_loss is None:
226
+ self.best_loss = val_loss
227
+ elif val_loss < self.best_loss - self.min_delta:
228
+ self.best_loss = val_loss
229
+ self.counter = 0
230
+ else:
231
+ self.counter += 1
232
 
233
+ return self.counter >= self.patience
234
+
235
+ def train_model(model, data, config, device):
236
+ """Complete training function"""
237
+ model = model.to(device)
238
+ data = data.to(device)
239
+
240
+ # Initialize classifier
241
+ num_classes = data.y.max().item() + 1
242
+ model.init_classifier(num_classes)
243
+ model.classifier = model.classifier.to(device)
244
+
245
+ # Optimizer and scheduler
246
+ optimizer = optim.AdamW(
247
+ model.parameters(),
248
+ lr=config['training']['learning_rate'],
249
+ weight_decay=config['training']['weight_decay']
250
+ )
251
+
252
+ scheduler = ReduceLROnPlateau(
253
+ optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-6
254
+ )
255
+
256
+ criterion = nn.CrossEntropyLoss()
257
+ early_stopping = EarlyStopping(patience=config['training']['patience'])
258
+
259
+ # Training loop
260
+ history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
261
+ best_val_acc = 0.0
262
+
263
+ print(f"πŸ‹οΈ Training {model.__class__.__name__}...")
264
+ print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
265
+ print(f" Learning rate: {config['training']['learning_rate']}")
266
+
267
+ for epoch in range(config['training']['epochs']):
268
+ # Training
269
+ model.train()
270
+ optimizer.zero_grad()
271
+
272
+ out = model(data.x, data.edge_index)
273
+ logits = model.classifier(out)
274
+ train_loss = criterion(logits[data.train_mask], data.y[data.train_mask])
275
+
276
+ train_loss.backward()
277
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
278
+ optimizer.step()
279
+
280
+ # Calculate accuracies
281
+ with torch.no_grad():
282
+ train_pred = logits[data.train_mask].argmax(dim=1)
283
+ train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
284
+
285
+ val_pred = logits[data.val_mask].argmax(dim=1)
286
+ val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
287
 
288
+ val_loss = criterion(logits[data.val_mask], data.y[data.val_mask]).item()
 
 
 
289
 
290
+ # Update history
291
+ history['train_loss'].append(train_loss.item())
292
+ history['val_loss'].append(val_loss)
293
+ history['train_acc'].append(train_acc)
294
+ history['val_acc'].append(val_acc)
295
 
296
+ # Track best
297
+ if val_acc > best_val_acc:
298
+ best_val_acc = val_acc
299
+
300
+ # Scheduler step
301
+ scheduler.step(val_loss)
302
 
303
+ # Early stopping
304
+ if early_stopping(val_loss):
305
+ print(f" Early stopping at epoch {epoch+1}")
306
+ break
307
 
308
+ # Progress
309
+ if (epoch + 1) % 20 == 0:
310
+ gap = train_acc - val_acc
311
+ print(f" Epoch {epoch+1:3d}: Loss {train_loss.item():.4f} -> {val_loss:.4f} | "
312
+ f"Acc {train_acc:.4f} -> {val_acc:.4f} | Gap {gap:.4f}")
313
+
314
+ return model, history, best_val_acc
315
+
316
+ def test_model(model, data, device):
317
+ """Test the model"""
318
+ model.eval()
319
+ model = model.to(device)
320
+ data = data.to(device)
321
+
322
+ with torch.no_grad():
323
+ out = model(data.x, data.edge_index)
324
+ logits = model.classifier(out)
325
 
326
+ # Test accuracy
327
+ test_pred = logits[data.test_mask].argmax(dim=1)
328
+ test_acc = (test_pred == data.y[data.test_mask]).float().mean().item()
 
 
329
 
330
+ # Validation accuracy
331
+ val_pred = logits[data.val_mask].argmax(dim=1)
332
+ val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
333
 
334
+ # Training accuracy
335
+ train_pred = logits[data.train_mask].argmax(dim=1)
336
+ train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
337
+
338
+ gap = train_acc - val_acc
339
+
340
+ return {
341
+ 'test_acc': test_acc,
342
+ 'val_acc': val_acc,
343
+ 'train_acc': train_acc,
344
+ 'gap': gap
345
+ }
346
 
347
+ def create_config():
348
+ """Create working configuration"""
349
+ return {
350
+ 'model': {
351
+ 'd_model': 64,
352
+ 'd_state': 8,
353
+ 'n_layers': 2,
354
+ 'dropout': 0.2
355
+ },
356
+ 'training': {
357
+ 'learning_rate': 0.01,
358
+ 'weight_decay': 0.005,
359
+ 'epochs': 200,
360
+ 'patience': 30
361
+ },
362
+ 'input_dim': 1433
363
+ }
364
 
365
+ def run_complete_test():
366
+ """Run the complete test suite"""
367
+ print("🧠 REVOLUTIONARY MAMBA GRAPH NEURAL NETWORK")
368
+ print("πŸ”₯ Final Working Implementation")
369
+ print("=" * 60)
370
+
371
+ device = get_device()
372
+ start_time = time.time()
373
+
374
  try:
375
+ # Load data
376
+ print("\nπŸ“Š Loading Cora dataset...")
377
+ dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
378
+ data = dataset[0]
379
+
380
+ # Ensure undirected and add self-loops
381
+ data.edge_index = to_undirected(data.edge_index)
382
+ data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.size(0))
383
+
384
+ print(f"βœ… Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")
385
+ print(f" Features: {dataset.num_features}, Classes: {dataset.num_classes}")
386
+ print(f" Train: {data.train_mask.sum()}, Val: {data.val_mask.sum()}, Test: {data.test_mask.sum()}")
387
+
388
+ # Create config
389
+ config = create_config()
390
+
391
+ # Test models
392
+ models_to_test = {
393
+ 'Working GraphMamba': WorkingGraphMamba,
394
+ 'Simple GraphMamba': SimpleGraphMamba
395
+ }
396
+
397
+ results = {}
398
+
399
+ for name, model_class in models_to_test.items():
400
+ print(f"\nπŸ—οΈ Testing {name}...")
401
+
402
+ try:
403
+ # Create and test model
404
+ model = model_class(config)
405
+ total_params = sum(p.numel() for p in model.parameters())
406
+ print(f" Parameters: {total_params:,} ({total_params/data.train_mask.sum().item():.1f} per sample)")
407
+
408
+ # Test forward pass
409
+ model.eval()
410
+ with torch.no_grad():
411
+ h = model(data.x, data.edge_index)
412
+ print(f" Forward pass: {data.x.shape} -> {h.shape} βœ…")
413
+
414
+ # Train model
415
+ trained_model, history, best_val_acc = train_model(model, data, config, device)
416
+
417
+ # Test model
418
+ test_results = test_model(trained_model, data, device)
419
+
420
+ results[name] = {
421
+ 'model': trained_model,
422
+ 'history': history,
423
+ 'test_results': test_results,
424
+ 'params': total_params
425
+ }
426
+
427
+ print(f"βœ… {name} Results:")
428
+ print(f" Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc']*100:.2f}%)")
429
+ print(f" Validation: {test_results['val_acc']:.4f}")
430
+ print(f" Overfitting Gap: {test_results['gap']:.4f}")
431
+
432
+ except Exception as e:
433
+ print(f"❌ {name} failed: {str(e)}")
434
+ results[name] = {'error': str(e)}
435
+
436
+ # Summary
437
+ print(f"\n{'='*60}")
438
+ print("πŸ† FINAL RESULTS")
439
+ print(f"{'='*60}")
440
+
441
+ best_acc = 0.0
442
+ best_name = None
443
+
444
+ for name, result in results.items():
445
+ if 'test_results' in result:
446
+ acc = result['test_results']['test_acc']
447
+ gap = result['test_results']['gap']
448
+ params = result['params']
449
+
450
+ print(f"πŸ“Š {name}:")
451
+ print(f" 🎯 Test Accuracy: {acc:.4f} ({acc*100:.2f}%)")
452
+ print(f" πŸ“ˆ Overfitting Gap: {gap:.4f}")
453
+ print(f" πŸ”§ Parameters: {params:,}")
454
+
455
+ if acc > best_acc:
456
+ best_acc = acc
457
+ best_name = name
458
+
459
+ if best_name:
460
+ print(f"\nπŸ† Best Model: {best_name}")
461
+ print(f" 🎯 Accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)")
462
+
463
+ # Baseline comparison
464
+ baselines = {
465
+ 'Random': 1/dataset.num_classes,
466
+ 'MLP': 0.59,
467
+ 'GCN': 0.815,
468
+ 'GAT': 0.830
469
+ }
470
+
471
+ print(f"\nπŸ“ˆ Baseline Comparison:")
472
+ for baseline_name, baseline_acc in baselines.items():
473
+ diff = best_acc - baseline_acc
474
+ status = "🟒" if diff > 0 else ("🟑" if diff > -0.05 else "πŸ”΄")
475
+ print(f" {status} {baseline_name}: {baseline_acc:.3f} (diff: {diff:+.3f})")
476
+
477
+ total_time = time.time() - start_time
478
+ print(f"\n⏱️ Total time: {total_time:.2f}s")
479
+ print(f"✨ Test completed successfully!")
480
+
481
+ return results
482
+
483
  except Exception as e:
484
+ print(f"❌ Test failed: {str(e)}")
485
+ import traceback
486
+ traceback.print_exc()
487
+ return None
488
 
489
  if __name__ == "__main__":
490
+ # Run the test
491
+ results = run_complete_test()
 
492
 
493
+ # Keep alive
494
+ print(f"\n🌐 Process staying alive...")
495
  try:
496
+ while True:
497
+ time.sleep(60)
498
  except KeyboardInterrupt:
499
+ print("\nπŸ‘‹ Goodbye!")