# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.tools.logging import logger


class BaseDataLoader(Stateful, ABC):
    """Base class for all dataloaders.

    This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
    ``state_dict()`` and ``load_state_dict()``.
    """

    @abstractmethod
    def __iter__(self):
        ...


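# A minimal sketch of the ``BaseDataLoader`` contract, kept in a comment so it
# does not alter the module's public surface. ``_CounterLoader`` is a
# hypothetical name used only for illustration, not part of torchtitan:
#
#     class _CounterLoader(BaseDataLoader):
#         def __init__(self) -> None:
#             self._step = 0
#
#         def __iter__(self):
#             while True:
#                 yield self._step
#                 self._step += 1
#
#         def state_dict(self) -> dict[str, Any]:
#             # Capture just enough state to resume iteration after a checkpoint.
#             return {"step": self._step}
#
#         def load_state_dict(self, state_dict: dict[str, Any]) -> None:
#             self._step = state_dict["step"]

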
class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
    """Dataloader that is aware of distributed data parallelism.

    This dataloader is used to load data in a distributed data parallel fashion. It also
    utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
    methods such as ``__iter__``.

    Args:
        dataset (IterableDataset): The dataset to iterate over.
        dp_rank (int): Data parallelism rank for this dataloader.
        dp_world_size (int): The world size of the data parallelism.
        batch_size (int): The batch size to use for each iteration.
        collate_fn (Callable, optional): Function used to collate samples in a batch.
    """

    dp_rank: int
    dp_world_size: int
    batch_size: int

    def __init__(
        self,
        dataset: IterableDataset,
        dp_rank: int,
        dp_world_size: int,
        batch_size: int,
        collate_fn: Callable | None = None,
    ):
        self.dp_world_size = dp_world_size
        self.dp_rank = dp_rank
        self.batch_size = batch_size
        super().__init__(dataset, batch_size, collate_fn=collate_fn)
        self._rank_id = f"dp_rank_{dp_rank}"

    def state_dict(self) -> dict[str, Any]:
        # Store state only under this dp rank's key to avoid replicating the same
        # state across the other parallelism dimensions.
        return {
            # We don't have to use pickle as DCP will serialize the state_dict. However,
            # we have to keep this for backward compatibility.
            self._rank_id: pickle.dumps(super().state_dict()),
            "world_size": self.dp_world_size,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # State being empty is valid.
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(
                f"DataLoader state is empty for dp rank {self.dp_rank}, "
                f"expected key {self._rank_id}."
            )
            return

        assert self.dp_world_size == state_dict["world_size"], (
            "dp_degree is inconsistent before and after checkpoint, "
            "dataloader resharding is not supported yet."
        )
        # We don't have to use pickle as DCP will serialize the state_dict. However,
        # we have to keep this for backward compatibility.
        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
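

# Hedged usage sketch, runnable as a script (not part of the original module).
# ``_RangeDataset`` and the rank/world-size values below are illustrative
# assumptions, not torchtitan APIs.
if __name__ == "__main__":

    class _RangeDataset(IterableDataset):
        # Toy dataset yielding integers; stands in for a real tokenized dataset.
        def __iter__(self):
            yield from range(100)

    loader = ParallelAwareDataloader(
        dataset=_RangeDataset(),
        dp_rank=0,
        dp_world_size=1,
        batch_size=4,
    )

    it = iter(loader)
    print(next(it))  # first batch, e.g. tensor([0, 1, 2, 3])

    # Snapshot mid-iteration, then restore into a fresh dataloader: iteration
    # resumes from the saved position instead of restarting the dataset.
    snapshot = loader.state_dict()
    resumed = ParallelAwareDataloader(
        dataset=_RangeDataset(),
        dp_rank=0,
        dp_world_size=1,
        batch_size=4,
    )
    resumed.load_state_dict(snapshot)
    print(next(iter(resumed)))  # continues, e.g. tensor([4, 5, 6, 7])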