Cosine learning rate schedule - minimum learning rate (#1062)
* Cosine min lr
* Cosine min lr - warn if using deepspeed
* cosine_min_lr_ratio readme
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
- README.md +1 -0
- src/axolotl/core/trainer_builder.py +20 -1
- src/axolotl/utils/schedulers.py +40 -0
README.md
CHANGED
@@ -755,6 +755,7 @@ early_stopping_patience: 3
 # Specify a scheduler and kwargs to use with the optimizer
 lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
 lr_scheduler_kwargs:
+cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr

 # For one_cycle optim
 lr_div_factor: # Learning rate div factor
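The option only sets a relative floor for the decayed learning rate; the absolute minimum is the ratio times the configured peak. A minimal sketch of the implied arithmetic, assuming an illustrative peak `learning_rate` of 2e-4 (not a value from this PR):

```python
# Illustrative values only; the peak lr comes from the usual learning_rate config key.
peak_lr = 2e-4
cosine_min_lr_ratio = 0.1  # decay to 10% of the peak, as in the README example

min_lr = peak_lr * cosine_min_lr_ratio
print(f"floor: {min_lr:g}")  # floor: 2e-05 -- the cosine schedule bottoms out here instead of 0
```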
src/axolotl/core/trainer_builder.py
CHANGED
@@ -38,7 +38,10 @@ from axolotl.utils.collators import (
     MambaDataCollator,
 )
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
-from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
+from axolotl.utils.schedulers import (
+    get_cosine_schedule_with_min_lr,
+    get_cosine_schedule_with_quadratic_warmup,
+)

 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -120,6 +123,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "prefetch_factor argument to the dataloader"},
     )
+    cosine_min_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
+    )


 class AxolotlTrainer(Trainer):
@@ -159,6 +166,17 @@ class AxolotlTrainer(Trainer):
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                     num_training_steps=num_training_steps,
                 )
+            elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
+                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
+                if self.args.deepspeed:
+                    LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
+                        in the deepspeed JSON")
+                self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
+                    optimizer,
+                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                    min_lr_ratio=self.args.cosine_min_lr_ratio,
+                )
             else:
                 return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler
@@ -745,6 +763,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
+        training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
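To see what the new `create_scheduler` branch wires up, here is a rough standalone sketch (not part of the PR) that calls the new `get_cosine_schedule_with_min_lr` directly, assuming axolotl with this change is installed. The toy model, peak lr, and step counts are illustrative assumptions:

```python
# Rough standalone sketch: linear warmup, then cosine decay that floors at
# min_lr_ratio * peak lr instead of decaying to 0.
import torch

from axolotl.utils.schedulers import get_cosine_schedule_with_min_lr

model = torch.nn.Linear(4, 4)                                # toy model
optimizer = torch.optim.SGD(model.parameters(), lr=2e-4)     # peak lr (illustrative)

scheduler = get_cosine_schedule_with_min_lr(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=100,
    min_lr_ratio=0.1,  # same meaning as cosine_min_lr_ratio in the config
)

for _ in range(100):
    optimizer.step()
    scheduler.step()

print(f"final lr: {scheduler.get_last_lr()[0]:g}")  # final lr: 2e-05, i.e. 10% of the peak
```

Because `LambdaLR` multiplies each param group's base lr by the lambda's return value, the floor is always relative to the configured peak learning rate rather than an absolute value.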
src/axolotl/utils/schedulers.py
CHANGED
@@ -100,3 +100,43 @@ def get_cosine_schedule_with_quadratic_warmup(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def _get_cosine_schedule_with_min_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_ratio: float
+):
+    # Warm up
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    # Cosine learning rate decay
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    scaling = 0.5 * (1.0 + math.cos(math.pi * progress))
+    return (1 - min_lr_ratio) * scaling + min_lr_ratio
+
+
+def get_cosine_schedule_with_min_lr(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_ratio: float = 0.0,
+):
+    """
+    Create a learning rate schedule which has:
+        - linear warmup from 0 -> `max_lr` over `num_warmup_steps`
+        - cosine learning rate annealing from `max_lr` -> `min_lr` over `num_training_steps`
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_min_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        min_lr_ratio=min_lr_ratio,
+    )
+    return LambdaLR(optimizer, lr_lambda)