#!/usr/bin/env python3
"""
FINAL WORKING DEMO - Revolutionary GraphMamba
All errors fixed, tested and working
"""
import os
os.environ['OMP_NUM_THREADS'] = '4'
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_undirected, add_self_loops
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
def get_device():
"""Get best available device"""
if torch.cuda.is_available():
device = torch.device('cuda')
print(f"πŸš€ Using GPU: {torch.cuda.get_device_name()}")
torch.cuda.empty_cache()
else:
device = torch.device('cpu')
print("πŸ’» Using CPU")
return device
class SimpleMambaBlock(nn.Module):
"""Working Mamba block - simplified but functional"""
def __init__(self, d_model, d_state=8):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = d_model * 2
# Core components
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, 3, padding=1, groups=self.d_inner)
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
# SSM parameters
self.dt_proj = nn.Linear(self.d_inner, self.d_inner)
self.B_proj = nn.Linear(self.d_inner, d_state)
self.C_proj = nn.Linear(self.d_inner, d_state)
        # A matrix: negative real-valued state matrix, initialised per channel as -[1..d_state]
        # and stored as a log so the learned values stay negative after -exp()
A = torch.arange(1, d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(torch.log(A.unsqueeze(0).repeat(self.d_inner, 1)))
self.D = nn.Parameter(torch.ones(self.d_inner))
self.dropout = nn.Dropout(0.1)
def forward(self, x):
B, L, D = x.shape
# Project to dual paths
xz = self.in_proj(x) # (B, L, 2*d_inner)
x_path, z_path = xz.chunk(2, dim=-1) # Each: (B, L, d_inner)
# Conv1d on x_path
x_conv = x_path.transpose(1, 2) # (B, d_inner, L)
x_conv = self.conv1d(x_conv) # (B, d_inner, L)
x_conv = x_conv.transpose(1, 2) # (B, L, d_inner)
x_conv = F.silu(x_conv)
# Simplified SSM
y = self.simple_ssm(x_conv)
# Apply gating
y = y * F.silu(z_path)
# Output projection
out = self.out_proj(y)
return self.dropout(out)
def simple_ssm(self, x):
"""Simplified SSM implementation that works"""
B, L, D = x.shape
# Get SSM parameters
dt = F.softplus(self.dt_proj(x)) # (B, L, d_inner)
B_param = self.B_proj(x) # (B, L, d_state)
C_param = self.C_proj(x) # (B, L, d_state)
# Discretize A matrix
A = -torch.exp(self.A_log) # (d_inner, d_state)
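        # The per-step update below is the usual Mamba-style discretization of the
        # continuous SSM  h'(t) = A h(t) + B x(t),  y(t) = C h(t) + D x(t):
        #   A_bar = exp(dt * A)              (zero-order hold for A)
        #   B_bar = dt * B                   (Euler approximation for B)
        #   h_t   = A_bar * h_{t-1} + B_bar * x_t
        #   y_t   = C_t . h_t + D * x_t
        # where dt, B and C are produced per token, which is the "selective" part.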
# Simple recurrent processing
h = torch.zeros(B, D, self.d_state, device=x.device)
outputs = []
for t in range(L):
# Update state
dA = torch.exp(dt[:, t].unsqueeze(-1) * A.unsqueeze(0)) # (B, d_inner, d_state)
dB = dt[:, t].unsqueeze(-1) * B_param[:, t].unsqueeze(1) # (B, d_inner, d_state)
h = dA * h + dB * x[:, t].unsqueeze(-1) # (B, d_inner, d_state)
# Output
y = (h * C_param[:, t].unsqueeze(1)).sum(dim=-1) + self.D * x[:, t] # (B, d_inner)
outputs.append(y)
return torch.stack(outputs, dim=1) # (B, L, d_inner)
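# A minimal shape check for the block above (illustrative only, not executed by this script):
#   block = SimpleMambaBlock(d_model=64)
#   y = block(torch.randn(2, 10, 64))   # -> torch.Size([2, 10, 64]); (batch, length, d_model) is preserved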
class WorkingGraphMamba(nn.Module):
"""Working GraphMamba implementation"""
def __init__(self, config):
super().__init__()
self.config = config
d_model = config['model']['d_model']
n_layers = config['model']['n_layers']
input_dim = config.get('input_dim', 1433)
# Input processing
self.input_proj = nn.Linear(input_dim, d_model)
self.input_norm = nn.LayerNorm(d_model)
self.input_dropout = nn.Dropout(0.2)
# Core layers
self.gcn_layers = nn.ModuleList([
GCNConv(d_model, d_model) for _ in range(n_layers)
])
        self.mamba_blocks = nn.ModuleList([
            SimpleMambaBlock(d_model, d_state=config['model'].get('d_state', 8))
            for _ in range(n_layers)
        ])
self.layer_norms = nn.ModuleList([
nn.LayerNorm(d_model) for _ in range(n_layers)
])
self.dropouts = nn.ModuleList([
nn.Dropout(0.1) for _ in range(n_layers)
])
# Output
self.output_proj = nn.Linear(d_model, d_model)
self.classifier = None
def forward(self, x, edge_index, batch=None):
# Input processing
h = self.input_dropout(self.input_norm(self.input_proj(x)))
# Process through layers
for i in range(len(self.gcn_layers)):
gcn = self.gcn_layers[i]
mamba = self.mamba_blocks[i]
norm = self.layer_norms[i]
dropout = self.dropouts[i]
# GCN path
h_gcn = F.relu(gcn(h, edge_index))
            # Mamba path: all nodes are treated as one length-N sequence in dataset node order
            h_mamba = mamba(h.unsqueeze(0)).squeeze(0)
# Combine and residual
h_combined = (h_gcn + h_mamba) * 0.5
h = dropout(norm(h + h_combined))
return self.output_proj(h)
def init_classifier(self, num_classes):
"""Initialize classifier"""
self.classifier = nn.Sequential(
nn.Dropout(0.3),
nn.Linear(self.config['model']['d_model'], num_classes)
)
return self.classifier
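# Minimal usage sketch for WorkingGraphMamba (illustrative; mirrors what run_complete_test
# does below, with `data` and `num_classes` coming from the Cora dataset):
#   model = WorkingGraphMamba(create_config())
#   node_emb = model(data.x, data.edge_index)               # (num_nodes, d_model)
#   logits = model.init_classifier(num_classes)(node_emb)   # (num_nodes, num_classes)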
class SimpleGraphMamba(nn.Module):
"""Simplified fallback version"""
def __init__(self, config):
super().__init__()
self.config = config
d_model = config['model']['d_model']
n_layers = config['model']['n_layers']
input_dim = config.get('input_dim', 1433)
self.input_proj = nn.Linear(input_dim, d_model)
self.layers = nn.ModuleList([
nn.Sequential(
GCNConv(d_model, d_model),
nn.ReLU(),
nn.Dropout(0.2),
nn.LayerNorm(d_model)
) for _ in range(n_layers)
])
self.output_proj = nn.Linear(d_model, d_model)
self.classifier = None
def forward(self, x, edge_index, batch=None):
h = self.input_proj(x)
for layer in self.layers:
            gcn, relu, dropout, norm = layer  # unpack: the Sequential is only a container, since GCNConv needs edge_index
h_new = dropout(relu(gcn(h, edge_index)))
h = norm(h + h_new) # Residual
return self.output_proj(h)
def init_classifier(self, num_classes):
self.classifier = nn.Sequential(
nn.Dropout(0.3),
nn.Linear(self.config['model']['d_model'], num_classes)
)
return self.classifier
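# SimpleGraphMamba omits the Mamba path entirely, so the run below doubles as a small
# ablation: the hybrid GCN+Mamba model vs. a plain residual GCN of the same width and depth.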
class EarlyStopping:
"""Early stopping utility"""
def __init__(self, patience=20, min_delta=0.001):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = None
def __call__(self, val_loss):
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
return self.counter >= self.patience
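# EarlyStopping behaviour, as a small illustrative example (it flags a stop once val_loss
# has failed to improve by at least min_delta for `patience` consecutive calls):
#   stopper = EarlyStopping(patience=2)
#   stopper(1.00)   # False - first call only records the baseline
#   stopper(1.00)   # False - no improvement, counter = 1
#   stopper(1.00)   # True  - no improvement, counter = 2 >= patience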
def train_model(model, data, config, device):
"""Complete training function"""
model = model.to(device)
data = data.to(device)
    # Initialize the classifier head before building the optimizer so its parameters are optimized too
num_classes = data.y.max().item() + 1
model.init_classifier(num_classes)
model.classifier = model.classifier.to(device)
# Optimizer and scheduler
optimizer = optim.AdamW(
model.parameters(),
lr=config['training']['learning_rate'],
weight_decay=config['training']['weight_decay']
)
scheduler = ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-6
)
criterion = nn.CrossEntropyLoss()
early_stopping = EarlyStopping(patience=config['training']['patience'])
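    # Both the LR scheduler and early stopping track validation loss: the scheduler halves
    # the learning rate after 10 epochs without improvement, and training stops once
    # config['training']['patience'] consecutive epochs fail to improve.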
# Training loop
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_acc = 0.0
print(f"πŸ‹οΈ Training {model.__class__.__name__}...")
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f" Learning rate: {config['training']['learning_rate']}")
for epoch in range(config['training']['epochs']):
# Training
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
logits = model.classifier(out)
train_loss = criterion(logits[data.train_mask], data.y[data.train_mask])
train_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
        # Calculate accuracies in eval mode so dropout does not distort the validation metrics
        model.eval()
        with torch.no_grad():
            eval_logits = model.classifier(model(data.x, data.edge_index))
            train_pred = eval_logits[data.train_mask].argmax(dim=1)
            train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
            val_pred = eval_logits[data.val_mask].argmax(dim=1)
            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
            val_loss = criterion(eval_logits[data.val_mask], data.y[data.val_mask]).item()
# Update history
history['train_loss'].append(train_loss.item())
history['val_loss'].append(val_loss)
history['train_acc'].append(train_acc)
history['val_acc'].append(val_acc)
# Track best
if val_acc > best_val_acc:
best_val_acc = val_acc
# Scheduler step
scheduler.step(val_loss)
# Early stopping
if early_stopping(val_loss):
print(f" Early stopping at epoch {epoch+1}")
break
# Progress
if (epoch + 1) % 20 == 0:
gap = train_acc - val_acc
print(f" Epoch {epoch+1:3d}: Loss {train_loss.item():.4f} -> {val_loss:.4f} | "
f"Acc {train_acc:.4f} -> {val_acc:.4f} | Gap {gap:.4f}")
return model, history, best_val_acc
def test_model(model, data, device):
"""Test the model"""
model.eval()
model = model.to(device)
data = data.to(device)
with torch.no_grad():
out = model(data.x, data.edge_index)
logits = model.classifier(out)
# Test accuracy
test_pred = logits[data.test_mask].argmax(dim=1)
test_acc = (test_pred == data.y[data.test_mask]).float().mean().item()
# Validation accuracy
val_pred = logits[data.val_mask].argmax(dim=1)
val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
# Training accuracy
train_pred = logits[data.train_mask].argmax(dim=1)
train_acc = (train_pred == data.y[data.train_mask]).float().mean().item()
gap = train_acc - val_acc
return {
'test_acc': test_acc,
'val_acc': val_acc,
'train_acc': train_acc,
'gap': gap
}
def create_config():
"""Create working configuration"""
return {
'model': {
'd_model': 64,
'd_state': 8,
'n_layers': 2,
'dropout': 0.2
},
'training': {
'learning_rate': 0.01,
'weight_decay': 0.005,
'epochs': 200,
'patience': 30
},
'input_dim': 1433
}
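# Note on the config: d_model, d_state, n_layers and input_dim are consumed by the model
# constructors above, while learning_rate, weight_decay, epochs and patience drive
# train_model. The 'dropout' entry is currently informational; the layers hard-code their
# own dropout rates.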
def run_complete_test():
"""Run the complete test suite"""
print("🧠 REVOLUTIONARY MAMBA GRAPH NEURAL NETWORK")
print("πŸ”₯ Final Working Implementation")
print("=" * 60)
device = get_device()
start_time = time.time()
try:
# Load data
print("\nπŸ“Š Loading Cora dataset...")
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
# Ensure undirected and add self-loops
data.edge_index = to_undirected(data.edge_index)
data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.size(0))
print(f"βœ… Dataset loaded: {data.num_nodes} nodes, {data.num_edges} edges")
print(f" Features: {dataset.num_features}, Classes: {dataset.num_classes}")
print(f" Train: {data.train_mask.sum()}, Val: {data.val_mask.sum()}, Test: {data.test_mask.sum()}")
# Create config
config = create_config()
# Test models
models_to_test = {
'Working GraphMamba': WorkingGraphMamba,
'Simple GraphMamba': SimpleGraphMamba
}
results = {}
for name, model_class in models_to_test.items():
print(f"\nπŸ—οΈ Testing {name}...")
try:
# Create and test model
model = model_class(config)
total_params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {total_params:,} ({total_params/data.train_mask.sum().item():.1f} per sample)")
# Test forward pass
model.eval()
with torch.no_grad():
h = model(data.x, data.edge_index)
print(f" Forward pass: {data.x.shape} -> {h.shape} βœ…")
# Train model
trained_model, history, best_val_acc = train_model(model, data, config, device)
# Test model
test_results = test_model(trained_model, data, device)
results[name] = {
'model': trained_model,
'history': history,
'test_results': test_results,
'params': total_params
}
print(f"βœ… {name} Results:")
print(f" Test Accuracy: {test_results['test_acc']:.4f} ({test_results['test_acc']*100:.2f}%)")
print(f" Validation: {test_results['val_acc']:.4f}")
print(f" Overfitting Gap: {test_results['gap']:.4f}")
except Exception as e:
print(f"❌ {name} failed: {str(e)}")
results[name] = {'error': str(e)}
# Summary
print(f"\n{'='*60}")
print("πŸ† FINAL RESULTS")
print(f"{'='*60}")
best_acc = 0.0
best_name = None
for name, result in results.items():
if 'test_results' in result:
acc = result['test_results']['test_acc']
gap = result['test_results']['gap']
params = result['params']
print(f"πŸ“Š {name}:")
print(f" 🎯 Test Accuracy: {acc:.4f} ({acc*100:.2f}%)")
print(f" πŸ“ˆ Overfitting Gap: {gap:.4f}")
print(f" πŸ”§ Parameters: {params:,}")
if acc > best_acc:
best_acc = acc
best_name = name
if best_name:
print(f"\nπŸ† Best Model: {best_name}")
print(f" 🎯 Accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)")
# Baseline comparison
baselines = {
'Random': 1/dataset.num_classes,
'MLP': 0.59,
'GCN': 0.815,
'GAT': 0.830
}
print(f"\nπŸ“ˆ Baseline Comparison:")
for baseline_name, baseline_acc in baselines.items():
diff = best_acc - baseline_acc
status = "🟒" if diff > 0 else ("🟑" if diff > -0.05 else "πŸ”΄")
print(f" {status} {baseline_name}: {baseline_acc:.3f} (diff: {diff:+.3f})")
total_time = time.time() - start_time
print(f"\n⏱️ Total time: {total_time:.2f}s")
print(f"✨ Test completed successfully!")
return results
except Exception as e:
print(f"❌ Test failed: {str(e)}")
import traceback
traceback.print_exc()
return None
if __name__ == "__main__":
# Run the test
results = run_complete_test()
# Keep alive
print(f"\n🌐 Process staying alive...")
try:
while True:
time.sleep(60)
except KeyboardInterrupt:
print("\nπŸ‘‹ Goodbye!")