from abc import abstractmethod
from functools import partial

import numpy as np
import torch

from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero

def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    # Select num_substeps indices from [0, max_step), roughly evenly spaced,
    # returned in ascending order and always ending at max_step - 1.
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
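
# Illustrative sketch (not part of the original module): with 4 substeps over a
# 1000-step schedule, the helper picks the indices [249, 499, 749, 999].
def _demo_equally_spaced_steps():
    steps = generate_roughly_equally_spaced_steps(4, 1000)
    assert steps.tolist() == [249, 499, 749, 999]
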
class Discretization:
    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))

    @abstractmethod
    def get_sigmas(self, n, device):
        pass
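
# Hypothetical subclass (illustration only; not from the original code): a new
# schedule only needs to return n sigmas in descending order from get_sigmas.
# The base __call__ then appends the terminal zero and handles flipping.
class _LinearDiscretization(Discretization):
    def __init__(self, sigma_min=0.1, sigma_max=10.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def get_sigmas(self, n, device="cpu"):
        # Plain linear ramp from sigma_max down to sigma_min.
        return torch.linspace(self.sigma_max, self.sigma_min, n, device=device)
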
class EDMDiscretization(Discretization):
    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device="cpu"):
        # Rho-schedule from Karras et al. (2022): interpolate linearly in
        # sigma^(1/rho) space, yielding sigmas that descend from sigma_max
        # to sigma_min.
        ramp = torch.linspace(0, 1, n, device=device)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        return sigmas
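
# Demo sketch (illustrative only; not part of the original file): with the
# defaults, the first sigma is sigma_max and the appended terminal value is 0.
def _demo_edm_sigmas():
    disc = EDMDiscretization(sigma_min=0.02, sigma_max=80.0, rho=7.0)
    sigmas = disc(10)  # 10 sigmas descending from ~80.0, plus an appended 0.0
    assert sigmas.shape == (11,)
    assert torch.isclose(sigmas[0], torch.tensor(80.0)) and sigmas[-1] == 0.0
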
class LegacyDDPMDiscretization(Discretization):
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    def get_sigmas(self, n, device="cpu"):
        if n < self.num_timesteps:
            # Sub-sample n of the num_timesteps training steps.
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(
                f"n ({n}) must not exceed num_timesteps ({self.num_timesteps})"
            )

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        # Sigmas come out ascending in timestep order; flip to the descending
        # convention used by the other discretizations.
        return torch.flip(sigmas, (0,))
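
# Demo sketch (illustrative only; not part of the original file): the legacy
# schedule converts alphas_cumprod to sigma = sqrt((1 - a_bar) / a_bar) and
# follows the same descending convention as EDMDiscretization.
def _demo_legacy_ddpm_sigmas():
    disc = LegacyDDPMDiscretization()
    sigmas = disc(50)  # 50 sub-sampled sigmas, descending, plus a final 0.0
    assert sigmas.shape == (51,)
    assert sigmas[0] > sigmas[1] and sigmas[-1] == 0.0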