# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Any, Callable
import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]


@dataclass
class DatasetConfig:
path: str
loader: Callable
text_processor: Callable


# Add your dataset here - more information at docs/datasets.md
DATASETS = {
"c4": DatasetConfig(
path="allenai/c4",
loader=_load_c4_dataset,
text_processor=_process_c4_text,
),
"c4_test": DatasetConfig(
path="tests/assets/c4_test",
loader=lambda path: load_dataset(path, split="train"),
text_processor=_process_c4_text,
),
}
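
# Illustrative sketch only (not a registered entry): a hypothetical Wikipedia
# dataset could be added the same way, assuming its samples expose a "text"
# field compatible with _process_c4_text:
#
#   DATASETS["wikipedia"] = DatasetConfig(
#       path="wikimedia/wikipedia",
#       loader=lambda path: load_dataset(
#           path, name="20231101.en", split="train", streaming=True
#       ),
#       text_processor=_process_c4_text,
#   )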


def _validate_dataset(
dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
"""Validate dataset name and path."""
if dataset_name not in DATASETS:
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(DATASETS.keys())}"
)
config = DATASETS[dataset_name]
path = dataset_path or config.path
logger.info(f"Preparing {dataset_name} dataset from {path}")
return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
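    """An IterableDataset that streams and tokenizes text from a Hugging Face dataset.

    Samples are tokenized and packed into a running token buffer, then emitted
    as fixed-length input/label pairs of ``seq_len`` tokens. The underlying
    dataset is sharded across data-parallel ranks via ``split_dataset_by_node``,
    and the class implements ``Stateful`` so the sample index and token buffer
    can be checkpointed and restored.
    """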
def __init__(
self,
dataset_name: str,
dataset_path: str | None,
tokenizer: Tokenizer,
seq_len: int = 2048,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
) -> None:
# Force lowercase for consistent comparison
dataset_name = dataset_name.lower()
path, dataset_loader, text_processor = _validate_dataset(
dataset_name, dataset_path
)
ds = dataset_loader(path)
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor
# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: list[int] = []

    def _get_data_iter(self):
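        """Return an iterator over the data, skipping samples that were
        already consumed (tracked by ``self._sample_idx``) so iteration can
        resume after a checkpoint restore."""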
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])
it = iter(self._data)
for _ in range(self._sample_idx):
next(it)
return it

    def __iter__(self):
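        """Yield ``({"input": input}, label)`` pairs of ``seq_len`` tokens each.

        Tokens from consecutive samples are packed into a buffer; whenever the
        buffer holds at least ``seq_len + 1`` tokens, a chunk is split off and
        shifted by one token to form the input/label pair. When the data is
        exhausted, iteration either stops or restarts from the beginning,
        depending on ``self.infinite``.
        """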
max_buffer_token_len = 1 + self.seq_len
while True:
for sample in self._get_data_iter():
# Use the dataset-specific text processor
sample_text = self._text_processor(sample)
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1
while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
self._all_tokens = self._all_tokens[max_buffer_token_len:]
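                    # shift by one token so label[t] is the target for input[t]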
input = x[:-1]
label = x[1:]
yield {"input": input}, label
if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
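        """Restore the sample index and token buffer from a checkpoint."""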
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
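        """Return the iteration state needed to resume after a checkpoint."""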
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,
job_config: JobConfig,
infinite: bool = True,
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.batch_size
seq_len = job_config.training.seq_len
hf_ds = HuggingFaceDataset(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
)
return ParallelAwareDataloader(
dataset=hf_ds,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)