from dataclasses import dataclass
from typing import Any, Callable

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
    """Load the C4 dataset in streaming mode with its default configuration."""
    return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Extract the raw text field from a C4 sample."""
    return sample["text"]


@dataclass
class DatasetConfig:
    """Describes how to load a dataset and extract text from its samples."""

    path: str
    loader: Callable
    text_processor: Callable


DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=_load_c4_dataset,
        text_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
}
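
# Supporting another dataset only requires a new DATASETS entry; nothing else
# in this module changes. A minimal sketch (the "wikitext" key, hub path,
# config name, and "text" field are illustrative assumptions, not part of
# this module):
#
#   DATASETS["wikitext"] = DatasetConfig(
#       path="Salesforce/wikitext",
#       loader=lambda path: load_dataset(
#           path, name="wikitext-103-raw-v1", split="train", streaming=True
#       ),
#       text_processor=lambda sample: sample["text"],
#   )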


def _validate_dataset(
    dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
    """Streams a HuggingFace dataset, tokenizes it, and yields fixed-length
    (input, label) pairs for next-token prediction."""

    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: Tokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Normalize the name so lookups into DATASETS are case-insensitive
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        # Shard the dataset so each data-parallel rank reads a disjoint slice
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        # Map-style datasets have a length; if a previous pass consumed every
        # sample, return an empty iterator instead of skipping past the end.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        # Skip the samples that were already consumed before a checkpoint
        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        # One extra token so each chunk yields seq_len inputs and seq_len labels
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                # Emit fixed-length chunks while the buffer holds enough tokens
                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # Keep only the tokens that did not fit into this chunk
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    yield {"input": input}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset the offset and loop over the dataset again
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
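
# Example wiring (a sketch, not a definitive entry point; the JobConfig values
# come from the training TOML, and `tokenizer` stands in for any concrete
# torchtitan Tokenizer):
#
#   dataloader = build_hf_dataloader(
#       dp_world_size=1, dp_rank=0, tokenizer=tokenizer, job_config=job_config
#   )
#   for inputs, labels in dataloader:
#       # inputs["input"] and labels are [batch_size, seq_len] LongTensors
#       # holding the same token stream shifted by one position.
#       ...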