Efficiently get the length of the tokenized docs (#1063)
* Efficiently get the length of the tokenized docs
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
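
In short: the change replaces a per-row pandas `apply` over the `position_ids` column with a direct numpy read of a precomputed `length` column whenever one exists, via a new shared helper `get_dataset_lengths`. A minimal sketch of the two paths, assuming a HuggingFace `datasets.Dataset` (whose `.data` attribute exposes the underlying Arrow table); the toy rows are illustrative only, not from the PR:

```python
import numpy as np
from datasets import Dataset

# Toy stand-in for a tokenized corpus (illustrative, not from the PR).
ds = Dataset.from_dict(
    {
        "position_ids": [[0, 1, 2], [0, 1, 2, 3, 4]],
        "length": [3, 5],
    }
)

# Old path: materialize the Arrow column as a pandas Series, then run a
# Python-level lambda per row to recover each doc's length.
slow = (
    ds.data.column("position_ids")
    .to_pandas()
    .apply(lambda x: x[-1] + 1)
    .values
)

# New fast path: read the precomputed "length" column straight into numpy.
fast = np.array(ds.data.column("length"))

assert (slow == fast).all()  # both yield [3, 5]
```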
src/axolotl/core/trainer_builder.py
CHANGED

@@ -37,7 +37,7 @@ from axolotl.utils.collators import (
     DataCollatorForSeq2Seq,
     MambaDataCollator,
 )
-from axolotl.utils.samplers import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 try:
@@ -170,12 +170,7 @@ class AxolotlTrainer(Trainer):
                 self.args.train_batch_size,
                 drop_last=True,
                 batch_max_len=self._train_batch_size * self.args.max_seq_length,
-                lengths=(
-                    self.train_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(self.train_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_train_sampler()
@@ -189,12 +184,7 @@ class AxolotlTrainer(Trainer):
                 self.args.per_device_eval_batch_size,
                 drop_last=True,
                 batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
-                lengths=(
-                    eval_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(eval_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_eval_sampler(eval_dataset)
src/axolotl/utils/data.py
CHANGED

@@ -44,7 +44,7 @@ from axolotl.prompters import (
 from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
-from axolotl.utils.samplers import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
@@ -889,12 +889,7 @@ def encode_packed_pretraining(
         batch_size=batch_size,
         drop_last=True,
         batch_max_len=batch_size * max_seq_length,
-        lengths=(
-            train_dataset.data.column("position_ids")
-            .to_pandas()
-            .apply(lambda x: x[-1] + 1)
-            .values
-        ),
+        lengths=get_dataset_lengths(train_dataset),
     )
 
     chunked_data = defaultdict(list)
src/axolotl/utils/samplers/__init__.py
CHANGED

@@ -2,3 +2,4 @@
 axolotl samplers module
 """
 from .multipack import MultipackBatchSampler  # noqa: F401
+from .utils import get_dataset_lengths  # noqa: F401
src/axolotl/utils/samplers/utils.py
ADDED

@@ -0,0 +1,17 @@
+"""
+helper util to calculate dataset lengths
+"""
+import numpy as np
+
+
+def get_dataset_lengths(dataset):
+    if "length" in dataset.data.column_names:
+        lengths = np.array(dataset.data.column("length"))
+    else:
+        lengths = (
+            dataset.data.column("position_ids")
+            .to_pandas()
+            .apply(lambda x: x[-1] + 1)
+            .values
+        )
+    return lengths
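
A quick usage sketch for the new helper (toy data; mirrors how the call sites in this PR invoke it):

```python
from datasets import Dataset

from axolotl.utils.samplers import get_dataset_lengths

# No "length" column here, so the helper falls back to the position_ids path.
ds = Dataset.from_dict({"position_ids": [[0, 1, 2, 3], [0, 1]]})
print(get_dataset_lengths(ds))  # [4 2]
```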
src/axolotl/utils/trainer.py
CHANGED

@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader, RandomSampler
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
-from axolotl.utils.samplers import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
 LOG = get_logger("axolotl")
 
@@ -212,12 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 drop_last=True,
                 batch_max_len=cfg.micro_batch_size
                 * (cfg.max_packed_sequence_len or cfg.sequence_len),
-                lengths=(
-                    train_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(train_dataset),
             )
 
             data_loader = DataLoader(
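
Worth noting: the fast path only fires when the dataset actually carries a `length` column. A sketch (an assumption about preprocessing, not something this PR adds) of precomputing that column so every call site above skips the pandas fallback:

```python
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[5, 6, 7], [8, 9]]})
# Precompute per-row lengths once; get_dataset_lengths will then read this
# column directly instead of running a per-row lambda over position_ids.
ds = ds.map(lambda row: {"length": len(row["input_ids"])})
print(ds["length"])  # [3, 2]
```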