Add validation: gradient checkpointing is not supported for MPT models
Browse files
- src/axolotl/utils/validation.py +5 -0
- tests/test_validation.py +14 -0
src/axolotl/utils/validation.py
CHANGED
|
@@ -57,6 +57,11 @@ def validate_config(cfg):
|
|
| 57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
| 58 |
raise ValueError("FSDP is not supported for falcon models")
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# TODO
|
| 61 |
# MPT 7b
|
| 62 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
| 58 |
raise ValueError("FSDP is not supported for falcon models")
|
| 59 |
|
| 60 |
+
if (
|
| 61 |
+
cfg.base_model and "mpt" in cfg.base_model.lower()
|
| 62 |
+
) and cfg.gradient_checkpointing:
|
| 63 |
+
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
| 64 |
+
|
| 65 |
# TODO
|
| 66 |
# MPT 7b
|
| 67 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
tests/test_validation.py
CHANGED
|
@@ -198,3 +198,17 @@ class ValidationTest(unittest.TestCase):
|
|
| 198 |
)
|
| 199 |
|
| 200 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
)
|
| 199 |
|
| 200 |
validate_config(cfg)
|
| 201 |
+
|
| 202 |
+
def test_mpt_gradient_checkpointing(self):
|
| 203 |
+
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
|
| 204 |
+
|
| 205 |
+
# A lower-case MPT model name with gradient_checkpointing enabled must be rejected
|
| 206 |
+
cfg = DictDefault(
|
| 207 |
+
{
|
| 208 |
+
"base_model": "mosaicml/mpt-7b",
|
| 209 |
+
"gradient_checkpointing": True,
|
| 210 |
+
}
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
with pytest.raises(ValueError, match=regex_exp):
|
| 214 |
+
validate_config(cfg)
|