Threaded MultipackDistributedDataloader with prefetched samples (#759)
Browse files* Multithreading implementation [WIP]
* Added benchmarking
* 35% increased throughput
* Memory pinning
* Start threads in init
* Correct print of samples
* Sleep if queue is full
* Remove pin_memory (worse)
* Simplify logic to one thread
* Remove benchmark
* Use deque for constant speed
* Formatting
* Formatting
* Formatting
* Formatting
* Rollback to use queue
* Fix multi-epoch training
* Add num epochs arg
* Start thread in __iter__
* Formatting
* Use is_alive correctly
* Simplify loading thread
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):
|
|
| 111 |
|
| 112 |
args = None # type: AxolotlTrainingArguments
|
| 113 |
|
| 114 |
-
def __init__(self, *args, bench_data_collator=None, **kwargs):
|
|
|
|
| 115 |
self.bench_data_collator = bench_data_collator
|
| 116 |
super().__init__(*args, **kwargs)
|
| 117 |
|
|
@@ -182,6 +183,7 @@ class AxolotlTrainer(Trainer):
|
|
| 182 |
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
| 183 |
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
| 184 |
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
|
|
|
| 185 |
)
|
| 186 |
)
|
| 187 |
return super().get_train_dataloader()
|
|
@@ -205,6 +207,7 @@ class AxolotlTrainer(Trainer):
|
|
| 205 |
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
| 206 |
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
| 207 |
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
|
|
|
| 208 |
)
|
| 209 |
)
|
| 210 |
return super().get_eval_dataloader(eval_dataset)
|
|
@@ -680,6 +683,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 680 |
**data_collator_kwargs,
|
| 681 |
),
|
| 682 |
callbacks=self.get_callbacks(),
|
|
|
|
| 683 |
**trainer_kwargs,
|
| 684 |
)
|
| 685 |
trainer = self.hook_post_create_trainer(trainer)
|
|
|
|
| 111 |
|
| 112 |
args = None # type: AxolotlTrainingArguments
|
| 113 |
|
| 114 |
+
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
    """Capture axolotl-specific keyword arguments, then delegate to Trainer.

    Args:
        num_epochs: number of training epochs, forwarded to the packed
            dataloaders so the prefetch worker knows how long to run.
        bench_data_collator: optional collator used for benchmark eval sets.
    """
    # Stash our extras before the base class consumes *args/**kwargs.
    self.bench_data_collator = bench_data_collator
    self.num_epochs = num_epochs
    super().__init__(*args, **kwargs)
