Create core/graph_sequencer.py
Browse files- core/graph_sequencer.py +193 -0
core/graph_sequencer.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import networkx as nx
|
4 |
+
from scipy.sparse.linalg import eigsh
|
5 |
+
from sklearn.cluster import SpectralClustering
|
6 |
+
from torch_geometric.utils import to_networkx, get_laplacian
|
7 |
+
import torch_geometric.utils as pyg_utils
|
8 |
+
|
9 |
+
class GraphSequencer:
|
10 |
+
"""
|
11 |
+
Production-ready graph ordering strategies
|
12 |
+
All methods use real graph data - no hardcoded values
|
13 |
+
"""
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
def bfs_ordering(edge_index, num_nodes, start_node=None):
|
17 |
+
"""Breadth-first search ordering"""
|
18 |
+
# Convert to NetworkX for BFS
|
19 |
+
G = nx.Graph()
|
20 |
+
G.add_nodes_from(range(num_nodes))
|
21 |
+
edge_list = edge_index.t().cpu().numpy()
|
22 |
+
G.add_edges_from(edge_list)
|
23 |
+
|
24 |
+
# Start from highest degree node if not specified
|
25 |
+
if start_node is None:
|
26 |
+
degrees = dict(G.degree())
|
27 |
+
start_node = max(degrees, key=degrees.get)
|
28 |
+
|
29 |
+
# BFS traversal
|
30 |
+
visited = set()
|
31 |
+
order = []
|
32 |
+
queue = [start_node]
|
33 |
+
|
34 |
+
while queue:
|
35 |
+
node = queue.pop(0)
|
36 |
+
if node in visited:
|
37 |
+
continue
|
38 |
+
|
39 |
+
visited.add(node)
|
40 |
+
order.append(node)
|
41 |
+
|
42 |
+
# Add neighbors by degree (deterministic)
|
43 |
+
neighbors = list(G.neighbors(node))
|
44 |
+
neighbors.sort(key=lambda n: G.degree(n), reverse=True)
|
45 |
+
|
46 |
+
for neighbor in neighbors:
|
47 |
+
if neighbor not in visited:
|
48 |
+
queue.append(neighbor)
|
49 |
+
|
50 |
+
# Add any disconnected nodes
|
51 |
+
for node in range(num_nodes):
|
52 |
+
if node not in visited:
|
53 |
+
order.append(node)
|
54 |
+
|
55 |
+
return torch.tensor(order, dtype=torch.long)
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def spectral_ordering(edge_index, num_nodes):
|
59 |
+
"""Spectral ordering using graph Laplacian eigenvector"""
|
60 |
+
try:
|
61 |
+
# Compute normalized Laplacian
|
62 |
+
edge_index_np = edge_index.cpu().numpy()
|
63 |
+
|
64 |
+
# Create adjacency matrix
|
65 |
+
A = np.zeros((num_nodes, num_nodes))
|
66 |
+
A[edge_index_np[0], edge_index_np[1]] = 1
|
67 |
+
A[edge_index_np[1], edge_index_np[0]] = 1 # Undirected
|
68 |
+
|
69 |
+
# Degree matrix
|
70 |
+
D = np.diag(np.sum(A, axis=1))
|
71 |
+
|
72 |
+
# Normalized Laplacian: L = D^(-1/2) * (D - A) * D^(-1/2)
|
73 |
+
D_sqrt_inv = np.diag(1.0 / np.sqrt(np.maximum(np.diag(D), 1e-12)))
|
74 |
+
L = D_sqrt_inv @ (D - A) @ D_sqrt_inv
|
75 |
+
|
76 |
+
# Compute second smallest eigenvector (Fiedler vector)
|
77 |
+
eigenvals, eigenvecs = eigsh(L, k=min(10, num_nodes-1), which='SM')
|
78 |
+
fiedler_vector = eigenvecs[:, 1] # Second smallest
|
79 |
+
|
80 |
+
# Order by Fiedler vector values
|
81 |
+
order = np.argsort(fiedler_vector)
|
82 |
+
|
83 |
+
return torch.tensor(order, dtype=torch.long)
|
84 |
+
|
85 |
+
except Exception as e:
|
86 |
+
print(f"Spectral ordering failed: {e}, falling back to degree ordering")
|
87 |
+
return GraphSequencer.degree_ordering(edge_index, num_nodes)
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def degree_ordering(edge_index, num_nodes):
|
91 |
+
"""Order nodes by degree (high to low)"""
|
92 |
+
# Count degrees
|
93 |
+
degrees = torch.zeros(num_nodes, dtype=torch.long)
|
94 |
+
degrees.index_add_(0, edge_index[0], torch.ones(edge_index.shape[1], dtype=torch.long))
|
95 |
+
degrees.index_add_(0, edge_index[1], torch.ones(edge_index.shape[1], dtype=torch.long))
|
96 |
+
|
97 |
+
# Sort by degree (descending), then by node index for determinism
|
98 |
+
_, order = torch.sort(-degrees * num_nodes - torch.arange(num_nodes))
|
99 |
+
|
100 |
+
return order
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def community_ordering(edge_index, num_nodes, n_clusters=None):
|
104 |
+
"""Community-aware ordering using spectral clustering"""
|
105 |
+
try:
|
106 |
+
if n_clusters is None:
|
107 |
+
n_clusters = max(2, min(10, num_nodes // 100))
|
108 |
+
|
109 |
+
# Convert to adjacency matrix
|
110 |
+
edge_index_np = edge_index.cpu().numpy()
|
111 |
+
A = np.zeros((num_nodes, num_nodes))
|
112 |
+
A[edge_index_np[0], edge_index_np[1]] = 1
|
113 |
+
A[edge_index_np[1], edge_index_np[0]] = 1
|
114 |
+
|
115 |
+
# Spectral clustering
|
116 |
+
clustering = SpectralClustering(
|
117 |
+
n_clusters=n_clusters,
|
118 |
+
affinity='precomputed',
|
119 |
+
random_state=42
|
120 |
+
)
|
121 |
+
|
122 |
+
labels = clustering.fit_predict(A)
|
123 |
+
|
124 |
+
# Order by cluster, then by degree within cluster
|
125 |
+
degrees = np.sum(A, axis=1)
|
126 |
+
|
127 |
+
order = []
|
128 |
+
for cluster in range(n_clusters):
|
129 |
+
cluster_nodes = np.where(labels == cluster)[0]
|
130 |
+
cluster_degrees = degrees[cluster_nodes]
|
131 |
+
cluster_order = cluster_nodes[np.argsort(-cluster_degrees)]
|
132 |
+
order.extend(cluster_order)
|
133 |
+
|
134 |
+
return torch.tensor(order, dtype=torch.long)
|
135 |
+
|
136 |
+
except Exception as e:
|
137 |
+
print(f"Community ordering failed: {e}, falling back to BFS ordering")
|
138 |
+
return GraphSequencer.bfs_ordering(edge_index, num_nodes)
|
139 |
+
|
140 |
+
@staticmethod
|
141 |
+
def multi_view_ordering(edge_index, num_nodes):
|
142 |
+
"""Generate multiple orderings for different perspectives"""
|
143 |
+
orderings = {}
|
144 |
+
|
145 |
+
# Primary orderings
|
146 |
+
orderings['bfs'] = GraphSequencer.bfs_ordering(edge_index, num_nodes)
|
147 |
+
orderings['degree'] = GraphSequencer.degree_ordering(edge_index, num_nodes)
|
148 |
+
orderings['spectral'] = GraphSequencer.spectral_ordering(edge_index, num_nodes)
|
149 |
+
orderings['community'] = GraphSequencer.community_ordering(edge_index, num_nodes)
|
150 |
+
|
151 |
+
return orderings
|
152 |
+
|
153 |
+
class PositionalEncoder:
|
154 |
+
"""Graph-aware positional encoding"""
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
def encode_positions(x, edge_index, order, max_dist=10):
|
158 |
+
"""
|
159 |
+
Create positional encodings that preserve graph structure
|
160 |
+
"""
|
161 |
+
num_nodes = x.size(0)
|
162 |
+
device = x.device
|
163 |
+
|
164 |
+
# Sequential positions
|
165 |
+
seq_pos = torch.zeros(num_nodes, device=device)
|
166 |
+
seq_pos[order] = torch.arange(num_nodes, device=device, dtype=torch.float)
|
167 |
+
|
168 |
+
# Graph distances (local neighborhood)
|
169 |
+
G = nx.Graph()
|
170 |
+
G.add_edges_from(edge_index.t().cpu().numpy())
|
171 |
+
|
172 |
+
# Compute shortest path distances
|
173 |
+
distances = torch.full((num_nodes, max_dist), float('inf'), device=device)
|
174 |
+
|
175 |
+
for i, node in enumerate(order):
|
176 |
+
# Get distances to previous nodes in sequence
|
177 |
+
start_idx = max(0, i - max_dist)
|
178 |
+
for j in range(start_idx, i):
|
179 |
+
prev_node = order[j].item()
|
180 |
+
try:
|
181 |
+
dist = nx.shortest_path_length(G, source=node.item(), target=prev_node)
|
182 |
+
distances[node, j - start_idx] = min(dist, max_dist - 1)
|
183 |
+
except nx.NetworkXNoPath:
|
184 |
+
distances[node, j - start_idx] = max_dist - 1
|
185 |
+
|
186 |
+
# Replace infinities with max distance
|
187 |
+
distances[distances == float('inf')] = max_dist - 1
|
188 |
+
|
189 |
+
# Normalize
|
190 |
+
seq_pos = seq_pos / num_nodes
|
191 |
+
distances = distances / max_dist
|
192 |
+
|
193 |
+
return seq_pos.unsqueeze(1), distances
|