Fix: Fail bf16 check when running on cpu during merge (#631)
Browse files- src/axolotl/utils/config.py +1 -1
- tests/test_validation.py +23 -0
src/axolotl/utils/config.py
CHANGED
|
@@ -94,7 +94,7 @@ def validate_config(cfg):
|
|
| 94 |
if not cfg.bf16 and not cfg.bfloat16:
|
| 95 |
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
| 96 |
else:
|
| 97 |
-
if cfg.bf16 or cfg.bfloat16:
|
| 98 |
raise ValueError(
|
| 99 |
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
| 100 |
)
|
|
|
|
| 94 |
if not cfg.bf16 and not cfg.bfloat16:
|
| 95 |
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
| 96 |
else:
|
| 97 |
+
if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
|
| 98 |
raise ValueError(
|
| 99 |
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
| 100 |
)
|
tests/test_validation.py
CHANGED
|
@@ -351,3 +351,26 @@ class ValidationTest(unittest.TestCase):
|
|
| 351 |
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
| 352 |
with pytest.raises(ValueError, match=regex_exp):
|
| 353 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
|
| 352 |
with pytest.raises(ValueError, match=regex_exp):
|
| 353 |
validate_config(cfg)
|
| 354 |
+
|
| 355 |
+
def test_merge_lora_no_bf16_fail(self):
|
| 356 |
+
"""
|
| 357 |
+
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
| 358 |
+
"""
|
| 359 |
+
|
| 360 |
+
cfg = DictDefault(
|
| 361 |
+
{
|
| 362 |
+
"bf16": True,
|
| 363 |
+
}
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
|
| 367 |
+
validate_config(cfg)
|
| 368 |
+
|
| 369 |
+
cfg = DictDefault(
|
| 370 |
+
{
|
| 371 |
+
"bf16": True,
|
| 372 |
+
"merge_lora": True,
|
| 373 |
+
}
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
validate_config(cfg)
|