Update data/processor.py
Browse files- data/processor.py +133 -8
data/processor.py
CHANGED
@@ -1,30 +1,155 @@
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
from torch_geometric.data import Data
|
|
|
|
|
4 |
|
5 |
class GraphProcessor:
|
6 |
-
"""
|
7 |
|
8 |
@staticmethod
|
9 |
-
def normalize_features(x):
|
10 |
"""Normalize node features"""
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
@staticmethod
|
14 |
def add_self_loops(edge_index, num_nodes):
|
15 |
"""Add self loops to graph"""
|
16 |
-
self_loops = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
|
17 |
edge_index = torch.cat([edge_index, self_loops], dim=1)
|
18 |
return edge_index
|
19 |
|
20 |
@staticmethod
|
21 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
"""Move data to device safely"""
|
23 |
if hasattr(data, 'to'):
|
24 |
return data.to(device)
|
25 |
elif isinstance(data, (list, tuple)):
|
26 |
-
return [GraphProcessor.
|
27 |
elif isinstance(data, dict):
|
28 |
-
return {k: GraphProcessor.
|
29 |
else:
|
30 |
-
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
from torch_geometric.data import Data
|
4 |
+
from torch_geometric.transforms import Compose
|
5 |
+
import numpy as np
|
6 |
|
7 |
class GraphProcessor:
    """Advanced data preprocessing utilities.

    A namespace of stateless static helpers operating on graph tensors:
    an ``edge_index`` tensor laid out as ``(2, num_edges)`` and, where
    relevant, a ``data`` object exposing ``edge_index``, ``x``, ``y`` and
    ``num_nodes`` attributes.
    """

    @staticmethod
    def normalize_features(x, method='l2'):
        """Normalize node features.

        Args:
            x: 2-D feature tensor. ``'l2'`` normalizes along ``dim=1``;
                ``'minmax'`` and ``'standard'`` compute statistics along
                ``dim=0``.
            method: ``'l2'``, ``'minmax'`` or ``'standard'``. Any other
                value returns ``x`` unchanged.

        Returns:
            The normalized tensor (``x`` itself for an unknown method or
            an empty input).
        """
        # min()/max() raise on an empty reduction dimension; there is
        # nothing to normalize in that case anyway.
        if x.numel() == 0:
            return x

        if method == 'l2':
            return F.normalize(x, p=2, dim=1)

        if method == 'minmax':
            x_min = x.min(dim=0, keepdim=True)[0]
            x_max = x.max(dim=0, keepdim=True)[0]
            # The epsilon keeps constant columns from dividing by zero.
            return (x - x_min) / (x_max - x_min + 1e-8)

        if method == 'standard':
            mean = x.mean(dim=0, keepdim=True)
            if x.shape[0] > 1:
                std = x.std(dim=0, keepdim=True)
            else:
                # The sample std of a single row is NaN; treat the spread
                # as zero so the result is 0 instead of NaN.
                std = torch.zeros_like(mean)
            return (x - mean) / (std + 1e-8)

        return x

    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Add self loops to graph.

        Appends one ``(i, i)`` edge for every ``i`` in ``range(num_nodes)``.
        Existing self loops are not detected, so calling this on a graph
        that already has them produces duplicates.

        Args:
            edge_index: ``(2, num_edges)`` index tensor.
            num_nodes: number of nodes to create loops for.

        Returns:
            A new ``(2, num_edges + num_nodes)`` tensor.
        """
        # Match dtype and device of the input so the concatenation neither
        # promotes the index dtype nor mixes devices.
        loops = torch.arange(
            num_nodes, dtype=edge_index.dtype, device=edge_index.device
        )
        loops = loops.unsqueeze(0).repeat(2, 1)
        return torch.cat([edge_index, loops], dim=1)

    @staticmethod
    def remove_self_loops(edge_index):
        """Remove self loops from graph.

        Args:
            edge_index: ``(2, num_edges)`` index tensor.

        Returns:
            The columns of ``edge_index`` whose two endpoints differ.
        """
        mask = edge_index[0] != edge_index[1]
        return edge_index[:, mask]

    @staticmethod
    def add_positional_features(data, encoding_dim=8):
        """Add positional encodings as features.

        For a graph with edges, column ``k`` of the encoding holds the
        diagonal of the ``(k + 1)``-th power of the symmetrically
        normalized adjacency matrix. For a graph without edges the
        encoding is a truncated identity matrix. The encoding is appended
        to ``data.x`` (or becomes ``data.x`` when it is ``None``).

        Note: this builds a dense ``num_nodes x num_nodes`` matrix.

        Args:
            data: object with ``edge_index``, ``num_nodes`` and ``x``.
                Modified in place.
            encoding_dim: number of encoding columns to add.

        Returns:
            The same ``data`` object.
        """
        num_nodes = data.num_nodes
        edge_index = data.edge_index
        # Build everything on the graph's device; mixing a CPU adjacency
        # with device-resident indices/features fails.
        device = edge_index.device

        if edge_index.shape[1] > 0:
            adj = torch.zeros(num_nodes, num_nodes, device=device)
            adj[edge_index[0], edge_index[1]] = 1
            # Symmetrize as a logical OR. A plain ``adj + adj.t()`` would
            # weight reciprocated edges and self loops 2 but one-way edges
            # 1, making the encoding depend on how the edge list is stored.
            adj = (adj + adj.t()).clamp(max=1.0)

            # Isolated nodes get degree 1 to avoid division by zero.
            degree = adj.sum(dim=1).clamp(min=1)
            d_inv_sqrt = 1.0 / torch.sqrt(degree)

            # D^-1/2 A D^-1/2 via broadcasting instead of dense diagonal
            # matrix products.
            a_norm = d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)

            walk = torch.eye(num_nodes, device=device)
            diagonals = []
            for _ in range(encoding_dim):
                walk = walk @ a_norm
                diagonals.append(walk.diagonal())

            if diagonals:
                pos_encoding = torch.stack(diagonals, dim=1)
            else:
                pos_encoding = torch.zeros(num_nodes, 0, device=device)
        else:
            # No edges - use node indices (one-hot on the leading diagonal).
            pos_encoding = torch.zeros(num_nodes, encoding_dim, device=device)
            hot = torch.arange(min(encoding_dim, num_nodes), device=device)
            pos_encoding[hot, hot] = 1.0

        # Concatenate with existing features.
        x = getattr(data, 'x', None)
        if x is not None:
            data.x = torch.cat([x, pos_encoding.to(x.device)], dim=1)
        else:
            data.x = pos_encoding

        return data

    @staticmethod
    def _keep_edges(data, edge_mask):
        """Keep only the edges selected by ``edge_mask``.

        Filters ``data.edge_index`` in place and, when ``data.edge_attr``
        is a tensor with one row per edge, filters it the same way so the
        two stay aligned.
        """
        num_edges = data.edge_index.shape[1]
        edge_attr = getattr(data, 'edge_attr', None)
        if (
            torch.is_tensor(edge_attr)
            and edge_attr.dim() > 0
            and edge_attr.size(0) == num_edges
        ):
            data.edge_attr = edge_attr[edge_mask.to(edge_attr.device)]
        data.edge_index = data.edge_index[:, edge_mask]

    @staticmethod
    def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1):
        """Graph augmentation for training.

        Args:
            data: object with ``edge_index``, ``x`` and ``num_nodes``
                (optionally ``y`` and ``edge_attr``). Modified IN PLACE;
                pass a copy if the original graph must be preserved.
            aug_type: ``'edge_drop'``, ``'node_drop'``, ``'feature_noise'``
                or ``'feature_mask'``. Any other value leaves ``data``
                untouched.
            aug_ratio: drop/mask probability, or the noise standard
                deviation for ``'feature_noise'``.

        Returns:
            The same ``data`` object.
        """
        if aug_type == 'edge_drop':
            # Randomly drop edges. The mask is sampled on the CPU and then
            # moved so it always lives with the tensor it indexes.
            device = data.edge_index.device
            num_edges = data.edge_index.shape[1]
            keep = (torch.rand(num_edges) > aug_ratio).to(device)
            GraphProcessor._keep_edges(data, keep)

        elif aug_type == 'node_drop':
            # Randomly drop nodes.
            device = data.edge_index.device
            num_nodes = data.num_nodes
            keep_mask = (torch.rand(num_nodes) > aug_ratio).to(device)
            keep_nodes = torch.where(keep_mask)[0]

            # Old node id -> new contiguous id (-1 for dropped nodes).
            node_map = torch.full(
                (num_nodes,), -1, dtype=torch.long, device=device
            )
            node_map[keep_nodes] = torch.arange(
                keep_nodes.numel(), device=device
            )

            # Keep only edges whose both endpoints survive, then relabel.
            src, dst = data.edge_index[0], data.edge_index[1]
            GraphProcessor._keep_edges(data, keep_mask[src] & keep_mask[dst])
            data.edge_index = node_map[data.edge_index]

            # Update features (a graph may legitimately have none).
            x = getattr(data, 'x', None)
            if x is not None:
                data.x = x[keep_nodes.to(x.device)]

            # Only node-level labels are filtered; a missing, scalar or
            # graph-level ``y`` is left alone.
            y = getattr(data, 'y', None)
            if torch.is_tensor(y) and y.dim() > 0 and y.size(0) == num_nodes:
                data.y = y[keep_nodes.to(y.device)]

            # Keep the node count in sync with the relabelled graph.
            data.num_nodes = int(keep_nodes.numel())

        elif aug_type == 'feature_noise':
            # Add Gaussian noise to features.
            if data.x is not None:
                data.x = data.x + torch.randn_like(data.x) * aug_ratio

        elif aug_type == 'feature_mask':
            # Randomly zero individual feature entries.
            if data.x is not None:
                keep = torch.rand_like(data.x) > aug_ratio
                data.x = data.x * keep

        return data

    @staticmethod
    def to_device_safe(data, device):
        """Move data to device safely.

        Anything exposing ``.to`` is moved with it; lists, tuples and
        dicts are processed recursively; every other value is returned
        unchanged.

        Args:
            data: tensor-like object, container of such, or any value.
            device: target passed through to ``.to``.

        Returns:
            The moved object. Lists stay lists, tuples stay tuples
            (named tuples keep their type) and dicts become plain dicts.
        """
        if hasattr(data, 'to'):
            return data.to(device)

        if isinstance(data, (list, tuple)):
            moved = [GraphProcessor.to_device_safe(item, device) for item in data]
            if isinstance(data, list):
                return moved
            if hasattr(data, '_fields'):
                # Named tuples take their fields as positional arguments.
                return type(data)(*moved)
            return tuple(moved)

        if isinstance(data, dict):
            return {
                k: GraphProcessor.to_device_safe(v, device)
                for k, v in data.items()
            }

        return data

    @staticmethod
    def validate_data(data):
        """Validate graph data integrity.

        Args:
            data: object that may expose ``edge_index``, ``x`` and
                ``num_nodes``.

        Returns:
            A list of human-readable error strings; empty when no problem
            was found.
        """
        errors = []

        # getattr with a default also covers objects that report a missing
        # attribute as None instead of raising.
        edge_index = getattr(data, 'edge_index', None)
        num_nodes = getattr(data, 'num_nodes', None)
        x = getattr(data, 'x', None)

        # Check basic structure.
        if edge_index is None:
            errors.append("Missing edge_index")
        elif edge_index.dim() != 2 or edge_index.shape[0] != 2:
            errors.append("edge_index must have shape (2, num_edges)")

        if x is not None and num_nodes is not None and x.shape[0] != num_nodes:
            errors.append("Feature matrix size mismatch")

        # Check edge indices.
        if edge_index is not None and edge_index.numel() > 0:
            if num_nodes is not None and edge_index.max().item() >= num_nodes:
                errors.append("Edge indices exceed number of nodes")
            if edge_index.min().item() < 0:
                errors.append("Edge indices must be non-negative")

        return errors
|