make sure to save on the last step (#1615)
Browse files
src/axolotl/utils/callbacks/__init__.py
CHANGED
|
@@ -778,6 +778,17 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
| 778 |
class SaveModelOnTrainEndCallback(TrainerCallback):
|
| 779 |
"""Callback to save model on train end"""
|
| 780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
def on_train_end( # pylint: disable=unused-argument
|
| 782 |
self, args, state, control, **kwargs
|
| 783 |
):
|
|
|
|
| 778 |
class SaveModelOnTrainEndCallback(TrainerCallback):
|
| 779 |
"""Callback to save model on train end"""
|
| 780 |
|
| 781 |
+
def on_step_end( # pylint: disable=unused-argument
|
| 782 |
+
self,
|
| 783 |
+
args: TrainingArguments,
|
| 784 |
+
state: TrainerState,
|
| 785 |
+
control: TrainerControl,
|
| 786 |
+
**kwargs,
|
| 787 |
+
):
|
| 788 |
+
# Save
|
| 789 |
+
if state.global_step >= state.max_steps:
|
| 790 |
+
control.should_save = True
|
| 791 |
+
|
| 792 |
def on_train_end( # pylint: disable=unused-argument
|
| 793 |
self, args, state, control, **kwargs
|
| 794 |
):
|