|  | import importlib | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from torch import optim | 
					
						
						|  | import numpy as np | 
					
						
						|  |  | 
					
						
						|  | from inspect import isfunction | 
					
						
						|  | from PIL import Image, ImageDraw, ImageFont | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def log_txt_as_img(wh, xc, size=10, font_path='data/DejaVuSans.ttf'):
						    """Render each caption in ``xc`` as black text on a white RGB image.

						    Captions are hard-wrapped every ``int(40 * wh[0] / 256)`` characters
						    (at least one). If drawing a caption raises ``UnicodeEncodeError``, a
						    message is printed and that image stays blank white.

						    Args:
						        wh: image size passed to ``PIL.Image.new``, i.e. ``(width, height)``.
						        xc: sequence of caption strings, one image per caption.
						        size: font size passed to ``ImageFont.truetype``.
						        font_path: TrueType font file used for rendering.

						    Returns:
						        ``torch.Tensor`` of shape ``(len(xc), 3, height, width)``, dtype
						        float64, with pixel values mapped from ``[0, 255]`` to ``[-1, 1]``.
						    """
						    if len(xc) == 0:
						        # np.stack([]) would raise; return an empty batch with the same layout.
						        return torch.empty((0, 3, wh[1], wh[0]), dtype=torch.float64)

						    # The font is identical for every caption, so load it only once.
						    font = ImageFont.truetype(font_path, size=size)
						    # Characters per line scale with image width; max(1, ...) keeps the
						    # range() step non-zero for images narrower than 7 px.
						    nc = max(1, int(40 * (wh[0] / 256)))

						    txts = []
						    for caption in xc:
						        txt = Image.new("RGB", wh, color="white")
						        draw = ImageDraw.Draw(txt)
						        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))

						        try:
						            draw.text((0, 0), lines, fill="black", font=font)
						        except UnicodeEncodeError:
						            print("Cant encode string for logging. Skipping.")

						        # HWC uint8 -> CHW float64 in [-1, 1].
						        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)
						    return torch.tensor(np.stack(txts))
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def ismap(x):
						    """Return True for a 4-D ``torch.Tensor`` whose dim 1 has more than 3 entries.

						    Anything that is not a ``torch.Tensor`` yields False.
						    """
						    is_tensor = isinstance(x, torch.Tensor)
						    return is_tensor and x.dim() == 4 and x.shape[1] > 3
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def isimage(x):
						    """Return True for a 4-D ``torch.Tensor`` whose dim 1 has size 1 or 3.

						    Anything that is not a ``torch.Tensor`` yields False.
						    """
						    if isinstance(x, torch.Tensor) and x.dim() == 4:
						        return x.shape[1] in (1, 3)
						    return False
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def exists(x):
						    """Return True unless ``x`` is ``None`` (falsy values like 0 or "" still count)."""
						    return not (x is None)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def default(val, d):
						    """Return ``val`` if it is not ``None``, otherwise the fallback ``d``.

						    When ``d`` is a plain Python function (``inspect.isfunction``), it is
						    called with no arguments and its result is returned, so the fallback is
						    only computed when needed. Any other ``d`` is returned unchanged.
						    """
						    if exists(val):
						        return val
						    if isfunction(d):
						        return d()
						    return d
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def mean_flat(tensor):
						    """
						    Take the mean over all non-batch dimensions (every dim except dim 0).

						    Adapted from:
						    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
						    """
						    non_batch_dims = list(range(1, tensor.dim()))
						    return tensor.mean(dim=non_batch_dims)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def count_params(model, verbose=False):
						    """Return the total number of elements over ``model.parameters()``.

						    If ``verbose`` is true, also print the count (in millions) prefixed
						    with the model's class name.
						    """
						    total_params = 0
						    for param in model.parameters():
						        total_params += param.numel()

						    if verbose:
						        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
						    return total_params
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def instantiate_from_config(config):
						    """Build an object from a ``{"target": "pkg.mod.Name", "params": {...}}`` config.

						    ``config["target"]`` is resolved with ``get_obj_from_str`` and called with
						    ``config["params"]`` as keyword arguments. A missing ``params`` key or an
						    explicit ``None`` value (e.g. an empty ``params:`` entry in YAML) means
						    "call with no keyword arguments".

						    The sentinel strings ``'__is_first_stage__'`` and ``'__is_unconditional__'``
						    return ``None`` instead of an object.

						    Raises:
						        KeyError: if ``config`` has no ``target`` and is not one of the sentinels.
						    """
						    if "target" not in config:
						        if config in ('__is_first_stage__', '__is_unconditional__'):
						            return None
						        raise KeyError("Expected key `target` to instantiate.")

						    params = config.get("params")
						    if params is None:
						        # ``**None`` would raise TypeError; treat it like a missing entry.
						        params = dict()
						    return get_obj_from_str(config["target"])(**params)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def get_obj_from_str(string, reload=False):
						    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named attribute.

						    Everything before the last ``.`` is imported as a module. The part after
						    it is looked up on that module. With ``reload=True`` the module is
						    re-executed via ``importlib.reload`` before the lookup.

						    Raises:
						        ValueError: if ``string`` contains no ``.``.
						    """
						    module_name, attr_name = string.rsplit(".", 1)
						    if reload:
						        importlib.reload(importlib.import_module(module_name))
						    target_module = importlib.import_module(module_name, package=None)
						    return getattr(target_module, attr_name)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						class AdamWwithEMAandWings(optim.Optimizer):
						    """AdamW optimizer that also keeps an exponential moving average (EMA) of each parameter.

						    For every parameter ``p`` that receives a gradient, a float32 EMA copy is
						    kept in ``self.state[p]['param_exp_avg']``. After each update of ``p`` it
						    is blended with decay ``min(ema_decay, 1 - step ** -ema_power)``, where
						    ``step`` is the number of updates ``p`` has received. With the default
						    ``ema_power=1`` the decay is 0 on the first step (EMA equals the
						    parameter) and warms up towards ``ema_decay``.

						    The AdamW update is computed in this class instead of calling the
						    private ``torch.optim._functional.adamw`` helper, whose calling
						    convention has changed across PyTorch releases.
						    """

						    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,
						                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,
						                 ema_power=1., param_names=()):
						        """AdamW that saves EMA versions of the parameters.

						        Args:
						            params: iterable of tensors or parameter-group dicts, as accepted
						                by ``torch.optim.Optimizer``.
						            lr, betas, eps, weight_decay, amsgrad: standard AdamW hyper-parameters.
						            ema_decay: upper bound on the EMA decay; must lie in ``[0, 1]``.
						            ema_power: exponent of the EMA warm-up ``1 - step ** -ema_power``.
						            param_names: stored in each group's defaults; not read by this class.

						        Raises:
						            ValueError: if a hyper-parameter is outside its valid range.
						        """
						        if not 0.0 <= lr:
						            raise ValueError("Invalid learning rate: {}".format(lr))
						        if not 0.0 <= eps:
						            raise ValueError("Invalid epsilon value: {}".format(eps))
						        if not 0.0 <= betas[0] < 1.0:
						            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
						        if not 0.0 <= betas[1] < 1.0:
						            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
						        if not 0.0 <= weight_decay:
						            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
						        if not 0.0 <= ema_decay <= 1.0:
						            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
						        defaults = dict(lr=lr, betas=betas, eps=eps,
						                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
						                        ema_power=ema_power, param_names=param_names)
						        super().__init__(params, defaults)

						    def __setstate__(self, state):
						        """Restore pickled state, defaulting ``amsgrad`` for groups that lack it."""
						        super().__setstate__(state)
						        for group in self.param_groups:
						            group.setdefault('amsgrad', False)

						    @staticmethod
						    def _adamw_update(param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, step, *,
						                      lr, beta1, beta2, weight_decay, eps):
						        """Apply one AdamW (decoupled weight decay) update to ``param`` in place.

						        ``exp_avg``/``exp_avg_sq`` (and ``max_exp_avg_sq`` when not ``None``,
						        i.e. AMSGrad) are updated in place; ``step`` is the 1-based update count.
						        """
						        # Decoupled weight decay: shrink the weights directly, not via the gradient.
						        param.mul_(1 - lr * weight_decay)

						        # First and second moment running averages.
						        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
						        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

						        bias_correction1 = 1 - beta1 ** step
						        bias_correction2 = 1 - beta2 ** step

						        if max_exp_avg_sq is not None:
						            # AMSGrad: normalise by the running maximum of the second moment.
						            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
						            second_moment = max_exp_avg_sq
						        else:
						            second_moment = exp_avg_sq

						        denom = (second_moment.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
						        param.addcdiv_(exp_avg, denom, value=-(lr / bias_correction1))

						    @torch.no_grad()
						    def step(self, closure=None):
						        """Performs a single optimization step and refreshes the parameter EMAs.

						        Args:
						            closure (callable, optional): A closure that reevaluates the model
						                and returns the loss.

						        Returns:
						            Whatever ``closure`` returned, or ``None`` if no closure was given.

						        Raises:
						            RuntimeError: if a parameter has a sparse gradient.
						        """
						        loss = None
						        if closure is not None:
						            with torch.enable_grad():
						                loss = closure()

						        for group in self.param_groups:
						            amsgrad = group['amsgrad']
						            beta1, beta2 = group['betas']
						            ema_decay = group['ema_decay']
						            ema_power = group['ema_power']

						            for p in group['params']:
						                if p.grad is None:
						                    continue
						                grad = p.grad
						                if grad.is_sparse:
						                    raise RuntimeError('AdamW does not support sparse gradients')

						                state = self.state[p]
						                # Lazy state initialization on the first update of this parameter.
						                if len(state) == 0:
						                    state['step'] = 0
						                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
						                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
						                    if amsgrad:
						                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
						                    # The EMA copy is always kept in float32.
						                    state['param_exp_avg'] = p.detach().float().clone()

						                state['step'] += 1
						                # int() keeps the bias-correction and EMA math in Python numbers
						                # even if the counter was restored as a 0-dim tensor.
						                step = int(state['step'])

						                self._adamw_update(
						                    p, grad, state['exp_avg'], state['exp_avg_sq'],
						                    state['max_exp_avg_sq'] if amsgrad else None, step,
						                    lr=group['lr'], beta1=beta1, beta2=beta2,
						                    weight_decay=group['weight_decay'], eps=group['eps'])

						                # Each parameter's own step count drives its EMA warm-up, so the
						                # decay is well-defined even when some or all parameters of a
						                # group had no gradient this step.
						                cur_ema_decay = min(ema_decay, 1 - step ** -ema_power)
						                state['param_exp_avg'].mul_(cur_ema_decay).add_(p.float(), alpha=1 - cur_ema_decay)

						        return loss