improve save callbacks (#1592)
Browse files
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -43,6 +43,7 @@ from axolotl.utils.callbacks import (
|
|
| 43 |
LossWatchDogCallback,
|
| 44 |
SaveAxolotlConfigtoWandBCallback,
|
| 45 |
SaveBetterTransformerModelCallback,
|
|
|
|
| 46 |
bench_eval_callback_factory,
|
| 47 |
causal_lm_bench_eval_callback_factory,
|
| 48 |
log_prediction_callback_factory,
|
|
@@ -888,6 +889,14 @@ class TrainerBuilderBase(abc.ABC):
|
|
| 888 |
callbacks.append(
|
| 889 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
| 890 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 891 |
|
| 892 |
return callbacks
|
| 893 |
|
|
@@ -933,18 +942,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 933 |
):
|
| 934 |
callbacks.append(SaveBetterTransformerModelCallback())
|
| 935 |
|
| 936 |
-
if self.cfg.use_mlflow and is_mlflow_available():
|
| 937 |
-
from axolotl.utils.callbacks.mlflow_ import (
|
| 938 |
-
SaveAxolotlConfigtoMlflowCallback,
|
| 939 |
-
)
|
| 940 |
-
|
| 941 |
-
callbacks.append(
|
| 942 |
-
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
| 943 |
-
)
|
| 944 |
-
|
| 945 |
if self.cfg.loss_watchdog_threshold is not None:
|
| 946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
| 947 |
|
|
|
|
|
|
|
| 948 |
return callbacks
|
| 949 |
|
| 950 |
def get_post_trainer_create_callbacks(self, trainer):
|
|
@@ -1427,6 +1429,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
| 1427 |
|
| 1428 |
def get_callbacks(self):
|
| 1429 |
callbacks = super().get_callbacks()
|
|
|
|
|
|
|
| 1430 |
return callbacks
|
| 1431 |
|
| 1432 |
def get_post_trainer_create_callbacks(self, trainer):
|
|
|
|
| 43 |
LossWatchDogCallback,
|
| 44 |
SaveAxolotlConfigtoWandBCallback,
|
| 45 |
SaveBetterTransformerModelCallback,
|
| 46 |
+
SaveModelOnTrainEndCallback,
|
| 47 |
bench_eval_callback_factory,
|
| 48 |
causal_lm_bench_eval_callback_factory,
|
| 49 |
log_prediction_callback_factory,
|
|
|
|
| 889 |
callbacks.append(
|
| 890 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
| 891 |
)
|
| 892 |
+
if self.cfg.use_mlflow and is_mlflow_available():
|
| 893 |
+
from axolotl.utils.callbacks.mlflow_ import (
|
| 894 |
+
SaveAxolotlConfigtoMlflowCallback,
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
callbacks.append(
|
| 898 |
+
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
| 899 |
+
)
|
| 900 |
|
| 901 |
return callbacks
|
| 902 |
|
|
|
|
| 942 |
):
|
| 943 |
callbacks.append(SaveBetterTransformerModelCallback())
|
| 944 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 945 |
if self.cfg.loss_watchdog_threshold is not None:
|
| 946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
| 947 |
|
| 948 |
+
callbacks.append(SaveModelOnTrainEndCallback())
|
| 949 |
+
|
| 950 |
return callbacks
|
| 951 |
|
| 952 |
def get_post_trainer_create_callbacks(self, trainer):
|
|
|
|
| 1429 |
|
| 1430 |
def get_callbacks(self):
|
| 1431 |
callbacks = super().get_callbacks()
|
| 1432 |
+
callbacks.append(SaveModelOnTrainEndCallback())
|
| 1433 |
+
|
| 1434 |
return callbacks
|
| 1435 |
|
| 1436 |
def get_post_trainer_create_callbacks(self, trainer):
|
src/axolotl/utils/callbacks/__init__.py
CHANGED
|
@@ -773,3 +773,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
| 773 |
except (FileNotFoundError, ConnectionError) as err:
|
| 774 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
| 775 |
return control
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
except (FileNotFoundError, ConnectionError) as err:
|
| 774 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
| 775 |
return control
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
class SaveModelOnTrainEndCallback(TrainerCallback):
|
| 779 |
+
"""Callback to save model on train end"""
|
| 780 |
+
|
| 781 |
+
def on_train_end( # pylint: disable=unused-argument
|
| 782 |
+
self, args, state, control, **kwargs
|
| 783 |
+
):
|
| 784 |
+
control.should_save = True
|
| 785 |
+
return control
|