import torch
from random import randrange
import torch.nn.functional as F


def noise_regularization(
    e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls
):
    """Push the noise prediction e_t toward Gaussian-like statistics by taking
    gradient steps on two regularizers: a patchwise KL term against
    noise_pred_optimal and a multi-scale auto-correlation term."""
    for _outer in range(num_reg_steps):
        if lambda_kl > 0:
            # KL step: match per-patch mean/variance statistics of the reference noise.
            _var = e_t.detach().clone().requires_grad_(True)
            l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
            l_kld.backward()
            _grad = _var.grad.detach()
            _grad = torch.clip(_grad, -100, 100)  # guard against exploding gradients
            e_t = e_t - lambda_kl * _grad
        if lambda_ac > 0:
            # Auto-correlation step: average the gradient over several random rolls.
            for _inner in range(num_ac_rolls):
                _var = e_t.detach().clone().requires_grad_(True)
                l_ac = auto_corr_loss(_var)
                l_ac.backward()
                _grad = _var.grad.detach() / num_ac_rolls
                e_t = e_t - lambda_ac * _grad
        e_t = e_t.detach()
    return e_t


def auto_corr_loss(x, random_shift=True):
    """Penalize spatial auto-correlation of a noise map at multiple scales.

    For each channel, the map is correlated with a randomly rolled copy of
    itself along H and W, then downsampled by 2x average pooling; the pyramid
    stops once the feature map is 8x8 or smaller.
    """
    B, C, H, W = x.shape
    assert B == 1
    x = x.squeeze(0)
    # x must be shape [C, H, W] now
    reg_loss = 0.0
    for ch_idx in range(x.shape[0]):
        noise = x[ch_idx][None, None, :, :]
        while True:
            if random_shift:
                roll_amount = randrange(noise.shape[2] // 2)
            else:
                roll_amount = 1
            reg_loss += (
                noise * torch.roll(noise, shifts=roll_amount, dims=2)
            ).mean() ** 2
            reg_loss += (
                noise * torch.roll(noise, shifts=roll_amount, dims=3)
            ).mean() ** 2
            if noise.shape[2] <= 8:
                break
            noise = F.avg_pool2d(noise, kernel_size=2)
    return reg_loss


def patchify_latents_kl_divergence(x0, x1, patch_size=4, num_channels=4):
    """Split both latents into non-overlapping patches and sum the per-patch KL."""

    def patchify_tensor(input_tensor):
        patches = (
            input_tensor.unfold(1, patch_size, patch_size)
            .unfold(2, patch_size, patch_size)
            .unfold(3, patch_size, patch_size)
        )
        patches = patches.contiguous().view(-1, num_channels, patch_size, patch_size)
        return patches

    x0 = patchify_tensor(x0)
    x1 = patchify_tensor(x1)
    kl = latents_kl_divergence(x0, x1).sum()
    return kl


def latents_kl_divergence(x0, x1):
    """Per-channel KL divergence between diagonal Gaussians fitted to x0 and x1
    by moment matching (up to the usual 1/2 factor). The expression is
    non-negative analytically; torch.abs is a numerical safeguard."""
    EPSILON = 1e-6
    x0 = x0.view(x0.shape[0], x0.shape[1], -1)
    x1 = x1.view(x1.shape[0], x1.shape[1], -1)
    mu0 = x0.mean(dim=-1)
    mu1 = x1.mean(dim=-1)
    var0 = x0.var(dim=-1)
    var1 = x1.var(dim=-1)
    kl = (
        torch.log((var1 + EPSILON) / (var0 + EPSILON))
        + (var0 + (mu0 - mu1) ** 2) / (var1 + EPSILON)
        - 1
    )
    kl = torch.abs(kl).sum(dim=-1)
    return kl
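

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). Shapes and hyperparameter
# values below are illustrative assumptions: SD-style latents of shape
# [1, 4, 64, 64], and placeholder lambdas / step counts rather than tuned
# defaults. It shows how a noise prediction might be regularized toward the
# statistics of a reference noise sample.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    e_t = torch.randn(1, 4, 64, 64)                 # noise prediction to regularize
    noise_pred_optimal = torch.randn(1, 4, 64, 64)  # reference noise with target statistics
    e_t_reg = noise_regularization(
        e_t,
        noise_pred_optimal,
        lambda_kl=0.065,   # assumed value, for illustration only
        lambda_ac=20.0,    # assumed value, for illustration only
        num_reg_steps=4,
        num_ac_rolls=5,
    )
    # Compare the auto-correlation penalty before and after regularization.
    print("auto-corr loss before:", float(auto_corr_loss(e_t)))
    print("auto-corr loss after: ", float(auto_corr_loss(e_t_reg)))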