File size: 994 Bytes
aab84ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
class GraphProcessor:
    """Data preprocessing utilities for graph tensors.

    All helpers are stateless ``@staticmethod``s; the class only groups them.
    """

    @staticmethod
    def normalize_features(
        x: torch.Tensor, *, p: float = 2.0, dim: int = 1
    ) -> torch.Tensor:
        """Normalize features to unit ``p``-norm along ``dim``.

        Args:
            x: Feature tensor. With the default ``dim=1`` and a 2-D input,
                each row is normalized (presumably one row per node -- confirm
                against callers).
            p: Exponent of the norm (default 2, i.e. L2). Keyword-only.
            dim: Dimension to normalize along (default 1). Keyword-only.

        Returns:
            A new tensor with the same shape as ``x``. ``F.normalize`` divides
            by ``max(norm, eps)``, so all-zero rows stay zero instead of
            producing NaN.
        """
        return F.normalize(x, p=p, dim=dim)

    @staticmethod
    def add_self_loops(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        """Append one self loop ``(i, i)`` for every ``i`` in ``range(num_nodes)``.

        Existing self loops are not removed, so a node that already has one
        ends up with a duplicate.

        Args:
            edge_index: Edge tensor with 2 rows, shape ``(2, num_edges)``.
            num_nodes: Number of nodes that receive a self loop.

        Returns:
            Tensor of shape ``(2, num_edges + num_nodes)``: the original edges
            followed by the self loops, on the same device and with the same
            dtype as ``edge_index``.
        """
        # Create the loops on edge_index's device/dtype: a plain arange lives on
        # the CPU as int64, which makes torch.cat fail for CUDA graphs and
        # silently promote int32 edge indices to int64.
        nodes = torch.arange(
            num_nodes, device=edge_index.device, dtype=edge_index.dtype
        )
        self_loops = nodes.unsqueeze(0).repeat(2, 1)
        return torch.cat([edge_index, self_loops], dim=1)

    @staticmethod
    def to_device(data, device):
        """Recursively move ``data`` to ``device``.

        - Objects with a callable ``.to`` method (e.g. ``torch.Tensor``) are
          moved with ``data.to(device)``.
        - Lists become new lists and tuples stay tuples (namedtuples keep their
          own type), each element moved recursively.
        - Dicts become new plain dicts with each value moved recursively.
        - Anything else (ints, strings, ``None``, ...) is returned unchanged.

        Args:
            data: The object or nested container to move.
            device: Target passed through to ``.to`` (e.g. ``"cpu"``).

        Returns:
            The moved object or a rebuilt container.
        """
        # Require a *callable* ``to``: a plain attribute named ``to`` (e.g. a
        # namedtuple field) must not be invoked.
        move = getattr(data, 'to', None)
        if callable(move):
            return move(device)
        if isinstance(data, (list, tuple)):
            moved = [GraphProcessor.to_device(item, device) for item in data]
            if isinstance(data, tuple):
                # namedtuples are built from positional fields; plain tuples
                # from an iterable.
                if hasattr(data, '_fields'):
                    return type(data)(*moved)
                return tuple(moved)
            return moved
        if isinstance(data, dict):
            return {k: GraphProcessor.to_device(v, device) for k, v in data.items()}
        return data