LISA (#1469)

* add lisa support
* fix default and fix attribute traversal for layers
* improve lisa callback logging
* fix LISA by ensuring params are not frozen during __init__
* example config for lisa

---------

Co-authored-by: Aman Karmani <[email protected]>
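At every switch point the new callback freezes all decoder blocks and re-enables gradients on a random subset of lisa_n_layers of them, while everything outside lisa_layers_attribute (embeddings, LM head) stays trainable. As a rough back-of-the-envelope check of the layer coverage implied by the example config below (4 of Llama-2-7B's 32 decoder layers, re-sampled every 20 steps), here is a small sketch that is not part of the PR:

# Not part of this PR: estimate how often each decoder layer is trainable.
# With lisa_n_layers=4 of 32 layers and lisa_step_interval=20, each layer
# should be active for roughly 4/32 = 12.5% of all optimizer steps.
import numpy as np

total_layers, n_layers, step_interval, total_steps = 32, 4, 20, 20_000
rng = np.random.default_rng(0)
steps_active = np.zeros(total_layers)
for _ in range(total_steps // step_interval):
    active = rng.choice(total_layers, n_layers, replace=False)
    steps_active[active] += step_interval
print(steps_active / total_steps)  # each entry hovers around 0.125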
examples/llama-2/lisa.yml  ADDED
@@ -0,0 +1,75 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./lisa-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

lisa_n_layers: 4
lisa_step_interval: 20
lisa_layers_attribute: model.layers

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 5e-5  # recommendation from lisa paper for 7b

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
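The three lisa_* keys are the only LISA-specific settings: lisa_layers_attribute is a dotted attribute path from the loaded model object to its list of decoder blocks, which the callback resolves with functools.reduce(getattr, ...). A minimal sketch of that resolution on a toy Llama (not part of this PR; the tiny config values are made up so nothing needs to be downloaded):

# Not part of the PR: how a dotted lisa_layers_attribute is resolved at runtime.
# For LlamaForCausalLM the decoder blocks live at model.model.layers, so the
# config value "model.layers" is split on "." and walked with getattr.
from functools import reduce
from transformers import LlamaConfig, LlamaForCausalLM

tiny = LlamaForCausalLM(
    LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=8,
                num_attention_heads=4, vocab_size=1000)
)
layers = reduce(getattr, "model.layers".split("."), tiny)
print(type(layers).__name__, len(layers))  # ModuleList 8

Since adapter and the lora_* keys are left blank, LISA here operates directly on the base model's weights; training would be launched the usual axolotl way, e.g. accelerate launch -m axolotl.cli.train examples/llama-2/lisa.yml.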
src/axolotl/core/trainer_builder.py  CHANGED
@@ -45,6 +45,7 @@ from axolotl.utils.callbacks import (
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
+from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,

@@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
     orpo_alpha: Optional[float] = field(
         default=None,
     )
+    lisa_n_layers: Optional[int] = field(
+        default=None,
+        metadata={"help": "the number of activate layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = field(
+        default=None,
+        metadata={"help": "path under the model to access the layers"},
+    )


 class AxolotlTrainer(Trainer):

@@ -938,6 +951,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
             callbacks.append(early_stop_cb)

+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            callbacks.append(lisa_callback_factory(trainer))
         return callbacks

     def _get_trainer_cls(self):

@@ -1229,6 +1244,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "relora_prune_ratio"
             ] = self.cfg.relora_prune_ratio

+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
+            training_arguments_kwargs[
+                "lisa_step_interval"
+            ] = self.cfg.lisa_step_interval
+            training_arguments_kwargs[
+                "lisa_layers_attribute"
+            ] = self.cfg.lisa_layers_attribute
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
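Note that both lisa_step_interval and lisa_n_layers have to be set (and non-zero) for anything to happen: the same condition gates registering the LISA callback and forwarding the lisa_* values into AxolotlTrainingArguments, so leaving either key out of the YAML leaves LISA disabled without a warning.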
src/axolotl/utils/callbacks/lisa.py  ADDED
@@ -0,0 +1,91 @@
"""
module for LISA

Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""

import logging
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np
from transformers import TrainerCallback

if TYPE_CHECKING:
    from axolotl.core.trainer_builder import AxolotlTrainer

LOG = logging.getLogger("axolotl.callbacks.lisa")


def lisa_callback_factory(trainer: "AxolotlTrainer"):
    class LISACallback(TrainerCallback):
        """trainer callback for lisa layer switching"""

        def __init__(
            self, n_layers, step_interval, trainer, layers_attribute="model.layers"
        ):
            super().__init__()
            self.n_layers = n_layers
            self.step_interval = step_interval
            self.layers_attribute = layers_attribute
            self.trainer = trainer

            reduce(getattr, self.layers_attribute.split("."), self.trainer.model)

            self.total_layers = len(
                reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
            )
            self.active_layers_indices = []

            layers = reduce(
                getattr, self.layers_attribute.split("."), self.trainer.model
            )
            LOG.info(
                f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
            )

        def freeze_all_layers(self):
            layers = reduce(
                getattr, self.layers_attribute.split("."), self.trainer.model
            )
            for layer in layers:
                for param in layer.parameters():
                    param.requires_grad = False

        def on_step_begin(
            self, args, state, control, **kwargs
        ):  # pylint: disable=unused-argument
            # Check if it's time to switch active layers, including at step 0
            if state.global_step % self.step_interval == 0 or state.global_step == 1:
                self.switch_active_layers()

        def switch_active_layers(self):
            # First, disable gradients for all layers
            self.freeze_all_layers()

            # Randomly select n_layers to activate
            layers = reduce(
                getattr, self.layers_attribute.split("."), self.trainer.model
            )
            self.active_layers_indices = np.random.choice(
                range(self.total_layers), self.n_layers, replace=False
            )
            LOG.info(
                f"Activating layers at indices: {self.active_layers_indices} for the next steps."
            )

            # Enable gradients only for the selected layers
            for idx in self.active_layers_indices:
                for param in layers[idx].parameters():
                    param.requires_grad = True

    lisa_callback = LISACallback(
        n_layers=trainer.args.lisa_n_layers,
        step_interval=trainer.args.lisa_step_interval,
        trainer=trainer,
        layers_attribute=trainer.args.lisa_layers_attribute,
    )

    return lisa_callback
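Two details of the callback are worth noting: __init__ only resolves and counts the layers and no longer freezes anything (the "params are not frozen during __init__" fix from the commit message), so nothing is frozen until the first on_step_begin fires; and the switch condition also fires at global_step 1, not only at multiples of the interval. A trivial sketch (not part of the PR) of which steps trigger a switch with the example config's interval of 20:

# Not part of the PR: steps at which switch_active_layers() runs for
# lisa_step_interval=20, per the on_step_begin condition above.
step_interval = 20
triggers = [step for step in range(61) if step % step_interval == 0 or step == 1]
print(triggers)  # [0, 1, 20, 40, 60]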
src/axolotl/utils/config/models/input/v0_4_1/__init__.py  CHANGED
@@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
     hf_mlflow_log_artifacts: Optional[bool] = None


+class LISAConfig(BaseModel):
+    """LISA options"""
+
+    lisa_n_layers: Optional[int] = Field(
+        default=None,
+        metadata={"help": "the number of activate layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = Field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = Field(
+        default="model.layers",
+        metadata={"help": "path under the model to access the layers"},
+    )
+
+
 class WandbConfig(BaseModel):
     """wandb configuration subset"""


@@ -404,6 +421,7 @@ class AxolotlInputConfig(
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
+    LISAConfig,
     RemappedParameters,
     DeprecatedParameters,
     BaseModel,