import torch
import torch.nn as nn

from .mamba_block import MambaBlock
from .graph_sequencer import GraphSequencer, PositionalEncoder


class GraphMamba(nn.Module):
    """
    Production Graph-Mamba model with training optimizations.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.d_model = config['model']['d_model']
        self.n_layers = config['model']['n_layers']
        self.dropout = config['model']['dropout']
        self.ordering_strategy = config['ordering']['strategy']

        # Input projection is built lazily on the first forward pass,
        # once the node feature dimension is known.
        self.input_proj = None

        # The positional encoder produces per-node positional features; their
        # 11-dimensional concatenation is projected up to d_model.
        self.pos_encoder = PositionalEncoder()
        self.pos_embed = nn.Linear(11, self.d_model)

        self.mamba_layers = nn.ModuleList([
            MambaBlock(
                d_model=self.d_model,
                d_state=config['model']['d_state'],
                d_conv=config['model']['d_conv'],
                expand=config['model']['expand']
            )
            for _ in range(self.n_layers)
        ])

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.d_model)
            for _ in range(self.n_layers)
        ])

        self.dropout_layer = nn.Dropout(self.dropout)

        self.sequencer = GraphSequencer()

        # Classifier head is also built lazily, once the number of classes is known.
        self.classifier = None

        # Cache of node orderings, keyed by ordering strategy and graph size.
        self._cache = {}

    def _init_input_proj(self, input_dim, device):
        """Initialize input projection dynamically"""
        if self.input_proj is None:
            self.input_proj = nn.Sequential(
                nn.Linear(input_dim, self.d_model),
                nn.LayerNorm(self.d_model),
                nn.ReLU(),
                nn.Dropout(self.dropout * 0.5)
            ).to(device)

    def _init_classifier(self, num_classes, device):
        """Initialize classifier dynamically"""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Linear(self.d_model, self.d_model // 2),
                nn.LayerNorm(self.d_model // 2),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.d_model // 2, num_classes)
            ).to(device)

    def forward(self, x, edge_index, batch=None):
        """
        Forward pass with training optimizations.
        """
        input_dim = x.size(1)
        device = x.device

        # Keep all submodules on the same device as the input.
        self.to(device)

        # Build the input projection on first use, now that input_dim is known.
        self._init_input_proj(input_dim, device)

        h = self.input_proj(x)

        if batch is None:
            # Single graph: process all nodes as one sequence.
            h = self._process_single_graph(h, edge_index)
        else:
            # Batched graphs: process each graph in the batch separately.
            h = self._process_batch(h, edge_index, batch)

        return h

    def _process_single_graph(self, h, edge_index):
        """Process a single graph with caching"""
        num_nodes = h.size(0)
        device = h.device

        edge_index = edge_index.to(device)

        # The cache key only encodes ordering strategy, node count, and edge count;
        # orderings are cached during training and recomputed in eval mode.
        cache_key = f"{self.ordering_strategy}_{num_nodes}_{edge_index.shape[1]}"

        if cache_key not in self._cache or not self.training:
            if self.ordering_strategy == "spectral":
                order = self.sequencer.spectral_ordering(edge_index, num_nodes)
            elif self.ordering_strategy == "degree":
                order = self.sequencer.degree_ordering(edge_index, num_nodes)
            elif self.ordering_strategy == "community":
                order = self.sequencer.community_ordering(edge_index, num_nodes)
            else:
                order = self.sequencer.bfs_ordering(edge_index, num_nodes)

            if self.training:
                self._cache[cache_key] = order
        else:
            order = self._cache[cache_key]

        order = order.to(device)

        # Positional features: sequence positions and distance encodings per node.
        seq_pos, distances = self.pos_encoder.encode_positions(h, edge_index, order)
        seq_pos = seq_pos.to(device)
        distances = distances.to(device)

        pos_features = torch.cat([seq_pos, distances], dim=1)
        pos_embed = self.pos_embed(pos_features)

        # Reorder nodes into the chosen sequence and add positional embeddings.
        h_ordered = h[order] + pos_embed[order]
        h_ordered = h_ordered.unsqueeze(0)

        # Pre-norm residual Mamba blocks.
        for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
            residual = h_ordered
            h_ordered = ln(h_ordered)
            h_ordered = mamba(h_ordered)
            h_ordered = residual + self.dropout_layer(h_ordered)

            # Training-only layer-wise scaling of activations.
            if self.training:
                h_ordered = h_ordered * (1.0 - 0.1 * i / self.n_layers)

        h_out = h_ordered.squeeze(0)

        # Undo the ordering so outputs align with the original node indices.
        inverse_order = torch.argsort(order)
        h_final = h_out[inverse_order]

        return h_final

    def _process_batch(self, h, edge_index, batch):
        """Process batched graphs efficiently"""
        device = h.device
        batch = batch.to(device)
        edge_index = edge_index.to(device)

        batch_size = batch.max().item() + 1
        outputs = []

        for b in range(batch_size):
            # Select the nodes belonging to graph b.
            mask = batch == b
            batch_h = h[mask]

            # Keep only edges whose endpoints are both in graph b.
            edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
            batch_edges = edge_index[:, edge_mask]

            if batch_edges.shape[1] > 0:
                # Remap global node indices to local indices within the subgraph.
                node_indices = torch.where(mask)[0]
                node_map = torch.zeros(h.size(0), dtype=torch.long, device=device)
                node_map[node_indices] = torch.arange(batch_h.size(0), device=device)
                batch_edges_local = node_map[batch_edges]
            else:
                batch_edges_local = torch.empty((2, 0), dtype=torch.long, device=device)

            batch_output = self._process_single_graph(batch_h, batch_edges_local)
            outputs.append(batch_output)

        # Scatter each graph's node representations back into a single tensor.
        h_out = torch.zeros_like(h)
        for b, output in enumerate(outputs):
            mask = batch == b
            h_out[mask] = output

        return h_out

    def get_graph_embedding(self, h, batch=None):
        """Get graph-level representation with multiple pooling"""
        if batch is None:
            # Single graph: combine mean, max, and attention pooling.
            mean_pool = h.mean(dim=0, keepdim=True)
            max_pool = h.max(dim=0)[0].unsqueeze(0)

            # Attention weights from a simple sum-of-features score per node.
            attn_weights = torch.softmax(h.sum(dim=1), dim=0)
            attn_pool = (h * attn_weights.unsqueeze(1)).sum(dim=0, keepdim=True)

            return torch.cat([mean_pool, max_pool, attn_pool], dim=1)
        else:
            device = h.device
            batch = batch.to(device)
            batch_size = batch.max().item() + 1

            graph_embeddings = []
            for b in range(batch_size):
                mask = batch == b
                if mask.any():
                    batch_h = h[mask]

                    mean_pool = batch_h.mean(dim=0)
                    max_pool = batch_h.max(dim=0)[0]

                    attn_weights = torch.softmax(batch_h.sum(dim=1), dim=0)
                    attn_pool = (batch_h * attn_weights.unsqueeze(1)).sum(dim=0)

                    graph_emb = torch.cat([mean_pool, max_pool, attn_pool])
                    graph_embeddings.append(graph_emb)
                else:
                    # Empty graph in the batch: fall back to a zero embedding.
                    graph_embeddings.append(torch.zeros(h.size(1) * 3, device=device))

            return torch.stack(graph_embeddings)

    def clear_cache(self):
        """Clear ordering cache"""
        self._cache.clear()
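

# --- Usage sketch (illustrative only) ---
# A minimal example of how this class could be constructed and called, assuming
# the module is imported as part of its package. The config keys mirror those
# read in __init__; the concrete values and the toy graph below are placeholders,
# not values prescribed anywhere in this repository.
#
#   config = {
#       'model': {'d_model': 64, 'n_layers': 2, 'dropout': 0.1,
#                 'd_state': 16, 'd_conv': 4, 'expand': 2},
#       'ordering': {'strategy': 'bfs'},
#   }
#   model = GraphMamba(config)
#
#   x = torch.randn(5, 8)                               # 5 nodes, 8 features each
#   edge_index = torch.tensor([[0, 1, 2, 3],
#                              [1, 2, 3, 4]], dtype=torch.long)
#
#   node_states = model(x, edge_index)                  # (5, d_model) node states
#   graph_emb = model.get_graph_embedding(node_states)  # (1, 3 * d_model) pooled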