import torch
import torch.nn as nn

from .mamba_block import MambaBlock
from .graph_sequencer import GraphSequencer, PositionalEncoder


class GraphMamba(nn.Module):
    """
    Production Graph-Mamba model with training optimizations
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_model = config['model']['d_model']
        self.n_layers = config['model']['n_layers']
        self.dropout = config['model']['dropout']
        self.ordering_strategy = config['ordering']['strategy']

        # Input projection (dynamic input dimension)
        self.input_proj = None

        # Positional encoding
        self.pos_encoder = PositionalEncoder()
        self.pos_embed = nn.Linear(11, self.d_model)

        # Mamba layers with residual connections
        self.mamba_layers = nn.ModuleList([
            MambaBlock(
                d_model=self.d_model,
                d_state=config['model']['d_state'],
                d_conv=config['model']['d_conv'],
                expand=config['model']['expand']
            )
            for _ in range(self.n_layers)
        ])

        # Layer norms
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(self.d_model) for _ in range(self.n_layers)
        ])

        # Dropout
        self.dropout_layer = nn.Dropout(self.dropout)

        # Graph sequencer
        self.sequencer = GraphSequencer()

        # Classification head (initialized dynamically)
        self.classifier = None

        # Cache for efficiency
        self._cache = {}

    def _init_input_proj(self, input_dim, device):
        """Initialize input projection dynamically"""
        if self.input_proj is None:
            self.input_proj = nn.Sequential(
                nn.Linear(input_dim, self.d_model),
                nn.LayerNorm(self.d_model),
                nn.ReLU(),
                nn.Dropout(self.dropout * 0.5)
            ).to(device)

    def _init_classifier(self, num_classes, device):
        """Initialize classifier dynamically"""
        if self.classifier is None:
            self.classifier = nn.Sequential(
                nn.Linear(self.d_model, self.d_model // 2),
                nn.LayerNorm(self.d_model // 2),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.d_model // 2, num_classes)
            ).to(device)

    def forward(self, x, edge_index, batch=None):
        """
        Forward pass with training optimizations
        """
        input_dim = x.size(1)
        device = x.device

        # Move all components to correct device
        self.to(device)

        # Initialize input projection if needed
        self._init_input_proj(input_dim, device)

        # Project input features
        h = self.input_proj(x)  # (num_nodes, d_model)

        if batch is None:
            # Single graph processing
            h = self._process_single_graph(h, edge_index)
        else:
            # Batch processing
            h = self._process_batch(h, edge_index, batch)

        return h

    def _process_single_graph(self, h, edge_index):
        """Process a single graph with caching"""
        num_nodes = h.size(0)
        device = h.device

        # Ensure edge_index is on correct device
        edge_index = edge_index.to(device)

        # Cache key for ordering
        cache_key = f"{self.ordering_strategy}_{num_nodes}_{edge_index.shape[1]}"

        # Get ordering (with caching during training)
        if cache_key not in self._cache or not self.training:
            if self.ordering_strategy == "spectral":
                order = self.sequencer.spectral_ordering(edge_index, num_nodes)
            elif self.ordering_strategy == "degree":
                order = self.sequencer.degree_ordering(edge_index, num_nodes)
            elif self.ordering_strategy == "community":
                order = self.sequencer.community_ordering(edge_index, num_nodes)
            else:
                # Default to BFS
                order = self.sequencer.bfs_ordering(edge_index, num_nodes)

            if self.training:
                self._cache[cache_key] = order
        else:
            order = self._cache[cache_key]

        # Ensure order is on correct device
        order = order.to(device)

        # Add positional encoding
        seq_pos, distances = self.pos_encoder.encode_positions(h, edge_index, order)
        seq_pos = seq_pos.to(device)
        distances = distances.to(device)
        pos_features = torch.cat([seq_pos, distances], dim=1)  # (num_nodes, 11)
        pos_embed = self.pos_embed(pos_features)

        # Reorder nodes for sequential processing and add positional encoding
        h_ordered = h[order] + pos_embed[order]
        h_ordered = h_ordered.unsqueeze(0)  # (1, num_nodes, d_model)

        # Process through Mamba layers with residual connections
        for i, (mamba, ln) in enumerate(zip(self.mamba_layers, self.layer_norms)):
            # Pre-norm residual connection with gradient scaling
            residual = h_ordered
            h_ordered = ln(h_ordered)
            h_ordered = mamba(h_ordered)
            h_ordered = residual + self.dropout_layer(h_ordered)

            # Layer-wise learning rate scaling
            if self.training:
                h_ordered = h_ordered * (1.0 - 0.1 * i / self.n_layers)

        # Restore original order
        h_out = h_ordered.squeeze(0)  # (num_nodes, d_model)

        # Create inverse mapping
        inverse_order = torch.argsort(order)
        h_final = h_out[inverse_order]

        return h_final

    def _process_batch(self, h, edge_index, batch):
        """Process batched graphs efficiently"""
        device = h.device
        batch = batch.to(device)
        edge_index = edge_index.to(device)

        batch_size = batch.max().item() + 1
        outputs = []

        for b in range(batch_size):
            # Extract subgraph
            mask = batch == b
            batch_h = h[mask]

            # Get edges for this graph
            edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
            batch_edges = edge_index[:, edge_mask]

            if batch_edges.shape[1] > 0:
                # Reindex edges to local indices
                node_indices = torch.where(mask)[0]
                node_map = torch.zeros(h.size(0), dtype=torch.long, device=device)
                node_map[node_indices] = torch.arange(batch_h.size(0), device=device)
                batch_edges_local = node_map[batch_edges]
            else:
                # Empty graph
                batch_edges_local = torch.empty((2, 0), dtype=torch.long, device=device)

            # Process subgraph
            batch_output = self._process_single_graph(batch_h, batch_edges_local)
            outputs.append(batch_output)

        # Reconstruct full batch
        h_out = torch.zeros_like(h)
        for b, output in enumerate(outputs):
            mask = batch == b
            h_out[mask] = output

        return h_out

    def get_graph_embedding(self, h, batch=None):
        """Get graph-level representation with multiple pooling"""
        if batch is None:
            # Single graph - multiple pooling strategies
            mean_pool = h.mean(dim=0, keepdim=True)
            max_pool = h.max(dim=0)[0].unsqueeze(0)

            # Attention pooling
            attn_weights = torch.softmax(h.sum(dim=1), dim=0)
            attn_pool = (h * attn_weights.unsqueeze(1)).sum(dim=0, keepdim=True)

            return torch.cat([mean_pool, max_pool, attn_pool], dim=1)
        else:
            # Batched graphs
            device = h.device
            batch = batch.to(device)
            batch_size = batch.max().item() + 1

            graph_embeddings = []
            for b in range(batch_size):
                mask = batch == b
                if mask.any():
                    batch_h = h[mask]

                    # Multiple pooling for this graph
                    mean_pool = batch_h.mean(dim=0)
                    max_pool = batch_h.max(dim=0)[0]
                    attn_weights = torch.softmax(batch_h.sum(dim=1), dim=0)
                    attn_pool = (batch_h * attn_weights.unsqueeze(1)).sum(dim=0)

                    graph_emb = torch.cat([mean_pool, max_pool, attn_pool])
                    graph_embeddings.append(graph_emb)
                else:
                    # Empty graph
                    graph_embeddings.append(torch.zeros(h.size(1) * 3, device=device))

            return torch.stack(graph_embeddings)

    def clear_cache(self):
        """Clear ordering cache"""
        self._cache.clear()
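

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The config values below are assumptions,
# not defaults shipped with this repo; only the keys mirror what GraphMamba
# reads in __init__. Because of the relative imports above, run this module
# from within its package (e.g. `python -m <package>.<module>`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        'model': {
            'd_model': 64,
            'n_layers': 2,
            'dropout': 0.1,
            'd_state': 16,
            'd_conv': 4,
            'expand': 2,
        },
        'ordering': {'strategy': 'bfs'},
    }
    model = GraphMamba(config)

    # Tiny random graph: 10 nodes with 5 input features, a short edge chain
    x = torch.randn(10, 5)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

    h = model(x, edge_index)            # (10, d_model) node embeddings
    g = model.get_graph_embedding(h)    # (1, 3 * d_model) graph embedding
    print(h.shape, g.shape)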