Jan Philipp Harries
committed on
Save Axolotl config as WandB artifact (#716)
Browse files
src/axolotl/cli/__init__.py
CHANGED
|
@@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|
| 194 |
# load the config from the yaml file
|
| 195 |
with open(config, encoding="utf-8") as file:
|
| 196 |
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
|
|
| 197 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
| 198 |
# then overwrite the value
|
| 199 |
cfg_keys = cfg.keys()
|
|
|
|
| 194 |
# load the config from the yaml file
|
| 195 |
with open(config, encoding="utf-8") as file:
|
| 196 |
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
| 197 |
+
cfg.axolotl_config_path = config
|
| 198 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
| 199 |
# then overwrite the value
|
| 200 |
cfg_keys = cfg.keys()
|
src/axolotl/utils/callbacks.py
CHANGED
|
@@ -514,3 +514,27 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|
| 514 |
return control
|
| 515 |
|
| 516 |
return LogPredictionCallback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
return control
|
| 515 |
|
| 516 |
return LogPredictionCallback
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
    """Callback that uploads the Axolotl config file to WandB as an artifact.

    The upload is best-effort: any expected failure is logged as a warning
    and training continues.
    """

    def __init__(self, axolotl_config_path):
        """Remember which config file to upload.

        Args:
            axolotl_config_path: Path of the Axolotl config file on disk
                (a ``str`` or ``pathlib.Path``).
        """
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Log the config file as a WandB artifact when training starts.

        The upload is attempted only when ``is_main_process()`` is true.

        Args:
            args: Training arguments (unused).
            state: Trainer state (unused).
            control: Trainer control object, returned unchanged.
            **kwargs: Extra callback arguments (unused).

        Returns:
            The ``control`` object that was passed in.
        """
        if not is_main_process():
            return control

        # Without an active run there is nothing to attach the artifact to;
        # dereferencing ``wandb.run`` would raise AttributeError and abort
        # training over an optional logging step.
        if wandb.run is None:
            LOG.warning(
                "Error while saving Axolotl config to WandB: no active WandB run"
            )
            return control

        try:
            artifact = wandb.Artifact(name="axolotl-config", type="config")
            # Normalise to ``str`` so both str and Path inputs are accepted.
            artifact.add_file(local_path=str(self.axolotl_config_path))
            wandb.run.log_artifact(artifact)
            LOG.info("Axolotl config has been saved to WandB as an artifact.")
        except (FileNotFoundError, ValueError, ConnectionError) as err:
            # ValueError is included because wandb's ``add_file`` is
            # documented to raise it when the path is not a file.
            LOG.warning("Error while saving Axolotl config to WandB: %s", err)
        return control
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -30,6 +30,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|
| 30 |
from axolotl.utils.callbacks import (
|
| 31 |
EvalFirstStepCallback,
|
| 32 |
GPUStatsCallback,
|
|
|
|
| 33 |
SaveBetterTransformerModelCallback,
|
| 34 |
bench_eval_callback_factory,
|
| 35 |
log_prediction_callback_factory,
|
|
@@ -775,6 +776,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
| 775 |
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
| 776 |
trainer.add_callback(LogPredictionCallback(cfg))
|
| 777 |
|
|
|
|
|
|
|
|
|
|
| 778 |
if cfg.do_bench_eval:
|
| 779 |
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
| 780 |
|
|
|
|
| 30 |
from axolotl.utils.callbacks import (
|
| 31 |
EvalFirstStepCallback,
|
| 32 |
GPUStatsCallback,
|
| 33 |
+
SaveAxolotlConfigtoWandBCallback,
|
| 34 |
SaveBetterTransformerModelCallback,
|
| 35 |
bench_eval_callback_factory,
|
| 36 |
log_prediction_callback_factory,
|
|
|
|
| 776 |
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
| 777 |
trainer.add_callback(LogPredictionCallback(cfg))
|
| 778 |
|
| 779 |
+
if cfg.use_wandb:
|
| 780 |
+
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
|
| 781 |
+
|
| 782 |
if cfg.do_bench_eval:
|
| 783 |
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
| 784 |
|