recast loralayer, norm, lmhead + embed token weights per original qlora (#393)
* recast loralayer, norm, lmhead + embed token weights per original qlora
* try again for the fix
* refactor torch dtype picking
* linter fixes
* missing import for LoraLayer
* fix install for tests now that peft is involved
- .github/workflows/tests.yml +1 -1
- setup.py +3 -0
- src/axolotl/utils/config.py +7 -0
- src/axolotl/utils/models.py +22 -21
.github/workflows/tests.yml
CHANGED

@@ -24,7 +24,7 @@ jobs:
 
       - name: Install dependencies
        run: |
-          pip install -e .
+          pip install -e .[peft]
          pip install -r requirements-tests.txt
 
      - name: Run tests
setup.py
CHANGED

@@ -32,5 +32,8 @@ setup(
         "extras": [
             "deepspeed",
         ],
+        "peft": [
+            "peft @ git+https://github.com/huggingface/peft.git",
+        ],
     },
 )
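Since the new recast logic in models.py imports from peft, the dependency is declared as an optional extra rather than a hard requirement, and the test workflow above installs it via pip install -e .[peft]. A minimal, illustrative sanity check that the extra provides the symbol models.py relies on, assuming peft has been installed from the git source pinned above:

# Quick check that the "peft" extra resolves the import used in
# src/axolotl/utils/models.py. Assumes `pip install -e .[peft]` has run.
from peft.tuners.lora import LoraLayer

print(LoraLayer.__module__)  # expected: peft.tuners.lora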
src/axolotl/utils/config.py
CHANGED

@@ -62,6 +62,13 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
+    if cfg.bf16 or cfg.bfloat16:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
+        cfg.torch_dtype = torch.float16
+    else:
+        cfg.torch_dtype = torch.float32
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
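The dtype selection that previously lived inside load_model now runs once in normalize_config and is cached on the config as cfg.torch_dtype, so every downstream from_pretrained call reads the same value. A standalone sketch of the same precedence (bf16 first, then 8-bit/fp16, then fp32 as the fallback), using a hypothetical SimpleNamespace stand-in for the real cfg object:

from types import SimpleNamespace

import torch


def pick_torch_dtype(cfg):
    # Mirrors the precedence added to normalize_config: bf16 wins,
    # then 8-bit/fp16, and fp32 is the fallback.
    if cfg.bf16 or cfg.bfloat16:
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        return torch.float16
    return torch.float32


# Hypothetical config values for illustration only.
cfg = SimpleNamespace(
    bf16=True, bfloat16=False, load_in_8bit=True, fp16=False, float16=False
)
assert pick_torch_dtype(cfg) is torch.bfloat16  # bf16 takes precedence over 8-bit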
src/axolotl/utils/models.py
CHANGED

@@ -11,6 +11,7 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
+from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -146,12 +147,6 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()
 
-    if cfg.bf16 or cfg.bfloat16:
-        torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
-        torch_dtype = torch.float16
-    else:
-        torch_dtype = torch.float32
     try:
         if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -183,7 +178,7 @@ def load_model(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
             llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=torch_dtype,
+            bnb_4bit_compute_dtype=cfg.torch_dtype,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
@@ -242,7 +237,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             **model_kwargs,
         )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -277,7 +272,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -308,7 +303,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -322,7 +317,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -356,16 +351,6 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-        # convert them back to fp16/bf16 for flash-attn compatibility.
-        if cfg.flash_attention and cfg.is_llama_derived_model:
-            for name, module in model.named_modules():
-                if "norm" in name:
-                    module.to(torch_dtype)
-                if "lm_head" in name or "embed_tokens" in name:
-                    if hasattr(module, "weight"):
-                        module.to(torch_dtype)
-
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -509,6 +494,22 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)
 
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(cfg.torch_dtype)
+        if "norm" in name:
+            module = module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module = module.to(cfg.torch_dtype)
+
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module = module.to(cfg.torch_dtype)
+
     model.print_trainable_parameters()
 
     return model, lora_config
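For reference, the recast added to load_lora follows the pattern from the original qlora repo: LoRA layers move to the configured training dtype, norm layers are pinned to fp32 for stability, and lm_head / embed_tokens weights are cast back to the training dtype, with a second pass returning norms to half precision when flash attention is active on a llama-derived model. A minimal standalone sketch of the core loop, assuming a PEFT-wrapped model; the recast_qlora_modules helper name is illustrative and not part of axolotl:

import torch
from peft.tuners.lora import LoraLayer


def recast_qlora_modules(model, torch_dtype=torch.bfloat16):
    # Sketch of the qlora-style recast applied after get_peft_model.
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            # LoRA adapter weights train in the configured half precision.
            module.to(torch_dtype)
        if "norm" in name:
            # Norm layers stay in fp32 for numerical stability.
            module.to(torch.float32)
        if ("lm_head" in name or "embed_tokens" in name) and hasattr(module, "weight"):
            # Output head and input embeddings return to the training dtype.
            module.to(torch_dtype)
    return model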