integrate qlora? maybe?
Browse files- requirements.txt +1 -1
- src/axolotl/utils/models.py +32 -2
requirements.txt
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
peft @ git+https://github.com/huggingface/peft.git
|
| 2 |
transformers @ git+https://github.com/huggingface/transformers.git
|
|
|
|
| 3 |
attrdict
|
| 4 |
fire
|
| 5 |
PyYAML==6.0
|
| 6 |
black
|
| 7 |
-
bitsandbytes==0.37.2
|
| 8 |
datasets
|
| 9 |
accelerate>=0.19.0
|
| 10 |
sentencepiece
|
|
|
|
| 1 |
peft @ git+https://github.com/huggingface/peft.git
|
| 2 |
transformers @ git+https://github.com/huggingface/transformers.git
|
| 3 |
+
bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git
|
| 4 |
attrdict
|
| 5 |
fire
|
| 6 |
PyYAML==6.0
|
| 7 |
black
|
|
|
|
| 8 |
datasets
|
| 9 |
accelerate>=0.19.0
|
| 10 |
sentencepiece
|
src/axolotl/utils/models.py
CHANGED
|
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import transformers
|
|
|
|
| 9 |
from transformers import (
|
| 10 |
AutoModelForCausalLM,
|
| 11 |
AutoTokenizer,
|
| 12 |
PreTrainedModel,
|
| 13 |
-
AutoConfig,
|
| 14 |
)
|
| 15 |
|
| 16 |
try:
|
|
@@ -81,6 +82,16 @@ def load_model(
|
|
| 81 |
logging.exception(e)
|
| 82 |
raise e
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
try:
|
| 85 |
if cfg.load_4bit and is_llama_derived_model:
|
| 86 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
|
@@ -125,6 +136,7 @@ def load_model(
|
|
| 125 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 126 |
torch_dtype=torch_dtype,
|
| 127 |
device_map=cfg.device_map,
|
|
|
|
| 128 |
)
|
| 129 |
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
| 130 |
# This is a WIP, still an issue with the backward pass
|
|
@@ -159,6 +171,7 @@ def load_model(
|
|
| 159 |
torch_dtype=torch_dtype,
|
| 160 |
device_map=cfg.device_map,
|
| 161 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
|
| 162 |
)
|
| 163 |
else:
|
| 164 |
config = AutoConfig.from_pretrained(
|
|
@@ -172,6 +185,7 @@ def load_model(
|
|
| 172 |
torch_dtype=torch_dtype,
|
| 173 |
device_map=cfg.device_map,
|
| 174 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
|
| 175 |
)
|
| 176 |
except Exception as e:
|
| 177 |
logging.error(
|
|
@@ -184,8 +198,24 @@ def load_model(
|
|
| 184 |
torch_dtype=torch_dtype,
|
| 185 |
device_map=cfg.device_map,
|
| 186 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
|
|
|
| 187 |
)
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
if not tokenizer:
|
| 190 |
try:
|
| 191 |
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
|
@@ -270,7 +300,7 @@ def load_adapter(model, cfg, adapter):
|
|
| 270 |
|
| 271 |
if adapter is None:
|
| 272 |
return model, None
|
| 273 |
-
if adapter == "lora":
|
| 274 |
return load_lora(model, cfg)
|
| 275 |
if adapter == "llama-adapter":
|
| 276 |
return load_llama_adapter(model, cfg)
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import transformers
|
| 9 |
+
from torch import nn
|
| 10 |
from transformers import (
|
| 11 |
AutoModelForCausalLM,
|
| 12 |
AutoTokenizer,
|
| 13 |
PreTrainedModel,
|
| 14 |
+
AutoConfig, BitsAndBytesConfig,
|
| 15 |
)
|
| 16 |
|
| 17 |
try:
|
|
|
|
| 82 |
logging.exception(e)
|
| 83 |
raise e
|
| 84 |
|
| 85 |
+
model_kwargs = {}
|
| 86 |
+
if cfg.adapter == "qlora":
|
| 87 |
+
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 88 |
+
load_in_4bit=True,
|
| 89 |
+
llm_int8_threshold=6.0,
|
| 90 |
+
llm_int8_has_fp16_weight=False,
|
| 91 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 92 |
+
bnb_4bit_use_double_quant=True,
|
| 93 |
+
bnb_4bit_quant_type="nf4",
|
| 94 |
+
)
|
| 95 |
try:
|
| 96 |
if cfg.load_4bit and is_llama_derived_model:
|
| 97 |
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
|
|
|
| 136 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
| 137 |
torch_dtype=torch_dtype,
|
| 138 |
device_map=cfg.device_map,
|
| 139 |
+
**model_kwargs,
|
| 140 |
)
|
| 141 |
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
| 142 |
# This is a WIP, still an issue with the backward pass
|
|
|
|
| 171 |
torch_dtype=torch_dtype,
|
| 172 |
device_map=cfg.device_map,
|
| 173 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
| 174 |
+
**model_kwargs,
|
| 175 |
)
|
| 176 |
else:
|
| 177 |
config = AutoConfig.from_pretrained(
|
|
|
|
| 185 |
torch_dtype=torch_dtype,
|
| 186 |
device_map=cfg.device_map,
|
| 187 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
| 188 |
+
**model_kwargs,
|
| 189 |
)
|
| 190 |
except Exception as e:
|
| 191 |
logging.error(
|
|
|
|
| 198 |
torch_dtype=torch_dtype,
|
| 199 |
device_map=cfg.device_map,
|
| 200 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
| 201 |
+
**model_kwargs,
|
| 202 |
)
|
| 203 |
|
| 204 |
+
"""### Post-processing on the model
|
| 205 |
+
Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
|
| 206 |
+
"""
|
| 207 |
+
if cfg.adapter == "qlora":
|
| 208 |
+
for param in model.parameters():
|
| 209 |
+
param.requires_grad = False # freeze the model - train adapters later
|
| 210 |
+
if param.ndim == 1:
|
| 211 |
+
# cast the small parameters (e.g. layernorm) to fp32 for stability
|
| 212 |
+
param.data = param.data.to(torch.float32)
|
| 213 |
+
class CastOutputToFloat(nn.Sequential):
|
| 214 |
+
def forward(self, x):
|
| 215 |
+
return super().forward(x).to(torch.float32)
|
| 216 |
+
|
| 217 |
+
model.lm_head = CastOutputToFloat(model.lm_head)
|
| 218 |
+
|
| 219 |
if not tokenizer:
|
| 220 |
try:
|
| 221 |
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
|
|
|
| 300 |
|
| 301 |
if adapter is None:
|
| 302 |
return model, None
|
| 303 |
+
if adapter == "lora" or adapter == "qlora":
|
| 304 |
return load_lora(model, cfg)
|
| 305 |
if adapter == "llama-adapter":
|
| 306 |
return load_llama_adapter(model, cfg)
|