Lint trainer.py
src/axolotl/utils/trainer.py · CHANGED · +20 -11
@@ -1,3 +1,5 @@
+"""Module containing the Trainer class and related functions"""
+
 import importlib
 import math
 import os
@@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
+    """
+    Trainer subclass that uses the OneCycleLR scheduler
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr_scheduler = None
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
-        num_training_steps = num_training_steps
         pct_start = num_warmup_steps / num_training_steps
 
         self.lr_scheduler = OneCycleLR(
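For context, a minimal sketch of what this scheduler wiring does, outside the Trainer subclass: HuggingFace counts warmup in steps, while OneCycleLR takes warmup as a fraction of the whole run, so create_scheduler converts one into the other via pct_start. The model, learning rate, and step counts below are placeholder assumptions, not values from axolotl.

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    num_training_steps = 1000
    num_warmup_steps = 100
    # OneCycleLR expresses warmup as a fraction of total_steps, so the
    # Trainer-style warmup step count is converted into pct_start.
    pct_start = num_warmup_steps / num_training_steps

    scheduler = OneCycleLR(
        optimizer,
        max_lr=1e-4,                     # peak LR, reached at the end of warmup
        total_steps=num_training_steps,
        pct_start=pct_start,             # ramp up for the first 10% of steps
    )

    for _ in range(num_training_steps):
        optimizer.step()                 # backward() omitted in this sketch
        scheduler.step()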
@@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["bf16_full_eval"] = True
     else:
         training_arguments_kwargs["bf16"] = cfg.bf16
-    training_arguments_kwargs["fp16"] = …
+    training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
-    if cfg.gradient_checkpointing …
+    if cfg.gradient_checkpointing:
         if cfg.gptq:
             from alpaca_lora_4bit.gradient_checkpointing import (
                 apply_gradient_checkpointing,
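As a sanity check on the precision logic in this hunk, a small sketch with a stand-in config object; the cfg field names are taken from the diff, not a full axolotl config. fp16 is only enabled when bf16 is not, and the trailing `or False` coerces a None flag into a real boolean before it reaches TrainingArguments.

    from types import SimpleNamespace

    cfg = SimpleNamespace(bf16=True, fp16=None, tf32=True)

    training_arguments_kwargs = {}
    training_arguments_kwargs["bf16"] = cfg.bf16
    # (None and ...) evaluates to None, and `or False` normalizes it,
    # so TrainingArguments always receives a true boolean here.
    training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
    training_arguments_kwargs["tf32"] = cfg.tf32

    assert training_arguments_kwargs == {"bf16": True, "fp16": False, "tf32": True}

The gradient-checkpointing branch below it is GPTQ-specific (apply_gradient_checkpointing from alpaca_lora_4bit); for ordinary models the same effect comes from passing gradient_checkpointing=True to TrainingArguments.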
@@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
-        load_best_model_at_end=…
-        …
-        …
-        …
-        …
-        …
-        …
+        load_best_model_at_end=(
+            cfg.val_set_size > 0
+            and save_steps
+            and save_steps % eval_steps == 0
+            and cfg.load_in_8bit is not True
+        )
+        or False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
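The rewritten load_best_model_at_end expression encodes a real Trainer constraint: the best checkpoint can only be tracked if every save lands on a step that was also evaluated, hence save_steps % eval_steps == 0. A sketch with made-up values (the 8-bit exclusion mirrors the diff; presumably an 8-bit-loaded model cannot be reloaded from a checkpoint at the end of training):

    val_set_size = 0.05   # fraction of data held out for eval
    save_steps = 500
    eval_steps = 100
    load_in_8bit = False

    load_best_model_at_end = (
        val_set_size > 0                  # need an eval set to rank checkpoints
        and save_steps                    # step-based saving must be configured
        and save_steps % eval_steps == 0  # every save coincides with an eval
        and load_in_8bit is not True
    ) or False

    assert load_best_model_at_end is True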
@@ -140,7 +149,7 @@
     if (
         cfg.optimizer == "adamw_bnb_8bit"
         and not cfg.gptq
-        and not "deepspeed" in training_arguments_kwargs
+        and "deepspeed" not in training_arguments_kwargs
         and not cfg.fsdp
     ):
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
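A sketch of what the guarded block presumably goes on to build: get_parameter_names(model, [nn.LayerNorm]) returns every parameter name except those inside LayerNorm modules, biases are then conventionally filtered out as well, and the two groups receive different weight decay before being handed to the 8-bit optimizer. The toy model, the weight_decay value, and the bnb.optim.Adam8bit call are assumptions standing in for code the diff does not show.

    import bitsandbytes as bnb  # needs a CUDA-enabled bitsandbytes install
    import torch.nn as nn
    from transformers.trainer_pt_utils import get_parameter_names

    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16), nn.Linear(16, 2))

    # Everything outside LayerNorm modules is eligible for weight decay...
    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    # ...and biases are excluded too, matching the usual HF Trainer grouping.
    decay_parameters = [name for name in decay_parameters if "bias" not in name]

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": 0.01,  # illustrative value
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]

    optimizer = bnb.optim.Adam8bit(optimizer_grouped_parameters, lr=1e-4)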