cleanup, prep for 4bit quant support

- README.md +21 -1
- scripts/finetune.py +18 -6
- setup.cfg +3 -0

README.md CHANGED
@@ -30,4 +30,24 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
 
 - Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
 - Install python dependencies `pip3 install -r requirements.txt`
--
+- Configure accelerate `accelerate launch` or update `~/.cache/huggingface/accelerate/default_config.yaml`
+
+```yaml
+compute_environment: LOCAL_MACHINE
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+- Train! `accelerate launch scripts/finetune.py`, make sure to choose the correct YAML config file
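The accelerate values added above are machine-specific (four processes, bf16 mixed precision). A quick local sanity check, sketched below with plain PyTorch calls (illustrative only, not part of the repo), confirms the config matches the hardware before running `accelerate launch`:

```python
# Illustrative check: do the accelerate config values fit this machine?
import torch

print("visible GPUs:", torch.cuda.device_count())  # should match num_processes
if torch.cuda.is_available():
    # mixed_precision: bf16 needs an Ampere-or-newer GPU
    print("bf16 supported:", torch.cuda.is_bf16_supported())
```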
scripts/finetune.py CHANGED
@@ -68,26 +68,27 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
         from axolotl.flash_attn import replace_llama_attn_with_flash_attn
         replace_llama_attn_with_flash_attn()
 
+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
     try:
         if "llama" in base_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
         else:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=
+            torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
 
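The hunk above computes `torch_dtype` once and reuses it across all three `from_pretrained` calls. A standalone sketch of the selection logic follows, with hypothetical flag values; the sketch drops the trailing comma shown on the assignment line in the diff, since with the comma `torch_dtype` becomes a one-element tuple rather than a `torch.dtype`:

```python
import torch

# Hypothetical config flags; in the script these come from cfg.
load_in_8bit = True
fp16 = False

# Half precision whenever 8-bit loading or fp16 training is requested, else full fp32.
torch_dtype = torch.float16 if load_in_8bit or fp16 else torch.float32
print(torch_dtype)  # torch.float16
```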
@@ -235,7 +236,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
-    training_arguments_kwargs["bf16"] = cfg.bf16
+    if cfg.bf16 == "full":
+        training_arguments_kwargs["bf16_full_eval"] = True
+    else:
+        training_arguments_kwargs["bf16"] = cfg.bf16
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
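In `transformers.TrainingArguments`, `bf16` enables bf16 mixed-precision training while `bf16_full_eval` runs evaluation entirely in bf16; the hunk above lets a `bf16: full` config value select the latter. A condensed sketch of the mapping, with a hypothetical config value:

```python
# Hypothetical config value: True, False, or "full" (bf16 for eval as well).
cfg_bf16 = "full"

training_arguments_kwargs = {}
if cfg_bf16 == "full":
    training_arguments_kwargs["bf16_full_eval"] = True
else:
    training_arguments_kwargs["bf16"] = cfg_bf16

# These kwargs are splatted into transformers.TrainingArguments further down.
print(training_arguments_kwargs)  # {'bf16_full_eval': True}
```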
@@ -256,10 +260,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
+        gradient_checkpointing=cfg.gradient_checkpointing,
         **training_arguments_kwargs,
     )
 
-    trainer_kwargs = {}
     decay_parameters = get_parameter_names(model, [nn.LayerNorm])
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
     optimizer_grouped_parameters = [
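`gradient_checkpointing` is forwarded straight into `transformers.TrainingArguments`; it re-computes activations during the backward pass, trading extra compute for a smaller memory footprint. A minimal illustration with a hypothetical output directory:

```python
from transformers import TrainingArguments

# Trades extra forward-pass compute for reduced activation memory.
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
print(args.gradient_checkpointing)  # True
```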
@@ -282,13 +286,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         lr=training_args.learning_rate,
     )
 
+    # TODO optionally use torch.optim.OneCycleLR
     lr_scheduler = transformers.get_cosine_schedule_with_warmup(
         adam_bnb_optim,
         training_args.warmup_steps,
         total_num_steps,
     )
-    trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
 
+    trainer_kwargs = {}
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
@@ -300,6 +305,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
+        optimizers=(adam_bnb_optim, lr_scheduler),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
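Rather than stashing the optimizer pair in a `trainer_kwargs` dict, the commit now passes it to the `Trainer` through its `optimizers=(optimizer, lr_scheduler)` parameter, which bypasses the Trainer's default optimizer creation. A reduced sketch with a stand-in model and a plain AdamW optimizer (so it runs without bitsandbytes; the script uses its `adam_bnb_optim` built over the weight-decay parameter groups):

```python
import torch
import transformers

# Stand-in model and optimizer for illustration only.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=10, num_training_steps=100
)

trainer = transformers.Trainer(
    model=model,
    args=transformers.TrainingArguments(output_dir="out"),
    optimizers=(optimizer, lr_scheduler),  # Trainer skips building its own
)
```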
@@ -342,6 +348,12 @@ def train(
         cfg.gradient_accumulation_steps // cfg.world_size
     )
     setup_wandb_env_vars(cfg)
+    if cfg.device == "mps":
+        cfg.load_in_8bit = False
+        cfg.tf32 = False
+        if cfg.bf16:
+            cfg.fp16 = True
+            cfg.bf16 = False
 
     # Load the model and tokenizer
     model, tokenizer, lora_config = load_model(
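The fallback above guards against CUDA-only features: bitsandbytes 8-bit loading and tf32 are unavailable on Apple Silicon, and bf16 runs are downgraded to fp16 there. How `cfg.device` ends up as `"mps"` is not shown in this diff; a typical detection, assuming a recent PyTorch with the `torch.backends.mps` module, looks like:

```python
import torch

# Illustrative device detection; the repo's own logic is not part of this diff.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon: no bitsandbytes 8-bit, no tf32
else:
    device = "cpu"
print(device)
```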
setup.cfg CHANGED
@@ -28,3 +28,6 @@ install_requires =
 [options.packages.find]
 where = src
 
+[options.extras_require]
+gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
+gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
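With these extras declared, the optional 4-bit (GPTQ) dependencies can be pulled in at install time rather than by hand, for example `pip3 install -e .[gptq_cuda]` for the CUDA kernels or `pip3 install -e .[gptq_triton]` for the Triton backend; a plain install remains unaffected.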