# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from dataclasses import dataclass from typing import Any, Callable import torch from datasets import Dataset, load_dataset from datasets.distributed import split_dataset_by_node from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset from torchtitan.components.dataloader import ParallelAwareDataloader from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig from torchtitan.tools.logging import logger def _load_c4_dataset(dataset_path: str): """Load C4 dataset with default configuration.""" return load_dataset(dataset_path, name="en", split="train", streaming=True) def _process_c4_text(sample: dict[str, Any]) -> str: """Process C4 dataset sample text.""" return sample["text"] @dataclass class DatasetConfig: path: str loader: Callable text_processor: Callable # Add your dataset here here - more information at docs/datasets.md DATASETS = { "c4": DatasetConfig( path="allenai/c4", loader=_load_c4_dataset, text_processor=_process_c4_text, ), "c4_test": DatasetConfig( path="tests/assets/c4_test", loader=lambda path: load_dataset(path, split="train"), text_processor=_process_c4_text, ), } def _validate_dataset( dataset_name: str, dataset_path: str | None = None ) -> tuple[str, Callable, Callable]: """Validate dataset name and path.""" if dataset_name not in DATASETS: raise ValueError( f"Dataset {dataset_name} is not supported. " f"Supported datasets are: {list(DATASETS.keys())}" ) config = DATASETS[dataset_name] path = dataset_path or config.path logger.info(f"Preparing {dataset_name} dataset from {path}") return path, config.loader, config.text_processor class HuggingFaceDataset(IterableDataset, Stateful): def __init__( self, dataset_name: str, dataset_path: str | None, tokenizer: Tokenizer, seq_len: int = 2048, dp_rank: int = 0, dp_world_size: int = 1, infinite: bool = False, ) -> None: # Force lowercase for consistent comparison dataset_name = dataset_name.lower() path, dataset_loader, text_processor = _validate_dataset( dataset_name, dataset_path ) ds = dataset_loader(path) self.dataset_name = dataset_name self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) self._tokenizer = tokenizer self.seq_len = seq_len self.infinite = infinite self._text_processor = text_processor # Variables for checkpointing self._sample_idx = 0 self._all_tokens: list[int] = [] def _get_data_iter(self): if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): return iter([]) it = iter(self._data) for _ in range(self._sample_idx): next(it) return it def __iter__(self): max_buffer_token_len = 1 + self.seq_len while True: for sample in self._get_data_iter(): # Use the dataset-specific text processor sample_text = self._text_processor(sample) sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) self._all_tokens.extend(sample_tokens) self._sample_idx += 1 while len(self._all_tokens) >= max_buffer_token_len: x = torch.LongTensor(self._all_tokens[:max_buffer_token_len]) # update tokens to the remaining tokens self._all_tokens = self._all_tokens[max_buffer_token_len:] input = x[:-1] label = x[1:] yield {"input": input}, label if not self.infinite: logger.warning(f"Dataset {self.dataset_name} has run out of data") break else: # Reset offset for the next iteration self._sample_idx = 0 logger.warning(f"Dataset {self.dataset_name} is being re-looped") def load_state_dict(self, state_dict): self._sample_idx = state_dict["sample_idx"] self._all_tokens = state_dict["token_buffer"] def state_dict(self): return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} def build_hf_dataloader( dp_world_size: int, dp_rank: int, tokenizer: Tokenizer, job_config: JobConfig, infinite: bool = True, ) -> ParallelAwareDataloader: """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path batch_size = job_config.training.batch_size seq_len = job_config.training.seq_len hf_ds = HuggingFaceDataset( dataset_name=dataset_name, dataset_path=dataset_path, tokenizer=tokenizer, seq_len=seq_len, dp_rank=dp_rank, dp_world_size=dp_world_size, infinite=infinite, ) return ParallelAwareDataloader( dataset=hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, batch_size=batch_size, )