add optimization for group-by-len (#563)
Browse files- src/axolotl/utils/trainer.py +10 -0
src/axolotl/utils/trainer.py
CHANGED
|
@@ -358,7 +358,14 @@ class ReLoRATrainer(AxolotlTrainer):
|
|
| 358 |
|
| 359 |
|
| 360 |
def add_position_ids(sample):
|
|
|
|
| 361 |
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
return sample
|
| 363 |
|
| 364 |
|
|
@@ -382,6 +389,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|
| 382 |
if eval_dataset:
|
| 383 |
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
| 384 |
|
|
|
|
|
|
|
|
|
|
| 385 |
if cfg.sample_packing:
|
| 386 |
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
| 387 |
if eval_dataset:
|
|
|
|
| 358 |
|
| 359 |
|
| 360 |
def add_position_ids(sample):
|
| 361 |
+
sample_len = len(sample["input_ids"])
|
| 362 |
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
| 363 |
+
sample["length"] = sample_len
|
| 364 |
+
return sample
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def add_length(sample):
|
| 368 |
+
sample["length"] = len(sample["input_ids"])
|
| 369 |
return sample
|
| 370 |
|
| 371 |
|
|
|
|
| 389 |
if eval_dataset:
|
| 390 |
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
| 391 |
|
| 392 |
+
if cfg.group_by_length:
|
| 393 |
+
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
|
| 394 |
+
|
| 395 |
if cfg.sample_packing:
|
| 396 |
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
| 397 |
if eval_dataset:
|