kfoughali committed on
Commit
abceea1
·
verified ·
1 Parent(s): c681cda

Create data/loader.py

Browse files
Files changed (1) hide show
  1. data/loader.py +104 -0
data/loader.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.datasets import Planetoid, TUDataset, Reddit, Flickr
3
+ from torch_geometric.loader import DataLoader
4
+ from torch_geometric.transforms import NormalizeFeatures
5
+ import yaml
6
+
7
class GraphDataLoader:
    """
    Load real graph datasets and build train/val/test loaders.

    Datasets are fetched through torch_geometric dataset classes; nothing is
    synthesized here. The batch size and test split are read from the
    ``data`` section of a YAML config file.
    """

    def __init__(self, config_path='config.yaml'):
        """
        Read the YAML configuration.

        Args:
            config_path: Path to a YAML file containing a ``data`` mapping
                with ``batch_size`` and ``test_split`` entries.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If ``data``, ``batch_size`` or ``test_split`` is missing.
        """
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        data_config = self.config['data']
        self.batch_size = data_config['batch_size']
        # NOTE(review): test_split is stored but not consulted by
        # create_dataloaders, which uses its own ratio arguments -- confirm
        # the intended meaning before wiring it in.
        self.test_split = data_config['test_split']

    def load_node_classification_data(self, dataset_name='Cora'):
        """
        Load a node classification dataset with feature normalization.

        Args:
            dataset_name: One of 'Cora', 'CiteSeer', 'PubMed', 'Reddit' or
                'Flickr'.

        Returns:
            The torch_geometric dataset, stored under ``./data/<name>``.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        if dataset_name in ('Cora', 'CiteSeer', 'PubMed'):
            return Planetoid(
                root=f'./data/{dataset_name}',
                name=dataset_name,
                transform=NormalizeFeatures(),
            )
        if dataset_name == 'Reddit':
            return Reddit(
                root='./data/Reddit',
                transform=NormalizeFeatures(),
            )
        if dataset_name == 'Flickr':
            return Flickr(
                root='./data/Flickr',
                transform=NormalizeFeatures(),
            )
        raise ValueError(f"Unknown dataset: {dataset_name}")

    def load_graph_classification_data(self, dataset_name='MUTAG'):
        """
        Load a graph classification dataset with feature normalization.

        Args:
            dataset_name: One of 'MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB' or
                'IMDB-BINARY'.

        Returns:
            The TUDataset instance, stored under ``./data/<name>``.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        valid_datasets = ['MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB', 'IMDB-BINARY']

        # Validate before touching the dataset class so a typo fails fast.
        if dataset_name not in valid_datasets:
            raise ValueError(f"Dataset must be one of {valid_datasets}")

        return TUDataset(
            root=f'./data/{dataset_name}',
            name=dataset_name,
            transform=NormalizeFeatures(),
        )

    @staticmethod
    def _split_indices(num_graphs, train_ratio=0.8, val_ratio=0.1, seed=None):
        """
        Return shuffled (train, val, test) index tensors for ``num_graphs`` items.

        The test part receives every index not taken by train or val, so the
        three parts always cover all indices exactly once.
        """
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
            raise ValueError(
                "train_ratio and val_ratio must be non-negative and sum to at most 1"
            )

        if seed is None:
            # Fall back to the global torch RNG.
            order = torch.randperm(num_graphs)
        else:
            # A dedicated generator keeps the split reproducible without
            # disturbing the global RNG state.
            generator = torch.Generator().manual_seed(seed)
            order = torch.randperm(num_graphs, generator=generator)

        train_end = int(train_ratio * num_graphs)
        val_end = train_end + int(val_ratio * num_graphs)
        return order[:train_end], order[train_end:val_end], order[val_end:]

    def create_dataloaders(self, dataset, task_type='node_classification', *,
                           train_ratio=0.8, val_ratio=0.1, seed=None):
        """
        Create train/val/test splits.

        Args:
            dataset: Dataset to split. For graph classification it must
                support ``len()`` and indexing with an index tensor.
            task_type: 'node_classification' or 'graph_classification'.
            train_ratio: Fraction of graphs used for training
                (graph classification only).
            val_ratio: Fraction of graphs used for validation
                (graph classification only); the remainder is the test set.
            seed: Optional integer seed for a reproducible shuffle. When
                None, the global torch RNG is used.

        Returns:
            For 'node_classification': ``(dataset[0], None, None)``; the
            single graph is returned as-is and no loaders are built.
            For 'graph_classification': ``(train_loader, val_loader,
            test_loader)``, with shuffling enabled only for training.

        Raises:
            ValueError: If ``task_type`` is not recognized or the ratios are
                invalid.
        """
        if task_type == 'node_classification':
            # Single graph; splitting is left to whatever masks it carries.
            return dataset[0], None, None

        if task_type == 'graph_classification':
            train_idx, val_idx, test_idx = self._split_indices(
                len(dataset), train_ratio, val_ratio, seed
            )

            train_loader = DataLoader(
                dataset[train_idx], batch_size=self.batch_size, shuffle=True
            )
            val_loader = DataLoader(
                dataset[val_idx], batch_size=self.batch_size, shuffle=False
            )
            test_loader = DataLoader(
                dataset[test_idx], batch_size=self.batch_size, shuffle=False
            )
            return train_loader, val_loader, test_loader

        # Previously an unknown task type silently returned None, which only
        # surfaced later as an unpacking error at the call site.
        raise ValueError(f"Unknown task type: {task_type}")

    def get_dataset_info(self, dataset):
        """
        Get dynamic dataset information.

        Args:
            dataset: A sized, iterable, indexable collection of graphs that
                expose ``num_nodes`` and ``num_edges``. When the collection
                lacks ``num_features`` / ``num_classes`` attributes, each
                graph's ``x`` / ``y`` tensors are used instead.

        Returns:
            dict with keys 'num_features', 'num_classes', 'num_graphs',
            'avg_nodes' and 'avg_edges'.

        Raises:
            ValueError: If ``dataset`` is empty.
        """
        num_graphs = len(dataset)
        if num_graphs == 0:
            raise ValueError("Cannot compute statistics for an empty dataset")

        if hasattr(dataset, 'num_features'):
            num_features = dataset.num_features
        else:
            num_features = dataset[0].x.size(1)

        infer_classes = not hasattr(dataset, 'num_classes')

        # One pass only: iterating may re-apply the dataset transform to
        # every graph, so avoid walking the dataset twice.
        total_nodes = 0
        total_edges = 0
        label_chunks = []
        for graph in dataset:
            total_nodes += graph.num_nodes
            total_edges += graph.num_edges
            if infer_classes:
                label_chunks.append(graph.y.reshape(-1))

        if infer_classes:
            # Count distinct labels over all graphs; looking at the first
            # graph alone undercounts when labels are per-graph.
            num_classes = int(torch.unique(torch.cat(label_chunks)).numel())
        else:
            num_classes = dataset.num_classes

        return {
            'num_features': num_features,
            'num_classes': num_classes,
            'num_graphs': num_graphs,
            'avg_nodes': total_nodes / num_graphs,
            'avg_edges': total_edges / num_graphs,
        }