Feat: Warns to add to modules_to_save when adding tokens or switching special_tokens (#787)

* Feat: Auto add to modules_to_save when adding tokens
* fix: swap to error instead of warning
* feat: add check when special_tokens differ and add test
- src/axolotl/utils/config.py +14 -0
- src/axolotl/utils/models.py +17 -0
- tests/test_tokenizers.py +36 -0
- tests/test_validation.py +37 -0
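For background on why these checks matter: adding new tokens or changing a special token grows the vocabulary, which resizes the model's embedding matrix and LM head, and under a LoRA/QLoRA adapter those modules are only trained and checkpointed if they are explicitly marked as full modules to save. The sketch below illustrates that interaction with standard transformers/PEFT calls (the model name, LoRA hyperparameters, and token value are illustrative, and axolotl's `lora_modules_to_save` is assumed to map onto PEFT's `modules_to_save`):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

# Adding a token changes the vocabulary size, so the embedding matrix
# and LM head must be resized to match.
tokenizer.add_tokens(["<|imstart|>"])
model.resize_token_embeddings(len(tokenizer))

# Without modules_to_save, only the LoRA adapter weights are trained and
# saved; the freshly resized embed_tokens / lm_head rows would be lost.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
)
model = get_peft_model(model, lora_config)
```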
src/axolotl/utils/config.py

@@ -448,6 +448,20 @@ def validate_config(cfg):
     if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
         raise ValueError("neftune_noise_alpha must be > 0.0")
 
+    if (
+        cfg.adapter
+        and cfg.tokens
+        and (
+            not cfg.lora_modules_to_save
+            or not all(
+                x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
+            )
+        )
+    ):
+        raise ValueError(
+            "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
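From the user side, satisfying this check only requires two extra entries in the training config. A minimal sketch (the adapter type and token value are illustrative), expressed here as the dict that `validate_config` receives:

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "adapter": "qlora",
        "load_in_4bit": True,
        "tokens": ["<|imstart|>"],  # new tokens being added to the vocabulary
        "lora_modules_to_save": ["embed_tokens", "lm_head"],  # required by the new check
    }
)

validate_config(cfg)  # passes; omitting either module raises ValueError
```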
src/axolotl/utils/models.py

@@ -136,6 +136,23 @@ def load_tokenizer(cfg):
 
     if cfg.special_tokens:
         for k, val in cfg.special_tokens.items():
+            # check if the new special token differs from the tokenizer's current value
+            # and, if adapter training, make sure lora_modules_to_save is set
+            if (
+                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
+                and cfg.adapter
+                and (
+                    not cfg.lora_modules_to_save
+                    or not all(
+                        x in cfg.lora_modules_to_save
+                        for x in ["embed_tokens", "lm_head"]
+                    )
+                )
+            ):
+                raise ValueError(
+                    "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
+                )
+
             tokenizer.add_special_tokens(
                 {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
             )
tests/test_tokenizers.py

@@ -3,6 +3,8 @@ Test cases for the tokenizer loading
 """
 import unittest
 
+import pytest
+
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_tokenizer
 

@@ -31,6 +33,40 @@ class TestTokenizers(unittest.TestCase):
         tokenizer = load_tokenizer(cfg)
         assert "Fast" not in tokenizer.__class__.__name__
 
+    def test_special_tokens_modules_to_save(self):
+        # setting special_tokens to a new token
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        with pytest.raises(
+            ValueError,
+            match=r".*Please set lora_modules_to_save*",
+        ):
+            load_tokenizer(cfg)
+
+        # setting special_tokens but not changing from the default
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "<s>"},
+            }
+        )
+        load_tokenizer(cfg)
+
+        # non-adapter run setting special_tokens
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        load_tokenizer(cfg)
+
 
 if __name__ == "__main__":
     unittest.main()
tests/test_validation.py

@@ -682,6 +682,43 @@ class ValidationTest(unittest.TestCase):
 
         validate_config(cfg)
 
+    def test_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        validate_config(cfg)
+
 
 class ValidationWandbTest(ValidationTest):
     """