import math

import torch
from torch.optim.optimizer import Optimizer


class SurgeAdafactor(Optimizer):
    """Adafactor-style second-moment optimizer with a cosine "surge" learning-rate cycle."""

    def __init__(self, params, lr=1e-4, beta2=0.95, weight_decay=0.,
                 min_lr=0., surge_amp=0.5, surge_period=10_000, eps=1e-30):
        defaults = dict(lr=lr, beta2=beta2, weight_decay=weight_decay,
                        min_lr=min_lr, surge_amp=surge_amp,
                        surge_period=surge_period, step=0, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            group['step'] += 1
            t = group['step']
            # Cosine "surge": the factor climbs from (1 - surge_amp) up to 1 over each period.
            cyc = 1 - group['surge_amp'] * (1 + math.cos(math.pi * (t % group['surge_period']) / group['surge_period'])) / 2
            lr = max(group['min_lr'], group['lr'] * cyc)
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.float()
                state = self.state[p]
                if 'exp_avg_sq' not in state:
                    state['exp_avg_sq'] = torch.zeros_like(p)
                # Exponential moving average of squared gradients (second moment).
                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(group['beta2']).addcmul_(grad, grad, value=1 - group['beta2'])
                # Scale the gradient by the root of the second moment; eps is added before the sqrt.
                update = grad / (exp_avg_sq + group['eps']).sqrt()
                # Decoupled weight decay, applied directly to the parameters.
                if group['weight_decay']:
                    p.mul_(1 - lr * group['weight_decay'])
                p.add_(update, alpha=-lr)
        return loss
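

# Minimal usage sketch: the toy linear model and random tensors below are illustrative
# assumptions, not part of the original file; they only exercise the standard
# torch.optim.Optimizer interface that SurgeAdafactor implements.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4)
    opt = SurgeAdafactor(model.parameters(), lr=1e-3, weight_decay=0.01)
    x, y = torch.randn(32, 16), torch.randn(32, 4)
    for _ in range(5):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        print(f"loss={loss.item():.4f}")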