don't train if eval split is too small (#873)

* allow zero len dataset
* better handling and warning of small eval splits
* raise error if eval split is too small
* don't mess with calculating total num steps in distributed context
* fix eval_sample_packing training args logic
src/axolotl/core/trainer_builder.py

```diff
@@ -658,7 +658,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.sample_packing if self.cfg.sample_packing else False
         )
         training_arguments_kwargs["eval_sample_packing"] = (
-            self.cfg.sample_packing
+            self.cfg.sample_packing
+            if self.cfg.eval_sample_packing is not False
+            else False
         )
         training_arguments_kwargs[
             "sample_packing_seq_len_multiplier"
```
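The fix matters because `eval_sample_packing` is effectively tri-state in the config: unset (None), True, or False. The old code forwarded `self.cfg.sample_packing` unconditionally, so an explicit `eval_sample_packing: false` was ignored whenever training used packing. A minimal sketch of the corrected resolution rule, using a `SimpleNamespace` stand-in for axolotl's config object:

```python
from types import SimpleNamespace

def resolve_eval_sample_packing(cfg):
    # Mirrors the new expression: eval packing follows the train-side
    # setting unless the user explicitly opted out with False.
    return cfg.sample_packing if cfg.eval_sample_packing is not False else False

# unset (None) inherits sample_packing, as before
assert resolve_eval_sample_packing(
    SimpleNamespace(sample_packing=True, eval_sample_packing=None)
) is True
# an explicit opt-out now wins, even when training packs samples
assert resolve_eval_sample_packing(
    SimpleNamespace(sample_packing=True, eval_sample_packing=False)
) is False
```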
src/axolotl/utils/data.py

```diff
@@ -79,6 +79,14 @@ def prepare_dataset(cfg, tokenizer):
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset, tokenizer
         )
+
+        if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
+            total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
+            if total_eval_steps == 0:
+                raise ValueError(
+                    "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+                )
+
     if cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
```
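The `total_eval_steps == 0` case is exactly what the multipack sampler change below makes representable. Note also the guard's first condition: `eval_dataset` is falsy both when no split was configured and when the split has zero rows, since `datasets.Dataset` exposes `__len__` and Python truthiness falls back to it. A quick check of that assumption (requires the `datasets` package):

```python
from datasets import Dataset

empty = Dataset.from_dict({"input_ids": []})
assert len(empty) == 0
assert not empty  # a zero-row split short-circuits the `if eval_dataset and ...` guard
```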
src/axolotl/utils/samplers/multipack.py

```diff
@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):

         # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
         return max(
-            1,
+            0,
             (
                 world_size
                 * math.floor(
```
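The estimate can legitimately come out non-positive for a tiny split: the packed-batch count is floored and then reduced by 1 as variance headroom (the "1% + 1" in the comment), so a split holding fewer tokens than one packed batch lands at -1. Clamping to 0 rather than a positive minimum is what lets callers, such as the new check in `prepare_dataset`, see "no full batch" instead of a phantom step. A quick arithmetic check mirroring the formula's shape; the constants and helper name are illustrative, not the sampler's real fields:

```python
import math

def est_packed_batches(world_size, tokens_per_device, eff_est, batch_max_len, batch_size):
    # Illustrative restatement of MultipackBatchSampler's length estimate.
    return max(
        0,  # previously clamped to a positive floor; 0 now signals "no full batch"
        world_size
        * math.floor(0.99 * tokens_per_device / eff_est // (batch_max_len * batch_size))
        - 1,
    )

# a tiny eval split: fewer tokens than one packed batch -> 0 estimated batches
assert est_packed_batches(
    world_size=1, tokens_per_device=500, eff_est=1.0, batch_max_len=4096, batch_size=1
) == 0
```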
src/axolotl/utils/trainer.py

```diff
@@ -141,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset


-def calculate_total_num_steps(cfg, train_dataset):
+def calculate_total_num_steps(cfg, train_dataset, update=True):
     if not cfg.total_num_tokens:
         total_num_tokens = np.sum(
             train_dataset.data.column("input_ids")
@@ -150,7 +150,8 @@ def calculate_total_num_steps(cfg, train_dataset):
             .values
         )
         LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
-        cfg.total_num_tokens = total_num_tokens
+        if update:
+            cfg.total_num_tokens = total_num_tokens

     if not cfg.total_supervised_tokens:
         total_supervised_tokens = (
@@ -163,7 +164,8 @@ def calculate_total_num_steps(cfg, train_dataset):
             f"`total_supervised_tokens: {total_supervised_tokens}`",
             main_process_only=True,
         )
-        cfg.total_supervised_tokens = total_supervised_tokens
+        if update:
+            cfg.total_supervised_tokens = total_supervised_tokens

     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
@@ -232,7 +234,8 @@ def calculate_total_num_steps(cfg, train_dataset):
             sample_packing_eff_est = (
                 math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
             )
-            cfg.sample_packing_eff_est = sample_packing_eff_est
+            if update:
+                cfg.sample_packing_eff_est = sample_packing_eff_est
             LOG.debug(
                 f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
                 main_process_only=True,
```
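Why gate the caching at all: `calculate_total_num_steps` memoizes its expensive counts onto the config (e.g. `cfg.total_num_tokens`), and the `if not cfg.total_num_tokens:` checks short-circuit on those caches. Since the eval probe in `prepare_dataset` runs before the train-side call, writing its tiny totals would poison the later train computation. A toy reduction of that hazard, with simplified stand-ins for the real config and counters:

```python
class Cfg:
    """Stand-in for axolotl's config object; only the cached field we need."""
    total_num_tokens = None

def count_tokens(cfg, dataset_tokens, update=True):
    """Toy version of the memoization pattern in calculate_total_num_steps."""
    if not cfg.total_num_tokens:
        if update:
            cfg.total_num_tokens = dataset_tokens  # cache onto the config
        return dataset_tokens
    return cfg.total_num_tokens  # later calls trust the cache

cfg = Cfg()
count_tokens(cfg, dataset_tokens=500, update=False)  # eval probe: read-only
assert count_tokens(cfg, dataset_tokens=1_000_000) == 1_000_000  # train cache intact
```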