import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
from typing import Dict, Optional


class DataModule(pl.LightningDataModule):
    """
    Lightning DataModule wrapping already-constructed training and validation datasets.

    Args:
        training_set (torch.utils.data.Dataset): Training dataset.
        validation_set (torch.utils.data.Dataset): Validation dataset.
        batch_size (Optional[int]): Samples per batch, used for both dataloaders.
            Defaults to 1. It is forwarded unchanged to ``DataLoader``, so
            ``None`` disables automatic batching.
        num_workers (int): Number of worker processes passed to each
            ``DataLoader``. Defaults to 0, meaning data is loaded in the main
            process (``DataLoader``'s own default).

    Attributes:
        training_set (torch.utils.data.Dataset): Training dataset as given.
        validation_set (torch.utils.data.Dataset): Validation dataset as given.
        train_ds (torch.utils.data.Dataset): Dataset used by ``train_dataloader``.
            Initialised in ``__init__`` and (re)assigned in ``setup``.
        val_ds (torch.utils.data.Dataset): Dataset used by ``val_dataloader``.
            Initialised in ``__init__`` and (re)assigned in ``setup``.
        batch_size (Optional[int]): Batch size for both dataloaders.
        num_workers (int): Worker count for both dataloaders.

    Methods:
        setup(self, stage: Optional[str] = None): Point train_ds / val_ds at the datasets.
        train_dataloader(self) -> DataLoader: Return a shuffled DataLoader over the training dataset.
        val_dataloader(self) -> DataLoader: Return a non-shuffled DataLoader over the validation dataset.
    """

    def __init__(
        self,
        training_set,
        validation_set,
        batch_size: Optional[int] = 1,
        num_workers: int = 0,
    ):
        super().__init__()
        self.training_set = training_set
        self.validation_set = validation_set
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Pre-populate the aliases so the dataloaders also work when they are
        # requested before setup() has run (e.g. when the module is used manually).
        self.train_ds = training_set
        self.val_ds = validation_set

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Assign the datasets used by the dataloaders.

        Args:
            stage: Lightning stage name (e.g. "fit"). It is ignored, because the
                same datasets are used for every stage. Optional so that the
                method can also be called manually without an argument.
        """
        self.train_ds = self.training_set
        self.val_ds = self.validation_set

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled DataLoader over the training dataset."""
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a DataLoader over the validation dataset in its original order."""
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )