Respair committed on
Commit e3b8436 · verified · 1 Parent(s): 2584aa9

Create optimizers.py

Files changed (1)
  1. optimizers.py +86 -0
optimizers.py ADDED
@@ -0,0 +1,86 @@
+ #coding:utf-8
+ import os, sys
+ import os.path as osp
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.optim import Optimizer
+ from functools import reduce
+ from torch.optim import AdamW
+
+ class MultiOptimizer:
+     def __init__(self, optimizers={}, schedulers={}):
+         self.optimizers = optimizers
+         self.schedulers = schedulers
+         self.keys = list(optimizers.keys())
+         self.param_groups = reduce(lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()])
+
+     def state_dict(self):
+         state_dicts = [(key, self.optimizers[key].state_dict())
+                        for key in self.keys]
+         return state_dicts
+
+     def load_state_dict(self, state_dict):
+         for key, val in state_dict:
+             try:
+                 self.optimizers[key].load_state_dict(val)
+             except:
+                 print("Unloaded %s" % key)
+
+
+     def step(self, key=None):
+         if key is not None:
+             self.optimizers[key].step()
+         else:
+             _ = [self.optimizers[key].step() for key in self.keys]
+
+     def zero_grad(self, key=None):
+         if key is not None:
+             self.optimizers[key].zero_grad()
+         else:
+             _ = [self.optimizers[key].zero_grad() for key in self.keys]
+
+     def scheduler(self, *args, key=None):
+         if key is not None:
+             self.schedulers[key].step(*args)
+         else:
+             _ = [self.schedulers[key].step(*args) for key in self.keys]
+
+
+ def build_optimizer(parameters):
+     optimizer, scheduler = _define_optimizer(parameters)
+     return optimizer, scheduler
+
+ def _define_optimizer(params):
+     optimizer_params = params['optimizer_params']
+     sch_params = params['scheduler_params']
+     optimizer = AdamW(
+         params['params'],
+         lr=optimizer_params.get('lr', 1e-4),
+         weight_decay=optimizer_params.get('weight_decay', 5e-4),
+         betas=(0.9, 0.98),
+         eps=1e-9)
+     scheduler = _define_scheduler(optimizer, sch_params)
+     return optimizer, scheduler
+
+ def _define_scheduler(optimizer, params):
+     print(params)
+     scheduler = torch.optim.lr_scheduler.OneCycleLR(
+         optimizer,
+         max_lr=params.get('max_lr', 5e-4),
+         epochs=params.get('epochs', 200),
+         steps_per_epoch=params.get('steps_per_epoch', 1000),
+         pct_start=params.get('pct_start', 0.0),
+         final_div_factor=5)
+
+     return scheduler
+
+ def build_multi_optimizer(parameters_dict, scheduler_params):
+     optim = dict([(key, AdamW(params, lr=1e-4, weight_decay=1e-6, betas=(0.9, 0.98), eps=1e-9))
+                   for key, params in parameters_dict.items()])
+
+     schedulers = dict([(key, _define_scheduler(opt, scheduler_params))
+                        for key, opt in optim.items()])
+
+     multi_optim = MultiOptimizer(optim, schedulers)
+     return multi_optim
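
For context, a minimal usage sketch of the file added above. The module names ("generator", "discriminator"), the toy nn.Linear sub-networks, and the scheduler settings are illustrative assumptions, not part of this commit; any mapping from key to parameter iterable is handled the same way by build_multi_optimizer.

import torch
from torch import nn
from optimizers import build_multi_optimizer

# Hypothetical sub-networks; in practice these would be the model's components.
nets = {"generator": nn.Linear(16, 16), "discriminator": nn.Linear(16, 1)}
parameters_dict = {key: net.parameters() for key, net in nets.items()}

# Assumed scheduler settings; _define_scheduler falls back to its own defaults
# for any key that is missing.
scheduler_params = {"max_lr": 5e-4, "epochs": 10, "steps_per_epoch": 100, "pct_start": 0.0}

multi_optim = build_multi_optimizer(parameters_dict, scheduler_params)

# One training step: zero all grads, backprop, then step either a single
# optimizer by key or every registered optimizer at once.
x = torch.randn(4, 16)
loss = nets["discriminator"](nets["generator"](x)).mean()
multi_optim.zero_grad()
loss.backward()
multi_optim.step("generator")   # update only the generator's AdamW
multi_optim.step()              # or update every registered optimizer
multi_optim.scheduler()         # advance every OneCycleLR scheduler by one step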