|
| 118 |
|
|
|
|
| 183 |
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
| 184 |
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
| 185 |
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
| 186 |
+
num_epochs=self.num_epochs,
|
| 187 |
)
|
| 188 |
)
|
| 189 |
return super().get_train_dataloader()
|
|
|
|
| 207 |
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
| 208 |
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
| 209 |
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
| 210 |
+
num_epochs=self.num_epochs,
|
| 211 |
)
|
| 212 |
)
|
| 213 |
return super().get_eval_dataloader(eval_dataset)
|
|
|
|
| 683 |
**data_collator_kwargs,
|
| 684 |
),
|
| 685 |
callbacks=self.get_callbacks(),
|
| 686 |
+
num_epochs=self.cfg.num_epochs,
|
| 687 |
**trainer_kwargs,
|
| 688 |
)
|
| 689 |
trainer = self.hook_post_create_trainer(trainer)
|
src/axolotl/utils/dataloader.py
CHANGED
|
@@ -3,6 +3,9 @@ import hashlib
|
|
| 3 |
import itertools
|
| 4 |
import logging
|
| 5 |
import math
|
|
|
|
|
|
|
|
|
|
| 6 |
from typing import Any, Callable, List, Union
|
| 7 |
|
| 8 |
import numba
|
|
@@ -149,6 +152,8 @@ class MultipackDistributedDataloader:
|
|
| 149 |
packing_efficiency_estimate: float = 1.0,
|
| 150 |
sample_packing_seq_len_multiplier: int = 1,
|
| 151 |
device_count: int = 1,
|
|
|
|
|
|
|
| 152 |
):
|
| 153 |
# Dataset
|
| 154 |
self.dataset = dataset
|
|
@@ -167,6 +172,7 @@ class MultipackDistributedDataloader:
|
|
| 167 |
self.seq_max_length = seq_max_length
|
| 168 |
self.batch_max_length = batch_size * seq_max_length
|
| 169 |
self.collate_fn = collate_fn
|
|
|
|
| 170 |
|
| 171 |
self.num_replicas = 1
|
| 172 |
self.rank = 0
|
|
@@ -177,6 +183,44 @@ class MultipackDistributedDataloader:
|
|
| 177 |
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
| 178 |
self.device_count = device_count
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
def generate_batches(self, set_stats=False):
|
| 181 |
LOG.info("generating packed batches")
|
| 182 |
if self.sampler:
|
|
@@ -206,11 +250,7 @@ class MultipackDistributedDataloader:
|
|
| 206 |
|
| 207 |
return batches, totseqs
|
| 208 |
|
| 209 |
-
def
|
| 210 |
-
if hasattr(self.sampler, "set_epoch"):
|
| 211 |
-
new_epoch = self.sampler.epoch + 1
|
| 212 |
-
self.sampler.set_epoch(new_epoch)
|
| 213 |
-
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
| 214 |
all_batches, _ = self.generate_batches(set_stats=True)
|
| 215 |
features = self.dataset.features.keys()
|
| 216 |
len_remaining = self._len_est()
|
|
|
|
| 3 |
import itertools
|
| 4 |
import logging
|
| 5 |
import math
|
| 6 |
+
import time
|
| 7 |
+
from queue import Queue
|
| 8 |
+
from threading import Thread
|
| 9 |
from typing import Any, Callable, List, Union
|
| 10 |
|
| 11 |
import numba
|
|
|
|
| 152 |
packing_efficiency_estimate: float = 1.0,
|
| 153 |
sample_packing_seq_len_multiplier: int = 1,
|
| 154 |
device_count: int = 1,
|
| 155 |
+
prefetch_max: int = 1000,
|
| 156 |
+
num_epochs: int = 1,
|
| 157 |
):
|
| 158 |
# Dataset
|
| 159 |
self.dataset = dataset
|
|
|
|
| 172 |
self.seq_max_length = seq_max_length
|
| 173 |
self.batch_max_length = batch_size * seq_max_length
|
| 174 |
self.collate_fn = collate_fn
|
| 175 |
+
self.num_epochs = num_epochs
|
| 176 |
|
| 177 |
self.num_replicas = 1
|
| 178 |
self.rank = 0
|
|
|
|
| 183 |
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
| 184 |
self.device_count = device_count
|
| 185 |
|
| 186 |
+
# maxsize is maximum number of samples in queue
|
| 187 |
+
self.prefetch_max = prefetch_max
|
| 188 |
+
self.queue: Queue = Queue(maxsize=prefetch_max)
|
| 189 |
+
self.thread = None
|
| 190 |
+
|
| 191 |
+
def _worker(self):
    """Producer thread: generate packed samples for every epoch and feed the queue.

    Runs ``_internal_batch_generator`` once per epoch and pushes each sample
    into ``self.queue``.  A single ``None`` sentinel is enqueued after the
    final epoch so the consumer in ``__iter__`` knows the stream is finished.
    """
    LOG.info(
        f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
    )
    for _ in range(self.num_epochs):
        for sample in self._internal_batch_generator():
            # Queue was created with maxsize=prefetch_max, so put() blocks
            # until the consumer drains an item.  The previous poll-and-sleep
            # loop (check full(), sleep 1s, retry) duplicated this built-in
            # back-pressure while adding up to 1s of latency per sample.
            self.queue.put(sample)

    # Sentinel: tell the consumer all epochs have been produced.
    self.queue.put(None)
|
| 206 |
+
|
| 207 |
+
def __iter__(self):
    """Yield prefetched samples produced by the background worker thread."""
    # Advance the distributed sampler's epoch (when it supports it) so
    # shuffling differs between epochs.
    if hasattr(self.sampler, "set_epoch"):
        new_epoch = self.sampler.epoch + 1
        self.sampler.set_epoch(new_epoch)
        LOG.info(f"calling sampler.set_epoch({new_epoch})")

    # Lazily start the single producer thread on first iteration.
    if self.thread is None:
        self.thread = Thread(target=self._worker, daemon=True)
        self.thread.start()

    # Drain the queue until the worker's ``None`` sentinel marks exhaustion.
    yield from iter(self.queue.get, None)
|
| 223 |
+
|
| 224 |
def generate_batches(self, set_stats=False):
|
| 225 |
LOG.info("generating packed batches")
|
| 226 |
if self.sampler:
|
|
|
|
| 250 |
|
| 251 |
return batches, totseqs
|
| 252 |
|
| 253 |
+
def _internal_batch_generator(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
all_batches, _ = self.generate_batches(set_stats=True)
|
| 255 |
features = self.dataset.features.keys()
|
| 256 |
len_remaining = self._len_est()
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -216,6 +216,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|
| 216 |
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
| 217 |
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
| 218 |
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
|
|
|
| 219 |
)
|
| 220 |
data_loader_len = data_loader.len_w_stats()
|
| 221 |
actual_eff = data_loader.efficiency()
|
|
|
|
| 216 |
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
| 217 |
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
| 218 |
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
| 219 |
+
num_epochs=cfg.num_epochs,
|
| 220 |
)
|
| 221 |
data_loader_len = data_loader.len_w_stats()
|
| 222 |
actual_eff = data_loader.efficiency()
|