JohanWork
committed on
Add mlflow callback for pushing config to mlflow artifacts (#1125)
Browse files
* Update callbacks.py
adding callback for mlflow
* Update trainer_builder.py
* clean up
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -28,6 +28,7 @@ from axolotl.utils.callbacks import (
|
|
| 28 |
EvalFirstStepCallback,
|
| 29 |
GPUStatsCallback,
|
| 30 |
LossWatchDogCallback,
|
|
|
|
| 31 |
SaveAxolotlConfigtoWandBCallback,
|
| 32 |
SaveBetterTransformerModelCallback,
|
| 33 |
bench_eval_callback_factory,
|
|
@@ -543,6 +544,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 543 |
callbacks.append(
|
| 544 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
| 545 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
if self.cfg.loss_watchdog_threshold is not None:
|
| 548 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
|
|
|
| 28 |
EvalFirstStepCallback,
|
| 29 |
GPUStatsCallback,
|
| 30 |
LossWatchDogCallback,
|
| 31 |
+
SaveAxolotlConfigtoMlflowCallback,
|
| 32 |
SaveAxolotlConfigtoWandBCallback,
|
| 33 |
SaveBetterTransformerModelCallback,
|
| 34 |
bench_eval_callback_factory,
|
|
|
|
| 544 |
callbacks.append(
|
| 545 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
| 546 |
)
|
| 547 |
+
if self.cfg.use_mlflow:
|
| 548 |
+
callbacks.append(
|
| 549 |
+
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
| 550 |
+
)
|
| 551 |
|
| 552 |
if self.cfg.loss_watchdog_threshold is not None:
|
| 553 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
src/axolotl/utils/callbacks.py
CHANGED
|
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile
|
|
| 9 |
from typing import TYPE_CHECKING, Dict, List
|
| 10 |
|
| 11 |
import evaluate
|
|
|
|
| 12 |
import numpy as np
|
| 13 |
import pandas as pd
|
| 14 |
import torch
|
|
@@ -575,3 +576,31 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
| 575 |
except (FileNotFoundError, ConnectionError) as err:
|
| 576 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
| 577 |
return control
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from typing import TYPE_CHECKING, Dict, List
|
| 10 |
|
| 11 |
import evaluate
|
| 12 |
+
import mlflow
|
| 13 |
import numpy as np
|
| 14 |
import pandas as pd
|
| 15 |
import torch
|
|
|
|
| 576 |
except (FileNotFoundError, ConnectionError) as err:
|
| 577 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
| 578 |
return control
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
    """Callback to save the Axolotl config file to MLflow as a run artifact.

    At the start of training, the main process copies the config file to a
    temporary ``axolotl_config_*.yml`` file and logs that copy with
    ``mlflow.log_artifact``. The temporary copy is removed afterwards so that
    repeated runs do not accumulate stray files in the temp directory.
    """

    def __init__(self, axolotl_config_path):
        """Remember the path of the config file to upload.

        Args:
            axolotl_config_path: Filesystem path of the Axolotl config file.
        """
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Log the config file to MLflow once, from the main process only.

        Failures to find the config file or to reach the tracking server are
        logged as warnings; they never abort training.

        Returns:
            The ``control`` object, unchanged.
        """
        if not is_main_process():
            return control

        # Local import keeps this block self-contained; `os` is only needed
        # for cleaning up the temporary copy.
        import os  # pylint: disable=import-outside-toplevel

        temp_path = None
        try:
            # Only reserve a unique, recognizably-named path here. The handle
            # is closed before copying so the file is not overwritten while a
            # write handle is still open on it (which fails on Windows).
            with NamedTemporaryFile(
                mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
            ) as temp_file:
                temp_path = temp_file.name
            copyfile(self.axolotl_config_path, temp_path)
            # NOTE(review): log_artifact is expected to finish copying/uploading
            # before it returns, which makes the cleanup below safe - confirm
            # against the MLflow version in use.
            mlflow.log_artifact(temp_path, artifact_path="")
            LOG.info("The Axolotl config has been saved to the MLflow artifacts.")
        except (FileNotFoundError, ConnectionError) as err:
            LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
        finally:
            # delete=False means nothing else will remove the file for us.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    # Best-effort cleanup; a leftover temp file is harmless.
                    pass
        return control
|