|
from typing import Dict, Optional

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
|
|
|
|
|
class DataModule(pl.LightningDataModule):
    """Lightning DataModule wrapping pre-built training and validation datasets.

    The datasets are supplied already constructed; this module only exposes
    them to Lightning through ``DataLoader`` objects.

    Args:
        training_set (torch.utils.data.Dataset): Training dataset.
        validation_set (torch.utils.data.Dataset): Validation dataset.
        batch_size (int): Batch size used by both dataloaders. Defaults to 1.

    Attributes:
        training_set (torch.utils.data.Dataset): Training dataset.
        validation_set (torch.utils.data.Dataset): Validation dataset.
        batch_size (int): Batch size used by both dataloaders.
        train_ds (Optional[torch.utils.data.Dataset]): Alias for the training
            dataset, assigned in ``setup``; ``None`` until then.
        val_ds (Optional[torch.utils.data.Dataset]): Alias for the validation
            dataset, assigned in ``setup``; ``None`` until then.
    """

    def __init__(
        self,
        training_set: torch.utils.data.Dataset,
        validation_set: torch.utils.data.Dataset,
        batch_size: int = 1,
    ) -> None:
        super().__init__()
        self.training_set = training_set
        self.validation_set = validation_set
        self.batch_size = batch_size
        # Declared here so the attributes always exist, even when a dataloader
        # is requested before ``setup`` has run.
        self.train_ds = None
        self.val_ds = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Bind the dataset aliases read by the dataloader methods.

        Args:
            stage: Stage name passed by the caller. It is not inspected; both
                aliases are assigned regardless of its value.
        """
        self.train_ds = self.training_set
        self.val_ds = self.validation_set

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling DataLoader over the training dataset."""
        # ``is None`` rather than truthiness: an empty dataset defining
        # ``__len__`` is falsy but is still a valid alias.
        dataset = self.training_set if self.train_ds is None else self.train_ds
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling DataLoader over the validation dataset."""
        dataset = self.validation_set if self.val_ds is None else self.val_ds
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)