import copy
import functools
import math
from typing import Any, Callable, Iterator

from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

__all__ = [
    "LRSchedulersContainer",
    "build_lr_schedulers",
]

class LRSchedulersContainer(Stateful):
    """Container for multiple learning rate schedulers.

    This class wraps multiple LRSchedulers into a single object to reduce the
    complexity of the training loop. It mimics the behavior of
    ``torch.optim.lr_scheduler.LRScheduler``. The design concept is the same as
    ``OptimizersContainer``. This class currently only supports ``LambdaLR``.

    **Note**
    Users who want to customize the lr_scheduler behavior can inherit from this
    class and extend the functionality as needed. The following methods must keep
    the same signatures as those of ``torch.optim.lr_scheduler.LRScheduler``:
    ``step()``, ``state_dict()``, ``load_state_dict()``.

    **Limitations**
    This class assumes all the lr schedulers are identical. There is no easy way to
    support resharding for multiple different LRSchedulers because
    ``LRScheduler.state_dict()`` is not resharding friendly. This restriction is
    what allows TorchTitan to support lr scheduler resharding.

    Args:
        optimizers (OptimizersContainer): The corresponding optimizers for the
            lr_schedulers.
    """

    schedulers: list[LRScheduler]
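
    # Typical lifecycle (a minimal sketch; ``optimizers`` is assumed to be an
    # existing ``OptimizersContainer`` and ``lr_lambda`` any LambdaLR-compatible
    # callable):
    #
    #     lr_schedulers = LRSchedulersContainer(optimizers, lr_lambda)
    #     lr_schedulers.step()                  # steps every wrapped scheduler
    #     state = lr_schedulers.state_dict()    # one state_dict represents all
    #     lr_schedulers.load_state_dict(state)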
    def __init__(self, optimizers: OptimizersContainer, lr_lambda: Callable) -> None:
        assert (
            len(optimizers) > 0
        ), "Must have at least one optimizer to create LRScheduler"

        self.schedulers = [LambdaLR(optimizer, lr_lambda) for optimizer in optimizers]

    def __iter__(self) -> Iterator[LRScheduler]:
        return iter(self.schedulers)

    def __len__(self) -> int:
        return len(self.schedulers)

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> dict[str, Any]:
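        # Only the first scheduler's state is saved: all schedulers are assumed to
        # be identical (see the Limitations section in the class docstring), so one
        # state_dict represents them all and stays resharding friendly.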
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
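        # Broadcast the same state_dict to every scheduler. Each scheduler gets a
        # deepcopy so they do not end up sharing (and mutating) one underlying dict.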
        for scheduler in self.schedulers:
            scheduler.load_state_dict(copy.deepcopy(state_dict))


def build_lr_schedulers(
    optimizers: OptimizersContainer, job_config: JobConfig
) -> LRSchedulersContainer:
    """Create an ``LRSchedulersContainer`` for the given optimizers and job config.

    This function creates an ``LRSchedulersContainer`` for the given optimizers.
    ``job_config`` should define the correct lr scheduler parameters.

    **Note**
    Users who want to customize the lr scheduler behavior can create their own
    ``LRSchedulersContainer`` subclass and ``build_lr_schedulers``. Passing the
    customized ``build_lr_schedulers`` to ``TrainSpec`` will create the customized
    ``LRSchedulersContainer``.

    Args:
        optimizers (OptimizersContainer): The corresponding optimizers for the
            lr_schedulers.
        job_config (JobConfig): The job config that defines the lr scheduler and
            training-step settings.
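
    Example (a minimal sketch; ``optimizers`` and ``job_config`` are assumed to come
    from the surrounding TorchTitan training setup)::

        lr_schedulers = build_lr_schedulers(optimizers, job_config)
        for _ in range(job_config.training.steps):
            ...  # forward/backward and optimizer stepping happen here
            lr_schedulers.step()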
    """
    training_steps = job_config.training.steps
    warmup_steps = int(job_config.lr_scheduler.warmup_steps)
    if job_config.lr_scheduler.decay_ratio is not None:
        decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio)
        if warmup_steps + decay_steps > training_steps:
            logger.warning(
                f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed "
                f"total training steps ({training_steps}). "
                f"Adjusting decay steps to {training_steps - warmup_steps}."
            )
            decay_steps = training_steps - warmup_steps
    else:
        decay_steps = training_steps - warmup_steps
    stable_steps = training_steps - warmup_steps - decay_steps
    lr_decay_type = job_config.lr_scheduler.decay_type
    lr_min = job_config.lr_scheduler.lr_min
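    # Illustrative arithmetic (hypothetical values, not defaults): with
    # training_steps=1000, warmup_steps=100, and decay_ratio=0.4, decay_steps is
    # round(1000 * 0.4) = 400 and stable_steps is 1000 - 100 - 400 = 500; if warmup
    # plus decay ever exceeded training_steps, decay_steps would be clamped as above
    # and stable_steps would become 0.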
    def linear_warmup_stable_decay(
        current_step: int,
        warmup_steps: int,
        stable_steps: int,
        decay_steps: int,
        lr_decay_type: str,
        lr_min: float,
    ):
        """
        Computes a linear warmup followed by a stable learning rate for a while,
        then some type of decay.

        Per the LambdaLR contract, this is accomplished by returning a
        multiplicative factor ``curr_adjustment`` between 0 and 1 that is applied
        to the base learning rate to create the desired schedule.

        Three types of learning rate decay are supported:
        1. ``linear``: decays linearly from 1 to 0 over the decay period.
        2. ``sqrt``: decays as 1 minus the square root of the decay progress.
        3. ``cosine``: follows the first half-period of the cosine function,
           decaying smoothly from 1 to 0.

        If ``lr_min`` is specified, the decay output is rescaled from [0, 1] to
        [``lr_min``, 1] so the learning rate does not drop below this minimum.
        """
        warmup_stable_steps = warmup_steps + stable_steps
        if current_step < warmup_steps:
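            # Steps are 0-indexed, hence the +1 adjustment; it keeps the warmup
            # factor strictly positive at step 0.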
            current_step += 1
            curr_adjustment = float(current_step / (warmup_steps + 1))
        elif current_step < warmup_stable_steps:
            curr_adjustment = 1.0
        else:
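            # Same 0-indexed +1 adjustment, which keeps the decay progress strictly
            # between 0 and 1.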
            current_step += 1
            progress = float(current_step - warmup_stable_steps) / (decay_steps + 1)

            if lr_decay_type == "linear":
                curr_adjustment = 1 - progress
            elif lr_decay_type == "sqrt":
                curr_adjustment = 1 - math.sqrt(progress)
            elif lr_decay_type == "cosine":
                curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
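            # Rescale the decay factor from [0, 1] to [lr_min, 1] so the final
            # learning rate never falls below lr_min times the base lr.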
            curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment
        return curr_adjustment

    lr_lambda = functools.partial(
        linear_warmup_stable_decay,
        warmup_steps=warmup_steps,
        stable_steps=stable_steps,
        decay_steps=decay_steps,
        lr_decay_type=lr_decay_type,
        lr_min=lr_min,
    )
    return LRSchedulersContainer(optimizers, lr_lambda)