import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import degree
import logging

logger = logging.getLogger(__name__)

class MambaBlock(nn.Module):
    """Heavily regularized Mamba block"""

    def __init__(self, d_model, d_state=4, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = int(expand * d_model)
        self.d_state = d_state

        # Input projection produces both the SSM input x and the gate z.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise causal convolution; the extra padding is trimmed in forward().
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                groups=self.d_inner, padding=d_conv - 1)
        self.act = nn.SiLU()
        # Per-step projection to (delta, B, C): 1 + d_state + d_state values.
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        # S4D-style initialization: A is negative real, stored as a log.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # skip-connection scale
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        batch, length, d_model = x.shape
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)

        # Causal depthwise convolution over the sequence dimension.
        x = x.transpose(1, 2)
        x = self.conv1d(x)[:, :, :length]
        x = x.transpose(1, 2)
        x = self.act(x)
        x = self.dropout(x)

        y = self.selective_scan(x)
        y = y * self.act(z)  # SiLU-gated output
        return self.dropout(self.out_proj(y))

    def selective_scan(self, x):
        """Sequential scan over the discretized state-space recurrence:

            s_t = exp(delta_t * A) * s_{t-1} + (delta_t * B_t) * x_t
            y_t = s_t . C_t + D * x_t

        where A = -exp(A_log) is negative real.
        """
        batch, length, d_inner = x.shape
        deltaBC = self.x_proj(x)
        delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # (batch, length, d_inner), positive

        # Discretize the continuous-time parameters with the step sizes delta.
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)

        states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
        outputs = []

        for i in range(length):
            states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
            y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
            outputs.append(y)

        return torch.stack(outputs, dim=1)
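
# A quick shape check for MambaBlock (illustrative sketch; the batch size,
# sequence length, and d_model below are arbitrary, not from the original
# configuration):
#
#     block = MambaBlock(d_model=64)
#     out = block(torch.randn(2, 10, 64))  # (batch, length, d_model)
#     assert out.shape == (2, 10, 64)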


class GraphDataAugmentation:
    """Data augmentation to combat overfitting"""

    @staticmethod
    def augment_features(x, noise_level=0.1, dropout_prob=0.2):
        if x.size(0) == 0:
            return x

        # Additive Gaussian feature noise.
        noise = torch.randn_like(x) * noise_level
        x_aug = x + noise

        # Random feature dropout (no rescaling).
        mask = torch.rand_like(x) > dropout_prob
        x_aug = x_aug * mask.float()

        return x_aug

    @staticmethod
    def augment_edges(edge_index, drop_prob=0.1):
        if edge_index.size(1) == 0:
            return edge_index

        # Randomly drop a fraction of edges.
        edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) > drop_prob
        return edge_index[:, edge_mask]
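
# Example use during training (sketch; the tensors shown are placeholders):
#
#     x_aug = GraphDataAugmentation.augment_features(x, noise_level=0.1)
#     ei_aug = GraphDataAugmentation.augment_edges(edge_index, drop_prob=0.1)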


class LightStructuralEncoding(nn.Module):
    """Lightweight structural encoding"""

    def __init__(self, d_model, max_degree=50):
        super().__init__()
        self.degree_encoding = nn.Embedding(max_degree, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)

        # Out-degree of each node, clamped to the embedding table size
        # (the original hardcoded max=49, which only matched max_degree=50).
        degrees = degree(edge_index[0], num_nodes).long().clamp(max=self.degree_encoding.num_embeddings - 1)
        degree_emb = self.degree_encoding(degrees)

        combined = self.layer_norm(x + degree_emb)
        return self.dropout(combined)
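
# Sketch of the encoding on its own (sizes are illustrative assumptions):
#
#     enc = LightStructuralEncoding(d_model=64)
#     h = enc(torch.randn(5, 64), torch.tensor([[0, 1, 2], [1, 2, 0]]))
#     assert h.shape == (5, 64)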


class GraphMamba(nn.Module):
    """Heavily regularized GraphMamba to prevent overfitting"""

    def __init__(self, config):
        super().__init__()

        self.config = config
        d_model = config['model']['d_model']
        n_layers = config['model']['n_layers']
        input_dim = config.get('input_dim', 1433)

        self.input_proj = nn.Linear(input_dim, d_model)
        self.input_dropout = nn.Dropout(0.5)

        self.structural_encoding = LightStructuralEncoding(d_model)

        # Stack of Mamba blocks; the SSM hyperparameters are read from the
        # config (the original hardcoded d_state=4 and ignored d_conv/expand).
        self.mamba_layers = nn.ModuleList([
            MambaBlock(
                d_model,
                d_state=config['model'].get('d_state', 4),
                d_conv=config['model'].get('d_conv', 4),
                expand=config['model'].get('expand', 2),
            )
            for _ in range(n_layers)
        ])

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])

        self.hidden_dropout = nn.Dropout(0.5)
        self.output_dropout = nn.Dropout(0.3)

        self.output_proj = nn.Linear(d_model, d_model)

        self.augmentation = GraphDataAugmentation()

        # Created lazily via _init_classifier once num_classes is known.
        self.classifier = None

    def forward(self, x, edge_index, batch=None):
        # Apply augmentation only during training.
        if self.training:
            x = self.augmentation.augment_features(x)
            edge_index = self.augmentation.augment_edges(edge_index)

        h = self.input_dropout(self.input_proj(x))

        h = self.structural_encoding(h, edge_index)

        # Identity node ordering: nodes are fed to the scan as one sequence.
        # Note the 'ordering' strategy in the config (e.g. BFS) is not
        # applied in this forward pass.
        order = torch.arange(h.size(0), device=h.device)
        h_ordered = h[order].unsqueeze(0)

        # Pre-norm residual Mamba layers.
        for mamba, ln in zip(self.mamba_layers, self.layer_norms):
            residual = h_ordered
            h_ordered = ln(h_ordered)
            h_ordered = residual + mamba(h_ordered)
            h_ordered = self.hidden_dropout(h_ordered)

        h_restored = h_ordered.squeeze(0)
        h_out = self.output_dropout(self.output_proj(h_restored))

        return h_out

    def _init_classifier(self, num_classes, device):
        """Initialize heavily regularized classifier"""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(self.config['model']['d_model'], num_classes)
            ).to(device)

    def get_performance_stats(self):
        """Get model statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'total_params': total_params,
            'device': next(self.parameters()).device,
            'dtype': next(self.parameters()).dtype,
            'model_size': f"{total_params / 1000:.1f}K parameters"
        }


def create_regularized_config():
    """Create config optimized for small training sets"""
    return {
        'model': {
            'd_model': 64,
            'd_state': 4,
            'd_conv': 4,
            'expand': 2,
            'n_layers': 2,
            'dropout': 0.5
        },
        'data': {
            'batch_size': 1,
            'test_split': 0.2
        },
        'training': {
            'learning_rate': 0.0005,
            'weight_decay': 0.01,
            'epochs': 200,
            'patience': 10,
            'warmup_epochs': 10,
            'min_lr': 1e-6
        },
        'ordering': {
            'strategy': 'bfs',
            'preserve_locality': True
        },
        'input_dim': 1433
    }
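

if __name__ == "__main__":
    # Smoke test (illustrative sketch): one forward pass on a tiny random
    # graph. The node count, edge count, and class count below are arbitrary
    # assumptions, not values taken from the original code.
    config = create_regularized_config()
    model = GraphMamba(config)

    num_nodes = 20
    x = torch.randn(num_nodes, config['input_dim'])
    edge_index = torch.randint(0, num_nodes, (2, 60))

    model.eval()  # disable augmentation and dropout for a deterministic check
    with torch.no_grad():
        h = model(x, edge_index)
    print(h.shape)  # expected: torch.Size([20, 64])

    model._init_classifier(num_classes=7, device=h.device)
    print(model.get_performance_stats())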