kfoughali commited on
Commit
8e24e05
·
verified ·
1 Parent(s): 3c6b427

Update core/graph_sequencer.py

Browse files
Files changed (1) hide show
  1. core/graph_sequencer.py +166 -83
core/graph_sequencer.py CHANGED
@@ -3,28 +3,39 @@ import numpy as np
3
  import networkx as nx
4
  from scipy.sparse.linalg import eigsh
5
  from sklearn.cluster import SpectralClustering
6
- from torch_geometric.utils import to_networkx, get_laplacian
7
- import torch_geometric.utils as pyg_utils
8
 
9
  class GraphSequencer:
10
  """
11
  Production-ready graph ordering strategies
12
- All methods use real graph data - no hardcoded values
13
  """
14
 
15
  @staticmethod
16
  def bfs_ordering(edge_index, num_nodes, start_node=None):
17
- """Breadth-first search ordering"""
18
- # Convert to NetworkX for BFS
19
- G = nx.Graph()
20
- G.add_nodes_from(range(num_nodes))
 
 
 
 
21
  edge_list = edge_index.t().cpu().numpy()
22
- G.add_edges_from(edge_list)
 
 
 
 
 
 
 
23
 
24
  # Start from highest degree node if not specified
25
  if start_node is None:
26
- degrees = dict(G.degree())
27
- start_node = max(degrees, key=degrees.get)
28
 
29
  # BFS traversal
30
  visited = set()
@@ -33,15 +44,15 @@ class GraphSequencer:
33
 
34
  while queue:
35
  node = queue.pop(0)
36
- if node in visited:
37
  continue
38
 
39
  visited.add(node)
40
  order.append(node)
41
 
42
  # Add neighbors by degree (deterministic)
43
- neighbors = list(G.neighbors(node))
44
- neighbors.sort(key=lambda n: G.degree(n), reverse=True)
45
 
46
  for neighbor in neighbors:
47
  if neighbor not in visited:
@@ -52,35 +63,64 @@ class GraphSequencer:
52
  if node not in visited:
53
  order.append(node)
54
 
55
- return torch.tensor(order, dtype=torch.long)
56
 
57
  @staticmethod
58
  def spectral_ordering(edge_index, num_nodes):
59
- """Spectral ordering using graph Laplacian eigenvector"""
 
 
 
 
 
60
  try:
61
- # Compute normalized Laplacian
62
- edge_index_np = edge_index.cpu().numpy()
63
 
64
  # Create adjacency matrix
65
  A = np.zeros((num_nodes, num_nodes))
66
- A[edge_index_np[0], edge_index_np[1]] = 1
67
- A[edge_index_np[1], edge_index_np[0]] = 1 # Undirected
 
 
 
68
 
69
  # Degree matrix
70
- D = np.diag(np.sum(A, axis=1))
 
 
 
 
 
 
 
 
 
71
 
72
  # Normalized Laplacian: L = D^(-1/2) * (D - A) * D^(-1/2)
73
- D_sqrt_inv = np.diag(1.0 / np.sqrt(np.maximum(np.diag(D), 1e-12)))
 
74
  L = D_sqrt_inv @ (D - A) @ D_sqrt_inv
75
 
76
- # Compute second smallest eigenvector (Fiedler vector)
77
- eigenvals, eigenvecs = eigsh(L, k=min(10, num_nodes-1), which='SM')
78
- fiedler_vector = eigenvecs[:, 1] # Second smallest
79
-
80
- # Order by Fiedler vector values
81
- order = np.argsort(fiedler_vector)
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- return torch.tensor(order, dtype=torch.long)
84
 
85
  except Exception as e:
86
  print(f"Spectral ordering failed: {e}, falling back to degree ordering")
@@ -88,35 +128,61 @@ class GraphSequencer:
88
 
89
  @staticmethod
90
  def degree_ordering(edge_index, num_nodes):
