# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable

import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
    """Load C4 dataset with default configuration."""
    return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Process C4 dataset sample text."""
    return sample["text"]

@dataclass
class DatasetConfig:
    path: str
    loader: Callable
    text_processor: Callable


# Add your dataset here - more information at docs/datasets.md
DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=_load_c4_dataset,
        text_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
}
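
# Example (hypothetical, for illustration only): a third entry could be
# registered the same way. The "wikitext" name, path, and loader arguments
# below are assumptions, not datasets shipped with this file; see
# docs/datasets.md for the supported procedure.
#
# DATASETS["wikitext"] = DatasetConfig(
#     path="Salesforce/wikitext",
#     loader=lambda path: load_dataset(
#         path, name="wikitext-2-raw-v1", split="train", streaming=True
#     ),
#     text_processor=_process_c4_text,  # wikitext samples also carry a "text" field
# )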


def _validate_dataset(
    dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
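    """IterableDataset that streams text from a Hugging Face dataset.

    Samples are tokenized with ``bos``/``eos`` markers and packed into a token
    buffer, from which fixed-length (input, label) pairs of ``seq_len`` tokens
    are yielded. The dataset is sharded across data-parallel ranks via
    ``split_dataset_by_node`` and implements ``Stateful`` so that the sample
    offset and leftover token buffer survive checkpointing.

    Args:
        dataset_name: key into the ``DATASETS`` registry (e.g. "c4").
        dataset_path: optional override for the registered dataset path.
        tokenizer: tokenizer used to encode each sample's text.
        seq_len: length of each yielded training sequence.
        dp_rank: data-parallel rank of this worker.
        dp_world_size: number of data-parallel ranks.
        infinite: if True, re-loop over the data instead of stopping.
    """
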
    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: Tokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        # For a map-style (non-streaming) dataset that has been fully consumed,
        # return an empty iterator; otherwise skip the samples already seen so
        # that resuming from a checkpoint continues where the last run stopped.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        # One extra token so that input (x[:-1]) and label (x[1:]) are both
        # exactly seq_len long after the shift.
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    yield {"input": input}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
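
# Rough usage sketch (hypothetical wiring; the tokenizer and JobConfig
# construction below are assumptions about the surrounding trainer, not part of
# this file). The dataset, batch_size, and seq_len values are read from
# job_config.training:
#
#     dataloader = build_hf_dataloader(
#         dp_world_size=world_size,
#         dp_rank=rank,
#         tokenizer=tokenizer,      # any torchtitan Tokenizer implementation
#         job_config=job_config,    # parsed training config
#     )
#     for input_dict, labels in dataloader:
#         ...  # input_dict["input"] and labels each have shape [batch_size, seq_len]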