use save_strategy from config if available (#434)

* use save_strategy from config if available
* update docs for save_strategy
- README.md +1 -0
- src/axolotl/utils/trainer.py +7 -1
README.md
CHANGED
|
@@ -472,6 +472,7 @@ warmup_steps: 100
|
|
| 472 |
learning_rate: 0.00003
|
| 473 |
lr_quadratic_warmup:
|
| 474 |
logging_steps:
|
|
|
|
| 475 |
save_steps: # leave empty to save at each epoch
|
| 476 |
eval_steps:
|
| 477 |
save_total_limit: # checkpoints saved at a time
|
|
|
|
| 472 |
learning_rate: 0.00003
|
| 473 |
lr_quadratic_warmup:
|
| 474 |
logging_steps:
|
| 475 |
+
save_strategy: # set to `no` to skip checkpoint saves
|
| 476 |
save_steps: # leave empty to save at each epoch
|
| 477 |
eval_steps:
|
| 478 |
save_total_limit: # checkpoints saved at a time
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -457,6 +457,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 457 |
# we have an eval set, but no steps defined, use epoch
|
| 458 |
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
| 459 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
| 461 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
| 462 |
max_seq_length=cfg.sequence_len,
|
|
@@ -468,7 +475,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 468 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
| 469 |
num_train_epochs=cfg.num_epochs,
|
| 470 |
learning_rate=cfg.learning_rate,
|
| 471 |
-
save_strategy="steps" if cfg.save_steps else "epoch",
|
| 472 |
save_steps=cfg.save_steps,
|
| 473 |
output_dir=cfg.output_dir,
|
| 474 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
|
|
|
| 457 |
# we have an eval set, but no steps defined, use epoch
|
| 458 |
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
| 459 |
|
| 460 |
+
if cfg.save_strategy:
|
| 461 |
+
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
|
| 462 |
+
else:
|
| 463 |
+
training_arguments_kwargs["save_strategy"] = (
|
| 464 |
+
"steps" if cfg.save_steps else "epoch",
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
| 468 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
| 469 |
max_seq_length=cfg.sequence_len,
|
|
|
|
| 475 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
| 476 |
num_train_epochs=cfg.num_epochs,
|
| 477 |
learning_rate=cfg.learning_rate,
|
|
|
|
| 478 |
save_steps=cfg.save_steps,
|
| 479 |
output_dir=cfg.output_dir,
|
| 480 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|