# core/graph_mamba.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import degree
import networkx as nx
import logging
logger = logging.getLogger(__name__)


class MambaBlock(nn.Module):
    """Heavily regularized Mamba block"""

    def __init__(self, d_model, d_state=4, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = int(expand * d_model)
        self.d_state = d_state

        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv, groups=self.d_inner, padding=d_conv - 1)
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Heavy regularization
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # x: (batch, length, d_model)
        batch, length, d_model = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Depthwise causal convolution over the sequence dimension
        x = x.transpose(1, 2)
        x = self.conv1d(x)[:, :, :length]
        x = x.transpose(1, 2)

        x = self.act(x)
        x = self.dropout(x)

        y = self.selective_scan(x)
        y = y * self.act(z)
        return self.dropout(self.out_proj(y))

    def selective_scan(self, x):
        # x: (batch, length, d_inner)
        batch, length, d_inner = x.shape

        # Project to the input-dependent delta, B and C parameters
        deltaBC = self.x_proj(x)
        delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))

        # Discretize the state-space parameters
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)

        # Sequential scan over the sequence
        states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
        outputs = []
        for i in range(length):
            states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
            y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
            outputs.append(y)
        return torch.stack(outputs, dim=1)
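
# Illustrative sketch (not part of the original module): a quick shape check
# for MambaBlock; the batch, length and d_model sizes below are arbitrary
# assumptions.
def _example_mamba_block():
    """Run MambaBlock on a random (batch, length, d_model) tensor."""
    block = MambaBlock(d_model=64)
    x = torch.randn(2, 10, 64)   # (batch=2, length=10, d_model=64)
    y = block(x)                 # output keeps the input shape: (2, 10, 64)
    return y.shape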
class GraphDataAugmentation:
    """Data augmentation to combat overfitting"""

    @staticmethod
    def augment_features(x, noise_level=0.1, dropout_prob=0.2):
        if x.size(0) == 0:
            return x
        # Feature noise
        noise = torch.randn_like(x) * noise_level
        x_aug = x + noise
        # Feature dropout
        mask = torch.rand(x.shape[0], x.shape[1], device=x.device) > dropout_prob
        x_aug = x_aug * mask.float()
        return x_aug

    @staticmethod
    def augment_edges(edge_index, drop_prob=0.1):
        if edge_index.size(1) == 0:
            return edge_index
        # Edge dropout
        edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) > drop_prob
        return edge_index[:, edge_mask]
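
# Illustrative sketch (not part of the original module): applying both
# augmentations to a small random graph; the node and edge counts are
# arbitrary assumptions.
def _example_augmentation():
    """Perturb node features and randomly drop edges on a toy graph."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    x_aug = GraphDataAugmentation.augment_features(x)           # noisy, partially zeroed features
    edge_aug = GraphDataAugmentation.augment_edges(edge_index)  # random subset of the edges
    return x_aug.shape, edge_aug.shape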
class LightStructuralEncoding(nn.Module):
    """Lightweight structural encoding"""

    def __init__(self, d_model, max_degree=50):
        super().__init__()
        self.degree_encoding = nn.Embedding(max_degree, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        # Only degree encoding (simpler); clamp to the embedding table size
        degrees = degree(edge_index[0], num_nodes).long().clamp(max=self.degree_encoding.num_embeddings - 1)
        degree_emb = self.degree_encoding(degrees)
        # Combine with heavy dropout
        combined = self.layer_norm(x + degree_emb)
        return self.dropout(combined)
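
# Illustrative sketch (not part of the original module): degree-based encoding
# for a small directed cycle with an assumed d_model of 64.
def _example_structural_encoding():
    """Add degree embeddings to random node features."""
    enc = LightStructuralEncoding(d_model=64)
    h = torch.randn(4, 64)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    return enc(h, edge_index).shape   # (4, 64)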
class GraphMamba(nn.Module):
    """Heavily regularized GraphMamba to prevent overfitting"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        d_model = config['model']['d_model']    # Should be 64
        n_layers = config['model']['n_layers']  # Should be 2
        input_dim = config.get('input_dim', 1433)

        # Minimal architecture
        self.input_proj = nn.Linear(input_dim, d_model)
        self.input_dropout = nn.Dropout(0.5)

        # Light structural encoding
        self.structural_encoding = LightStructuralEncoding(d_model)

        # Minimal Mamba layers
        self.mamba_layers = nn.ModuleList([
            MambaBlock(d_model, d_state=4) for _ in range(n_layers)
        ])

        # Layer norms with dropout
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])
        self.hidden_dropout = nn.Dropout(0.5)
        self.output_dropout = nn.Dropout(0.3)

        # Simple output
        self.output_proj = nn.Linear(d_model, d_model)

        # Data augmentation
        self.augmentation = GraphDataAugmentation()

        # Classifier will be added later
        self.classifier = None

    def forward(self, x, edge_index, batch=None):
        # Apply data augmentation during training
        if self.training:
            x = self.augmentation.augment_features(x)
            edge_index = self.augmentation.augment_edges(edge_index)

        # Input projection with dropout
        h = self.input_dropout(self.input_proj(x))

        # Add minimal structural information
        h = self.structural_encoding(h, edge_index)

        # Identity ordering: nodes are kept in index order and treated as a
        # single sequence of length num_nodes
        order = torch.arange(h.size(0), device=h.device)
        h_ordered = h[order].unsqueeze(0)

        # Process through minimal Mamba layers
        for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
            residual = h_ordered
            h_ordered = ln(h_ordered)
            h_ordered = residual + mamba(h_ordered)
            h_ordered = self.hidden_dropout(h_ordered)

        # Restore order and final processing
        h_restored = h_ordered.squeeze(0)
        h_out = self.output_dropout(self.output_proj(h_restored))
        return h_out

    def _init_classifier(self, num_classes, device):
        """Initialize heavily regularized classifier"""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(self.config['model']['d_model'], num_classes)
            ).to(device)

    def get_performance_stats(self):
        """Get model statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'total_params': total_params,
            'device': next(self.parameters()).device,
            'dtype': next(self.parameters()).dtype,
            'model_size': f"{total_params/1000:.1f}K parameters"
        }
def create_regularized_config():
    """Create config optimized for small training sets"""
    return {
        'model': {
            'd_model': 64,    # Reduced from 128
            'd_state': 4,     # Reduced from 8
            'd_conv': 4,
            'expand': 2,
            'n_layers': 2,    # Reduced from 3
            'dropout': 0.5    # Increased from 0.1
        },
        'data': {
            'batch_size': 1,  # Full batch for small datasets
            'test_split': 0.2
        },
        'training': {
            'learning_rate': 0.0005,  # Reduced from 0.001
            'weight_decay': 0.01,     # High regularization
            'epochs': 200,
            'patience': 10,           # More patient early stopping
            'warmup_epochs': 10,
            'min_lr': 1e-6
        },
        'ordering': {
            'strategy': 'bfs',  # Simple strategy only
            'preserve_locality': True
        },
        'input_dim': 1433
    }
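
# Minimal smoke test (illustrative sketch, not part of the original training
# pipeline): builds the regularized config, runs a forward pass on a random
# toy graph, attaches the classifier, and prints the model statistics. The
# graph size (20 nodes, 40 random edges) and class count (7) are arbitrary
# assumptions.
if __name__ == "__main__":
    config = create_regularized_config()
    model = GraphMamba(config)

    num_nodes, num_classes = 20, 7
    x = torch.randn(num_nodes, config['input_dim'])
    edge_index = torch.randint(0, num_nodes, (2, 40))

    model.eval()  # disable augmentation and dropout for the smoke test
    with torch.no_grad():
        h = model(x, edge_index)                       # (num_nodes, d_model)
        model._init_classifier(num_classes, h.device)
        logits = model.classifier(h)                   # (num_nodes, num_classes)

    print(h.shape, logits.shape)
    print(model.get_performance_stats())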