File size: 994 Bytes
aab84ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn.functional as F
from torch_geometric.data import Data

class GraphProcessor:
    """Stateless data-preprocessing helpers for graph tensors.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def normalize_features(x, p=2, dim=1):
        """Scale ``x`` to unit L_p norm along ``dim`` via ``F.normalize``.

        Args:
            x: Feature tensor. Presumably (num_nodes, num_features) given the
                default ``dim=1`` (each row normalized) -- confirm with callers.
            p: Norm exponent (default 2, i.e. L2 norm).
            dim: Dimension reduced when computing the norm (default 1).

        Returns:
            A new tensor with the same shape as ``x``.
        """
        return F.normalize(x, p=p, dim=dim)

    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Append a self loop (i, i) for every node i in ``range(num_nodes)``.

        Existing edges are kept as-is, so a node that already has a self loop
        ends up with a duplicate one.

        Args:
            edge_index: Tensor with 2 rows (concatenated along dim 1 with a
                (2, num_nodes) tensor, so it must be shaped (2, E)).
            num_nodes: Number of nodes to add loops for.

        Returns:
            Tensor of shape (2, E + num_nodes) with the same dtype and device
            as ``edge_index``.
        """
        # Build the loop indices on edge_index's own device/dtype: a default
        # CPU int64 arange makes torch.cat fail for CUDA inputs and changes
        # (or rejects) non-int64 index dtypes.
        loop_index = torch.arange(
            num_nodes, dtype=edge_index.dtype, device=edge_index.device
        )
        self_loops = loop_index.unsqueeze(0).repeat(2, 1)
        return torch.cat([edge_index, self_loops], dim=1)

    @staticmethod
    def to_device(data, device):
        """Recursively move ``data`` to ``device`` where possible.

        - Objects exposing a ``.to`` method (e.g. ``torch.Tensor``) are
          returned as ``data.to(device)``.
        - Lists are processed element-wise and returned as lists.
        - Tuples are processed element-wise and keep their type; named tuples
          are rebuilt from positional fields.
        - Dicts are processed value-wise (keys untouched) and returned as a
          plain ``dict``.
        - Anything else is returned unchanged.
        """
        if hasattr(data, 'to'):
            return data.to(device)
        if isinstance(data, list):
            return [GraphProcessor.to_device(item, device) for item in data]
        if isinstance(data, tuple):
            moved = [GraphProcessor.to_device(item, device) for item in data]
            # Named tuples take fields positionally, not a single iterable.
            if hasattr(data, '_fields'):
                return type(data)(*moved)
            return type(data)(moved)
        if isinstance(data, dict):
            return {k: GraphProcessor.to_device(v, device) for k, v in data.items()}
        return data