xzuyn committed
Add `layers_to_transform` for `lora_config` (#1118)

- README.md +2 -1
- src/axolotl/utils/config.py +5 -0
- src/axolotl/utils/models.py +1 -0
- tests/test_validation.py +15 -0
README.md
CHANGED
@@ -677,7 +677,8 @@ lora_target_modules:
 # - gate_proj
 # - down_proj
 # - up_proj
-lora_target_linear: # If true, will target all linear
+lora_target_linear: # If true, will target all linear modules
+peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
 
 # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
 # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
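For a rough sense of how the new option fits alongside the existing LoRA settings, here is a minimal sketch, expressed as the Python dict that axolotl ends up handling rather than YAML; only `peft_layers_to_transform` comes from this commit, and all values are illustrative.

# Hypothetical config fragment, not part of this commit.
lora_cfg = {
    "adapter": "lora",
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "lora_target_linear": True,
    "peft_layers_to_transform": [0, 1],  # apply LoRA adapters only to layer indices 0 and 1
}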
src/axolotl/utils/config.py
CHANGED
@@ -257,6 +257,11 @@ def validate_config(cfg):
     if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
         raise ValueError("Fused modules are not supported with LoRA")
 
+    if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
+        raise ValueError(
+            "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
+        )
+
     if cfg.relora_steps:
         if cfg.adapter not in ("lora", "qlora"):
            raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
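The new guard rejects configs that combine `peft_layers_to_transform` with `unfrozen_parameters`, since unfreezing base-model parameters by pattern while also restricting the adapter to a subset of layers "can have unexpected behavior". A minimal sketch of the failing path, mirroring the unit test added below and assuming the same DictDefault/validate_config imports that test file uses:

# Hypothetical example mirroring the new unit test; the pattern value is illustrative.
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

bad_cfg = DictDefault(
    {
        "adapter": "lora",
        "peft_layers_to_transform": [0, 1],
        "unfrozen_parameters": ["lm_head.*"],
    }
)

try:
    validate_config(bad_cfg)
except ValueError as err:
    print(err)  # "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."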
src/axolotl/utils/models.py
CHANGED
@@ -733,6 +733,7 @@ def load_lora(model, cfg, inference=False):
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
         target_modules=lora_target_modules,
+        layers_to_transform=cfg.peft_layers_to_transform,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
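The config value is passed straight through to PEFT's `LoraConfig` via its `layers_to_transform` argument; when the option is unset it stays None, PEFT's default of transforming every matched module. A trimmed-down sketch of the kind of config `load_lora()` now builds, with hypothetical values and only the arguments relevant here:

# Hypothetical, trimmed-down LoraConfig; values are illustrative, not from this commit.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    layers_to_transform=[0, 1],  # None (the default) applies LoRA to all matched layers
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)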
tests/test_validation.py
CHANGED
@@ -694,6 +694,21 @@ class ValidationTest(BaseValidation):
 
         validate_config(cfg)
 
+    def test_unfrozen_parameters_w_peft_layers_to_transform(self):
+        cfg = DictDefault(
+            {
+                "adapter": "lora",
+                "unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
+                "peft_layers_to_transform": [0, 1],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*can have unexpected behavior*",
+        ):
+            validate_config(cfg)
+
 
 class ValidationCheckModelConfig(BaseValidation):
     """