kfoughali commited on
Commit
aab84ef
·
verified ·
1 Parent(s): beb8b0c

Create processor.py

Browse files
Files changed (1) hide show
  1. data/processor.py +30 -0
data/processor.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch_geometric.data import Data
4
+
5
+ class GraphProcessor:
6
+ """Data preprocessing utilities"""
7
+
8
+ @staticmethod
9
+ def normalize_features(x):
10
+ """Normalize node features"""
11
+ return F.normalize(x, p=2, dim=1)
12
+
13
+ @staticmethod
14
+ def add_self_loops(edge_index, num_nodes):
15
+ """Add self loops to graph"""
16
+ self_loops = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
17
+ edge_index = torch.cat([edge_index, self_loops], dim=1)
18
+ return edge_index
19
+
20
+ @staticmethod
21
+ def to_device(data, device):
22
+ """Move data to device safely"""
23
+ if hasattr(data, 'to'):
24
+ return data.to(device)
25
+ elif isinstance(data, (list, tuple)):
26
+ return [GraphProcessor.to_device(item, device) for item in data]
27
+ elif isinstance(data, dict):
28
+ return {k: GraphProcessor.to_device(v, device) for k, v in data.items()}
29
+ else:
30
+ return data