from collections import deque

import torch
import numpy as np
import networkx as nx
from scipy.sparse.linalg import eigsh
from sklearn.cluster import SpectralClustering


class GraphSequencer:
    """
    Graph ordering strategies for serializing nodes into sequences.

    All orderings are derived from the actual graph structure
    (edge_index); nothing is hardcoded.
    """

    @staticmethod
    def bfs_ordering(edge_index, num_nodes, start_node=None):
        """Breadth-first search ordering."""
        # Convert to NetworkX for BFS
        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))
        edge_list = edge_index.t().cpu().numpy()
        G.add_edges_from(edge_list)

        # Start from the highest-degree node if none is specified
        if start_node is None:
            degrees = dict(G.degree())
            start_node = max(degrees, key=degrees.get)

        # BFS traversal
        visited = set()
        order = []
        queue = deque([start_node])
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            order.append(node)
            # Enqueue neighbors in descending degree order (deterministic)
            neighbors = sorted(G.neighbors(node), key=lambda n: G.degree(n), reverse=True)
            for neighbor in neighbors:
                if neighbor not in visited:
                    queue.append(neighbor)

        # Append any nodes from disconnected components
        for node in range(num_nodes):
            if node not in visited:
                order.append(node)

        return torch.tensor(order, dtype=torch.long)

    @staticmethod
    def spectral_ordering(edge_index, num_nodes):
        """Spectral ordering along the Fiedler vector of the normalized Laplacian."""
        try:
            edge_index_np = edge_index.cpu().numpy()

            # Dense adjacency matrix (symmetrized for an undirected graph)
            A = np.zeros((num_nodes, num_nodes))
            A[edge_index_np[0], edge_index_np[1]] = 1
            A[edge_index_np[1], edge_index_np[0]] = 1

            # Degree matrix
            D = np.diag(np.sum(A, axis=1))

            # Normalized Laplacian: L = D^(-1/2) (D - A) D^(-1/2)
            D_sqrt_inv = np.diag(1.0 / np.sqrt(np.maximum(np.diag(D), 1e-12)))
            L = D_sqrt_inv @ (D - A) @ D_sqrt_inv

            # Second-smallest eigenvector (Fiedler vector)
            eigenvals, eigenvecs = eigsh(L, k=min(10, num_nodes - 1), which='SM')
            fiedler_vector = eigenvecs[:, 1]

            # Order nodes by their Fiedler-vector value
            order = np.argsort(fiedler_vector)
            return torch.tensor(order, dtype=torch.long)
        except Exception as e:
            print(f"Spectral ordering failed: {e}, falling back to degree ordering")
            return GraphSequencer.degree_ordering(edge_index, num_nodes)

    @staticmethod
    def degree_ordering(edge_index, num_nodes):
        """Order nodes by degree (high to low)."""
        # Count degrees over both endpoints of each edge
        degrees = torch.zeros(num_nodes, dtype=torch.long)
        ones = torch.ones(edge_index.shape[1], dtype=torch.long)
        degrees.index_add_(0, edge_index[0], ones)
        degrees.index_add_(0, edge_index[1], ones)

        # Stable sort by degree (descending); ties keep ascending node index,
        # so the result is deterministic
        _, order = torch.sort(degrees, descending=True, stable=True)
        return order

    @staticmethod
    def community_ordering(edge_index, num_nodes, n_clusters=None):
        """Community-aware ordering using spectral clustering."""
        try:
            if n_clusters is None:
                n_clusters = max(2, min(10, num_nodes // 100))

            # Dense adjacency matrix (symmetrized for an undirected graph)
            edge_index_np = edge_index.cpu().numpy()
            A = np.zeros((num_nodes, num_nodes))
            A[edge_index_np[0], edge_index_np[1]] = 1
            A[edge_index_np[1], edge_index_np[0]] = 1

            # Spectral clustering on the precomputed affinity matrix
            clustering = SpectralClustering(
                n_clusters=n_clusters,
                affinity='precomputed',
                random_state=42
            )
            labels = clustering.fit_predict(A)

            # Order by cluster, then by descending degree within each cluster
            degrees = np.sum(A, axis=1)
            order = []
            for cluster in range(n_clusters):
                cluster_nodes = np.where(labels == cluster)[0]
                cluster_degrees = degrees[cluster_nodes]
                cluster_order = cluster_nodes[np.argsort(-cluster_degrees)]
                order.extend(cluster_order)

            return torch.tensor(order, dtype=torch.long)
        except Exception as e:
            print(f"Community ordering failed: {e}, falling back to BFS ordering")
            return GraphSequencer.bfs_ordering(edge_index, num_nodes)

    @staticmethod
    def multi_view_ordering(edge_index, num_nodes):
        """Generate multiple orderings of the same graph, one per strategy."""
        orderings = {}
        orderings['bfs'] = GraphSequencer.bfs_ordering(edge_index, num_nodes)
        orderings['degree'] = GraphSequencer.degree_ordering(edge_index, num_nodes)
        orderings['spectral'] = GraphSequencer.spectral_ordering(edge_index, num_nodes)
        orderings['community'] = GraphSequencer.community_ordering(edge_index, num_nodes)
        return orderings
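

# Illustrative usage sketch (not part of the original module): the helper name
# `_demo_orderings` and the Zachary karate-club test graph are assumptions made
# for this example; any edge_index tensor of shape [2, num_edges] works the same way.
def _demo_orderings():
    """Compare the four ordering strategies on a small, well-known graph."""
    G = nx.karate_club_graph()
    num_nodes = G.number_of_nodes()
    # Build a [2, num_edges] edge_index; add both directions since the
    # orderings treat edge_index as a directed edge list
    edges = torch.tensor(list(G.edges()), dtype=torch.long).t()
    edge_index = torch.cat([edges, edges.flip(0)], dim=1)
    orderings = GraphSequencer.multi_view_ordering(edge_index, num_nodes)
    for name, order in orderings.items():
        print(f"{name:>9}: first 10 nodes -> {order[:10].tolist()}")
    return orderings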


class PositionalEncoder:
    """Graph-aware positional encoding."""

    @staticmethod
    def encode_positions(x, edge_index, order, max_dist=10):
        """
        Create positional encodings that preserve graph structure:
        a normalized sequence position plus shortest-path distances
        to the preceding nodes in the ordering.
        """
        num_nodes = x.size(0)
        device = x.device

        # Sequential positions: where each node falls in the given order
        seq_pos = torch.zeros(num_nodes, device=device)
        seq_pos[order] = torch.arange(num_nodes, device=device, dtype=torch.float)

        # Build a NetworkX graph; add all nodes so isolated nodes are present
        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))
        G.add_edges_from(edge_index.t().cpu().numpy())

        # Shortest-path distances to the previous max_dist nodes in the sequence
        distances = torch.full((num_nodes, max_dist), float('inf'), device=device)
        for i, node in enumerate(order):
            start_idx = max(0, i - max_dist)
            for j in range(start_idx, i):
                prev_node = order[j].item()
                try:
                    dist = nx.shortest_path_length(G, source=node.item(), target=prev_node)
                    distances[node, j - start_idx] = min(dist, max_dist - 1)
                except nx.NetworkXNoPath:
                    distances[node, j - start_idx] = max_dist - 1

        # Replace remaining infinities (unfilled slots) with the maximum distance
        distances[distances == float('inf')] = max_dist - 1

        # Normalize to [0, 1)
        seq_pos = seq_pos / num_nodes
        distances = distances / max_dist

        return seq_pos.unsqueeze(1), distances
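

# Minimal end-to-end sketch (not part of the original module): the toy ring
# graph, the feature dimension 16, and the choice of BFS ordering here are
# assumptions made only to illustrate how the two classes fit together.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_nodes = 20

    # Ring graph with two extra chords, symmetrized into a [2, num_edges] edge_index
    ring = [(i, (i + 1) % num_nodes) for i in range(num_nodes)]
    chords = [(0, 10), (5, 15)]
    edges = torch.tensor(ring + chords, dtype=torch.long).t()
    edge_index = torch.cat([edges, edges.flip(0)], dim=1)
    x = torch.randn(num_nodes, 16)

    order = GraphSequencer.bfs_ordering(edge_index, num_nodes)
    seq_pos, distances = PositionalEncoder.encode_positions(x, edge_index, order)

    # seq_pos: [num_nodes, 1] normalized sequence positions
    # distances: [num_nodes, max_dist] normalized shortest-path distances
    print("bfs order:", order.tolist())
    print("seq_pos:", tuple(seq_pos.shape), "distances:", tuple(distances.shape))

    # One common way to use the encodings: concatenate them onto the node features
    x_aug = torch.cat([x, seq_pos, distances], dim=1)
    print("augmented features:", tuple(x_aug.shape))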