kfoughali committed
Commit 4992374 · verified · 1 Parent(s): 021bc4e

Update data/processor.py

Files changed (1)
  1. data/processor.py +133 -8
data/processor.py CHANGED
@@ -1,30 +1,155 @@
 import torch
 import torch.nn.functional as F
 from torch_geometric.data import Data
+from torch_geometric.transforms import Compose
+import numpy as np
 
 class GraphProcessor:
-    """Data preprocessing utilities"""
+    """Advanced data preprocessing utilities"""
 
     @staticmethod
-    def normalize_features(x):
+    def normalize_features(x, method='l2'):
         """Normalize node features"""
-        return F.normalize(x, p=2, dim=1)
+        if method == 'l2':
+            return F.normalize(x, p=2, dim=1)
+        elif method == 'minmax':
+            x_min = x.min(dim=0, keepdim=True)[0]
+            x_max = x.max(dim=0, keepdim=True)[0]
+            return (x - x_min) / (x_max - x_min + 1e-8)
+        elif method == 'standard':
+            return (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)
+        else:
+            return x
 
     @staticmethod
     def add_self_loops(edge_index, num_nodes):
         """Add self loops to graph"""
-        self_loops = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
+        self_loops = torch.arange(num_nodes, device=edge_index.device).unsqueeze(0).repeat(2, 1)
         edge_index = torch.cat([edge_index, self_loops], dim=1)
         return edge_index
 
     @staticmethod
-    def to_device(data, device):
+    def remove_self_loops(edge_index):
+        """Remove self loops from graph"""
+        mask = edge_index[0] != edge_index[1]
+        return edge_index[:, mask]
+
+    @staticmethod
+    def add_positional_features(data, encoding_dim=8):
+        """Add positional encodings as features"""
+        num_nodes = data.num_nodes
+
+        # Random walk positional encoding
+        if data.edge_index.shape[1] > 0:
+            adj = torch.zeros(num_nodes, num_nodes)
+            adj[data.edge_index[0], data.edge_index[1]] = 1
+            adj = adj + adj.t() # Make symmetric
+
+            # Degree normalization
+            degree = adj.sum(dim=1)
+            degree[degree == 0] = 1 # Avoid division by zero
+            D_inv_sqrt = torch.diag(1.0 / torch.sqrt(degree))
+
+            # Normalized adjacency
+            A_norm = D_inv_sqrt @ adj @ D_inv_sqrt
+
+            # Random walk features
+            rw_features = []
+            A_power = torch.eye(num_nodes)
+
+            for k in range(encoding_dim):
+                A_power = A_power @ A_norm
+                rw_features.append(A_power.diag().unsqueeze(1))
+
+            pos_encoding = torch.cat(rw_features, dim=1)
+        else:
+            # No edges - use node indices
+            pos_encoding = torch.zeros(num_nodes, encoding_dim)
+            for i in range(min(encoding_dim, num_nodes)):
+                pos_encoding[i, i] = 1.0
+
+        # Concatenate with existing features
+        if data.x is not None:
+            data.x = torch.cat([data.x, pos_encoding], dim=1)
+        else:
+            data.x = pos_encoding
+
+        return data
+
+    @staticmethod
+    def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1):
+        """Graph augmentation for training"""
+        if aug_type == 'edge_drop':
+            # Randomly drop edges
+            num_edges = data.edge_index.shape[1]
+            mask = torch.rand(num_edges) > aug_ratio
+            data.edge_index = data.edge_index[:, mask]
+
+        elif aug_type == 'node_drop':
+            # Randomly drop nodes
+            num_nodes = data.num_nodes
+            keep_mask = torch.rand(num_nodes) > aug_ratio
+            keep_nodes = torch.where(keep_mask)[0]
+
+            # Update edge index
+            node_map = torch.full((num_nodes,), -1, dtype=torch.long)
+            node_map[keep_nodes] = torch.arange(len(keep_nodes))
+
+            # Filter edges
+            edge_mask = keep_mask[data.edge_index[0]] & keep_mask[data.edge_index[1]]
+            filtered_edges = data.edge_index[:, edge_mask]
+            data.edge_index = node_map[filtered_edges]
+
+            # Update features
+            data.x = data.x[keep_nodes]
+            if hasattr(data, 'y') and data.y.size(0) == num_nodes:
+                data.y = data.y[keep_nodes]
+
+        elif aug_type == 'feature_noise':
+            # Add Gaussian noise to features
+            if data.x is not None:
+                noise = torch.randn_like(data.x) * aug_ratio
+                data.x = data.x + noise
+
+        elif aug_type == 'feature_mask':
+            # Randomly mask features
+            if data.x is not None:
+                mask = torch.rand_like(data.x) > aug_ratio
+                data.x = data.x * mask
+
+        return data
+
+    @staticmethod
+    def to_device_safe(data, device):
         """Move data to device safely"""
         if hasattr(data, 'to'):
             return data.to(device)
         elif isinstance(data, (list, tuple)):
-            return [GraphProcessor.to_device(item, device) for item in data]
+            return [GraphProcessor.to_device_safe(item, device) for item in data]
         elif isinstance(data, dict):
-            return {k: GraphProcessor.to_device(v, device) for k, v in data.items()}
+            return {k: GraphProcessor.to_device_safe(v, device) for k, v in data.items()}
         else:
-            return data
+            return data
+
+    @staticmethod
+    def validate_data(data):
+        """Validate graph data integrity"""
+        errors = []
+
+        # Check basic structure
+        if not hasattr(data, 'edge_index'):
+            errors.append("Missing edge_index")
+        elif data.edge_index.shape[0] != 2:
+            errors.append("edge_index must have shape (2, num_edges)")
+
+        if hasattr(data, 'x') and data.x is not None:
+            if hasattr(data, 'num_nodes') and data.x.shape[0] != data.num_nodes:
+                errors.append("Feature matrix size mismatch")
+
+        # Check edge indices
+        if hasattr(data, 'edge_index') and data.edge_index.shape[1] > 0:
+            max_idx = data.edge_index.max().item()
+            if hasattr(data, 'num_nodes') and max_idx >= data.num_nodes:
+                errors.append("Edge indices exceed number of nodes")
+
+        return errors
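
For reference, a minimal usage sketch of the updated GraphProcessor API introduced by this commit. It is an illustration only: the toy graph, the data.processor import path, and the chosen method arguments are assumptions, not part of the commit.

import torch
from torch_geometric.data import Data
from data.processor import GraphProcessor  # assumed import path for data/processor.py

# Toy graph: 4 nodes, undirected edges stored as directed pairs (illustrative)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
data = Data(x=torch.randn(4, 16), edge_index=edge_index, num_nodes=4)

# Sanity-check the graph before any processing
assert GraphProcessor.validate_data(data) == []

# Normalize features and append 8 random-walk positional dimensions (16 -> 24)
data.x = GraphProcessor.normalize_features(data.x, method='standard')
data = GraphProcessor.add_positional_features(data, encoding_dim=8)

# Self-loops plus a light training-time augmentation
data.edge_index = GraphProcessor.add_self_loops(data.edge_index, data.num_nodes)
data = GraphProcessor.augment_graph(data, aug_type='feature_noise', aug_ratio=0.05)

# Recursively move Data objects, lists, and dicts to the target device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = GraphProcessor.to_device_safe(data, device)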