import math

import torch
from torch.optim.optimizer import Optimizer


class SurgeAdafactor(Optimizer):
    """Adafactor-style optimizer with a cyclic ("surge") learning-rate schedule."""

    def __init__(self, params, lr=1e-4, beta2=0.95, weight_decay=0.,
                 min_lr=0., surge_amp=0.5, surge_period=10_000, eps=1e-30):
        defaults = dict(lr=lr, beta2=beta2, weight_decay=weight_decay,
                        min_lr=min_lr, surge_amp=surge_amp,
                        surge_period=surge_period, step=0, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            g['step'] += 1
            t = g['step']

            # Cosine "surge": the multiplier climbs from (1 - surge_amp) at the
            # start of each period back to 1.0 at its end, floored at min_lr.
            cyc = 1 - g['surge_amp'] * (1 + math.cos(math.pi * (t % g['surge_period']) / g['surge_period'])) / 2
            lr = max(g['min_lr'], g['lr'] * cyc)

            for p in g['params']:
                if p.grad is None:
                    continue
                grad = p.grad.float()

                state = self.state[p]
                if 'exp_avg_sq' not in state:
                    state['exp_avg_sq'] = torch.zeros_like(p)
                exp_avg_sq = state['exp_avg_sq']

                # EMA of squared gradients, then RMS-normalize (eps inside the sqrt).
                exp_avg_sq.mul_(g['beta2']).addcmul_(grad, grad, value=1 - g['beta2'])
                upd = grad / (exp_avg_sq + g['eps']).sqrt()

                # Decoupled weight decay, scaled by the scheduled learning rate.
                if g['weight_decay']:
                    p.mul_(1 - lr * g['weight_decay'])
                p.add_(upd, alpha=-lr)

        return loss
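

# Minimal usage sketch (not part of the original file): the toy linear model,
# random data, and hyperparameter values below are assumptions chosen only to
# show how SurgeAdafactor plugs into a standard PyTorch training loop.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 1)
    opt = SurgeAdafactor(model.parameters(), lr=1e-3, weight_decay=0.01,
                         surge_amp=0.5, surge_period=100)
    x, y = torch.randn(32, 8), torch.randn(32, 1)
    for _ in range(300):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
    print("final loss:", loss.item())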