import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.tools.logging import logger


class BaseDataLoader(Stateful, ABC):
    """Base class for all dataloaders.

    This enforces that all dataloaders implement the methods defined in
    ``Stateful``, ``state_dict()`` and ``load_state_dict()``.
    """

    @abstractmethod
    def __iter__(self):
        ...
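
# A minimal sketch of what the ``BaseDataLoader`` contract looks like for a
# concrete implementation (hypothetical, for illustration only; the class,
# file format, and state keys below are not part of this module):
#
#     class JsonlDataLoader(BaseDataLoader):
#         def __init__(self, path: str):
#             self._path = path
#             self._offset = 0
#
#         def __iter__(self):
#             with open(self._path) as f:
#                 f.seek(self._offset)
#                 # readline() is used instead of line iteration so that
#                 # f.tell() stays legal while streaming.
#                 while line := f.readline():
#                     self._offset = f.tell()
#                     yield line
#
#         def state_dict(self):
#             return {"offset": self._offset}
#
#         def load_state_dict(self, state_dict):
#             self._offset = state_dict["offset"]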


class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
    """Dataloader that is aware of distributed data parallelism.

    This dataloader is used to load data in a distributed data parallel
    fashion. It also utilizes ``torchdata.stateful_dataloader.StatefulDataLoader``
    to implement the necessary methods such as ``__iter__``.

    Args:
        dataset (IterableDataset): The dataset to iterate over.
        dp_rank: Data parallelism rank for this dataloader.
        dp_world_size: The world size of the data parallelism.
        batch_size: The batch size to use for each iteration.
        collate_fn: Optional function to collate samples in a batch.
    """

    dp_rank: int
    dp_world_size: int
    batch_size: int

    def __init__(
        self,
        dataset: IterableDataset,
        dp_rank: int,
        dp_world_size: int,
        batch_size: int,
        collate_fn: Callable | None = None,
    ):
        self.dp_world_size = dp_world_size
        self.dp_rank = dp_rank
        self.batch_size = batch_size
        super().__init__(dataset, batch_size, collate_fn=collate_fn)
        self._rank_id = f"dp_rank_{dp_rank}"

    def state_dict(self) -> dict[str, Any]:
        # Store state only under this rank's key, so each dp rank checkpoints
        # its own dataloader progress independently.
        return {
            # Pickle the inner StatefulDataLoader state so it round-trips
            # through the checkpoint as a single bytes object.
            self._rank_id: pickle.dumps(super().state_dict()),
            "world_size": self.dp_world_size,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # An empty state is valid, e.g., when training starts from scratch.
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(
                f"DataLoader state is empty for dp rank {self.dp_rank}, "
                f"expected key {self._rank_id}."
            )
            return

        assert self.dp_world_size == state_dict["world_size"], (
            "dp_degree is inconsistent before and after checkpoint, "
            "dataloader resharding is not supported yet."
        )

        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
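
# Usage sketch (illustrative only, not part of this module): how a training
# loop might construct a ParallelAwareDataloader and round-trip its state
# through a checkpoint. ``MyIterableDataset`` is a hypothetical stand-in for
# a real IterableDataset; the rank/world-size values are examples.
#
#     dataloader = ParallelAwareDataloader(
#         dataset=MyIterableDataset(),
#         dp_rank=0,
#         dp_world_size=8,
#         batch_size=16,
#     )
#
#     for step, batch in enumerate(dataloader):
#         ...  # training step
#         if step == 100:
#             # Keys: "dp_rank_0" (pickled inner state) and "world_size".
#             state = dataloader.state_dict()
#             break
#
#     # On resume, a dataloader built with the same dp_world_size restores
#     # its iteration progress; resharding to a different dp_world_size is
#     # rejected by the assert in load_state_dict().
#     dataloader.load_state_dict(state)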