ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]
Browse files* Add checkpoint logging to mlflow artifact registry
* clean up
* Update README.md
Co-authored-by: NanoCode012 <[email protected]>
* update pydantic config from rebase
---------
Co-authored-by: NanoCode012 <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
README.md
CHANGED
|
@@ -763,6 +763,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
|
| 763 |
# mlflow configuration if you're using it
|
| 764 |
mlflow_tracking_uri: # URI to mlflow
|
| 765 |
mlflow_experiment_name: # Your experiment name
|
|
|
|
| 766 |
|
| 767 |
# Where to save the full-finetuned model to
|
| 768 |
output_dir: ./completed-model
|
|
|
|
| 763 |
# mlflow configuration if you're using it
|
| 764 |
mlflow_tracking_uri: # URI to mlflow
|
| 765 |
mlflow_experiment_name: # Your experiment name
|
| 766 |
+
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
| 767 |
|
| 768 |
# Where to save the full-finetuned model to
|
| 769 |
output_dir: ./completed-model
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
|
@@ -305,6 +305,7 @@ class MLFlowConfig(BaseModel):
|
|
| 305 |
use_mlflow: Optional[str] = None
|
| 306 |
mlflow_tracking_uri: Optional[str] = None
|
| 307 |
mlflow_experiment_name: Optional[str] = None
|
|
|
|
| 308 |
|
| 309 |
|
| 310 |
class WandbConfig(BaseModel):
|
|
|
|
| 305 |
use_mlflow: Optional[str] = None
|
| 306 |
mlflow_tracking_uri: Optional[str] = None
|
| 307 |
mlflow_experiment_name: Optional[str] = None
|
| 308 |
+
hf_mlflow_log_artifacts: Optional[bool] = None
|
| 309 |
|
| 310 |
|
| 311 |
class WandbConfig(BaseModel):
|
src/axolotl/utils/mlflow_.py
CHANGED
|
@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
|
|
| 7 |
|
| 8 |
def setup_mlflow_env_vars(cfg: DictDefault):
|
| 9 |
for key in cfg.keys():
|
| 10 |
-
if key.startswith("mlflow_"):
|
| 11 |
value = cfg.get(key, "")
|
| 12 |
|
| 13 |
if value and isinstance(value, str) and len(value) > 0:
|
|
|
|
| 7 |
|
| 8 |
def setup_mlflow_env_vars(cfg: DictDefault):
|
| 9 |
for key in cfg.keys():
|
| 10 |
+
if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
|
| 11 |
value = cfg.get(key, "")
|
| 12 |
|
| 13 |
if value and isinstance(value, str) and len(value) > 0:
|