Lint models.py
src/axolotl/utils/models.py  +34 -30
@@ -1,13 +1,16 @@
+"""Module for models and model loading"""
+
+
 import logging
 import math
 import os
 from pathlib import Path
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Optional, Tuple, TYPE_CHECKING  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import (
+from transformers import (  # noqa: F401
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
@@ -18,9 +21,8 @@ from transformers import (
 try:
     from transformers import (
         LlamaForCausalLM,
-        LlamaTokenizer,
     )
-except:
+except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
     )
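Note: the hunk above narrows a bare "except:" to "except ImportError:", the only error an optional import is expected to raise. A minimal, self-contained sketch of the same guard (the warning text is taken from the hunk; everything else is illustrative):

import logging

try:
    from transformers import LlamaForCausalLM  # noqa: F401
except ImportError:
    logging.warning(
        "This version of transformers does not support Llama. Consider upgrading."
    )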
@@ -28,9 +30,9 @@ except:
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
-    from peft import
-    from axolotl.utils.dict import DictDefault
-    from transformers import PreTrainedTokenizer
+    from peft import PeftConfig  # noqa: F401
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+    from transformers import PreTrainedTokenizer  # noqa: F401
 
 
 def load_tokenizer(
@@ -62,8 +64,8 @@ def load_tokenizer(
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
     if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            tokenizer.add_special_tokens({k: v})
+        for k, val in cfg.special_tokens.items():
+            tokenizer.add_special_tokens({k: val})
     if cfg.tokens:
         tokenizer.add_tokens(list(cfg.tokens))
 
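Note: the load_tokenizer hunk above applies config-driven tokens one at a time. A minimal sketch of that flow, assuming a plain dict stands in for the axolotl cfg object and using hypothetical token values:

from transformers import AutoTokenizer

# Plain dict as a stand-in for cfg (assumption); the token values are made up.
cfg = {
    "special_tokens": {"pad_token": "[PAD]"},
    "tokens": ["<|user|>", "<|assistant|>"],
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")

if cfg["special_tokens"]:
    for k, val in cfg["special_tokens"].items():
        tokenizer.add_special_tokens({k: val})  # e.g. registers a pad token
if cfg["tokens"]:
    tokenizer.add_tokens(list(cfg["tokens"]))  # plain additional tokens

print(len(tokenizer))  # grows by however many of the tokens were actually new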
@@ -80,6 +82,9 @@ def load_model(
     inference=False,
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    """
+    Load a model from a base model and a model type.
+    """
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
@@ -115,9 +120,9 @@ def load_model(
 
             replace_peft_model_with_int4_lora_model()
         from peft import prepare_model_for_int8_training
-    except Exception as e:
-        logging.exception(e)
-        raise e
+    except Exception as err:
+        logging.exception(err)
+        raise err
 
     model_kwargs = {}
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
@@ -155,7 +160,7 @@ def load_model(
                     "unable to find a cached model file, this will likely fail..."
                 )
                 model_path = str(cache_model_path)
-        except:
+        except Exception:  # pylint: disable=broad-exception-caught
             model_path = cfg.base_model
         model, _ = load_llama_model_4bit_low_ram(
             base_model_config if base_model_config else base_model,
@@ -210,13 +215,13 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
-                trust_remote_code=
+                trust_remote_code=cfg.trust_remote_code or False,
             )
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
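Note: "cfg.trust_remote_code or False" coerces the flag to a plain bool before it is handed to from_pretrained, so an unset value does not leak through as None. A small sketch of the coercion, with SimpleNamespace standing in for the project's config object (assumption):

from types import SimpleNamespace

# Only the behaviour of "or False" matters here; SimpleNamespace is a stand-in.
for raw in (True, False, None):
    cfg = SimpleNamespace(trust_remote_code=raw)
    print(raw, "->", cfg.trust_remote_code or False)
# True -> True, False -> False, None -> False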
@@ -225,30 +230,29 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
-    except Exception as e:
+    except Exception as err:  # pylint: disable=broad-exception-caught
         logging.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
         )
-        logging.exception(e)
+        logging.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
-            trust_remote_code=
+            trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
 
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
-    if (
-        (
-        and
-        and (load_in_8bit or cfg.load_in_4bit)
+    if not cfg.gptq and (
+        (cfg.adapter == "lora" and load_in_8bit)
+        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
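Note: the hunk above rounds the embedding table up to a multiple of 32 before resize_token_embeddings. The rounding itself, shown with hypothetical vocabulary sizes:

import math

# Each size is rounded up to the next multiple of 32.
for vocab in (32000, 32001, 50257):
    print(vocab, "->", math.ceil(vocab / 32) * 32)
# 32000 -> 32000, 32001 -> 32032, 50257 -> 50272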
@@ -261,14 +265,14 @@ def load_model(
     if cfg.gptq:
         # Scales to half
         logging.info("Fitting 4bit scales and zeros to half")
-        for n, m in model.named_modules():
-            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
-                type(m)
+        for _, module in model.named_modules():
+            if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
+                type(module)
             ):
-                if hasattr(m, "is_v1_model") and m.is_v1_model:
-                    m.zeros = m.zeros.half()
-                m.scales = m.scales.half()
-                m.bias = m.bias.half()
+                if hasattr(module, "is_v1_model") and module.is_v1_model:
+                    module.zeros = module.zeros.half()
+                module.scales = module.scales.half()
+                module.bias = module.bias.half()
 
     if (
         torch.cuda.device_count() > 1
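Note: the GPTQ hunk matches quantized linear layers by class-name string rather than importing the quantized classes directly, then casts their scale/zero tensors to half precision. A self-contained sketch of that pattern; the Autograd4bitQuantLinear class below is a toy stand-in, not the real implementation:

import torch
from torch import nn


class Autograd4bitQuantLinear(nn.Module):  # toy stand-in for illustration only
    def __init__(self):
        super().__init__()
        self.is_v1_model = True
        self.zeros = torch.zeros(4)   # fp32 tensors that the loop downcasts
        self.scales = torch.ones(4)
        self.bias = torch.zeros(4)


model = nn.Sequential(Autograd4bitQuantLinear(), nn.Linear(4, 4))

for _, module in model.named_modules():
    if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
        type(module)
    ):
        if hasattr(module, "is_v1_model") and module.is_v1_model:
            module.zeros = module.zeros.half()
        module.scales = module.scales.half()
        module.bias = module.bias.half()

print(model[0].scales.dtype)  # torch.float16; the plain nn.Linear is untouched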