# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchsde
from tqdm.auto import trange

from video_to_video.utils.logger import get_logger

logger = get_logger()

def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """
    Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.
    """
    if not eta:
        return sigma_to, 0.
    sigma_up = min(
        sigma_to,
        eta * (
            sigma_to**2 *  # noqa
            (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
    sigma_down = (sigma_to**2 - sigma_up**2)**0.5
    return sigma_down, sigma_up

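
# Illustrative sketch (not part of the original module): a worked example of
# get_ancestral_step. With sigma_from=2.0, sigma_to=1.0 and eta=1.0, the noise
# to add is sigma_up = sqrt(1.0 * (4.0 - 1.0) / 4.0) ~= 0.866 and the level to
# step down to is sigma_down = sqrt(1.0 - 0.75) = 0.5, so that
# sigma_down**2 + sigma_up**2 == sigma_to**2.
def _example_ancestral_step():
    sigma_down, sigma_up = get_ancestral_step(2.0, 1.0, eta=1.0)
    return sigma_down, sigma_up  # ~0.5, ~0.866
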
def get_scalings(sigma):
    c_out = -sigma
    c_in = 1 / (sigma**2 + 1.**2)**0.5
    return c_out, c_in

def sample_heun(noise,
                model,
                sigmas,
                s_churn=0.,
                s_tmin=0.,
                s_tmax=float('inf'),
                s_noise=1.,
                show_progress=True):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
    """
    x = noise * sigmas[0]
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        gamma = 0.
        if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
        if sigmas[i] == float('inf'):
            # Euler method
            denoised = model(noise, sigma_hat)
            x = denoised + sigmas[i + 1] * (gamma + 1) * noise
        else:
            _, c_in = get_scalings(sigma_hat)
            denoised = model(x * c_in, sigma_hat)
            d = (x - denoised) / sigma_hat
            dt = sigmas[i + 1] - sigma_hat
            if sigmas[i + 1] == 0:
                # Euler method
                x = x + d * dt
            else:
                # Heun's method
                x_2 = x + d * dt
                _, c_in = get_scalings(sigmas[i + 1])
                denoised_2 = model(x_2 * c_in, sigmas[i + 1])
                d_2 = (x_2 - denoised_2) / sigmas[i + 1]
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
    return x

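
# Illustrative usage sketch (assumption, not part of the original module):
# sample_heun expects a denoiser model(x_scaled, sigma) that returns the
# denoised prediction, plus a decreasing noise schedule ending at 0. The
# Karras-style schedule and the zero-returning dummy_model below are
# placeholders, not the project's real model or schedule.
def _example_sample_heun(steps=10, sigma_min=0.02, sigma_max=14.6):
    ramp = torch.linspace(0, 1, steps)
    rho = 7.0
    sigmas = (sigma_max**(1 / rho)
              + ramp * (sigma_min**(1 / rho) - sigma_max**(1 / rho)))**rho
    sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])

    def dummy_model(x, sigma):
        # Placeholder denoiser: a real model would predict the clean sample.
        return torch.zeros_like(x)

    noise = torch.randn(1, 4, 8, 8)
    return sample_heun(noise, dummy_model, sigmas, show_progress=False)
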
class BatchedBrownianTree:
    """
    A wrapper around torchsde.BrownianTree that enables batches of entropy.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        self.batched = True
        try:
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [
            torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
            for s in seed
        ]

    @staticmethod
    def sort(a, b):
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
            self.sign * sign)
        return w if self.batched else w[0]

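
# Illustrative sketch (assumption, not part of the original module): passing a
# list of seeds, one per batch item, makes BatchedBrownianTree build one
# torchsde.BrownianTree per item, so each sample in the batch gets its own
# reproducible Brownian increments.
def _example_batched_tree():
    x = torch.zeros(2, 3)
    tree = BatchedBrownianTree(
        x, torch.tensor(0.0), torch.tensor(1.0), seed=[0, 1])
    # Increments are stacked along the batch dimension, shape (2, 3).
    return tree(torch.tensor(0.2), torch.tensor(0.8))
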
class BrownianTreeNoiseSampler:
    """
    A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to
            generate random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self,
                 x,
                 sigma_min,
                 sigma_max,
                 seed=None,
                 transform=lambda x: x):
        self.transform = transform
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()

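
# Illustrative usage sketch (assumption, not part of the original module): for
# any (sigma, sigma_next) pair inside [sigma_min, sigma_max], the sampler
# returns noise shaped like x with roughly unit variance, because the Brownian
# increment is divided by sqrt(|t1 - t0|).
def _example_noise_sampler():
    x = torch.zeros(1, 4, 8, 8)
    sampler = BrownianTreeNoiseSampler(x, sigma_min=0.02, sigma_max=14.6, seed=0)
    eps = sampler(torch.tensor(5.0), torch.tensor(2.0))
    return eps  # same shape as x
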
def sample_dpmpp_2m_sde(noise,
                        model,
                        sigmas,
                        eta=1.,
                        s_noise=1.,
                        solver_type='midpoint',
                        show_progress=True,
                        variant_info=None):
    """
    DPM-Solver++ (2M) SDE.
    """
    assert solver_type in {'heun', 'midpoint'}

    x = noise * sigmas[0]
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
        sigmas < float('inf')].max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    old_denoised = None
    h_last = None

    for i in trange(len(sigmas) - 1, disable=not show_progress):
        logger.info(f'step: {i}')
        if sigmas[i] == float('inf'):
            # Euler method
            denoised = model(noise, sigmas[i], variant_info=variant_info)
            x = denoised + sigmas[i + 1] * noise
        else:
            _, c_in = get_scalings(sigmas[i])
            denoised = model(x * c_in, sigmas[i], variant_info=variant_info)
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h

                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
                    (-h - eta_h).expm1().neg() * denoised

                if old_denoised is not None:
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
                            (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * \
                            (1 / r) * (denoised - old_denoised)

                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
                    i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

            old_denoised = denoised
            h_last = h

    if variant_info is not None and variant_info.get('type') == 'variant1':
        x_long, x_short = x.chunk(2, dim=0)
        x = x_long * (1 - variant_info['alpha']) + x_short * variant_info['alpha']

    return x
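
# Illustrative usage sketch (assumption, not part of the original module): the
# schedule below starts at float('inf') so the first iteration takes the Euler
# branch, and the zero-returning dummy_model only mirrors the expected call
# signature model(x_scaled, sigma, variant_info=...); a real model would return
# the denoised prediction for the current noise level.
def _example_sample_dpmpp_2m_sde(steps=8, sigma_min=0.02, sigma_max=14.6):
    ramp = torch.linspace(0, 1, steps)
    rho = 7.0
    sigmas = (sigma_max**(1 / rho)
              + ramp * (sigma_min**(1 / rho) - sigma_max**(1 / rho)))**rho
    sigmas = torch.cat(
        [torch.tensor([float('inf')]), sigmas, sigmas.new_zeros([1])])

    def dummy_model(x, sigma, variant_info=None):
        # Placeholder denoiser: a real model would predict the clean sample.
        return torch.zeros_like(x)

    noise = torch.randn(1, 4, 8, 8)
    return sample_dpmpp_2m_sde(
        noise, dummy_model, sigmas, solver_type='midpoint', show_progress=False)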