import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import degree
import networkx as nx
import logging
logger = logging.getLogger(__name__)

class MambaBlock(nn.Module):
    """Heavily regularized Mamba block"""

    def __init__(self, d_model, d_state=4, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = int(expand * d_model)
        self.d_state = d_state
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv, groups=self.d_inner, padding=d_conv - 1)
        self.act = nn.SiLU()
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        # Standard Mamba initialization: A_log parameterizes a negative-real state matrix
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        # Heavy regularization
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        batch, length, d_model = x.shape
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)
        # Depthwise causal convolution along the sequence dimension
        x = x.transpose(1, 2)
        x = self.conv1d(x)[:, :, :length]
        x = x.transpose(1, 2)
        x = self.act(x)
        x = self.dropout(x)
        y = self.selective_scan(x)
        y = y * self.act(z)  # gate with the second projection branch
        return self.dropout(self.out_proj(y))

    def selective_scan(self, x):
        batch, length, d_inner = x.shape
        # Input-dependent delta, B, C are what make the scan "selective"
        deltaBC = self.x_proj(x)
        delta, B, C = torch.split(deltaBC, [1, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))
        # Discretized recurrence: s_t = exp(delta_t * A) * s_{t-1} + (delta_t * B_t) * x_t
        deltaA = torch.exp(delta.unsqueeze(-1) * (-torch.exp(self.A_log)))
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
        states = torch.zeros(batch, d_inner, self.d_state, device=x.device)
        outputs = []
        for i in range(length):
            states = deltaA[:, i] * states + deltaB[:, i] * x[:, i, :, None]
            # Readout: y_t = C_t . s_t + D * x_t
            y = (states @ C[:, i, :, None]).squeeze(-1) + self.D * x[:, i]
            outputs.append(y)
        return torch.stack(outputs, dim=1)
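

# Illustrative smoke test: a minimal sketch showing that MambaBlock maps
# (batch, length, d_model) back to the same shape. The helper name and the
# sizes below are arbitrary example values, not part of any training pipeline.
def _demo_mamba_block():
    block = MambaBlock(d_model=64)
    x = torch.randn(2, 10, 64)  # (batch, length, d_model)
    y = block(x)
    assert y.shape == x.shape, "MambaBlock should preserve the input shape"
    return y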


class GraphDataAugmentation:
    """Data augmentation to combat overfitting"""

    @staticmethod
    def augment_features(x, noise_level=0.1, dropout_prob=0.2):
        if x.size(0) == 0:
            return x
        # Feature noise
        noise = torch.randn_like(x) * noise_level
        x_aug = x + noise
        # Feature dropout
        mask = torch.rand(x.shape[0], x.shape[1], device=x.device) > dropout_prob
        x_aug = x_aug * mask.float()
        return x_aug

    @staticmethod
    def augment_edges(edge_index, drop_prob=0.1):
        if edge_index.size(1) == 0:
            return edge_index
        # Edge dropout
        edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) > drop_prob
        return edge_index[:, edge_mask]
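

# Illustrative usage sketch: both augmentations are shape-safe, so they can be
# applied per training step without touching labels. Graph sizes here are
# arbitrary example values.
def _demo_augmentation():
    x = torch.randn(5, 16)                         # 5 nodes, 16 features
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 4]])      # 4 directed edges
    x_aug = GraphDataAugmentation.augment_features(x)
    edge_aug = GraphDataAugmentation.augment_edges(edge_index)
    assert x_aug.shape == x.shape                  # features keep their shape
    assert edge_aug.size(1) <= edge_index.size(1)  # edges can only be dropped
    return x_aug, edge_aug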


class LightStructuralEncoding(nn.Module):
    """Lightweight structural encoding"""

    def __init__(self, d_model, max_degree=50):
        super().__init__()
        self.max_degree = max_degree
        self.degree_encoding = nn.Embedding(max_degree, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        # Only degree encoding (simpler); clamp into the embedding table's range
        degrees = degree(edge_index[0], num_nodes).long().clamp(max=self.max_degree - 1)
        degree_emb = self.degree_encoding(degrees)
        # Combine with heavy dropout
        combined = self.layer_norm(x + degree_emb)
        return self.dropout(combined)
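

# Illustrative usage sketch: degree-based positional information is added to
# the node features; degrees above max_degree - 1 share the last embedding row.
# Sizes below are arbitrary example values.
def _demo_structural_encoding():
    enc = LightStructuralEncoding(d_model=64)
    x = torch.randn(5, 64)                     # 5 nodes with d_model features
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 4]])
    out = enc(x, edge_index)
    assert out.shape == x.shape
    return out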


class GraphMamba(nn.Module):
    """Heavily regularized GraphMamba to prevent overfitting"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = config['model']['d_model']    # Should be 64
        n_layers = config['model']['n_layers']  # Should be 2
        input_dim = config.get('input_dim', 1433)
        # Minimal architecture
        self.input_proj = nn.Linear(input_dim, d_model)
        self.input_dropout = nn.Dropout(0.5)
        # Light structural encoding
        self.structural_encoding = LightStructuralEncoding(d_model)
        # Minimal Mamba layers
        self.mamba_layers = nn.ModuleList([
            MambaBlock(d_model, d_state=4) for _ in range(n_layers)
        ])
        # Layer norms with dropout
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])
        self.hidden_dropout = nn.Dropout(0.5)
        self.output_dropout = nn.Dropout(0.3)
        # Simple output
        self.output_proj = nn.Linear(d_model, d_model)
        # Data augmentation
        self.augmentation = GraphDataAugmentation()
        # Classifier will be added later
        self.classifier = None

    def forward(self, x, edge_index, batch=None):
        # Apply data augmentation during training
        if self.training:
            x = self.augmentation.augment_features(x)
            edge_index = self.augmentation.augment_edges(edge_index)
        # Input projection with dropout
        h = self.input_dropout(self.input_proj(x))
        # Add minimal structural information
        h = self.structural_encoding(h, edge_index)
        # Identity node ordering (placeholder for the configured BFS strategy)
        order = torch.arange(h.size(0), device=h.device)
        h_ordered = h[order].unsqueeze(0)
        # Process through minimal Mamba layers (pre-norm residual blocks)
        for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
            residual = h_ordered
            h_ordered = ln(h_ordered)
            h_ordered = residual + mamba(h_ordered)
            h_ordered = self.hidden_dropout(h_ordered)
        # Drop the batch dimension and apply the output head
        h_restored = h_ordered.squeeze(0)
        h_out = self.output_dropout(self.output_proj(h_restored))
        return h_out

    def _init_classifier(self, num_classes, device):
        """Initialize heavily regularized classifier"""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(self.config['model']['d_model'], num_classes)
            ).to(device)

    def get_performance_stats(self):
        """Get model statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'total_params': total_params,
            'device': next(self.parameters()).device,
            'dtype': next(self.parameters()).dtype,
            'model_size': f"{total_params/1000:.1f}K parameters"
        }


def create_regularized_config():
    """Create config optimized for small training sets"""
    return {
        'model': {
            'd_model': 64,    # Reduced from 128
            'd_state': 4,     # Reduced from 8
            'd_conv': 4,
            'expand': 2,
            'n_layers': 2,    # Reduced from 3
            'dropout': 0.5    # Increased from 0.1
        },
        'data': {
            'batch_size': 1,  # Full batch for small datasets
            'test_split': 0.2
        },
        'training': {
            'learning_rate': 0.0005,  # Reduced from 0.001
            'weight_decay': 0.01,     # High regularization
            'epochs': 200,
            'patience': 10,           # More patient early stopping
            'warmup_epochs': 10,
            'min_lr': 1e-6
        },
        'ordering': {
            'strategy': 'bfs',  # Simple strategy only
            'preserve_locality': True
        },
        'input_dim': 1433
    }
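

# End-to-end sketch: wire the regularized config into GraphMamba, run one
# forward pass on a toy graph, and attach the lazily created classifier.
# Node/edge counts and num_classes=7 are arbitrary example values.
if __name__ == "__main__":
    config = create_regularized_config()
    model = GraphMamba(config)
    model._init_classifier(num_classes=7, device=next(model.parameters()).device)
    model.eval()  # disable augmentation and dropout for the demo pass
    x = torch.randn(10, config['input_dim'])    # 10 nodes, Cora-sized features
    edge_index = torch.randint(0, 10, (2, 30))  # 30 random directed edges
    with torch.no_grad():
        h = model(x, edge_index)                # (10, d_model) node embeddings
        logits = model.classifier(h)            # (10, 7) class logits
    logger.info(model.get_performance_stats()['model_size'])
    print(h.shape, logits.shape)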