Add a config not to shuffle merged dataset (#1394) [skip ci]
Browse files
* Add a config not to shuffle merged dataset
* Update README.md
* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Co-authored-by: Wing Lian <[email protected]>
* invert the condition name
* update README
* info -> debug
---------
Co-authored-by: Wing Lian <[email protected]>
README.md
CHANGED
|
@@ -678,6 +678,10 @@ datasets:
|
|
| 678 |
# For `completion` datasets only, uses the provided field instead of `text` column
|
| 679 |
field:
|
| 680 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
# A list of one or more datasets to eval the model with.
|
| 682 |
# You can use either test_datasets, or val_set_size, but not both.
|
| 683 |
test_datasets:
|
|
|
|
| 678 |
# For `completion` datasets only, uses the provided field instead of `text` column
|
| 679 |
field:
|
| 680 |
|
| 681 |
+
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
|
| 682 |
+
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
| 683 |
+
shuffle_merged_datasets: true
|
| 684 |
+
|
| 685 |
# A list of one or more datasets to eval the model with.
|
| 686 |
# You can use either test_datasets, or val_set_size, but not both.
|
| 687 |
test_datasets:
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
|
@@ -416,6 +416,7 @@ class AxolotlInputConfig(
|
|
| 416 |
|
| 417 |
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
| 418 |
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
|
|
|
| 419 |
dataset_prepared_path: Optional[str] = None
|
| 420 |
dataset_shard_num: Optional[int] = None
|
| 421 |
dataset_shard_idx: Optional[int] = None
|
|
|
|
| 416 |
|
| 417 |
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
| 418 |
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
| 419 |
+
shuffle_merged_datasets: Optional[bool] = True
|
| 420 |
dataset_prepared_path: Optional[str] = None
|
| 421 |
dataset_shard_num: Optional[int] = None
|
| 422 |
dataset_shard_idx: Optional[int] = None
|
src/axolotl/utils/data.py
CHANGED
|
@@ -415,8 +415,11 @@ def load_tokenized_prepared_datasets(
|
|
| 415 |
dataset = concatenate_datasets(datasets)
|
| 416 |
|
| 417 |
if len(datasets) > 1:
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
| 422 |
|
|
@@ -819,7 +822,11 @@ def wrap_pretraining_dataset(
|
|
| 819 |
else:
|
| 820 |
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
| 821 |
|
| 822 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 823 |
dataset = dataset.map(
|
| 824 |
encode,
|
| 825 |
batched=True,
|
|
|
|
| 415 |
dataset = concatenate_datasets(datasets)
|
| 416 |
|
| 417 |
if len(datasets) > 1:
|
| 418 |
+
if cfg.shuffle_merged_datasets:
|
| 419 |
+
LOG.debug("shuffle merged datasets")
|
| 420 |
+
dataset = dataset.shuffle(seed=seed)
|
| 421 |
+
else:
|
| 422 |
+
LOG.debug("NOT shuffling merged datasets")
|
| 423 |
|
| 424 |
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
| 425 |
|
|
|
|
| 822 |
else:
|
| 823 |
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
| 824 |
|
| 825 |
+
if cfg.shuffle_merged_datasets:
|
| 826 |
+
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
| 827 |
+
else:
|
| 828 |
+
LOG.debug("NOT shuffling merged pretraining datasets")
|
| 829 |
+
|
| 830 |
dataset = dataset.map(
|
| 831 |
encode,
|
| 832 |
batched=True,
|