# SD15-Surge-V1 / ada_surge.py
import math

import torch
from torch.optim.optimizer import Optimizer


class SurgeAdafactor(Optimizer):
    """Second-moment-only adaptive optimizer with a periodic cosine 'surge'
    that cycles the learning rate between lr * (1 - surge_amp) and lr."""

    def __init__(self, params, lr=1e-4, beta2=0.95, weight_decay=0.,
                 min_lr=0., surge_amp=0.5, surge_period=10_000, eps=1e-30):
        defaults = dict(lr=lr, beta2=beta2, weight_decay=weight_decay,
                        min_lr=min_lr, surge_amp=surge_amp,
                        surge_period=surge_period, step=0, eps=eps)
        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            group['step'] += 1
            t = group['step']
            # Cosine surge: multiplier rises from (1 - surge_amp) to 1 over each
            # surge_period steps, then resets; the result is floored at min_lr.
            cyc = 1 - group['surge_amp'] * (
                1 + math.cos(math.pi * (t % group['surge_period']) / group['surge_period'])) / 2
            lr = max(group['min_lr'], group['lr'] * cyc)
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.float()
                state = self.state[p]
                if 'exp_avg_sq' not in state:
                    state['exp_avg_sq'] = torch.zeros_like(grad)
                exp_avg_sq = state['exp_avg_sq']
                # Exponential moving average of squared gradients (no first moment).
                exp_avg_sq.mul_(group['beta2']).addcmul_(grad, grad, value=1 - group['beta2'])
                update = grad / (exp_avg_sq + group['eps']).sqrt()
                # Decoupled weight decay, scaled by the surged learning rate.
                if group['weight_decay']:
                    p.mul_(1 - lr * group['weight_decay'])
                p.add_(update.to(p.dtype), alpha=-lr)  # cast back for low-precision params
        return loss
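

# A minimal, hypothetical usage sketch (not part of the original file): the toy
# Linear model, random data, and hyperparameter values below are assumptions,
# chosen only to show how SurgeAdafactor drops into a standard training loop.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4)
    opt = SurgeAdafactor(model.parameters(), lr=1e-4, weight_decay=1e-2,
                         surge_amp=0.5, surge_period=10_000)
    x, target = torch.randn(8, 16), torch.randn(8, 4)
    for _ in range(3):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), target)
        loss.backward()
        opt.step()
        print(f"loss={loss.item():.4f}")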