Added "epoch" evaluation_strategy (#388)
Browse files
- src/axolotl/utils/trainer.py +10 -1
src/axolotl/utils/trainer.py
CHANGED
|
@@ -451,6 +451,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 451 |
"sample_packing_efficiency"
|
| 452 |
] = cfg.sample_packing_eff_est
|
| 453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
| 455 |
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
| 456 |
max_seq_length=cfg.sequence_len,
|
|
@@ -462,7 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 462 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
| 463 |
num_train_epochs=cfg.num_epochs,
|
| 464 |
learning_rate=cfg.learning_rate,
|
| 465 |
-
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
| 466 |
save_strategy="steps" if cfg.save_steps else "epoch",
|
| 467 |
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
| 468 |
save_steps=cfg.save_steps,
|
|
|
|
| 451 |
"sample_packing_efficiency"
|
| 452 |
] = cfg.sample_packing_eff_est
|
| 453 |
|
| 454 |
+
if cfg.val_set_size == 0:
|
| 455 |
+
evaluation_strategy = "no"
|
| 456 |
+
elif cfg.eval_steps < 1:
|
| 457 |
+
# eval every epoch
|
| 458 |
+
evaluation_strategy = "epoch"
|
| 459 |
+
else:
|
| 460 |
+
# eval every eval_steps steps
|
| 461 |
+
evaluation_strategy = "steps"
|
| 462 |
+
|
| 463 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
| 464 |
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
| 465 |
max_seq_length=cfg.sequence_len,
|
|
|
|
| 471 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
| 472 |
num_train_epochs=cfg.num_epochs,
|
| 473 |
learning_rate=cfg.learning_rate,
|
| 474 |
+
evaluation_strategy=evaluation_strategy,
|
| 475 |
save_strategy="steps" if cfg.save_steps else "epoch",
|
| 476 |
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
| 477 |
save_steps=cfg.save_steps,
|