import torch from torch_geometric.datasets import Planetoid, TUDataset from torch_geometric.loader import DataLoader from torch_geometric.transforms import NormalizeFeatures import yaml import os class GraphDataLoader: """ Production data loading with real datasets only Device-safe implementation """ def __init__(self, config_path='config.yaml'): if os.path.exists(config_path): with open(config_path, 'r') as f: self.config = yaml.safe_load(f) else: # Default config if file doesn't exist self.config = { 'data': { 'batch_size': 32, 'test_split': 0.2 } } self.batch_size = self.config['data']['batch_size'] self.test_split = self.config['data']['test_split'] def load_node_classification_data(self, dataset_name='Cora'): """Load real node classification datasets""" try: if dataset_name in ['Cora', 'CiteSeer', 'PubMed']: dataset = Planetoid( root=f'./data/{dataset_name}', name=dataset_name, transform=NormalizeFeatures() ) else: # Fallback to Cora dataset = Planetoid( root='./data/Cora', name='Cora', transform=NormalizeFeatures() ) except Exception as e: print(f"Error loading {dataset_name}: {e}") # Fallback to Cora dataset = Planetoid( root='./data/Cora', name='Cora', transform=NormalizeFeatures() ) return dataset def load_graph_classification_data(self, dataset_name='MUTAG'): """Load real graph classification datasets""" valid_datasets = ['MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB', 'IMDB-BINARY'] try: if dataset_name not in valid_datasets: dataset_name = 'MUTAG' # Default fallback dataset = TUDataset( root=f'./data/{dataset_name}', name=dataset_name, transform=NormalizeFeatures() ) except Exception as e: print(f"Error loading {dataset_name}: {e}") # Create a minimal synthetic dataset as fallback from torch_geometric.data import Data dataset = [Data( x=torch.randn(10, 5), edge_index=torch.randint(0, 10, (2, 20)), y=torch.randint(0, 2, (1,)) )] return dataset def create_dataloaders(self, dataset, task_type='node_classification'): """Create train/val/test splits""" if task_type == 'node_classification': # Use predefined splits for node classification data = dataset[0] return data, None, None # Single graph with masks elif task_type == 'graph_classification': # Random split for graph classification num_graphs = len(dataset) indices = torch.randperm(num_graphs) train_size = int(0.8 * num_graphs) val_size = int(0.1 * num_graphs) train_dataset = [dataset[i] for i in indices[:train_size]] val_dataset = [dataset[i] for i in indices[train_size:train_size+val_size]] test_dataset = [dataset[i] for i in indices[train_size+val_size:]] train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False) return train_loader, val_loader, test_loader def get_dataset_info(self, dataset): """Get dynamic dataset information""" try: if hasattr(dataset, 'num_features'): num_features = dataset.num_features else: num_features = dataset[0].x.size(1) if hasattr(dataset, 'num_classes'): num_classes = dataset.num_classes else: if hasattr(dataset[0], 'y'): if len(dataset) > 1: all_labels = torch.cat([data.y.flatten() for data in dataset]) num_classes = len(torch.unique(all_labels)) else: num_classes = len(torch.unique(dataset[0].y)) else: num_classes = 2 # Default binary num_graphs = len(dataset) avg_nodes = sum([data.num_nodes for data in dataset]) / len(dataset) avg_edges = sum([data.num_edges for data in dataset]) / len(dataset) except Exception as e: print(f"Error getting dataset info: {e}") # Return defaults num_features = 1433 # Cora default num_classes = 7 # Cora default num_graphs = 1 avg_nodes = 2708 avg_edges = 10556 return { 'num_features': num_features, 'num_classes': num_classes, 'num_graphs': num_graphs, 'avg_nodes': avg_nodes, 'avg_edges': avg_edges }