WIP: Support table logging for mlflow, too (#1506)
* WIP: Support table logging for mlflow, too
Create a `LogPredictionCallback` for both "wandb" and "mlflow" if
specified.
In `log_prediction_callback_factory`, create a generic table and make it
logger-specific only when the newly added `logger` argument is set to
"wandb" or "mlflow", respectively.
See https://github.com/OpenAccess-AI-Collective/axolotl/issues/1505
* chore: lint
* add an additional clause for mlflow, since it is an optional dependency
* Fix circular imports
---------
Co-authored-by: Dave Farago <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
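
The change keeps the table itself logger-agnostic and only branches on the new `logger` argument at the final logging step. A minimal sketch of that pattern, separate from the actual axolotl code (the helper names `collect_table_data` and `log_table` are illustrative only; the real implementation is in the `callbacks/__init__.py` diff below):

```python
# Minimal sketch of the "generic table, logger-specific sink" pattern used by
# this commit. Helper names are illustrative, not axolotl APIs.
from typing import Any, Dict, List


def collect_table_data() -> Dict[str, List[Any]]:
    # Column-oriented table that no particular experiment tracker depends on.
    return {"id": [], "Prompt": [], "Correct Completion": []}


def log_table(table_data: Dict[str, List[Any]], logger: str) -> None:
    # Only this final step depends on which tracker was requested.
    if logger == "wandb":
        import pandas as pd
        import wandb

        # wandb renders a logged DataFrame as a table in the run UI.
        wandb.run.log({"Predictions vs Ground Truth": pd.DataFrame(table_data)})
    elif logger == "mlflow":
        import mlflow

        # mlflow.log_table (MLflow >= 2.3) stores the rows as a JSON artifact.
        mlflow.log_table(
            data=table_data, artifact_file="PredictionsVsGroundTruth.json"
        )
```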
src/axolotl/core/trainer_builder.py
CHANGED
@@ -36,6 +36,7 @@ from trl.trainer.utils import pad_to_length
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+from axolotl.utils import is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
@@ -71,10 +72,6 @@ except ImportError:
 LOG = logging.getLogger("axolotl.core.trainer_builder")
 
 
-def is_mlflow_available():
-    return importlib.util.find_spec("mlflow") is not None
-
-
 def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     if isinstance(tag_names, str):
         tag_names = [tag_names]
@@ -943,7 +940,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         callbacks = []
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer
+                trainer, self.tokenizer, "wandb"
+            )
+            callbacks.append(LogPredictionCallback(self.cfg))
+        if (
+            self.cfg.use_mlflow
+            and is_mlflow_available()
+            and self.cfg.eval_table_size > 0
+        ):
+            LogPredictionCallback = log_prediction_callback_factory(
+                trainer, self.tokenizer, "mlflow"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
 
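
With this wiring, the wandb path keeps its existing gate, while the mlflow path additionally requires that the optional mlflow package is actually installed. An illustrative snippet of how the two guards above evaluate (the config values and the use of `DictDefault` here are placeholders, not a complete axolotl config):

```python
# Illustrative only: how the two guards from the diff above evaluate for a
# given cfg. DictDefault is axolotl's config wrapper; values are placeholders.
from axolotl.utils import is_mlflow_available
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "use_wandb": False,
        "use_mlflow": True,  # enable the mlflow tracker
        "eval_table_size": 5,  # > 0 switches prediction-table logging on
        "mlflow_tracking_uri": "http://127.0.0.1:5000",  # placeholder URI
    }
)

# Mirrors the conditions in the diff: the wandb branch is skipped here, and
# the mlflow branch only fires when the mlflow package is importable.
wants_wandb_table = bool(cfg.use_wandb and cfg.eval_table_size > 0)
wants_mlflow_table = bool(
    cfg.use_mlflow and is_mlflow_available() and cfg.eval_table_size > 0
)
```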
src/axolotl/utils/__init__.py
CHANGED
@@ -0,0 +1,8 @@
+"""
+Basic utils for Axolotl
+"""
+import importlib
+
+
+def is_mlflow_available():
+    return importlib.util.find_spec("mlflow") is not None
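
Placing `is_mlflow_available` in `axolotl.utils` (instead of `trainer_builder.py`) avoids the circular import mentioned in the commit message, and the helper only probes for the package spec rather than importing mlflow. A small usage sketch:

```python
# Usage sketch: guard optional mlflow-specific code paths without importing
# mlflow at module import time.
from axolotl.utils import is_mlflow_available

if is_mlflow_available():
    import mlflow  # safe to import: the package spec was found

    print(f"mlflow {mlflow.__version__} detected")
else:
    print("mlflow is not installed; mlflow table logging will be skipped")
```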
src/axolotl/utils/callbacks/__init__.py
CHANGED
@@ -6,7 +6,7 @@ import logging
 import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
 
 import evaluate
 import numpy as np
@@ -27,7 +27,9 @@ from transformers import (
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
+from axolotl.utils import is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -540,7 +542,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
     return CausalLMBenchEvalCallback
 
 
-def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
 
@@ -597,15 +599,13 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
                 return ranges
 
             def log_table_from_dataloader(name: str, table_dataloader):
-                table = wandb.Table(  # type: ignore[attr-defined]
-                    columns=[
-                        "id",
-                        "Prompt",
-                        "Correct Completion",
-                        "Predicted Completion (model.generate)",
-                        "Predicted Completion (trainer.prediction_step)",
-                    ]
-                )
+                table_data: Dict[str, List[Any]] = {
+                    "id": [],
+                    "Prompt": [],
+                    "Correct Completion": [],
+                    "Predicted Completion (model.generate)": [],
+                    "Predicted Completion (trainer.prediction_step)": [],
+                }
                 row_index = 0
 
                 for batch in tqdm(table_dataloader):
@@ -709,16 +709,29 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
                     ) in zip(
                         prompt_texts, completion_texts, predicted_texts, pred_step_texts
                     ):
-                        table.add_data(
-                            row_index,
-                            prompt_text,
-                            completion_text,
-                            prediction_text,
-                            pred_step_text,
+                        table_data["id"].append(row_index)
+                        table_data["Prompt"].append(prompt_text)
+                        table_data["Correct Completion"].append(completion_text)
+                        table_data["Predicted Completion (model.generate)"].append(
+                            prediction_text
                         )
+                        table_data[
+                            "Predicted Completion (trainer.prediction_step)"
+                        ].append(pred_step_text)
                         row_index += 1
-
-                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})  # type: ignore[attr-defined]
+                if logger == "wandb":
+                    wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)})  # type: ignore[attr-defined]
+                elif logger == "mlflow" and is_mlflow_available():
+                    import mlflow
+
+                    tracking_uri = AxolotlInputConfig(
+                        **self.cfg.to_dict()
+                    ).mlflow_tracking_uri
+                    mlflow.log_table(
+                        data=table_data,
+                        artifact_file="PredictionsVsGroundTruth.json",
+                        tracking_uri=tracking_uri,
+                    )
 
             if is_main_process():
                 log_table_from_dataloader("Eval", eval_dataloader)
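
On the mlflow side, `log_table` attaches the rows to the active run as a JSON artifact. A hedged sketch of reading that table back, assuming an MLflow version that ships `log_table`/`load_table` (2.3 or later); the tracking URI and run id are placeholders:

```python
# Sketch of retrieving the logged predictions table from mlflow.
# Assumes MLflow >= 2.3; tracking URI and run id below are placeholders.
import mlflow

mlflow.set_tracking_uri("http://127.0.0.1:5000")  # placeholder tracking server
df = mlflow.load_table(
    artifact_file="PredictionsVsGroundTruth.json",
    run_ids=["<run-id>"],  # placeholder run id
)
print(df[["Prompt", "Predicted Completion (model.generate)"]].head())
```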