# serpent/core/graph_mamba.py

import torch
import torch.nn as nn
from .mamba_block import MambaBlock
from .graph_sequencer import GraphSequencer, PositionalEncoder
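
# GraphMamba linearizes a graph into a node sequence (BFS, spectral, degree, or
# community ordering via GraphSequencer), adds sequence- and distance-based
# positional encodings, and processes the sequence with a stack of residual
# MambaBlock state-space layers to produce per-node embeddings.
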
class GraphMamba(nn.Module):
    """
    Production Graph-Mamba model with training optimizations
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_model = config['model']['d_model']
        self.n_layers = config['model']['n_layers']
        self.dropout = config['model']['dropout']
        self.ordering_strategy = config['ordering']['strategy']

        # Input projection (lazily initialized, since the input feature
        # dimension is only known at the first forward pass)
        self.input_proj = None

        # Positional encoding: pos_embed projects the 11-dim positional
        # features assembled in forward (sequence position + graph distances)
        # to d_model
        self.pos_encoder = PositionalEncoder()
        self.pos_embed = nn.Linear(11, self.d_model)

        # Mamba layers with residual connections
        self.mamba_layers = nn.ModuleList([
            MambaBlock(
                d_model=self.d_model,
                d_state=config['model']['d_state'],
                d_conv=config['model']['d_conv'],
                expand=config['model']['expand']
            )
            for _ in range(self.n_layers)
        ])

        # Layer norms
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.d_model)
            for _ in range(self.n_layers)
        ])

        # Dropout
        self.dropout_layer = nn.Dropout(self.dropout)

        # Graph sequencer
        self.sequencer = GraphSequencer()

        # Classification head (initialized dynamically)
        self.classifier = None

        # Cache for node orderings
        self._cache = {}
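
    # Expected config layout (inferred from the keys read in __init__ above;
    # the concrete values live in the repo's config files, not here):
    #   config['model']:    d_model, n_layers, dropout, d_state, d_conv, expand
    #   config['ordering']: strategy in {"spectral", "degree", "community"};
    #                       any other value falls back to BFS ordering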

    def _init_input_proj(self, input_dim, device):
        """Initialize input projection dynamically"""
        if self.input_proj is None:
            self.input_proj = nn.Sequential(
                nn.Linear(input_dim, self.d_model),
                nn.LayerNorm(self.d_model),
                nn.ReLU(),
                nn.Dropout(self.dropout * 0.5)
            ).to(device)

    def _init_classifier(self, num_classes, device):
        """Initialize classifier dynamically"""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Linear(self.d_model, self.d_model // 2),
                nn.LayerNorm(self.d_model // 2),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.d_model // 2, num_classes)
            ).to(device)
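
    # Note: the two lazily built modules above only exist after the first
    # forward pass (or an explicit call), so an optimizer constructed from
    # model.parameters() before that point will not include their weights.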

    def forward(self, x, edge_index, batch=None):
        """
        Forward pass: project node features, order the nodes into a sequence,
        and run them through the Mamba layers.
        Returns node embeddings of shape (num_nodes, d_model).
        """
        num_nodes = x.size(0)
        input_dim = x.size(1)
        device = x.device

        # Move all components to the correct device
        self.to(device)

        # Initialize the input projection if needed
        self._init_input_proj(input_dim, device)

        # Project input features
        h = self.input_proj(x)  # (num_nodes, d_model)

        if batch is None:
            # Single graph processing
            h = self._process_single_graph(h, edge_index)
        else:
            # Batch processing
            h = self._process_batch(h, edge_index, batch)

        return h

    def _process_single_graph(self, h, edge_index):
        """Process a single graph, caching node orderings during training"""
        num_nodes = h.size(0)
        device = h.device

        # Ensure edge_index is on the correct device
        edge_index = edge_index.to(device)

        # Cache key for the ordering. Note that the key only encodes node and
        # edge counts, so distinct graphs with identical counts share an entry.
        cache_key = f"{self.ordering_strategy}_{num_nodes}_{edge_index.shape[1]}"

        # Get the node ordering. It is cached only during training; in eval
        # mode it is recomputed on every call.
        if cache_key not in self._cache or not self.training:
            if self.ordering_strategy == "spectral":
                order = self.sequencer.spectral_ordering(edge_index, num_nodes)
            elif self.ordering_strategy == "degree":
                order = self.sequencer.degree_ordering(edge_index, num_nodes)
            elif self.ordering_strategy == "community":
                order = self.sequencer.community_ordering(edge_index, num_nodes)
            else:  # default to BFS
                order = self.sequencer.bfs_ordering(edge_index, num_nodes)

            if self.training:
                self._cache[cache_key] = order
        else:
            order = self._cache[cache_key]

        # Ensure the ordering is on the correct device
        order = order.to(device)

        # Positional encoding: sequence positions and graph distances,
        # concatenated to 11 features per node and projected to d_model
        seq_pos, distances = self.pos_encoder.encode_positions(h, edge_index, order)
        seq_pos = seq_pos.to(device)
        distances = distances.to(device)
        pos_features = torch.cat([seq_pos, distances], dim=1)  # (num_nodes, 11)
        pos_embed = self.pos_embed(pos_features)

        # Reorder nodes for sequential processing and add the positional encoding
        h_ordered = h[order] + pos_embed[order]
        h_ordered = h_ordered.unsqueeze(0)  # (1, num_nodes, d_model)

        # Process through Mamba layers with residual connections
        for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
            # Pre-norm residual block: LayerNorm -> Mamba -> dropout -> skip
            residual = h_ordered
            h_ordered = ln(h_ordered)
            h_ordered = mamba(h_ordered)
            h_ordered = residual + self.dropout_layer(h_ordered)

            # Depth-dependent down-scaling of the activations, applied only
            # during training (note the resulting train/eval scale difference)
            if self.training:
                h_ordered = h_ordered * (1.0 - 0.1 * i / self.n_layers)

        # Restore the original node order
        h_out = h_ordered.squeeze(0)  # (num_nodes, d_model)

        # argsort of a permutation gives its inverse
        inverse_order = torch.argsort(order)
        h_final = h_out[inverse_order]

        return h_final
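
    # The state-space layers above only ever see a 1-D token sequence; all of
    # the graph structure enters through the chosen node ordering and the
    # distance-based positional features added before the scan.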

    def _process_batch(self, h, edge_index, batch):
        """Process batched graphs by iterating over the graphs in the batch"""
        device = h.device
        batch = batch.to(device)
        edge_index = edge_index.to(device)
        batch_size = batch.max().item() + 1

        outputs = []
        for b in range(batch_size):
            # Extract subgraph
            mask = batch == b
            batch_h = h[mask]

            # Get edges for this graph
            edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
            batch_edges = edge_index[:, edge_mask]

            if batch_edges.shape[1] > 0:
                # Reindex edges to local indices
                node_indices = torch.where(mask)[0]
                node_map = torch.zeros(h.size(0), dtype=torch.long, device=device)
                node_map[node_indices] = torch.arange(batch_h.size(0), device=device)
                batch_edges_local = node_map[batch_edges]
            else:
                # Empty graph
                batch_edges_local = torch.empty((2, 0), dtype=torch.long, device=device)

            # Process subgraph
            batch_output = self._process_single_graph(batch_h, batch_edges_local)
            outputs.append(batch_output)

        # Reconstruct full batch
        h_out = torch.zeros_like(h)
        for b, output in enumerate(outputs):
            mask = batch == b
            h_out[mask] = output

        return h_out
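
    # Note: batching is handled with a Python loop, one graph per iteration, so
    # each Mamba scan only ever covers a single graph's node sequence.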

    def get_graph_embedding(self, h, batch=None):
        """Get graph-level representation with multiple pooling"""
        if batch is None:
            # Single graph - multiple pooling strategies
            mean_pool = h.mean(dim=0, keepdim=True)
            max_pool = h.max(dim=0)[0].unsqueeze(0)

            # Attention pooling
            attn_weights = torch.softmax(h.sum(dim=1), dim=0)
            attn_pool = (h * attn_weights.unsqueeze(1)).sum(dim=0, keepdim=True)

            return torch.cat([mean_pool, max_pool, attn_pool], dim=1)
        else:
            # Batched graphs
            device = h.device
            batch = batch.to(device)
            batch_size = batch.max().item() + 1

            graph_embeddings = []
            for b in range(batch_size):
                mask = batch == b
                if mask.any():
                    batch_h = h[mask]

                    # Multiple pooling for this graph
                    mean_pool = batch_h.mean(dim=0)
                    max_pool = batch_h.max(dim=0)[0]
                    attn_weights = torch.softmax(batch_h.sum(dim=1), dim=0)
                    attn_pool = (batch_h * attn_weights.unsqueeze(1)).sum(dim=0)
                    graph_emb = torch.cat([mean_pool, max_pool, attn_pool])
                    graph_embeddings.append(graph_emb)
                else:
                    # Empty graph
                    graph_embeddings.append(torch.zeros(h.size(1) * 3, device=device))

            return torch.stack(graph_embeddings)
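
    # Note: each graph embedding is 3 * d_model wide (mean, max, and attention
    # pooling concatenated), so any downstream head must expect that width.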

    def clear_cache(self):
        """Clear ordering cache"""
        self._cache.clear()
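

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model). This is a minimal, illustrative example
# assuming the sibling modules (MambaBlock, GraphSequencer, PositionalEncoder)
# behave as they are used above; the config values are placeholders, not the
# repo's official defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        'model': {'d_model': 64, 'n_layers': 2, 'dropout': 0.1,
                  'd_state': 16, 'd_conv': 4, 'expand': 2},
        'ordering': {'strategy': 'bfs'},
    }
    model = GraphMamba(config)

    # Toy graph: 5 nodes with 8 input features, edges forming the path 0-1-2-3-4
    x = torch.randn(5, 8)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],
                               [1, 0, 2, 1, 3, 2, 4, 3]], dtype=torch.long)

    node_emb = model(x, edge_index)                   # (5, d_model)
    graph_emb = model.get_graph_embedding(node_emb)   # (1, 3 * d_model)
    print(node_emb.shape, graph_emb.shape)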