ADD: warning if hub_model_id ist set but not any save strategy (#1202)
Browse files* warning if hub model id set but no save
* add warning
* move the warning
* add test
* allow more public methods for tests for now
* fix tests
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/utils/config.py +5 -0
- tests/test_validation.py +17 -0
src/axolotl/utils/config.py
CHANGED
|
@@ -340,6 +340,11 @@ def validate_config(cfg):
|
|
| 340 |
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
| 341 |
)
|
| 342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
if cfg.gptq and cfg.model_revision:
|
| 344 |
raise ValueError(
|
| 345 |
"model_revision is not supported for GPTQ models. "
|
|
|
|
| 340 |
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
| 341 |
)
|
| 342 |
|
| 343 |
+
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
| 344 |
+
LOG.warning(
|
| 345 |
+
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
if cfg.gptq and cfg.model_revision:
|
| 349 |
raise ValueError(
|
| 350 |
"model_revision is not supported for GPTQ models. "
|
tests/test_validation.py
CHANGED
|
@@ -26,6 +26,7 @@ class BaseValidation(unittest.TestCase):
|
|
| 26 |
self._caplog = caplog
|
| 27 |
|
| 28 |
|
|
|
|
| 29 |
class ValidationTest(BaseValidation):
|
| 30 |
"""
|
| 31 |
Test the validation module
|
|
@@ -698,6 +699,22 @@ class ValidationTest(BaseValidation):
|
|
| 698 |
):
|
| 699 |
validate_config(cfg)
|
| 700 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 701 |
|
| 702 |
class ValidationCheckModelConfig(BaseValidation):
|
| 703 |
"""
|
|
|
|
| 26 |
self._caplog = caplog
|
| 27 |
|
| 28 |
|
| 29 |
+
# pylint: disable=too-many-public-methods
|
| 30 |
class ValidationTest(BaseValidation):
|
| 31 |
"""
|
| 32 |
Test the validation module
|
|
|
|
| 699 |
):
|
| 700 |
validate_config(cfg)
|
| 701 |
|
| 702 |
+
def test_hub_model_id_save_value_warns(self):
|
| 703 |
+
cfg = DictDefault({"hub_model_id": "test"})
|
| 704 |
+
|
| 705 |
+
with self._caplog.at_level(logging.WARNING):
|
| 706 |
+
validate_config(cfg)
|
| 707 |
+
assert (
|
| 708 |
+
"set without any models being saved" in self._caplog.records[0].message
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
def test_hub_model_id_save_value(self):
|
| 712 |
+
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
|
| 713 |
+
|
| 714 |
+
with self._caplog.at_level(logging.WARNING):
|
| 715 |
+
validate_config(cfg)
|
| 716 |
+
assert len(self._caplog.records) == 0
|
| 717 |
+
|
| 718 |
|
| 719 |
class ValidationCheckModelConfig(BaseValidation):
|
| 720 |
"""
|