support for explicit test_dataset definition for evals (#786)
- src/axolotl/utils/config.py  +5 -0
- src/axolotl/utils/data.py  +39 -29
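The new test_datasets option lets a config name evaluation data explicitly instead of carving it out of the training set with val_set_size. A hedged sketch of the config shape this change expects, written as a plain Python dict for illustration (in axolotl these keys live in the YAML config; the dataset paths below are made up):

# Hypothetical config fragment; keys mirror the axolotl YAML options touched by
# this PR, values and dataset paths are placeholders.
cfg_fragment = {
    "datasets": [
        {"path": "my-org/train-data", "type": "alpaca"},
    ],
    # Explicit eval data instead of carving it out of the training set.
    "test_datasets": [
        # 'split' names which split of the loaded dataset to use; without it
        # the loader falls back to the split being prepared ("test" here).
        {"path": "my-org/heldout-data", "type": "alpaca", "split": "train"},
    ],
    # Must stay 0 when test_datasets is set; validate_config rejects the combination.
    "val_set_size": 0,
}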
src/axolotl/utils/config.py
CHANGED
@@ -519,6 +519,11 @@ def validate_config(cfg):
                         "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
                     )
 
+    if cfg.test_datasets and cfg.val_set_size:
+        raise ValueError(
+            "non-zero val_set_size should not be used with test_datasets configuration"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
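validate_config now treats the two ways of defining eval data as mutually exclusive. A minimal sketch of that guard, using types.SimpleNamespace as a stand-in for axolotl's config object (the dataset path is a placeholder):

from types import SimpleNamespace

def check_eval_settings(cfg):
    # Mirrors the guard added to validate_config above.
    if cfg.test_datasets and cfg.val_set_size:
        raise ValueError(
            "non-zero val_set_size should not be used with test_datasets configuration"
        )

check_eval_settings(SimpleNamespace(test_datasets=[{"path": "my-org/heldout-data"}], val_set_size=0))
try:
    check_eval_settings(SimpleNamespace(test_datasets=[{"path": "my-org/heldout-data"}], val_set_size=0.05))
except ValueError as err:
    print(err)  # both ways of defining eval data were requested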
src/axolotl/utils/data.py
CHANGED
@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from datasets import (
@@ -65,9 +65,17 @@ def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_main_process()):
-            train_dataset, eval_dataset, prompters = load_prepare_datasets(
-                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-            )
+            if cfg.test_datasets:
+                train_dataset, _, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
+                )
+                _, eval_dataset, _ = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
+                )
+            else:
+                train_dataset, eval_dataset, prompters = load_prepare_datasets(
+                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+                )
     else:
         path = cfg.pretraining_dataset
         name = None
@@ -108,8 +116,12 @@ def prepare_dataset(cfg, tokenizer):
 
 
 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    tokenizer,
+    cfg,
+    default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[DatasetDict, List[Prompter]]:
+    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -126,7 +138,7 @@ def load_tokenized_prepared_datasets(
                    sorted(
                        [
                            f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
-                            for d in cfg.datasets
+                            for d in cfg_datasets
                        ]
                    )
                )
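Because the prepared-dataset cache key is derived from whichever dataset list is selected, the train and test preparations land in separate cache entries. A rough, illustrative sketch of that idea only; the real ds_hash in data.py also folds in the tokenizer name and the per-dataset shards and conversation fields:

from hashlib import md5

def prepared_cache_key(dataset_specs):
    # Simplified stand-in for the ds_hash computation shown above.
    key = "|".join(sorted(f"{d['path']}:{d.get('type')}" for d in dataset_specs))
    return md5(key.encode("utf-8")).hexdigest()

print(prepared_cache_key([{"path": "my-org/train-data", "type": "alpaca"}]))
print(prepared_cache_key([{"path": "my-org/heldout-data", "type": "alpaca"}]))  # different cache entry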
@@ -149,7 +161,7 @@ def load_tokenized_prepared_datasets(
                f"{cfg.push_dataset_to_hub}/{ds_hash}",
                token=use_auth_token,
            )
-            dataset = dataset["train"]
+            dataset = dataset[split]
        except Exception:  # pylint: disable=broad-except  # nosec
            pass
 
@@ -188,8 +200,8 @@ def load_tokenized_prepared_datasets(
                yield dataset
 
    # pylint: disable=invalid-name
-    for config_dataset in for_d_in_datasets(cfg.datasets):
-        ds: Union[Dataset, DatasetDict] = None
+    for config_dataset in for_d_in_datasets(cfg_datasets):
+        ds: Optional[Union[Dataset, DatasetDict]] = None
        ds_from_hub = False
        try:
            load_dataset(
@@ -342,16 +354,6 @@ def load_tokenized_prepared_datasets(
            )
        if not ds:
            raise ValueError("unhandled dataset load")
-        # support for using a subset of the data
-        if config_dataset.shards:
-            if "train" in ds:
-                ds = ds.shuffle(seed=seed)["train"].shard(
-                    num_shards=config_dataset.shards, index=0
-                )
-            else:
-                ds = ds.shuffle(seed=seed).shard(
-                    num_shards=config_dataset.shards, index=0
-                )
 
        d_base_type = d_prompt_style = None
        d_type = config_dataset.type
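The subset logic removed above is re-added after split selection in the next hunk, where it no longer needs the "train" special case and gains an optional shards_idx. For reference, a small self-contained example of what Dataset.shard does on toy data:

from datasets import Dataset

# Toy illustration of Dataset.shard, the primitive behind the shards option.
toy = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})
subset = toy.shuffle(seed=42).shard(num_shards=4, index=0)  # keep 1 of 4 shards
print(len(subset))  # 25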
@@ -359,17 +361,21 @@ def load_tokenized_prepared_datasets(
            d_type_split = d_type.split(":")
            d_base_type = d_type_split[0]
            d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-        if "train" in ds:
-            ds = ds["train"]
-        elif (
-            isinstance(ds, DatasetDict)
-            and config_dataset.train_on_split
-            and config_dataset.train_on_split in ds
-        ):
-            ds = ds[config_dataset.train_on_split]
+
+        if config_dataset.split and config_dataset.split in ds:
+            ds = ds[config_dataset.split]
+        elif split in ds:
+            ds = ds[split]
        elif isinstance(ds, DatasetDict):
            raise ValueError(
-                f"no train split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
+                f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
+            )
+
+        # support for using a subset of the data
+        if config_dataset.shards:
+            shards_idx = config_dataset.get("shards_idx", 0)
+            ds = ds.shuffle(seed=seed).shard(
+                num_shards=config_dataset.shards, index=shards_idx
            )
 
        dataset_wrapper, dataset_prompter = get_dataset_wrapper(
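A runnable sketch of the split-selection order the new code follows: an explicit per-dataset split wins, then the split being prepared, and a DatasetDict offering neither raises. The helper below is illustrative, not the function in data.py:

from datasets import Dataset, DatasetDict

def pick_split(ds, requested_split, explicit_split=None):
    # Mirrors the selection order in load_tokenized_prepared_datasets.
    if explicit_split and explicit_split in ds:
        return ds[explicit_split]          # per-dataset 'split:' wins
    if requested_split in ds:
        return ds[requested_split]         # otherwise use the split being prepared
    if isinstance(ds, DatasetDict):
        raise ValueError(f"no {requested_split} split found")
    return ds                              # a bare Dataset is used as-is

dd = DatasetDict({
    "train": Dataset.from_dict({"text": ["a", "b"]}),
    "validation": Dataset.from_dict({"text": ["c"]}),
})
print(len(pick_split(dd, "train")))                              # 2
print(len(pick_split(dd, "test", explicit_split="validation")))  # 1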
@@ -428,6 +434,7 @@ def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
+    split="train",
 ) -> Tuple[Dataset, Dataset, List[Prompter]]:
    dataset, prompters = load_tokenized_prepared_datasets(
        tokenizer, cfg, default_dataset_prepared_path
@@ -442,7 +449,7 @@ def load_prepare_datasets(
            index=cfg.dataset_shard_idx,
        )
 
-    if cfg.val_set_size:
+    if split == "train" and cfg.val_set_size:
        # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
        to_hash_train = (
            dataset._fingerprint  # pylint: disable=protected-access
@@ -475,6 +482,9 @@ def load_prepare_datasets(
 
        train_dataset = dataset["train"]
        eval_dataset = dataset["test"]
+    elif split == "test":
+        train_dataset = None
+        eval_dataset = dataset
    else:
        train_dataset = dataset
        eval_dataset = None
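Taken together, the tail of load_prepare_datasets now returns one of three (train_dataset, eval_dataset) shapes depending on the requested split. A hedged stand-alone summary of those branches on toy data (the real code also derives deterministic fingerprints for caching before splitting):

from datasets import Dataset

def resulting_pair(split, val_set_size, dataset):
    # "dataset" stands for the tokenized data built earlier in the function.
    if split == "train" and val_set_size:
        parts = dataset.train_test_split(test_size=val_set_size, seed=42)
        return parts["train"], parts["test"]   # carve eval out of the training data
    if split == "test":
        return None, dataset                   # explicit test_datasets become the eval set
    return dataset, None                       # train only, no eval data

toy = Dataset.from_dict({"text": [str(i) for i in range(50)]})
train_ds, eval_ds = resulting_pair("train", 0.2, toy)
print(len(train_ds), len(eval_ds))        # 40 10
print(resulting_pair("test", 0, toy)[0])  # None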