add axolotl trainer and quadratic warmup
- src/axolotl/utils/schedulers.py +59 -1
- src/axolotl/utils/trainer.py +35 -3
src/axolotl/utils/schedulers.py
CHANGED
@@ -1,6 +1,9 @@
 """Module for custom LRScheduler class"""
+import math
+from functools import partial
 
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
 
 class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
             lrs = [self.max_lr for base_lr in self.base_lrs]
 
         return lrs
+
+
+def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float
+):
+    if current_step < num_warmup_steps:
+        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    return max(
+        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    )
+
+
+def get_cosine_schedule_with_quadratic_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function from the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
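For context, a minimal standalone sketch of how the new schedule behaves (illustrative only, not part of this commit; the step counts and base learning rate are arbitrary):

import torch

from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

# Throwaway parameter/optimizer purely to drive the schedule.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = get_cosine_schedule_with_quadratic_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    optimizer.step()
    scheduler.step()
    # Warmup multiplier is (step / 100) ** 2, so the lr sits at 1% of its peak at
    # step 10 rather than the 10% a linear warmup would give; after step 100 it
    # follows a half-cosine decay from 1e-4 down to 0.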
src/axolotl/utils/trainer.py
CHANGED
@@ -17,10 +17,42 @@ from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import SavePeftModelCallback
-from axolotl.utils.schedulers import InterpolatingLogScheduler
+from axolotl.utils.schedulers import (
+    InterpolatingLogScheduler,
+    get_cosine_schedule_with_quadratic_warmup,
+)
 
 
-class OneCycleLRSchedulerTrainer(Trainer):
+class AxolotlTrainer(Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+    ):
+        """
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.
+
+        Args:
+            num_training_steps (int): The number of training steps to do.
+        """
+
+        if self.lr_scheduler is None:  # pylint: disable=access-member-before-definition
+            """# type: ignore"""
+            if self.args.lr_scheduler_type == "cosine_with_quadratic":
+                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
+                    optimizer,
+                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                )
+            else:
+                return super().create_scheduler(num_training_steps, optimizer)
+        return self.lr_scheduler
+
+
+class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
     """
@@ -259,7 +291,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else Trainer
+        else AxolotlTrainer
     )
     trainer = trainer_cls(
         model=model,
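The new create_scheduler hook only takes over when args.lr_scheduler_type is "cosine_with_quadratic"; any other value falls through to the stock transformers scheduler. The warmup length it passes along comes from TrainingArguments.get_warmup_steps, which resolves roughly as sketched below (a paraphrase for illustration, not code in this diff; check the installed transformers version):

import math

def resolve_warmup_steps(warmup_steps: int, warmup_ratio: float, num_training_steps: int) -> int:
    # An explicit warmup_steps wins; otherwise warmup_ratio is scaled by the total step count.
    return warmup_steps if warmup_steps > 0 else math.ceil(num_training_steps * warmup_ratio)

print(resolve_warmup_steps(0, 0.1, 1200))    # -> 120
print(resolve_warmup_steps(100, 0.1, 1200))  # -> 100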