kfoughali committed on
Commit
f3d5bea
·
verified ·
1 Parent(s): 159f602

Create core/graph_sequencer.py

Browse files
Files changed (1) hide show
  1. core/graph_sequencer.py +193 -0
core/graph_sequencer.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import networkx as nx
4
+ from scipy.sparse.linalg import eigsh
5
+ from sklearn.cluster import SpectralClustering
6
+ from torch_geometric.utils import to_networkx, get_laplacian
7
+ import torch_geometric.utils as pyg_utils
8
+
9
class GraphSequencer:
    """
    Production-ready graph ordering strategies
    All methods use real graph data - no hardcoded values

    Every strategy is a static method that takes ``edge_index`` (indexed as
    two rows: ``edge_index[0]`` and ``edge_index[1]`` hold the two endpoints
    of each edge) plus ``num_nodes`` and returns a 1-D ``torch.long`` CPU
    tensor of node indices. Edges are treated as undirected throughout.
    """

    @staticmethod
    def bfs_ordering(edge_index, num_nodes, start_node=None):
        """Breadth-first search ordering.

        Args:
            edge_index: Tensor whose two rows hold the endpoints of each edge.
            num_nodes: Number of nodes; nodes ``0..num_nodes-1`` always exist
                in the graph even when they have no edges.
            start_node: Node to start from. Defaults to the highest-degree
                node (first one in node order on ties).

        Returns:
            ``torch.long`` tensor: the BFS visit order of the start node's
            component, followed by every node in ``range(num_nodes)`` that
            was not reached, in ascending index order. Empty when the graph
            has no nodes at all.
        """
        from collections import deque

        # Convert to NetworkX for BFS
        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))
        # tolist() hands NetworkX plain Python ints instead of NumPy scalars.
        G.add_edges_from(edge_index.t().cpu().tolist())

        # max() over an empty degree table would raise; nothing to order.
        if G.number_of_nodes() == 0:
            return torch.empty(0, dtype=torch.long)

        degrees = dict(G.degree())

        # Start from highest degree node if not specified
        if start_node is None:
            start_node = max(degrees, key=degrees.get)

        # Nodes are marked when enqueued; with a FIFO queue the enqueue order
        # is exactly the visit order, and no node enters the queue twice.
        order = [start_node]
        seen = {start_node}
        queue = deque([start_node])

        while queue:
            node = queue.popleft()

            # Expand neighbors by descending degree (stable, so deterministic)
            ranked = sorted(G.neighbors(node), key=degrees.__getitem__, reverse=True)
            for neighbor in ranked:
                if neighbor not in seen:
                    seen.add(neighbor)
                    order.append(neighbor)
                    queue.append(neighbor)

        # Add any disconnected nodes
        order.extend(node for node in range(num_nodes) if node not in seen)

        return torch.tensor(order, dtype=torch.long)

    @staticmethod
    def spectral_ordering(edge_index, num_nodes):
        """Spectral ordering using graph Laplacian eigenvector.

        Builds a dense symmetric adjacency matrix, forms the normalized
        Laplacian ``D^(-1/2) (D - A) D^(-1/2)`` and sorts nodes by the
        eigenvector of its second-smallest eigenvalue (the Fiedler vector).

        The eigenvector's sign is mathematically arbitrary, so it is oriented
        such that its first non-negligible entry is negative, and ties are
        broken by node index; this makes the result reproducible across runs.

        Args:
            edge_index: Tensor whose two rows hold the endpoints of each edge.
            num_nodes: Number of nodes (size of the dense matrices built).

        Returns:
            ``torch.long`` tensor with a permutation of ``range(num_nodes)``.
            With fewer than two nodes the identity ordering is returned. If
            the computation raises, a message is printed and the result of
            ``degree_ordering`` is returned instead.
        """
        # A Fiedler vector needs at least two nodes.
        if num_nodes < 2:
            return torch.arange(num_nodes, dtype=torch.long)

        try:
            edge_index_np = edge_index.cpu().numpy()

            # Create adjacency matrix
            A = np.zeros((num_nodes, num_nodes))
            A[edge_index_np[0], edge_index_np[1]] = 1
            A[edge_index_np[1], edge_index_np[0]] = 1  # Undirected

            degree = A.sum(axis=1)
            # Clamp so isolated nodes (degree 0) do not divide by zero.
            inv_sqrt_degree = 1.0 / np.sqrt(np.maximum(degree, 1e-12))

            # Normalized Laplacian: L = D^(-1/2) * (D - A) * D^(-1/2).
            # Scaling rows and columns by broadcasting avoids two dense
            # matrix products with diagonal matrices.
            L = (np.diag(degree) - A) * inv_sqrt_degree[:, None] * inv_sqrt_degree[None, :]

            # L is already dense, so use the dense symmetric solver: it is
            # deterministic and returns eigenvalues in ascending order.
            _, eigenvecs = np.linalg.eigh(L)
            fiedler_vector = eigenvecs[:, 1]  # Second smallest

            # Fix the arbitrary sign of the eigenvector.
            significant = np.flatnonzero(np.abs(fiedler_vector) > 1e-9)
            if significant.size and fiedler_vector[significant[0]] > 0:
                fiedler_vector = -fiedler_vector

            # Order by Fiedler vector values; stable sort => ties by index
            order = np.argsort(fiedler_vector, kind="stable")

            return torch.tensor(order, dtype=torch.long)

        except Exception as e:
            # Deliberate best-effort: any failure degrades to degree ordering.
            print(f"Spectral ordering failed: {e}, falling back to degree ordering")
            return GraphSequencer.degree_ordering(edge_index, num_nodes)

    @staticmethod
    def degree_ordering(edge_index, num_nodes):
        """Order nodes by degree (high to low).

        The degree of a node is the number of times it appears in the first
        two rows of ``edge_index`` (so repeated edges count each time and a
        self-loop counts twice). Among nodes of equal degree the higher node
        index comes first.

        Args:
            edge_index: Integer tensor whose two rows hold edge endpoints; it
                is copied to CPU and cast to ``torch.long`` before counting.
            num_nodes: Number of nodes.

        Returns:
            ``torch.long`` CPU tensor with a permutation of
            ``range(num_nodes)``.
        """
        # Count on CPU in int64 so the counter and the indices always share a
        # device and an index-compatible dtype.
        endpoints = edge_index[:2].detach().cpu().to(torch.long).reshape(-1)

        degrees = torch.zeros(num_nodes, dtype=torch.long)
        degrees.index_add_(0, endpoints, torch.ones_like(endpoints))

        # One unique integer key per node: degree is the major component and
        # node index the minor one; negating turns the ascending sort into
        # "largest key first".
        sort_key = -(degrees * num_nodes + torch.arange(num_nodes))

        return torch.argsort(sort_key)

    @staticmethod
    def community_ordering(edge_index, num_nodes, n_clusters=None):
        """Community-aware ordering using spectral clustering.

        Clusters the dense symmetric adjacency matrix with
        ``SpectralClustering`` (precomputed affinity, ``random_state=42``)
        and lists nodes cluster by cluster; inside a cluster nodes are sorted
        by descending degree, ties by ascending node index.

        Args:
            edge_index: Tensor whose two rows hold the endpoints of each edge.
            num_nodes: Number of nodes.
            n_clusters: Number of clusters. Defaults to
                ``max(2, min(10, num_nodes // 100))``.

        Returns:
            ``torch.long`` tensor of node indices. If clustering raises, a
            message is printed and the result of ``bfs_ordering`` is
            returned instead.
        """
        try:
            if n_clusters is None:
                n_clusters = max(2, min(10, num_nodes // 100))

            # Convert to adjacency matrix
            edge_index_np = edge_index.cpu().numpy()
            A = np.zeros((num_nodes, num_nodes))
            A[edge_index_np[0], edge_index_np[1]] = 1
            A[edge_index_np[1], edge_index_np[0]] = 1

            # Spectral clustering
            clustering = SpectralClustering(
                n_clusters=n_clusters,
                affinity='precomputed',
                random_state=42,
            )
            labels = clustering.fit_predict(A)

            # Order by cluster, then by degree within cluster
            degrees = A.sum(axis=1)

            order = []
            for cluster in range(n_clusters):
                members = np.flatnonzero(labels == cluster)
                # Stable sort keeps equal-degree members in index order.
                ranking = np.argsort(-degrees[members], kind="stable")
                order.extend(members[ranking].tolist())

            return torch.tensor(order, dtype=torch.long)

        except Exception as e:
            # Deliberate best-effort: any failure degrades to BFS ordering.
            print(f"Community ordering failed: {e}, falling back to BFS ordering")
            return GraphSequencer.bfs_ordering(edge_index, num_nodes)

    @staticmethod
    def multi_view_ordering(edge_index, num_nodes):
        """Generate multiple orderings for different perspectives.

        Returns:
            Dict with keys ``'bfs'``, ``'degree'``, ``'spectral'`` and
            ``'community'``, each mapped to the tensor returned by the
            corresponding ordering method with default arguments.
        """
        return {
            'bfs': GraphSequencer.bfs_ordering(edge_index, num_nodes),
            'degree': GraphSequencer.degree_ordering(edge_index, num_nodes),
            'spectral': GraphSequencer.spectral_ordering(edge_index, num_nodes),
            'community': GraphSequencer.community_ordering(edge_index, num_nodes),
        }
152
+
153
class PositionalEncoder:
    """Graph-aware positional encoding"""

    @staticmethod
    def encode_positions(x, edge_index, order, max_dist=10):
        """
        Create positional encodings that preserve graph structure

        Args:
            x: Tensor whose first dimension is the number of nodes; only its
                size and device are used.
            edge_index: Tensor whose two rows hold the endpoints of each
                (undirected) edge.
            order: 1-D tensor of node indices giving the sequence order.
            max_dist: Number of preceding sequence positions to look at, and
                the value used to clamp and normalize hop distances.

        Returns:
            Tuple ``(seq_pos, distances)`` on ``x.device``:

            * ``seq_pos`` has shape ``(num_nodes, 1)``; entry ``n`` is the
              position of node ``n`` in ``order`` divided by ``num_nodes``.
            * ``distances`` has shape ``(num_nodes, max_dist)``; for the node
              at sequence position ``i``, column ``c`` is the hop distance to
              the node at position ``max(0, i - max_dist) + c``, clamped to
              ``max_dist - 1`` and divided by ``max_dist``. Columns with no
              corresponding earlier node, and unreachable pairs, hold the
              clamped maximum.
        """
        num_nodes = x.size(0)
        device = x.device

        # Sequential positions
        seq_pos = torch.zeros(num_nodes, device=device)
        seq_pos[order] = torch.arange(num_nodes, device=device, dtype=torch.float)

        # Graph distances (local neighborhood)
        G = nx.Graph()
        # Register every node up front: a node without edges would otherwise
        # be missing from the graph and the distance lookup would raise.
        G.add_nodes_from(range(num_nodes))
        G.add_edges_from(edge_index.t().cpu().tolist())

        sequence = order.tolist()
        cap = max_dist - 1

        # Filled on the host and moved to the device once; every slot starts
        # at the clamp value, which is also the value for "no path".
        hop_table = np.full((num_nodes, max_dist), float(cap))

        for i, node in enumerate(sequence):
            # Get distances to previous nodes in sequence
            start_idx = max(0, i - max_dist)
            if start_idx == i:
                continue  # no earlier nodes in the window

            # One BFS limited to `cap` hops covers the whole window: anything
            # farther away would be clamped to `cap` anyway.
            reachable = nx.single_source_shortest_path_length(G, node, cutoff=cap)
            for j in range(start_idx, i):
                hop_table[node, j - start_idx] = reachable.get(sequence[j], cap)

        distances = torch.as_tensor(
            hop_table, dtype=torch.get_default_dtype(), device=device
        )

        # Normalize
        seq_pos = seq_pos / num_nodes
        distances = distances / max_dist

        return seq_pos.unsqueeze(1), distances