kfoughali committed on
Commit 021bc4e · verified · 1 Parent(s): 395c718

Update data/loader.py

Files changed (1)
data/loader.py +143 -47
data/loader.py CHANGED
@@ -1,14 +1,13 @@
 import torch
-from torch_geometric.datasets import Planetoid, TUDataset
+from torch_geometric.datasets import Planetoid, TUDataset, Amazon, Coauthor
 from torch_geometric.loader import DataLoader
-from torch_geometric.transforms import NormalizeFeatures
+from torch_geometric.transforms import NormalizeFeatures, Compose
 import yaml
 import os

 class GraphDataLoader:
     """
-    Production data loading with real datasets only
-    Device-safe implementation
+    Production data loading with comprehensive dataset support
     """

     def __init__(self, config_path='config.yaml'):
@@ -16,7 +15,7 @@ class GraphDataLoader:
             with open(config_path, 'r') as f:
                 self.config = yaml.safe_load(f)
         else:
-            # Default config if file doesn't exist
+            # Default config
             self.config = {
                 'data': {
                     'batch_size': 32,
@@ -27,70 +26,139 @@ class GraphDataLoader:
         self.batch_size = self.config['data']['batch_size']
         self.test_split = self.config['data']['test_split']

+        # Standard transform
+        self.transform = Compose([
+            NormalizeFeatures()
+        ])
+
     def load_node_classification_data(self, dataset_name='Cora'):
-        """Load real node classification datasets"""
+        """Load node classification datasets with proper splits"""

         try:
             if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
                 dataset = Planetoid(
                     root=f'./data/{dataset_name}',
                     name=dataset_name,
-                    transform=NormalizeFeatures()
+                    transform=self.transform
+                )
+
+            elif dataset_name in ['Computers', 'Photo']:
+                dataset = Amazon(
+                    root=f'./data/Amazon{dataset_name}',
+                    name=dataset_name,
+                    transform=self.transform
+                )
+
+            elif dataset_name in ['CS', 'Physics']:
+                dataset = Coauthor(
+                    root=f'./data/Coauthor{dataset_name}',
+                    name=dataset_name,
+                    transform=self.transform
                 )
+
             else:
-                # Fallback to Cora
+                print(f"Unknown dataset {dataset_name}, falling back to Cora")
                 dataset = Planetoid(
                     root='./data/Cora',
                     name='Cora',
-                    transform=NormalizeFeatures()
+                    transform=self.transform
                 )
+
         except Exception as e:
             print(f"Error loading {dataset_name}: {e}")
             # Fallback to Cora
             dataset = Planetoid(
                 root='./data/Cora',
                 name='Cora',
-                transform=NormalizeFeatures()
+                transform=self.transform
             )
-
+
+        # Ensure proper masks exist
+        data = dataset[0]
+        self._ensure_masks(data)
+
         return dataset

+    def _ensure_masks(self, data):
+        """Ensure train/val/test masks exist"""
+        num_nodes = data.num_nodes
+
+        if not hasattr(data, 'train_mask') or data.train_mask is None:
+            # Create random splits
+            indices = torch.randperm(num_nodes)
+
+            train_size = int(0.6 * num_nodes)
+            val_size = int(0.2 * num_nodes)
+
+            train_mask = torch.zeros(num_nodes, dtype=torch.bool)
+            val_mask = torch.zeros(num_nodes, dtype=torch.bool)
+            test_mask = torch.zeros(num_nodes, dtype=torch.bool)
+
+            train_mask[indices[:train_size]] = True
+            val_mask[indices[train_size:train_size + val_size]] = True
+            test_mask[indices[train_size + val_size:]] = True
+
+            data.train_mask = train_mask
+            data.val_mask = val_mask
+            data.test_mask = test_mask
+
     def load_graph_classification_data(self, dataset_name='MUTAG'):
-        """Load real graph classification datasets"""
+        """Load graph classification datasets"""

-        valid_datasets = ['MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB', 'IMDB-BINARY']
+        valid_datasets = ['MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'DD']

         try:
             if dataset_name not in valid_datasets:
-                dataset_name = 'MUTAG'  # Default fallback
+                dataset_name = 'MUTAG'

             dataset = TUDataset(
                 root=f'./data/{dataset_name}',
                 name=dataset_name,
-                transform=NormalizeFeatures()
+                transform=self.transform
             )
+
+            # Handle missing features
+            if dataset[0].x is None:
+                # Use degree as features
+                max_degree = 0
+                for data in dataset:
+                    if data.edge_index.shape[1] > 0:
+                        degree = torch.zeros(data.num_nodes)
+                        degree.index_add_(0, data.edge_index[0], torch.ones(data.edge_index.shape[1]))
+                        max_degree = max(max_degree, degree.max().item())
+
+                for data in dataset:
+                    if data.edge_index.shape[1] > 0:
+                        degree = torch.zeros(data.num_nodes)
+                        degree.index_add_(0, data.edge_index[0], torch.ones(data.edge_index.shape[1]))
+                        data.x = degree.unsqueeze(1) / max(max_degree, 1)
+                    else:
+                        data.x = torch.zeros(data.num_nodes, 1)
+
         except Exception as e:
             print(f"Error loading {dataset_name}: {e}")
-            # Create a minimal synthetic dataset as fallback
+            # Create minimal synthetic dataset
             from torch_geometric.data import Data
-            dataset = [Data(
-                x=torch.randn(10, 5),
-                edge_index=torch.randint(0, 10, (2, 20)),
-                y=torch.randint(0, 2, (1,))
-            )]
+            dataset = [
+                Data(
+                    x=torch.randn(10, 5),
+                    edge_index=torch.randint(0, 10, (2, 20)),
+                    y=torch.randint(0, 2, (1,))
+                ) for _ in range(100)
+            ]

         return dataset

     def create_dataloaders(self, dataset, task_type='node_classification'):
-        """Create train/val/test splits"""
+        """Create train/val/test splits with dataloaders"""

         if task_type == 'node_classification':
-            # Use predefined splits for node classification
+            # Single graph with masks
             data = dataset[0]
-            return data, None, None  # Single graph with masks
+            return data, None, None

         elif task_type == 'graph_classification':
-            # Random split for graph classification
+            # Split dataset
             num_graphs = len(dataset)
             indices = torch.randperm(num_graphs)

@@ -108,42 +176,70 @@
         return train_loader, val_loader, test_loader

     def get_dataset_info(self, dataset):
-        """Get dynamic dataset information"""
+        """Get comprehensive dataset information"""
         try:
             if hasattr(dataset, 'num_features'):
                 num_features = dataset.num_features
             else:
-                num_features = dataset[0].x.size(1)
+                num_features = dataset[0].x.size(1) if dataset[0].x is not None else 1

             if hasattr(dataset, 'num_classes'):
                 num_classes = dataset.num_classes
             else:
-                if hasattr(dataset[0], 'y'):
+                if hasattr(dataset[0], 'y') and dataset[0].y is not None:
                     if len(dataset) > 1:
-                        all_labels = torch.cat([data.y.flatten() for data in dataset])
-                        num_classes = len(torch.unique(all_labels))
+                        all_labels = []
+                        for data in dataset:
+                            if data.y is not None:
+                                all_labels.extend(data.y.flatten().tolist())
+                        num_classes = len(set(all_labels)) if all_labels else 2
                     else:
                         num_classes = len(torch.unique(dataset[0].y))
                 else:
-                    num_classes = 2  # Default binary
+                    num_classes = 2

             num_graphs = len(dataset)
-            avg_nodes = sum([data.num_nodes for data in dataset]) / len(dataset)
-            avg_edges = sum([data.num_edges for data in dataset]) / len(dataset)
+
+            # Calculate statistics
+            total_nodes = sum([data.num_nodes for data in dataset])
+            total_edges = sum([data.num_edges for data in dataset])
+
+            avg_nodes = total_nodes / num_graphs
+            avg_edges = total_edges / num_graphs
+
+            # Additional statistics
+            node_counts = [data.num_nodes for data in dataset]
+            edge_counts = [data.num_edges for data in dataset]
+
+            stats = {
+                'num_features': num_features,
+                'num_classes': num_classes,
+                'num_graphs': num_graphs,
+                'avg_nodes': avg_nodes,
+                'avg_edges': avg_edges,
+                'min_nodes': min(node_counts),
+                'max_nodes': max(node_counts),
+                'min_edges': min(edge_counts),
+                'max_edges': max(edge_counts),
+                'total_nodes': total_nodes,
+                'total_edges': total_edges
+            }

         except Exception as e:
             print(f"Error getting dataset info: {e}")
-            # Return defaults
-            num_features = 1433  # Cora default
-            num_classes = 7  # Cora default
-            num_graphs = 1
-            avg_nodes = 2708
-            avg_edges = 10556
-
-        return {
-            'num_features': num_features,
-            'num_classes': num_classes,
-            'num_graphs': num_graphs,
-            'avg_nodes': avg_nodes,
-            'avg_edges': avg_edges
-        }
+            # Return safe defaults
+            stats = {
+                'num_features': 1433,
+                'num_classes': 7,
+                'num_graphs': 1,
+                'avg_nodes': 2708.0,
+                'avg_edges': 10556.0,
+                'min_nodes': 2708,
+                'max_nodes': 2708,
+                'min_edges': 10556,
+                'max_edges': 10556,
+                'total_nodes': 2708,
+                'total_edges': 10556
+            }
+
+        return stats
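
For reference, a minimal usage sketch of the updated loader (illustrative only: it assumes PyTorch Geometric is installed and that the class is importable as data.loader.GraphDataLoader; it is not part of the commit):

from data.loader import GraphDataLoader

loader = GraphDataLoader(config_path='config.yaml')  # falls back to the built-in defaults if the file is missing

# Node classification: a single graph whose train/val/test masks are guaranteed by _ensure_masks()
dataset = loader.load_node_classification_data('Cora')
data, _, _ = loader.create_dataloaders(dataset, task_type='node_classification')
print(loader.get_dataset_info(dataset))  # num_features, num_classes, node/edge statistics

# Graph classification: three DataLoaders over a random split of the graphs
graphs = loader.load_graph_classification_data('MUTAG')
train_loader, val_loader, test_loader = loader.create_dataloaders(
    graphs, task_type='graph_classification'
)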
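
One caveat on the degree-feature fallback above: indexing a PyTorch Geometric InMemoryDataset such as TUDataset returns a fresh Data copy (with the transform applied on access), so the attribute assignments inside the for data in dataset loops may not persist to later accesses. A sketch of a transform-based alternative using PyG's built-in OneHotDegree, which is applied each time a graph is accessed (the max_degree value here is illustrative, not from this commit):

from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import OneHotDegree

# OneHotDegree writes one-hot node degrees into data.x on every access,
# so datasets without node features get consistent inputs without
# mutating the stored dataset.
dataset = TUDataset(
    root='./data/IMDB-BINARY',
    name='IMDB-BINARY',
    transform=OneHotDegree(max_degree=135),
)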