params are adam_*, not adamw_*
Browse files
src/axolotl/utils/validation.py
CHANGED
|
@@ -87,7 +87,7 @@ def validate_config(cfg):
|
|
| 87 |
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
| 88 |
)
|
| 89 |
|
| 90 |
-
if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
|
| 91 |
not cfg.optimizer or "adamw" not in cfg.optimizer
|
| 92 |
):
|
| 93 |
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
|
|
|
| 87 |
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
| 88 |
)
|
| 89 |
|
| 90 |
+
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
| 91 |
not cfg.optimizer or "adamw" not in cfg.optimizer
|
| 92 |
):
|
| 93 |
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
tests/test_validation.py
CHANGED
|
@@ -268,7 +268,7 @@ class ValidationTest(unittest.TestCase):
|
|
| 268 |
cfg = DictDefault(
|
| 269 |
{
|
| 270 |
"optimizer": None,
|
| 271 |
-
"adamw_epsilon": 0.0001,
|
| 272 |
}
|
| 273 |
)
|
| 274 |
|
|
@@ -283,7 +283,7 @@ class ValidationTest(unittest.TestCase):
|
|
| 283 |
cfg = DictDefault(
|
| 284 |
{
|
| 285 |
"optimizer": "adafactor",
|
| 286 |
-
"adamw_beta1": 0.0001,
|
| 287 |
}
|
| 288 |
)
|
| 289 |
|
|
@@ -298,9 +298,9 @@ class ValidationTest(unittest.TestCase):
|
|
| 298 |
cfg = DictDefault(
|
| 299 |
{
|
| 300 |
"optimizer": "adamw_bnb_8bit",
|
| 301 |
-
"adamw_beta1": 0.9,
|
| 302 |
-
"adamw_beta2": 0.99,
|
| 303 |
-
"adamw_epsilon": 0.0001,
|
| 304 |
}
|
| 305 |
)
|
| 306 |
|
|
|
|
| 268 |
cfg = DictDefault(
|
| 269 |
{
|
| 270 |
"optimizer": None,
|
| 271 |
+
"adam_epsilon": 0.0001,
|
| 272 |
}
|
| 273 |
)
|
| 274 |
|
|
|
|
| 283 |
cfg = DictDefault(
|
| 284 |
{
|
| 285 |
"optimizer": "adafactor",
|
| 286 |
+
"adam_beta1": 0.0001,
|
| 287 |
}
|
| 288 |
)
|
| 289 |
|
|
|
|
| 298 |
cfg = DictDefault(
|
| 299 |
{
|
| 300 |
"optimizer": "adamw_bnb_8bit",
|
| 301 |
+
"adam_beta1": 0.9,
|
| 302 |
+
"adam_beta2": 0.99,
|
| 303 |
+
"adam_epsilon": 0.0001,
|
| 304 |
}
|
| 305 |
)
|
| 306 |
|