import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import degree
import networkx as nx
import logging

logger = logging.getLogger(__name__)


class MambaBlock(nn.Module):
    """Heavily regularized Mamba block"""

    def __init__(self, d_model, d_state=4, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = int(expand * d_model)
        self.d_state = d_state

        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                groups=self.d_inner, padding=d_conv - 1)
        self.act = nn.SiLU()
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # Heavy regularization
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        batch, length, d_model = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        x = x.transpose(1, 2)
        x = self.conv1d(x)[:, :, :length]
        x = x.transpose(1, 2)
        x = self.act(x)
        x = self.dropout(x)

        y = self.selective_scan(x)
        y = y * self.act(z)
        return self.dropout(self.out_proj(y))

    def selective_scan(self, x):
        batch, length, d_inner = x.shape

        deltaBC = self.x_proj(x)
        delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))

        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)

        states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
        outputs = []
        for i in range(length):
            states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
            y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
            outputs.append(y)
        return torch.stack(outputs, dim=1)


class GraphDataAugmentation:
    """Data augmentation to combat overfitting"""

    @staticmethod
    def augment_features(x, noise_level=0.1, dropout_prob=0.2):
        if x.size(0) == 0:
            return x

        # Feature noise
        noise = torch.randn_like(x) * noise_level
        x_aug = x + noise

        # Feature dropout
        mask = torch.rand(x.shape[0], x.shape[1], device=x.device) > dropout_prob
        x_aug = x_aug * mask.float()
        return x_aug

    @staticmethod
    def augment_edges(edge_index, drop_prob=0.1):
        if edge_index.size(1) == 0:
            return edge_index

        # Edge dropout
        edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) > drop_prob
        return edge_index[:, edge_mask]


class LightStructuralEncoding(nn.Module):
    """Lightweight structural encoding"""

    def __init__(self, d_model, max_degree=50):
        super().__init__()
        self.degree_encoding = nn.Embedding(max_degree, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)

        # Only degree encoding (simpler)
        degrees = degree(edge_index[0], num_nodes).long().clamp(max=49)
        degree_emb = self.degree_encoding(degrees)

        # Combine with heavy dropout
        combined = self.layer_norm(x + degree_emb)
        return self.dropout(combined)


class GraphMamba(nn.Module):
    """Heavily regularized GraphMamba to prevent overfitting"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = config['model']['d_model']     # Should be 64
        n_layers = config['model']['n_layers']   # Should be 2
        input_dim = config.get('input_dim', 1433)

        # Minimal architecture
        self.input_proj = nn.Linear(input_dim, d_model)
        self.input_dropout = nn.Dropout(0.5)

        # Light structural encoding
        self.structural_encoding = LightStructuralEncoding(d_model)

        # Minimal Mamba layers
        self.mamba_layers = nn.ModuleList([
            MambaBlock(d_model, d_state=4) for _ in range(n_layers)
        ])

        # Layer norms with dropout
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])
        self.hidden_dropout = nn.Dropout(0.5)
        self.output_dropout = nn.Dropout(0.3)

        # Simple output
        self.output_proj = nn.Linear(d_model, d_model)

        # Data augmentation
        self.augmentation = GraphDataAugmentation()

        # Classifier will be added later
        self.classifier = None

    def forward(self, x, edge_index, batch=None):
        # Apply data augmentation during training
        if self.training:
            x = self.augmentation.augment_features(x)
            edge_index = self.augmentation.augment_edges(edge_index)

        # Input projection with dropout
        h = self.input_dropout(self.input_proj(x))

        # Add minimal structural information
        h = self.structural_encoding(h, edge_index)

        # Simple sequential ordering (identity; nodes are kept in input order
        # and treated as one sequence of length num_nodes)
        order = torch.arange(h.size(0), device=h.device)
        h_ordered = h[order].unsqueeze(0)

        # Process through minimal Mamba layers
        for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
            residual = h_ordered
            h_ordered = ln(h_ordered)
            h_ordered = residual + mamba(h_ordered)
            h_ordered = self.hidden_dropout(h_ordered)

        # Drop the batch dimension (order is identity) and apply final processing
        h_restored = h_ordered.squeeze(0)
        h_out = self.output_dropout(self.output_proj(h_restored))
        return h_out

    def _init_classifier(self, num_classes, device):
        """Initialize heavily regularized classifier"""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(self.config['model']['d_model'], num_classes)
            ).to(device)

    def get_performance_stats(self):
        """Get model statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'total_params': total_params,
            'device': next(self.parameters()).device,
            'dtype': next(self.parameters()).dtype,
            'model_size': f"{total_params / 1000:.1f}K parameters"
        }


def create_regularized_config():
    """Create config optimized for small training sets"""
    return {
        'model': {
            'd_model': 64,    # Reduced from 128
            'd_state': 4,     # Reduced from 8
            'd_conv': 4,
            'expand': 2,
            'n_layers': 2,    # Reduced from 3
            'dropout': 0.5    # Increased from 0.1
        },
        'data': {
            'batch_size': 1,  # Full batch for small datasets
            'test_split': 0.2
        },
        'training': {
            'learning_rate': 0.0005,  # Reduced from 0.001
            'weight_decay': 0.01,     # High regularization
            'epochs': 200,
            'patience': 10,           # More patient early stopping
            'warmup_epochs': 10,
            'min_lr': 1e-6
        },
        'ordering': {
            'strategy': 'bfs',        # Simple strategy only
            'preserve_locality': True
        },
        'input_dim': 1433
    }
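

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): builds
# the model from the regularized config and runs a single forward pass on a
# tiny random graph. The feature dimension comes from config['input_dim']
# (1433 by default); the number of nodes, edges, and classes below are
# arbitrary example values chosen only for this smoke test.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = create_regularized_config()
    model = GraphMamba(config)

    num_nodes, num_classes = 10, 7
    x = torch.randn(num_nodes, config['input_dim'])
    edge_index = torch.randint(0, num_nodes, (2, 40))

    model.eval()  # disables augmentation and dropout for the smoke test
    with torch.no_grad():
        h = model(x, edge_index)                       # (num_nodes, d_model)
        model._init_classifier(num_classes, h.device)
        logits = model.classifier(h)                   # (num_nodes, num_classes)

    print(h.shape, logits.shape)
    print(model.get_performance_stats())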