# data/loader.py
import os

import torch
import yaml
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeFeatures

class GraphDataLoader:
    """
    Production data loading with real datasets only.
    Device-safe implementation.
    """

    def __init__(self, config_path='config.yaml'):
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                self.config = yaml.safe_load(f)
        else:
            # Default config if the file doesn't exist
            self.config = {
                'data': {
                    'batch_size': 32,
                    'test_split': 0.2
                }
            }
        self.batch_size = self.config['data']['batch_size']
        self.test_split = self.config['data']['test_split']
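
    # A config.yaml matching the defaults above might look like this
    # (sketch; only the 'data' section is read by this loader):
    #
    #     data:
    #       batch_size: 32
    #       test_split: 0.2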

    def load_node_classification_data(self, dataset_name='Cora'):
        """Load real node classification datasets."""
        try:
            if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
                dataset = Planetoid(
                    root=f'./data/{dataset_name}',
                    name=dataset_name,
                    transform=NormalizeFeatures()
                )
            else:
                # Fallback to Cora
                dataset = Planetoid(
                    root='./data/Cora',
                    name='Cora',
                    transform=NormalizeFeatures()
                )
        except Exception as e:
            print(f"Error loading {dataset_name}: {e}")
            # Fallback to Cora
            dataset = Planetoid(
                root='./data/Cora',
                name='Cora',
                transform=NormalizeFeatures()
            )
        return dataset
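
    # Usage sketch (assumes the Planetoid download succeeds; Cora ships
    # with predefined train/val/test masks on its single graph):
    #
    #     loader = GraphDataLoader()
    #     dataset = loader.load_node_classification_data('Cora')
    #     data = dataset[0]  # data.train_mask, data.val_mask, data.test_mask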

    def load_graph_classification_data(self, dataset_name='MUTAG'):
        """Load real graph classification datasets."""
        valid_datasets = ['MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB', 'IMDB-BINARY']
        try:
            if dataset_name not in valid_datasets:
                dataset_name = 'MUTAG'  # Default fallback
            dataset = TUDataset(
                root=f'./data/{dataset_name}',
                name=dataset_name,
                transform=NormalizeFeatures()
            )
        except Exception as e:
            print(f"Error loading {dataset_name}: {e}")
            # Create a minimal synthetic dataset as a last-resort fallback:
            # one random 10-node graph with 5 features and a binary label
            dataset = [Data(
                x=torch.randn(10, 5),
                edge_index=torch.randint(0, 10, (2, 20)),
                y=torch.randint(0, 2, (1,))
            )]
        return dataset
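
    # Usage sketch: TUDataset yields one Data object per graph. Note that
    # some TUDatasets (e.g. COLLAB, IMDB-BINARY) ship without node features,
    # so data.x may be missing depending on the PyG version.
    #
    #     dataset = GraphDataLoader().load_graph_classification_data('MUTAG')
    #     print(len(dataset), dataset[0])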

    def create_dataloaders(self, dataset, task_type='node_classification'):
        """Create train/val/test splits for the given task type."""
        if task_type == 'node_classification':
            # Planetoid datasets are a single graph with predefined
            # train/val/test masks, so no DataLoaders are needed
            data = dataset[0]
            return data, None, None
        elif task_type == 'graph_classification':
            # Random 80/10/10 split (note: self.test_split from the config
            # is not applied here; the split ratios are fixed)
            num_graphs = len(dataset)
            indices = torch.randperm(num_graphs).tolist()
            train_size = int(0.8 * num_graphs)
            val_size = int(0.1 * num_graphs)

            train_dataset = [dataset[i] for i in indices[:train_size]]
            val_dataset = [dataset[i] for i in indices[train_size:train_size + val_size]]
            test_dataset = [dataset[i] for i in indices[train_size + val_size:]]

            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
            return train_loader, val_loader, test_loader
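
    # Batch-iteration sketch: PyG's DataLoader collates graphs into a single
    # batched Data object with a `batch` vector mapping nodes to graphs.
    # `model` below is hypothetical, not part of this module:
    #
    #     train_loader, _, _ = loader.create_dataloaders(
    #         dataset, task_type='graph_classification')
    #     for batch in train_loader:
    #         out = model(batch.x, batch.edge_index, batch.batch)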

    def get_dataset_info(self, dataset):
        """Derive dataset statistics dynamically from the data itself."""
        try:
            if hasattr(dataset, 'num_features'):
                num_features = dataset.num_features
            else:
                # Inspect the first graph; 0 if it has no node features
                num_features = dataset[0].x.size(1) if dataset[0].x is not None else 0

            if hasattr(dataset, 'num_classes'):
                num_classes = dataset.num_classes
            elif getattr(dataset[0], 'y', None) is not None:
                if len(dataset) > 1:
                    all_labels = torch.cat([data.y.flatten() for data in dataset])
                    num_classes = len(torch.unique(all_labels))
                else:
                    num_classes = len(torch.unique(dataset[0].y))
            else:
                num_classes = 2  # Default to binary

            num_graphs = len(dataset)
            avg_nodes = sum(data.num_nodes for data in dataset) / num_graphs
            avg_edges = sum(data.num_edges for data in dataset) / num_graphs
        except Exception as e:
            print(f"Error getting dataset info: {e}")
            # Return Cora's statistics as defaults
            num_features = 1433
            num_classes = 7
            num_graphs = 1
            avg_nodes = 2708
            avg_edges = 10556
        return {
            'num_features': num_features,
            'num_classes': num_classes,
            'num_graphs': num_graphs,
            'avg_nodes': avg_nodes,
            'avg_edges': avg_edges
        }
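

if __name__ == '__main__':
    # Minimal smoke test (sketch): assumes network access for the first
    # Planetoid/TUDataset download. Not part of the production API.
    loader = GraphDataLoader()

    cora = loader.load_node_classification_data('Cora')
    print('Cora info:', loader.get_dataset_info(cora))
    data, _, _ = loader.create_dataloaders(cora, task_type='node_classification')
    print('Cora graph:', data)

    mutag = loader.load_graph_classification_data('MUTAG')
    print('MUTAG info:', loader.get_dataset_info(mutag))
    train_loader, val_loader, test_loader = loader.create_dataloaders(
        mutag, task_type='graph_classification')
    print('MUTAG batches:', len(train_loader), len(val_loader), len(test_loader))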