Update doc for gradient accumulation (`gradient_accumulation_steps`) and add validation tests for batch size
Browse files
- README.md +1 -0
- src/axolotl/utils/validation.py +6 -0
- tests/test_validation.py +19 -0
README.md
CHANGED
|
@@ -397,6 +397,7 @@ Add below flag to train command above
|
|
| 397 |
Please reduce any below
|
| 398 |
- `micro_batch_size`
|
| 399 |
- `eval_batch_size`
|
|
|
|
| 400 |
- `sequence_len`
|
| 401 |
|
| 402 |
> RuntimeError: expected scalar type Float but found Half
|
|
|
|
| 397 |
Please reduce any below
|
| 398 |
- `micro_batch_size`
|
| 399 |
- `eval_batch_size`
|
| 400 |
+
- `gradient_accumulation_steps`
|
| 401 |
- `sequence_len`
|
| 402 |
|
| 403 |
> RuntimeError: expected scalar type Float but found Half
|
src/axolotl/utils/validation.py
CHANGED
|
@@ -8,6 +8,12 @@ def validate_config(cfg):
|
|
| 8 |
raise ValueError(
|
| 9 |
"please set only one of gradient_accumulation_steps or batch_size"
|
| 10 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
if cfg.load_4bit:
|
| 12 |
raise ValueError(
|
| 13 |
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
|
|
|
| 8 |
raise ValueError(
|
| 9 |
"please set only one of gradient_accumulation_steps or batch_size"
|
| 10 |
)
|
| 11 |
+
if cfg.batch_size:
|
| 12 |
+
logging.warning(
|
| 13 |
+
"%s\n%s",
|
| 14 |
+
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
| 15 |
+
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
| 16 |
+
)
|
| 17 |
if cfg.load_4bit:
|
| 18 |
raise ValueError(
|
| 19 |
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
tests/test_validation.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
"""Module for testing the validation module"""
|
| 2 |
|
|
|
|
| 3 |
import unittest
|
|
|
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
|
|
@@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase):
|
|
| 13 |
Test the validation module
|
| 14 |
"""
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def test_load_4bit_deprecate(self):
|
| 17 |
cfg = DictDefault(
|
| 18 |
{
|
|
@@ -23,6 +31,17 @@ class ValidationTest(unittest.TestCase):
|
|
| 23 |
with pytest.raises(ValueError):
|
| 24 |
validate_config(cfg)
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def test_qlora(self):
|
| 27 |
base_cfg = DictDefault(
|
| 28 |
{
|
|
|
|
| 1 |
"""Module for testing the validation module"""
|
| 2 |
|
| 3 |
+
import logging
|
| 4 |
import unittest
|
| 5 |
+
from typing import Optional
|
| 6 |
|
| 7 |
import pytest
|
| 8 |
|
|
|
|
| 15 |
Test the validation module
|
| 16 |
"""
|
| 17 |
|
| 18 |
+
_caplog: Optional[pytest.LogCaptureFixture] = None
|
| 19 |
+
|
| 20 |
+
@pytest.fixture(autouse=True)
|
| 21 |
+
def inject_fixtures(self, caplog):
|
| 22 |
+
self._caplog = caplog
|
| 23 |
+
|
| 24 |
def test_load_4bit_deprecate(self):
|
| 25 |
cfg = DictDefault(
|
| 26 |
{
|
|
|
|
| 31 |
with pytest.raises(ValueError):
|
| 32 |
validate_config(cfg)
|
| 33 |
|
| 34 |
+
def test_batch_size_unused_warning(self):
|
| 35 |
+
cfg = DictDefault(
|
| 36 |
+
{
|
| 37 |
+
"batch_size": 32,
|
| 38 |
+
}
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
with self._caplog.at_level(logging.WARNING):
|
| 42 |
+
validate_config(cfg)
|
| 43 |
+
assert "batch_size is not recommended" in self._caplog.records[0].message
|
| 44 |
+
|
| 45 |
def test_qlora(self):
|
| 46 |
base_cfg = DictDefault(
|
| 47 |
{
|