be more robust about checking embedding modules for lora finetunes (#1074) [skip ci]

* be more robust about checking embedding modules for lora finetunes
* update dynamic error message
- src/axolotl/utils/config.py +4 -14
- src/axolotl/utils/lora_embeddings.py +12 -0
- src/axolotl/utils/models.py +27 -7
- tests/test_validation.py +61 -9
src/axolotl/utils/config.py
CHANGED

@@ -151,6 +151,10 @@ def normalize_config(cfg):
 
 
 def validate_config(cfg):
+    """
+    This is a "pre-validation" step that handles the yaml configuration before we have any
+    information about the model architecture
+    """
     if is_torch_bf16_gpu_available():
         if not cfg.bf16 and not cfg.bfloat16:
             LOG.info("bf16 support detected, but not enabled for this configuration.")

@@ -443,20 +447,6 @@ def validate_config(cfg):
     if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
         raise ValueError("neftune_noise_alpha must be > 0.0")
 
-    if (
-        cfg.adapter
-        and cfg.tokens
-        and (
-            not cfg.lora_modules_to_save
-            or not all(
-                x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
-            )
-        )
-    ):
-        raise ValueError(
-            "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
-        )
-
     if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
         raise ValueError(
             "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
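For reference, the configuration shape this removed pre-validation used to guard looks like the sketch below. The field names come from the diff; the token strings are hypothetical, and the required module names are no longer hard-coded here but derived from the model architecture by the new helper.

```python
# Illustrative only: a LoRA/QLoRA config that adds new tokens.
# Previously validate_config required the hard-coded ["embed_tokens", "lm_head"];
# the check now runs later, in check_model_config, with architecture-aware names.
cfg = {
    "adapter": "qlora",
    "tokens": ["<|im_start|>", "<|im_end|>"],  # hypothetical new special tokens
    "lora_modules_to_save": ["embed_tokens", "lm_head"],
}
```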
src/axolotl/utils/lora_embeddings.py
ADDED

@@ -0,0 +1,12 @@
+"""
+helpers for lora embeddings
+"""
+
+
+def get_linear_embedding_layers(model_type):
+    """
+    returns the linear embedding layers needed for loras, dependent on the model arch
+    """
+    if model_type == "phi-msft":
+        return ["embd", "lm_head.linear"]
+    return ["lm_head", "embed_tokens"]
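The helper is a pure function, so its behaviour can be checked directly; a minimal sketch of the two branches in the diff:

```python
from axolotl.utils.lora_embeddings import get_linear_embedding_layers

# phi-msft names its embedding and output projection differently
assert get_linear_embedding_layers("phi-msft") == ["embd", "lm_head.linear"]
# all other architectures fall through to the common llama-style names
assert get_linear_embedding_layers("llama") == ["lm_head", "embed_tokens"]
```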
src/axolotl/utils/models.py
CHANGED

@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Any, Optional, Tuple  # noqa: F401
+from typing import Any, Optional, Tuple, Union  # noqa: F401
 
 import addict
 import bitsandbytes as bnb

@@ -28,12 +28,16 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 
 LOG = logging.getLogger("axolotl")
 
 
-def check_model_config(cfg: DictDefault, model_config: AutoConfig):
-    quant_config_exists = hasattr(model_config, "quantization_config")
+def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
+    quant_config_exists = (
+        hasattr(model_config, "quantization_config")
+        and model_config.quantization_config
+    )
     quant_config_method_is_gptq = (
         quant_config_exists
         and "quant_method" in model_config.quantization_config

@@ -52,6 +56,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
         "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
     )
 
+    lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
+    if (
+        cfg.adapter
+        and cfg.tokens
+        and (
+            not cfg.lora_modules_to_save
+            or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
+        )
+    ):
+        lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
+        raise ValueError(
+            f"`lora_modules_to_save` not properly set when adding new tokens. Please include {lora_modules_to_save} in `lora_modules_to_save`."
+        )
+
 
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model

@@ -139,6 +157,7 @@ def load_tokenizer(cfg):
             setattr(tokenizer, attr_name, "<|endoftext|>")
 
     if cfg.special_tokens:
+        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
         for k, val in cfg.special_tokens.items():
             # check if new special token is not already in tokenizer and
             # is adapter training to make sure lora_modules_to_save is set

@@ -149,14 +168,15 @@ def load_tokenizer(cfg):
                 and (
                     not cfg.lora_modules_to_save
                     or not all(
-                        x in cfg.lora_modules_to_save
-                        for x in ["embed_tokens", "lm_head"]
+                        x in cfg.lora_modules_to_save for x in lora_modules_to_save
                     )
                 )
-                and (model_config.model_type in ["llama", "mistral", "mixtral"])
             ):
+                lora_modules_to_save = ", ".join(
+                    [f"`{x}`" for x in lora_modules_to_save]
+                )
                 raise ValueError(
-                    "Please set lora_modules_to_save to [`embed_tokens`, `lm_head`] when using an adapter and changing the special tokens."
+                    f"Please set lora_modules_to_save to {lora_modules_to_save} when using an adapter and changing the special tokens."
                 )
 
             tokenizer.add_special_tokens(
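Since the signature now accepts `Union[AutoConfig, DictDefault]`, the check can be exercised with either a real Hugging Face config or a lightweight stand-in, which is what the new tests do. A hedged sketch of that usage:

```python
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config

cfg = DictDefault({"adapter": "qlora", "tokens": ["<|imstart|>"]})

# A minimal stand-in exposing only the field the check reads; a real
# transformers AutoConfig loaded from the hub would work the same way.
model_config = DictDefault({"model_type": "llama"})

# Raises ValueError until cfg.lora_modules_to_save includes
# both `embed_tokens` and `lm_head` for a llama-type model.
check_model_config(cfg, model_config)
```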
tests/test_validation.py
CHANGED

@@ -10,12 +10,13 @@ from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import check_model_config
 from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 
-class ValidationTest(unittest.TestCase):
+class BaseValidation(unittest.TestCase):
     """
-    Test the validation module
+    Base validation module to setup the log capture
     """
 
     _caplog: Optional[pytest.LogCaptureFixture] = None

@@ -24,6 +25,12 @@ class ValidationTest(unittest.TestCase):
     def inject_fixtures(self, caplog):
         self._caplog = caplog
 
+
+class ValidationTest(BaseValidation):
+    """
+    Test the validation module
+    """
+
     def test_load_4bit_deprecate(self):
         cfg = DictDefault(
             {

@@ -687,16 +694,23 @@ class ValidationTest(unittest.TestCase):
 
         validate_config(cfg)
 
-    def test_add_tokens_adapter(self):
+
+class ValidationCheckModelConfig(BaseValidation):
+    """
+    Test the validation for the config when the model config is available
+    """
+
+    def test_llama_add_tokens_adapter(self):
         cfg = DictDefault(
             {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
         )
+        model_config = DictDefault({"model_type": "llama"})
 
         with pytest.raises(
             ValueError,
-            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
         ):
-            validate_config(cfg)
+            check_model_config(cfg, model_config)
 
         cfg = DictDefault(
             {

@@ -709,9 +723,9 @@ class ValidationTest(unittest.TestCase):
 
         with pytest.raises(
             ValueError,
-            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
         ):
-            validate_config(cfg)
+            check_model_config(cfg, model_config)
 
         cfg = DictDefault(
             {

@@ -722,10 +736,48 @@ class ValidationTest(unittest.TestCase):
             }
         )
 
-        validate_config(cfg)
+        check_model_config(cfg, model_config)
+
+    def test_phi2_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+        model_config = DictDefault({"model_type": "phi-msft"})
+
+        with pytest.raises(
+            ValueError,
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
+        ):
+            check_model_config(cfg, model_config)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
+        ):
+            check_model_config(cfg, model_config)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embd", "lm_head.linear"],
+            }
+        )
+
+        check_model_config(cfg, model_config)
 
 
-class ValidationWandbTest(unittest.TestCase):
+class ValidationWandbTest(BaseValidation):
     """
     Validation test for wandb
     """