# unilm/edgelm/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional, List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler

@dataclass
class PolynomialDecayLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    force_anneal: Optional[int] = field(
        default=None,
        metadata={"help": "force annealing at specified epoch"},
    )
    end_learning_rate: float = field(
        default=0.0,
        metadata={"help": "learning rate to decay to"},
    )
    power: float = field(
        default=1.0,
        metadata={"help": "decay exponent"},
    )
    total_num_update: float = field(
        # II(...) interpolates this value from optimization.max_update, so the
        # decay horizon defaults to the total training length.
        default=II("optimization.max_update"),
        metadata={"help": "total number of updates over which to decay learning rate"},
    )
    lr: List[float] = II("optimization.lr")
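
# The config above is typically populated from fairseq-train command-line flags.
# A hedged example of a corresponding invocation (the data path, task, arch, and
# concrete numbers are illustrative assumptions, not values from this file):
#
#   fairseq-train data-bin/ \
#       --task translation --arch transformer \
#       --optimizer adam --lr 3e-05 --lr-scheduler polynomial_decay \
#       --warmup-updates 1000 --total-num-update 20000 \
#       --end-learning-rate 0.0 --power 1.0 --max-update 20000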

@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig)
class PolynomialDecayLRSchedule(FairseqLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)

        assert cfg.total_num_update > 0

        self.lr = cfg.lr[0]
        if cfg.warmup_updates > 0:
            self.warmup_factor = 1.0 / cfg.warmup_updates
        else:
            self.warmup_factor = 1
        self.end_learning_rate = cfg.end_learning_rate
        self.total_num_update = cfg.total_num_update
        self.power = cfg.power
        self.optimizer.set_lr(self.warmup_factor * self.lr)

    def get_next_lr(self, epoch):
        lrs = self.cfg.lr
        if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal:
            # use fixed LR schedule
            next_lr = lrs[min(epoch, len(lrs) - 1)]
        else:
            # anneal based on lr_shrink
            next_lr = self.optimizer.get_lr()
        return next_lr

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.cfg.warmup_updates > 0 and num_updates <= self.cfg.warmup_updates:
            # linear warmup toward the base learning rate
            self.warmup_factor = num_updates / float(self.cfg.warmup_updates)
            lr = self.warmup_factor * self.lr
        elif num_updates >= self.total_num_update:
            # past the decay horizon: hold at the final learning rate
            lr = self.end_learning_rate
        else:
            # polynomial decay from the base LR down to end_learning_rate
            warmup = self.cfg.warmup_updates
            lr_range = self.lr - self.end_learning_rate
            pct_remaining = 1 - (num_updates - warmup) / (
                self.total_num_update - warmup
            )
            lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate
        self.optimizer.set_lr(lr)
        return self.optimizer.get_lr()
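
# A minimal standalone sketch (not part of the original file) of the same
# arithmetic step_update applies: linear warmup to the base LR, then polynomial
# decay toward end_learning_rate. The numbers below are illustrative
# assumptions, not values taken from any particular training recipe.
if __name__ == "__main__":
    base_lr, end_lr, power = 3e-5, 0.0, 1.0
    warmup_updates, total_num_update = 100, 1000

    def expected_lr(num_updates):
        if warmup_updates > 0 and num_updates <= warmup_updates:
            # warmup phase: scale the base LR by the fraction of warmup done
            return num_updates / float(warmup_updates) * base_lr
        if num_updates >= total_num_update:
            # past the decay horizon: stay at the final LR
            return end_lr
        # polynomial interpolation between base_lr and end_lr
        pct_remaining = 1 - (num_updates - warmup_updates) / (
            total_num_update - warmup_updates
        )
        return (base_lr - end_lr) * pct_remaining ** power + end_lr

    for step in (0, 50, 100, 550, 1000):
        print(f"update {step:4d} -> lr {expected_lr(step):.2e}")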