import math
from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torchvision.transforms import Resize

#from optim import get_optimizer_class, OPTIMIZER_MAP
from losses.regularize_noise import NoiseRegularizer
from optim import RAdam
from utils.misc import (
    iterable_to_str,
    optional_string,
)


class OptimizerArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument('--coarse_min', type=int, default=32)
        parser.add_argument('--wplus_step', type=int, nargs="+", default=[250, 750],
                            help="number of steps for optimizing w_plus at each coarse level")
        #parser.add_argument('--lr_rampup', type=float, default=0.05)
        #parser.add_argument('--lr_rampdown', type=float, default=0.25)
        parser.add_argument('--lr', type=float, default=0.1)
        parser.add_argument('--noise_strength', type=float, default=0.0)
        parser.add_argument('--noise_ramp', type=float, default=0.75)
        #parser.add_argument('--optimize_noise', action="store_true")
        parser.add_argument('--camera_lr', type=float, default=0.01)
        parser.add_argument("--log_dir", default="log/projector", help="tensorboard log directory")
        parser.add_argument("--log_freq", type=int, default=10, help="scalar log frequency")
        parser.add_argument("--log_visual_freq", type=int, default=50, help="visual log frequency")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            f"lr{args.lr}_{args.camera_lr}-c{args.coarse_min}"
            + f"-wp({iterable_to_str(args.wplus_step)})"
            + optional_string(args.noise_strength, f"-n{args.noise_strength}")
        )
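
    # Illustration (assuming iterable_to_str joins values with commas and optional_string
    # returns its second argument only when the first is truthy): with the defaults above,
    # to_string(args) yields "lr0.1_0.01-c32-wp(250,750)", and "-n0.05" would be appended
    # if --noise_strength 0.05 were set.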


class LatentNoiser(nn.Module):
    def __init__(
            self, generator: nn.Module,
            noise_ramp: float = 0.75, noise_strength: float = 0.05,
            n_mean_latent: int = 10000
    ):
        super().__init__()

        self.noise_ramp = noise_ramp
        self.noise_strength = noise_strength

        with torch.no_grad():
            # estimate the std of W by mapping random Z samples through the style network
            # TODO: get 512 from generator
            noise_sample = torch.randn(n_mean_latent, 512, device=generator.device)
            latent_out = generator.style(noise_sample)
            latent_mean = latent_out.mean(0)
            self.latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    def forward(self, latent: torch.Tensor, t: float) -> torch.Tensor:
        strength = self.latent_std * self.noise_strength * max(0, 1 - t / self.noise_ramp) ** 2
        noise = torch.randn_like(latent) * strength
        return latent + noise
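
    # Schedule note: the perturbation decays quadratically with the normalized step t and
    # is exactly zero once t >= noise_ramp (the last quarter of optimization with the
    # default noise_ramp=0.75), mirroring the latent-noise ramp in the official StyleGAN2
    # projector.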


class Optimizer:
    @classmethod
    def optimize(
            cls,
            generator: nn.Module,
            criterion: nn.Module,
            degrade: nn.Module,
            target: torch.Tensor,  # only used by the writer since it is mostly baked into criterion
            latent_init: torch.Tensor,
            noise_init: torch.Tensor,
            args: Namespace,
            writer: Optional[SummaryWriter] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # do not optimize the generator
        generator = generator.eval()
        target = target.detach()

        # prepare parameters: optimize detached copies of the per-layer noise maps
        noises = []
        for n in noise_init:
            noise = n.detach().clone()
            noise.requires_grad = True
            noises.append(noise)

        def create_parameters(latent_coarse):
            parameters = [
                {'params': [latent_coarse], 'lr': args.lr},
                {'params': noises, 'lr': args.lr},
                {'params': degrade.parameters(), 'lr': args.camera_lr},
            ]
            return parameters
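
        # Note on the groups above: the coarse slice of W+ and the noise maps share args.lr,
        # while the differentiable degradation ("camera") model uses the slower args.camera_lr.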

        device = target.device

        # coarse-to-fine optimization over progressively larger coarse resolutions
        total_steps = np.sum(args.wplus_step)
        max_coarse_size = (2 ** (len(args.wplus_step) - 1)) * args.coarse_min
        noiser = LatentNoiser(
            generator, noise_ramp=args.noise_ramp, noise_strength=args.noise_strength).to(device)
        latent = latent_init.detach().clone()
        for coarse_level, steps in enumerate(args.wplus_step):
            if criterion.weights["contextual"] > 0:
                with torch.no_grad():
                    # synthesize a new sibling image from the current optimization results
                    # FIXME: update rgbs sibling
                    sibling, _, _ = generator([latent], input_is_latent=True, randomize_noise=True)
                    criterion.update_sibling(sibling)

            coarse_size = (2 ** coarse_level) * args.coarse_min
            latent_coarse, latent_fine = cls.split_latent(
                latent, generator.get_latent_size(coarse_size))
            parameters = create_parameters(latent_coarse)
            optimizer = RAdam(parameters)

            print(f"Optimizing {coarse_size}x{coarse_size}")
            pbar = tqdm(range(steps))
            for si in pbar:
                latent = torch.cat((latent_coarse, latent_fine), dim=1)
                niters = si + np.sum(args.wplus_step[:coarse_level])
                latent_noisy = noiser(latent, niters / total_steps)
                img_gen, _, rgbs = generator([latent_noisy], input_is_latent=True, noise=noises)
                # TODO: use coarse_size instead of args.coarse_size for rgb_level
                loss, losses = criterion(img_gen, degrade=degrade, noises=noises, rgbs=rgbs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                NoiseRegularizer.normalize(noises)

                # log
                pbar.set_description("; ".join([f"{k}: {v.item(): .3e}" for k, v in losses.items()]))
                if writer is not None and niters % args.log_freq == 0:
                    cls.log_losses(writer, niters, loss, losses, criterion.weights)
                    cls.log_parameters(writer, niters, degrade.named_parameters())
                if writer is not None and niters % args.log_visual_freq == 0:
                    cls.log_visuals(writer, niters, img_gen, target, degraded=degrade(img_gen), rgbs=rgbs)

            # re-assemble the full latent for the next (finer) level
            latent = torch.cat((latent_coarse, latent_fine), dim=1).detach()

        return latent, noises

    @staticmethod
    def split_latent(latent: torch.Tensor, coarse_latent_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        latent_coarse = latent[:, :coarse_latent_size]
        latent_coarse.requires_grad = True
        latent_fine = latent[:, coarse_latent_size:]
        latent_fine.requires_grad = False
        return latent_coarse, latent_fine
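
    # Gradient note: only latent_coarse is handed to the optimizer, so each level updates
    # the first coarse_latent_size W+ vectors while the remaining (fine) vectors stay
    # frozen; the two halves are re-concatenated at every step before synthesis.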

    @staticmethod
    def log_losses(
            writer: SummaryWriter,
            niters: int,
            loss_total: torch.Tensor,
            losses: Dict[str, torch.Tensor],
            weights: Optional[Dict[str, torch.Tensor]] = None,
    ):
        writer.add_scalar("loss", loss_total.item(), niters)
        for name, loss in losses.items():
            writer.add_scalar(name, loss.item(), niters)
            if weights is not None:
                writer.add_scalar(f"weighted_{name}", weights[name] * loss.item(), niters)

    @staticmethod
    def log_parameters(
            writer: SummaryWriter,
            niters: int,
            named_parameters: Iterable[Tuple[str, torch.nn.Parameter]],
    ):
        for name, para in named_parameters:
            writer.add_scalar(name, para.item(), niters)
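
    # Note: add_scalar with para.item() assumes every logged degradation parameter is a
    # single-element tensor; a multi-element parameter would raise at .item().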

    @classmethod
    def log_visuals(
            cls,
            writer: SummaryWriter,
            niters: int,
            img: torch.Tensor,
            target: torch.Tensor,
            degraded=None,
            rgbs=None,
    ):
        # log the full-resolution prediction separately when it is larger than the target
        if target.shape[-1] != img.shape[-1]:
            visual = make_grid(img, nrow=1, normalize=True, range=(-1, 1))
            writer.add_image("pred", visual, niters)

        def resize(img):
            return F.interpolate(img, size=target.shape[2:], mode="area")

        # target | (degraded) | prediction, side by side at the target resolution
        vis = resize(img)
        if degraded is not None:
            vis = torch.cat((resize(degraded), vis), dim=-1)
        visual = make_grid(
            torch.cat((target.repeat(1, vis.shape[1] // target.shape[1], 1, 1), vis), dim=-1),
            nrow=1, normalize=True, range=(-1, 1))
        writer.add_image("gnd[-degraded]-pred", visual, niters)

        # log the intermediate to_rgb outputs
        if rgbs is not None:
            cls.log_torgbs(writer, niters, rgbs)

    @staticmethod
    def log_torgbs(writer: SummaryWriter, niters: int, rgbs: Iterable[torch.Tensor], prefix: str = ""):
        for ri, rgb in enumerate(rgbs):
            scale = 2 ** (-(len(rgbs) - ri))
            visual = make_grid(torch.cat((rgb, rgb / scale), dim=-1), nrow=1, normalize=True, range=(-1, 1))
            writer.add_image(f"{prefix}to_rgb_{2 ** (ri + 2)}", visual, niters)