FSDP + QLoRA (#1378)
* wip qlora + fsdp fixes
* more fixes
* make sure to load the lora :facepalm:
* only setup quantized meta on non-zero rank:
* only run setup_quantized_peft_meta_for_training for qlora+fsdp
* more fixes for qlora+fsdp
* chore: lint
* add example yml
* support mistral too
* fix for model_type and add mixtral support too
* set cpu_offload: false to reduce vram, constrain new accelerator logic to qlora + fsdp
* refactor for duplicate code
- examples/llama-2/qlora-fsdp.yml +70 -0
- examples/mistral/mixtral-qlora-fsdp.yml +74 -0
- requirements.txt +2 -1
- src/axolotl/core/policies/__init__.py +0 -0
- src/axolotl/core/policies/auto_wrap.py +55 -0
- src/axolotl/core/trainer_builder.py +62 -0
- src/axolotl/utils/bench.py +1 -1
- src/axolotl/utils/models.py +238 -7
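Before the per-file diffs, a minimal sketch (not part of the change set) of how the qlora + fsdp gate added in src/axolotl/utils/models.py can be reproduced from the new llama-2 example config. The file path assumes the repository root and the pyyaml import is an assumption for illustration only; the real gate additionally checks that the model type is in SUPPORTED_AUTO_WRAP_MODEL_TYPES.

import yaml  # assumed available; axolotl configs are plain YAML

# Hypothetical illustration: read the example config added below and recompute
# the gate that load_model() uses to pick the qlora + fsdp code path.
with open("examples/llama-2/qlora-fsdp.yml", encoding="utf-8") as cfg_file:
    cfg = yaml.safe_load(cfg_file)

qlora_fsdp = bool(cfg.get("fsdp")) and cfg.get("adapter") == "qlora"
print(f"qlora + fsdp path enabled: {qlora_fsdp}")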
examples/llama-2/qlora-fsdp.yml
ADDED
@@ -0,0 +1,70 @@
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: yahma/alpaca-cleaned
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./qlora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 4
+num_epochs: 4
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+fsdp_config:
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+special_tokens:
examples/mistral/mixtral-qlora-fsdp.yml
ADDED
@@ -0,0 +1,74 @@
+base_model: mistralai/Mixtral-8x7B-v0.1
+model_type: AutoModelForCausalLM
+tokenizer_type: LlamaTokenizer
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+output_dir: ./qlora-out
+
+model_config:
+  output_router_logits: true
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 1024
+sample_packing: false
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+fsdp_config:
+  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+special_tokens:
requirements.txt
CHANGED
@@ -3,7 +3,7 @@ packaging==23.2
 peft==0.9.0
 transformers==4.38.2
 tokenizers==0.15.0
-bitsandbytes>=0.
+bitsandbytes>=0.43.0
 accelerate==0.26.1
 deepspeed==0.13.1
 pydantic==2.6.3
@@ -40,3 +40,4 @@ gcsfs
 # adlfs
 
 trl>=0.7.9
+fastcore>=1.5.29
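The new fastcore pin is what src/axolotl/utils/models.py uses further down to fan shard loading out over a thread pool. A small, self-contained sketch of that call pattern; the square worker and its offset keyword are made up for illustration, and only the parallel(..., n_workers=..., threadpool=True, **kwargs) shape mirrors the real call.

from fastcore.parallel import parallel


def square(item, offset=0):
    # toy stand-in for load_and_quantize_parallel in models.py
    return item * item + offset


# extra keyword arguments are forwarded to the worker, which is how models.py
# passes model/dtype/device through to load_and_quantize
results = parallel(square, [1, 2, 3], n_workers=2, threadpool=True, offset=1)
print(list(results))  # [2, 5, 10]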
src/axolotl/core/policies/__init__.py
ADDED
File without changes
src/axolotl/core/policies/auto_wrap.py
ADDED
@@ -0,0 +1,55 @@
+"""module for building the auto wrap policy for FSDP"""
+import functools
+
+from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
+from torch.distributed.fsdp.wrap import (
+    _or_policy,
+    lambda_auto_wrap_policy,
+    transformer_auto_wrap_policy,
+)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
+
+SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
+    "llama",
+    "mistral",
+    "mixtral",
+]
+
+
+def get_wrapping_policy_factory(model_type):
+    if model_type == "llama":
+        layer_to_wrap = LlamaDecoderLayer
+    elif model_type == "mistral":
+        layer_to_wrap = MistralDecoderLayer
+    elif model_type == "mixtral":
+        layer_to_wrap = MixtralDecoderLayer
+
+    def get_wrapping_policy():
+        """This checks for lora layers (has weight and requires_grad)"""
+
+        def lambda_policy_fn(module):
+            return (
+                len(list(module.named_children())) == 0
+                and getattr(module, "weight", None) is not None
+                and module.weight.requires_grad
+            )
+
+        lambda_policy = functools.partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
+        )
+        transformer_layer_name = layer_to_wrap
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls=(
+                PrefixEncoder,
+                PromptEncoder,
+                PromptEmbedding,
+                transformer_layer_name,
+            ),
+        )
+        policies = [lambda_policy, transformer_wrap_policy]
+        return functools.partial(_or_policy, policies=policies)
+
+    return get_wrapping_policy
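A hedged sketch of how this factory is meant to be consumed; it mirrors the trainer_builder.py change below rather than adding new behavior, and the hard-coded "llama" model type plus the standalone plugin construction are illustrative assumptions.

from accelerate import FullyShardedDataParallelPlugin

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory

# The factory returns a zero-argument callable; calling it yields an _or_policy
# combining the LoRA-leaf lambda policy with the transformer-layer policy.
policy_factory = get_wrapping_policy_factory("llama")
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=policy_factory(),
    use_orig_params=False,
    limit_all_gathers=True,
)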
src/axolotl/core/trainer_builder.py
CHANGED
@@ -8,6 +8,7 @@ import importlib
 import importlib.util
 import logging
 import math
+import os
 import sys
 from abc import abstractmethod
 from dataclasses import dataclass, field
@@ -17,7 +18,10 @@ from typing import List, Optional, Type, Union
 
 import torch
 import transformers
+from accelerate import FullyShardedDataParallelPlugin
+from accelerate.utils import str_to_bool
 from datasets import Dataset
+from torch.distributed.fsdp import MixedPrecision
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -30,6 +34,7 @@ from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer
 
+from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=1e-6,
         metadata={"help": "loraplus learning rate for lora embedding layers."},
     )
+    qlora: bool = field(
+        default=False,
+        metadata={"help": "whether this is a qlora training"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ class AxolotlTrainer(Trainer):
 
         return super().push_to_hub(*args, **kwargs)
 
+    @wraps(Trainer.create_accelerator_and_postprocess)
+    def create_accelerator_and_postprocess(self):
+        rank = int(os.environ.get("LOCAL_RANK", 0))
+        res = super().create_accelerator_and_postprocess()
+
+        if self.args.qlora is False:
+            return res
+
+        # the rest of this method override is specific to fsdp + qlora (for now)
+        sync_module_states = (
+            str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
+        )
+
+        mp_policy = None
+        amp = os.environ["ACCELERATE_MIXED_PRECISION"]
+        if amp == "fp16":
+            mp_policy = MixedPrecision(
+                param_dtype=torch.float32,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+            )
+        elif amp == "bf16":
+            mp_policy = MixedPrecision(
+                param_dtype=torch.float32,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+            )
+
+        # If somehow we figure out how we want to parameterize we want to autocast buffers...
+        # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
+        # load_param_skip_names = ['inv_freq']
+
+        if self.is_fsdp_enabled:
+            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
+            fsdp_plugin = FullyShardedDataParallelPlugin(
+                auto_wrap_policy=wrapping_policy(),
+                cpu_offload=False,
+                use_orig_params=False,
+                limit_all_gathers=True,
+                param_init_fn=lambda module: module.to_empty(
+                    device=torch.device("cuda"), recurse=False
+                )
+                if (rank != 0 and sync_module_states)
+                else None,
+                mixed_precision_policy=mp_policy,
+            )
+            self.accelerator.state.fsdp_plugin = fsdp_plugin
+
+        return res
+
 
 class AxolotlMambaTrainer(AxolotlTrainer):
     """
@@ -787,6 +846,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.fsdp_config:
             training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
 
+        if self.cfg.adapter == "qlora":
+            training_arguments_kwargs["qlora"] = True
+
         # deepspeed
         if self.cfg.deepspeed:
             training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
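The param_init_fn above is the low-memory half of the scheme: with FSDP_SYNC_MODULE_STATES enabled, non-zero ranks keep modules on the meta device and materialize empty CUDA storage just before FSDP broadcasts rank 0's quantized weights into it. A tiny illustration of that call, assuming a CUDA device is available; the Linear layer here is a stand-in, not axolotl code.

import torch

# a meta-device module carries shapes and dtypes but no storage
layer = torch.nn.Linear(4, 4, device="meta")

# to_empty() allocates uninitialized storage on the target device, which
# sync_module_states then overwrites with the parameters from rank 0
layer = layer.to_empty(device=torch.device("cuda"), recurse=False)
print(layer.weight.device)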
src/axolotl/utils/bench.py
CHANGED
@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
                 or not torch.cuda.is_available()
                 or device == "auto"
                 or torch.device(device).type == "cpu"
+                or torch.device(device).type == "meta"
             ):
                 return default_value
-
             return func(*args, **kwargs)
 
         return wrapper
src/axolotl/utils/models.py
CHANGED
@@ -1,13 +1,20 @@
 """Module for models and model loading"""
+# pylint: disable=too-many-lines
+
 import logging
 import math
 import os
-
+import types
+from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa: F401
 
 import addict
 import bitsandbytes as bnb
+import safetensors
 import torch
 import transformers
+from accelerate import init_empty_weights
+from bitsandbytes.nn import Linear4bit, Params4bit
+from fastcore.parallel import parallel
 from peft import (
     LoftQConfig,
     PeftConfig,
@@ -16,6 +23,7 @@ from peft import (
     prepare_model_for_kbit_training,
 )
 from peft.tuners.lora import QuantLinear
+from torch import Tensor, nn
 from transformers import (  # noqa: F401
     AddedToken,
     AutoConfig,
@@ -27,7 +35,9 @@ from transformers import (  # noqa: F401
     PreTrainedTokenizerBase,
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
 
+from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
     return tokenizer
 
 
+def replace_linear(
+    model: nn.Module,
+    linear_replacement: Type[nn.Module],
+    quant_config: Union[dict, None] = None,
+    skip_modules=None,
+    **kwargs,
+):
+    """
+    Replace linear modules with a new Linear module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
+            List of modules names not to convert. Defaults to `lm_head`.
+    """
+    if skip_modules is None:
+        skip_modules = ["lm_head"]
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(
+                module, linear_replacement, quant_config, skip_modules, **kwargs
+            )
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            if issubclass(linear_replacement, Linear4bit):
+                model._modules[  # pylint: disable=protected-access
+                    name
+                ] = linear_replacement(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported linear replacement: {type(linear_replacement)}"
+                )
+    return model
+
+
+def load_and_quantize(
+    module: nn.Module,
+    name: str,
+    value: Tensor,
+    device: torch.device = None,
+    dtype: torch.dtype = None,
+    skip_names: Optional[List[str]] = None,
+    is_meta_rank: bool = False,
+    low_memory: bool = True,
+    verbose: bool = False,
+    quant_method: str = "bnb",
+):
+    """
+    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
+
+    Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
+    """
+
+    if skip_names is None:
+        skip_names = []
+
+    def place_on_device(value):
+        if is_meta_rank:
+            device = "meta"
+        elif low_memory:
+            device = "cpu"
+        else:
+            device = "cuda"
+        return value.to(device=device, dtype=dtype)
+
+    if any(skip_name in name for skip_name in skip_names):
+        if verbose:
+            print(f"Skipping {name} because it is in skip_names")
+        return
+
+    module_key, _, value_key = name.rpartition(".")
+    try:
+        submodule = module.get_submodule(module_key)
+    except AttributeError as exc:
+        print(f"Module {module_key} not found:\n{exc}")
+        return
+
+    try:
+        if quant_method == "bnb":
+            param = submodule.get_parameter(value_key)
+            if isinstance(param, Params4bit):
+                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
+                # shape as the quantized Params4bit with an initialized quant_state. However,
+                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
+                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
+                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
+                value = type(param)(
+                    value.to(device=device, dtype=dtype).data, **param.__dict__
+                ).cuda(device)
+                if is_meta_rank:
+                    value = type(param)(value.data.to("meta"), **value.__dict__)
+                elif low_memory:
+                    value = type(param)(value.data.to("cpu"), **value.__dict__)
+            else:
+                value = type(param)(place_on_device(value).data)
+
+    except AttributeError:
+        # it's a buffer
+        value = place_on_device(value)
+
+    setattr(submodule, value_key, value)
+
+
 def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
@@ -394,7 +515,7 @@ def load_model(
 
     if max_memory is not None:
         # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
-        from accelerate import infer_auto_device_map, init_empty_weights
+        from accelerate import infer_auto_device_map
 
         with init_empty_weights():
            model_canvas = AutoModelForCausalLM.from_config(model_config)
@@ -496,8 +617,78 @@ def load_model(
         model_kwargs["attn_implementation"] = "eager"
         model_config._attn_implementation = "eager"  # pylint: disable=protected-access
 
+    qlora_fsdp = (
+        cfg.fsdp
+        and cfg.adapter == "qlora"
+        and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
+    )
+
     try:
-        if (
+        if qlora_fsdp:
+            if cfg.bf16 or cfg.bfloat16:
+                torch_dtype, compute_dtype = torch.float32, torch.bfloat16
+            elif cfg.fp16 or cfg.float16:
+                torch_dtype, compute_dtype = torch.float32, torch.float16
+            else:
+                torch_dtype, compute_dtype = torch.float32, torch.float16
+
+            with init_empty_weights():
+                LOG.info("Loading model with empty weights.")
+                model = AutoModelForCausalLM.from_config(model_config)
+                model.model = replace_linear(
+                    model.model,
+                    Linear4bit,
+                    compute_dtype=compute_dtype,
+                    quant_type="nf4",
+                    quant_storage=torch_dtype,
+                )
+
+            model.is_loaded_in_4bit = True
+
+            # Grab the safetensors files that hold the weights
+            try:
+                idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
+                files, _ = hub.get_checkpoint_shard_files(base_model, idx)
+            except OSError:
+                try:
+                    # This means the model doesn't have a model.safetensors.index.json because it is not sharded
+                    files = []
+                    files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
+                except OSError as exc:
+                    # This means the model probably doesn't have a safetensors file
+                    raise exc
+
+            # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
+            # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
+            def load_and_quantize_parallel(name_param, model, **kwargs):
+                name, param = name_param
+                load_and_quantize(model, name, param, **kwargs)
+
+            param_count = sum((p.numel() for n, p in model.named_parameters()))
+            for filename in files:
+                weights = safetensors.torch.load_file(filename)
+                quant_method = "bnb"
+                devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
+                left = int(os.cpu_count() / torch.cuda.device_count())
+                right = int(
+                    8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
+                )
+                n_workers = min(left, right)
+                parallel(
+                    load_and_quantize_parallel,
+                    weights.items(),
+                    n_workers=n_workers,
+                    threadpool=True,
+                    model=model,
+                    dtype=torch_dtype,
+                    device=cfg.local_rank,
+                    skip_names=[],
+                    is_meta_rank=(cfg.local_rank != 0),
+                    verbose=False,
+                    quant_method=quant_method,
+                )
+
+        elif (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq
@@ -613,7 +804,7 @@ def load_model(
         LOG.exception(err)
         raise err
 
-    if isinstance(model, (PeftModel, PeftModelForCausalLM)):
+    if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
         model = model.merge_and_unload()
 
     embeddings_len = (
@@ -692,6 +883,9 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
         skip_prepare_model_for_kbit_training = True
 
+    if qlora_fsdp:
+        skip_prepare_model_for_kbit_training = True
+
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
@@ -706,7 +900,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or cfg.flash_attention:
+    if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
@@ -724,7 +918,12 @@ def load_model(
     else:
         model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
-    if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
+    if (
+        cfg.ddp
+        and not load_in_8bit
+        and not (cfg.rl and cfg.load_in_4bit)
+        and not qlora_fsdp
+    ):
         # TODO revaldate this conditional
         model.to(f"cuda:{cfg.local_rank}")
 
@@ -813,6 +1012,30 @@ def find_all_linear_names(model):
     return list(lora_module_names)
 
 
+def setup_quantized_meta_for_peft(model: nn.Module):
+    """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
+
+    def temp_to_method(self, *args, **kwargs):  # pylint: disable=unused-argument
+        return self
+
+    for param in model.parameters():
+        if isinstance(param, Params4bit):
+            param.quant_state._orig_to = (  # pylint: disable=protected-access
+                param.quant_state.to
+            )
+            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
+
+
+def setup_quantized_peft_meta_for_training(model: nn.Module):
+    """Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
+    for param in model.parameters():
+        if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
+            param.quant_state.to = (
+                param.quant_state._orig_to  # pylint: disable=protected-access
+            )
+            param.quant_state._orig_to = None  # pylint: disable=protected-access
+
+
 def load_lora(model, cfg, inference=False, config_only=False):
     # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
 
@@ -849,6 +1072,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
     if config_only:
         return None, lora_config
 
+    rank = int(os.environ.get("LOCAL_RANK", 0))
+
+    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+        setup_quantized_meta_for_peft(model)
+
     if cfg.lora_model_dir:
         LOG.debug("Loading pretrained PEFT - LoRA")
         model_kwargs: Any = {}
@@ -864,6 +1092,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
     else:
         model = get_peft_model(model, lora_config)
 
-    model.print_trainable_parameters()
+    if rank == 0:
+        model.print_trainable_parameters()
+    elif cfg.fsdp and cfg.adapter == "qlora":
+        setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
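Finally, the setup_quantized_meta_for_peft / setup_quantized_peft_meta_for_training pair above works by temporarily swapping quant_state.to for a no-op so that PEFT's .to(device) calls cannot drag quantization state onto the meta device on non-zero ranks, then restoring it before training. A dependency-free sketch of that swap pattern; DummyQuantState is hypothetical and stands in for bitsandbytes' Params4bit.quant_state.

import types


class DummyQuantState:
    """Hypothetical stand-in for bitsandbytes' quant_state object."""

    def to(self, *args, **kwargs):  # pylint: disable=unused-argument
        return "moved"


def temp_to_method(self, *args, **kwargs):  # mirrors the no-op patch above
    return self


state = DummyQuantState()
state._orig_to = state.to                          # stash the real method
state.to = types.MethodType(temp_to_method, state)
assert state.to("meta") is state                   # .to() now leaves state in place

state.to = state._orig_to                          # restore for training
assert state.to("meta") == "moved"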