Preprocess dataset size fix (#1131)

* overwrite cache on preprocess step
* don't cache the TokenizedPromptDataset at all
* load_from_cache_file no longer needed

Files changed:
- src/axolotl/cli/preprocess.py (+1 -0)
- src/axolotl/datasets.py (+5 -1)
- src/axolotl/utils/data.py (+30 -10)
- src/axolotl/utils/trainer.py (+17 -5)
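For context on the bug class: Hugging Face `datasets` fingerprints every `map`/`filter` call and silently reuses a previously written cache file when the fingerprint matches, so a preprocess run could end up reusing a stale tokenized dataset. A minimal sketch of the knob the fix leans on (illustrative dataset and function, not from this PR):

```python
from datasets import load_dataset

ds = load_dataset("imdb", split="train")  # any cache-backed dataset

def add_char_count(example):
    return {"n_chars": len(example["text"])}

cached = ds.map(add_char_count)  # a second run is served from the Arrow cache
fresh = ds.map(add_char_count, load_from_cache_file=False)  # always recomputed; cache rewritten
```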
src/axolotl/cli/preprocess.py

```diff
@@ -25,6 +25,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    parsed_cfg.is_preprocess = True
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((PreprocessCliArgs))
```
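The one added line is the switch the rest of the PR keys off: the preprocess CLI marks the config, and downstream code (see the trainer.py diff below) derives cache behavior from it. A hedged sketch of the relationship, with `use_cache`, `ds`, and `some_fn` as illustrative names only:

```python
parsed_cfg.is_preprocess = True           # set by the preprocess CLI entry point
use_cache = not parsed_cfg.is_preprocess  # False while preprocessing ...
ds = ds.map(some_fn, load_from_cache_file=use_cache)  # ... so maps overwrite the cache
```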
src/axolotl/datasets.py

```diff
@@ -35,7 +35,10 @@ class TokenizedPromptDataset(Dataset):
     ):
         self.prompt_tokenizer = prompt_tokenizer
         self.process_count = process_count
-        super().__init__(self.process(dataset).data, **kwargs)
+        super().__init__(
+            self.process(dataset).data,
+            **kwargs,
+        )

     def process(self, dataset):
         features = dataset.features.keys()
@@ -52,6 +55,7 @@ class TokenizedPromptDataset(Dataset):
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
             remove_columns=features,
+            keep_in_memory=True,
             **map_kwargs,
         )
```
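`keep_in_memory=True` is the `datasets.Dataset.map` option that materializes the result in RAM instead of writing a cache file, so the tokenized wrapper is never cached at all; the `super().__init__` change above is formatting only. A small self-contained sketch of the effect (names illustrative):

```python
from datasets import Dataset

raw = Dataset.from_dict({"prompt": ["hi", "yo"]})

def tokenize_prompt(example):  # stand-in for prompt_tokenizer.tokenize_prompt
    return {"input_ids": [ord(c) for c in example["prompt"]]}

tokenized = raw.map(
    tokenize_prompt,
    remove_columns=list(raw.features.keys()),
    keep_in_memory=True,  # compute in RAM; skip the on-disk cache entirely
)
print(tokenized[0])  # {'input_ids': [104, 105]}
```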
src/axolotl/utils/data.py

```diff
@@ -594,12 +594,16 @@ def get_dataset_wrapper(
         )
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -610,7 +614,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
@@ -622,7 +628,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
@@ -634,7 +642,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
@@ -646,7 +656,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
@@ -658,7 +670,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
@@ -670,7 +684,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
@@ -682,7 +698,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
@@ -694,7 +712,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     else:
```
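All ten call sites across the `d_base_type` branches now share the same three-argument shape; only the wrapped strategy differs. For illustration, the alpaca branch ends up reading roughly as below (the `AlpacaPromptTokenizingStrategy` constructor arguments are an assumption pieced together from the hunk context, not confirmed by this diff):

```python
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(  # constructor args assumed
    dataset_prompter,
    tokenizer,
    cfg.train_on_inputs,
    cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
    ds_strategy,
    dataset,
    process_count=cfg.dataset_processes,
)
dataset_wrapper = ds_wrapper
```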
src/axolotl/utils/trainer.py

```diff
@@ -111,27 +111,39 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     with zero_first(is_main_process()):
         if cfg.group_by_length:
             train_dataset = train_dataset.map(
-                add_length, num_proc=cfg.dataset_processes
+                add_length,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )

         if cfg.sample_packing:
             train_dataset = train_dataset.map(
-                add_position_ids, num_proc=cfg.dataset_processes
+                add_position_ids,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
             if cfg.eval_sample_packing is not False:
                 if eval_dataset:
                     eval_dataset = eval_dataset.map(
-                        add_position_ids, num_proc=cfg.dataset_processes
+                        add_position_ids,
+                        num_proc=cfg.dataset_processes,
+                        load_from_cache_file=not cfg.is_preprocess,
                     )

         if cfg.group_by_length or cfg.sample_packing:
             max_input_len = np.max(get_dataset_lengths(train_dataset))
             LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

-        train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
+        train_dataset = train_dataset.filter(
+            drop_long,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+        )
         if eval_dataset:
             eval_dataset = eval_dataset.filter(
-                drop_long, num_proc=cfg.dataset_processes
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )

         # Phi doesn't want the attention_mask feature when training
```
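For readability, hedged sketches of the helpers this hunk references; the real implementations live elsewhere in axolotl's trainer utilities and may differ in detail:

```python
from functools import partial

import torch

def add_length(sample):
    # record each example's token count so group_by_length can bucket by it
    sample["length"] = len(sample["input_ids"])
    return sample

def add_position_ids(sample):
    # packed samples need explicit position ids
    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
    return sample

def drop_long_seq(sample, sequence_len=2048):
    # filter predicate: keep only examples that fit the context window
    return len(sample["input_ids"]) <= sequence_len

drop_long = partial(drop_long_seq, sequence_len=2048)  # cfg.sequence_len in the real code
```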