91
- """Order nodes by degree (high to low)"""
92
- # Count degrees
93
- degrees = torch.zeros(num_nodes, dtype=torch.long)
94
- degrees.index_add_(0, edge_index[0], torch.ones(edge_index.shape[1], dtype=torch.long))
95
- degrees.index_add_(0, edge_index[1], torch.ones(edge_index.shape[1], dtype=torch.long))
 
 
 
 
 
 
 
 
 
96
 
97
  # Sort by degree (descending), then by node index for determinism
98
- _, order = torch.sort(-degrees * num_nodes - torch.arange(num_nodes))
 
99
 
100
  return order
101
 
102
  @staticmethod
103
  def community_ordering(edge_index, num_nodes, n_clusters=None):
104
- """Community-aware ordering using spectral clustering"""
 
 
 
 
 
105
  try:
106
  if n_clusters is None:
107
- n_clusters = max(2, min(10, num_nodes // 100))
108
 
109
- # Convert to adjacency matrix
110
- edge_index_np = edge_index.cpu().numpy()
 
 
111
  A = np.zeros((num_nodes, num_nodes))
112
- A[edge_index_np[0], edge_index_np[1]] = 1
113
- A[edge_index_np[1], edge_index_np[0]] = 1
 
 
 
 
 
 
 
 
114
 
115
  # Spectral clustering
116
  clustering = SpectralClustering(
117
  n_clusters=n_clusters,
118
  affinity='precomputed',
119
- random_state=42
 
120
  )
121
 
122
  labels = clustering.fit_predict(A)
@@ -127,36 +193,30 @@ class GraphSequencer:
127
  order = []
128
  for cluster in range(n_clusters):
129
  cluster_nodes = np.where(labels == cluster)[0]
130
- cluster_degrees = degrees[cluster_nodes]
131
- cluster_order = cluster_nodes[np.argsort(-cluster_degrees)]
132
- order.extend(cluster_order)
 
133
 
134
- return torch.tensor(order, dtype=torch.long)
 
 
 
 
 
135
 
136
  except Exception as e:
137
  print(f"Community ordering failed: {e}, falling back to BFS ordering")
138
  return GraphSequencer.bfs_ordering(edge_index, num_nodes)
139
-
140
- @staticmethod
141
- def multi_view_ordering(edge_index, num_nodes):
142
- """Generate multiple orderings for different perspectives"""
143
- orderings = {}
144
-
145
- # Primary orderings
146
- orderings['bfs'] = GraphSequencer.bfs_ordering(edge_index, num_nodes)
147
- orderings['degree'] = GraphSequencer.degree_ordering(edge_index, num_nodes)
148
- orderings['spectral'] = GraphSequencer.spectral_ordering(edge_index, num_nodes)
149
- orderings['community'] = GraphSequencer.community_ordering(edge_index, num_nodes)
150
-
151
- return orderings
152
 
153
  class PositionalEncoder:
154
- """Graph-aware positional encoding"""
155
 
156
  @staticmethod
157
  def encode_positions(x, edge_index, order, max_dist=10):
158
  """
159
  Create positional encodings that preserve graph structure
 
160
  """
161
  num_nodes = x.size(0)
162
  device = x.device
@@ -164,30 +224,53 @@ class PositionalEncoder:
164
  # Sequential positions
165
  seq_pos = torch.zeros(num_nodes, device=device)
166
  seq_pos[order] = torch.arange(num_nodes, device=device, dtype=torch.float)
 
 
 
 
167
 
168
- # Graph distances (local neighborhood)
169
- G = nx.Graph()
170
- G.add_edges_from(edge_index.t().cpu().numpy())
171
-
172
- # Compute shortest path distances
173
- distances = torch.full((num_nodes, max_dist), float('inf'), device=device)
174
-
175
- for i, node in enumerate(order):
176
- # Get distances to previous nodes in sequence
177
- start_idx = max(0, i - max_dist)
178
- for j in range(start_idx, i):
179
- prev_node = order[j].item()
180
- try:
181
- dist = nx.shortest_path_length(G, source=node.item(), target=prev_node)
182
- distances[node, j - start_idx] = min(dist, max_dist - 1)
183
- except nx.NetworkXNoPath:
184
- distances[node, j - start_idx] = max_dist - 1
185
-
186
- # Replace infinities with max distance
187
- distances[distances == float('inf')] = max_dist - 1
188
-
189
- # Normalize
190
- seq_pos = seq_pos / num_nodes
191
- distances = distances / max_dist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  return seq_pos.unsqueeze(1), distances
 
3
  import networkx as nx
4
  from scipy.sparse.linalg import eigsh
5
  from sklearn.cluster import SpectralClustering
6
+ import warnings
7
+ warnings.filterwarnings('ignore')
8
 
9
  class GraphSequencer:
10
  """
11
  Production-ready graph ordering strategies
12
+ Device-safe implementation with performance optimizations
13
  """
14
 
15
  @staticmethod
16
  def bfs_ordering(edge_index, num_nodes, start_node=None):
17
+ """Breadth-first search ordering - optimized version"""
18
+ device = edge_index.device
19
+
20
+ if num_nodes <= 1:
21
+ return torch.arange(num_nodes, device=device)
22
+
23
+ # Convert to adjacency list efficiently
24
+ adj_list = [[] for _ in range(num_nodes)]
25
  edge_list = edge_index.t().cpu().numpy()
26
+
27
+ for src, dst in edge_list:
28
+ if src < num_nodes and dst < num_nodes:
29
+ adj_list[src].append(dst)
30
+ adj_list[dst].append(src)
31
+
32
+ # Remove duplicates and sort for determinism
33
+ adj_list = [sorted(list(set(neighbors))) for neighbors in adj_list]
34
 
35
  # Start from highest degree node if not specified
36
  if start_node is None:
37
+ degrees = [len(neighbors) for neighbors in adj_list]
38
+ start_node = np.argmax(degrees) if degrees else 0
39
 
40
  # BFS traversal
41
  visited = set()
 
44
 
45
  while queue:
46
  node = queue.pop(0)
47
+ if node in visited or node >= num_nodes:
48
  continue
49
 
50
  visited.add(node)
51
  order.append(node)
52
 
53
  # Add neighbors by degree (deterministic)
54
+ neighbors = adj_list[node]
55
+ neighbors.sort(key=lambda n: (len(adj_list[n]), n), reverse=True)
56
 
57
  for neighbor in neighbors:
58
  if neighbor not in visited:
 
63
  if node not in visited:
64
  order.append(node)
65
 
66
+ return torch.tensor(order, dtype=torch.long, device=device)
67
 
68
  @staticmethod
69
  def spectral_ordering(edge_index, num_nodes):
70
+ """Spectral ordering using graph Laplacian eigenvector - robust version"""
71
+ device = edge_index.device
72
+
73
+ if num_nodes <= 2:
74
+ return torch.arange(num_nodes, device=device)
75
+
76
  try:
77
+ # Move to CPU for scipy operations
78
+ edge_index_cpu = edge_index.cpu().numpy()
79
 
80
  # Create adjacency matrix
81
  A = np.zeros((num_nodes, num_nodes))
82
+ valid_edges = (edge_index_cpu[0] < num_nodes) & (edge_index_cpu[1] < num_nodes)
83
+ valid_edge_index = edge_index_cpu[:, valid_edges]
84
+
85
+ A[valid_edge_index[0], valid_edge_index[1]] = 1
86
+ A[valid_edge_index[1], valid_edge_index[0]] = 1 # Undirected
87
 
88
  # Degree matrix
89
+ degrees = np.sum(A, axis=1)
90
+
91
+ # Handle disconnected components
92
+ if np.any(degrees == 0):
93
+ # Add self-loops to isolated nodes
94
+ isolated = degrees == 0
95
+ A[isolated, isolated] = 1
96
+ degrees = np.sum(A, axis=1)
97
+
98
+ D = np.diag(degrees)
99
 
100
  # Normalized Laplacian: L = D^(-1/2) * (D - A) * D^(-1/2)
101
+ degrees_sqrt_inv = np.where(degrees > 0, 1.0 / np.sqrt(degrees), 0)
102
+ D_sqrt_inv = np.diag(degrees_sqrt_inv)
103
  L = D_sqrt_inv @ (D - A) @ D_sqrt_inv
104
 
105
+ # Compute eigenvectors
106
+ k = min(10, num_nodes - 1)
107
+ try:
108
+ eigenvals, eigenvecs = eigsh(L, k=k, which='SM', sigma=0.0)
109
+
110
+ # Use second smallest eigenvector (Fiedler vector)
111
+ if eigenvecs.shape[1] > 1:
112
+ fiedler_vector = eigenvecs[:, 1]
113
+ else:
114
+ fiedler_vector = eigenvecs[:, 0]
115
+
116
+ # Order by Fiedler vector values
117
+ order = np.argsort(fiedler_vector)
118
+
119
+ except Exception:
120
+ # Fallback to degree ordering
121
+ order = np.argsort(-degrees)
122
 
123
+ return torch.tensor(order, dtype=torch.long, device=device)
124
 
125
  except Exception as e:
126
  print(f"Spectral ordering failed: {e}, falling back to degree ordering")
 
128
 
129
  @staticmethod
130
  def degree_ordering(edge_index, num_nodes):
131
+ """Order nodes by degree (high to low) - optimized version"""
132
+ device = edge_index.device
133
+
134
+ # Count degrees efficiently
135
+ degrees = torch.zeros(num_nodes, dtype=torch.long, device=device)
136
+
137
+ if edge_index.shape[1] > 0:
138
+ # Ensure valid indices
139
+ valid_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
140
+ valid_edges = edge_index[:, valid_mask]
141
+
142
+ if valid_edges.shape[1] > 0:
143
+ degrees.index_add_(0, valid_edges[0], torch.ones(valid_edges.shape[1], dtype=torch.long, device=device))
144
+ degrees.index_add_(0, valid_edges[1], torch.ones(valid_edges.shape[1], dtype=torch.long, device=device))
145
 
146
  # Sort by degree (descending), then by node index for determinism
147
+ node_indices = torch.arange(num_nodes, device=device)
148
+ _, order = torch.sort(-degrees * num_nodes - node_indices)
149
 
150
  return order
151
 
152
  @staticmethod
153
  def community_ordering(edge_index, num_nodes, n_clusters=None):
154
+ """Community-aware ordering - robust version"""
155
+ device = edge_index.device
156
+
157
+ if num_nodes <= 3:
158
+ return GraphSequencer.degree_ordering(edge_index, num_nodes)
159
+
160
  try:
161
  if n_clusters is None:
162
+ n_clusters = max(2, min(10, int(np.sqrt(num_nodes))))
163
 
164
+ n_clusters = min(n_clusters, num_nodes)
165
+
166
+ # Convert to adjacency matrix on CPU
167
+ edge_index_cpu = edge_index.cpu().numpy()
168
  A = np.zeros((num_nodes, num_nodes))
169
+
170
+ valid_edges = (edge_index_cpu[0] < num_nodes) & (edge_index_cpu[1] < num_nodes)
171
+ valid_edge_index = edge_index_cpu[:, valid_edges]
172
+
173
+ if valid_edge_index.shape[1] > 0:
174
+ A[valid_edge_index[0], valid_edge_index[1]] = 1
175
+ A[valid_edge_index[1], valid_edge_index[0]] = 1
176
+
177
+ # Add small diagonal for stability
178
+ A += np.eye(num_nodes) * 0.01
179
 
180
  # Spectral clustering
181
  clustering = SpectralClustering(
182
  n_clusters=n_clusters,
183
  affinity='precomputed',
184
+ random_state=42,
185
+ assign_labels='discretize'
186
  )
187
 
188
  labels = clustering.fit_predict(A)
 
193
  order = []
194
  for cluster in range(n_clusters):
195
  cluster_nodes = np.where(labels == cluster)[0]
196
+ if len(cluster_nodes) > 0:
197
+ cluster_degrees = degrees[cluster_nodes]
198
+ cluster_order = cluster_nodes[np.argsort(-cluster_degrees)]
199
+ order.extend(cluster_order)
200
 
201
+ # Add any missed nodes
202
+ for i in range(num_nodes):
203
+ if i not in order:
204
+ order.append(i)
205
+
206
+ return torch.tensor(order, dtype=torch.long, device=device)
207
 
208
  except Exception as e:
209
  print(f"Community ordering failed: {e}, falling back to BFS ordering")
210
  return GraphSequencer.bfs_ordering(edge_index, num_nodes)
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  class PositionalEncoder:
213
+ """Graph-aware positional encoding - optimized version"""
214
 
215
  @staticmethod
216
  def encode_positions(x, edge_index, order, max_dist=10):
217
  """
218
  Create positional encodings that preserve graph structure
219
+ Optimized for training stability
220
  """
221
  num_nodes = x.size(0)
222
  device = x.device
 
224
  # Sequential positions
225
  seq_pos = torch.zeros(num_nodes, device=device)
226
  seq_pos[order] = torch.arange(num_nodes, device=device, dtype=torch.float)
227
+ seq_pos = seq_pos / max(num_nodes, 1)
228
+
229
+ # Enhanced distance encoding
230
+ distances = torch.zeros((num_nodes, max_dist), device=device)
231
 
232
+ if edge_index.shape[1] > 0:
233
+ # Create adjacency matrix efficiently
234
+ adj = torch.zeros(num_nodes, num_nodes, device=device, dtype=torch.bool)
235
+
236
+ # Filter valid edges
237
+ valid_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
238
+ if valid_mask.any():
239
+ valid_edges = edge_index[:, valid_mask]
240
+ adj[valid_edges[0], valid_edges[1]] = True
241
+ adj[valid_edges[1], valid_edges[0]] = True # Undirected
242
+
243
+ # Compute 2-hop neighbors for richer encoding
244
+ adj2 = torch.matmul(adj.float(), adj.float()) > 0
245
+
246
+ # Fill distance features
247
+ for i, node in enumerate(order):
248
+ node_idx = node.item() if isinstance(node, torch.Tensor) else node
249
+
250
+ if node_idx < num_nodes:
251
+ # Get 1-hop and 2-hop neighbors
252
+ neighbors_1hop = torch.where(adj[node_idx])[0]
253
+ neighbors_2hop = torch.where(adj2[node_idx] & ~adj[node_idx])[0]
254
+
255
+ # Fill distance features based on order position
256
+ start_idx = max(0, i - max_dist)
257
+ for j in range(start_idx, i):
258
+ if j - start_idx < max_dist:
259
+ prev_node = order[j]
260
+ prev_idx = prev_node.item() if isinstance(prev_node, torch.Tensor) else prev_node
261
+
262
+ if prev_idx < num_nodes:
263
+ # Multi-scale distance encoding
264
+ if prev_idx in neighbors_1hop:
265
+ distances[node_idx, j - start_idx] = 0.9 # Direct neighbor
266
+ elif prev_idx in neighbors_2hop:
267
+ distances[node_idx, j - start_idx] = 0.6 # 2-hop neighbor
268
+ else:
269
+ distances[node_idx, j - start_idx] = 0.3 # Distant
270
+ else:
271
+ # No edges - use position-based encoding
272
+ for i in range(num_nodes):
273
+ for j in range(max_dist):
274
+ distances[i, j] = (max_dist - j) / max_dist
275
 
276
  return seq_pos.unsqueeze(1), distances