from typing import Union

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm

# Inlined below instead of importing:
# from seva.geometry import get_camera_dist
def get_camera_dist(
    source_c2ws: torch.Tensor,  # N x 3 x 4
    target_c2ws: torch.Tensor,  # M x 3 x 4
    mode: str = "translation",
) -> torch.Tensor:
    """Pairwise N x M distances between two sets of camera-to-world poses.

    "rotation" returns geodesic rotation distances in degrees; "translation"
    returns Euclidean distances between camera centers.
    """
    if mode == "rotation":
        dists = torch.acos(
            (
                (
                    torch.matmul(
                        source_c2ws[:, None, :3, :3],
                        target_c2ws[None, :, :3, :3].transpose(-1, -2),
                    )
                    .diagonal(offset=0, dim1=-2, dim2=-1)
                    .sum(-1)
                    - 1
                )
                / 2
            ).clamp(-1, 1)
        ) * (180 / torch.pi)
    elif mode == "translation":
        dists = torch.norm(
            source_c2ws[:, None, :3, 3] - target_c2ws[None, :, :3, 3], dim=-1
        )
    else:
        raise NotImplementedError(
            f"Mode {mode} is not implemented for finding nearest source indices."
        )
    return dists
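
# Usage sketch (made-up poses, not part of the original module): pairwise
# distances between two hypothetical camera sets.
def _demo_get_camera_dist():
    source = torch.eye(4)[:3].repeat(2, 1, 1)  # 2 x 3 x 4 camera-to-world
    target = torch.eye(4)[:3].repeat(3, 1, 1)  # 3 x 3 x 4
    target[:, :, 3] = torch.randn(3, 3)  # random camera centers
    rot = get_camera_dist(source, target, mode="rotation")  # (2, 3), degrees
    trans = get_camera_dist(source, target, mode="translation")  # (2, 3)
    return rot, trans
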
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends singleton dimensions to the end of a tensor until it has
    target_dims dimensions, e.g. (B,) -> (B, 1, 1, 1) for target_dims=4."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is fewer"
        )
    return x[(...,) + (None,) * dims_to_append]
def append_zero(x: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, x.new_zeros([1])])


def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    return (x - denoised) / append_dims(sigma, x.ndim)
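
# Sketch: to_d is the probability-flow ODE derivative dx/dsigma used by the
# Euler sampler below. With denoised = D(x; sigma), one Euler step is
#   x_next = x + (sigma_next - sigma) * (x - denoised) / sigma.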
def make_betas(
    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
) -> np.ndarray:
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
        )
        ** 2
    )
    return betas.numpy()
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
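
# Worked example (assumed values): generate_roughly_equally_spaced_steps(4, 1000)
# -> np.linspace(999, 0, 4, endpoint=False) = [999.0, 749.25, 499.5, 249.75]
# -> astype(int)[::-1] = [249, 499, 749, 999], i.e. ascending DDPM timesteps.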
class EpsScaling(object):
    def __call__(
        self, sigma: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
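
# Sketch of the eps-parameterization these coefficients implement: the denoiser
# below computes D(x, sigma) = c_skip * x + c_out * network(c_in * x, c_noise),
# which with c_skip = 1, c_out = -sigma, c_in = 1 / sqrt(sigma^2 + 1) is the
# usual "network predicts noise" convention, D(x, sigma) = x - sigma * eps_hat.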
class DDPMDiscretization(object):
    def __init__(
        self,
        linear_start: float = 5e-06,
        linear_end: float = 0.012,
        num_timesteps: int = 1000,
        log_snr_shift: Union[float, None] = 2.4,
    ):
        self.num_timesteps = num_timesteps
        betas = make_betas(
            num_timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
        )
        self.log_snr_shift = log_snr_shift
        alphas = 1.0 - betas  # first alpha here is on data side
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

    def get_sigmas(
        self, n: int, device: Union[str, torch.device] = "cpu"
    ) -> torch.Tensor:
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if self.log_snr_shift is not None:
            sigmas = sigmas * np.exp(self.log_snr_shift)
        return torch.flip(
            torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
        )

    def __call__(
        self,
        n: int,
        do_append_zero: bool = True,
        flip: bool = False,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))
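
# Usage sketch (default arguments assumed): the discretization returns
# descending noise levels, with a trailing zero for the final denoising step.
def _demo_discretization():
    disc = DDPMDiscretization()
    sigmas = disc(50)  # 51 values: sigma_max, ..., sigma_min, 0
    assert sigmas[0] > sigmas[-2] and sigmas[-1] == 0
    return sigmas
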
class DiscreteDenoiser(object):
    sigmas: torch.Tensor

    def __init__(
        self,
        discretization: DDPMDiscretization,
        num_idx: int = 1000,
        device: Union[str, torch.device] = "cpu",
    ):
        self.scaling = EpsScaling()
        self.discretization = discretization
        self.num_idx = num_idx
        self.device = device
        self.register_sigmas()

    def register_sigmas(self):
        self.sigmas = self.discretization(
            self.num_idx, do_append_zero=False, flip=True, device=self.device
        )

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
        return self.sigmas[idx]

    def __call__(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        # Snap continuous sigmas to the nearest of the num_idx discrete levels.
        sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        # The network is conditioned on the integer timestep index, not sigma.
        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
        if "replace" in cond:
            # Overwrite masked latents with the provided replacement content.
            x, mask = cond.get("replace").split((input.shape[1], 1), dim=1)
            input = input * (1 - mask) + x * mask
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
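
# Usage sketch with a stand-in network (hypothetical; any module mapping
# (x, timestep_idx, cond) to a tensor of x's shape would do).
class _DummyNet(nn.Module):
    def forward(self, x, t, cond):
        return torch.zeros_like(x)


def _demo_denoiser():
    denoiser = DiscreteDenoiser(DDPMDiscretization())
    x = torch.randn(2, 4, 8, 8)
    sigma = denoiser.idx_to_sigma(torch.tensor([999, 500]))
    return denoiser(_DummyNet(), x, sigma, cond={})
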
class ConstantScaleRule(object):
    def __call__(
        self, scale: Union[float, torch.Tensor]
    ) -> Union[float, torch.Tensor]:
        return scale
class MultiviewScaleRule(object):
    def __init__(self, min_scale: float = 1.0):
        self.min_scale = min_scale

    def __call__(
        self,
        scale: Union[float, torch.Tensor],
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        c2w_input = c2w[input_frame_mask]
        rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
        translation_diff = (
            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
        )
        K_diff = (
            ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
        )
        close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff
        if isinstance(scale, torch.Tensor):
            scale = scale.clone()
            scale[close_frame] = self.min_scale
        elif isinstance(scale, float):
            scale = torch.where(close_frame, self.min_scale, scale)
        else:
            raise ValueError(f"Invalid scale type {type(scale)}.")
        return scale
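
# Sketch: a frame counts as "close" to some input frame when its rotation is
# within 10 degrees, its camera center within 1e-5, and its intrinsics match
# exactly; such frames are essentially re-renderings of an input view, so they
# get the minimum guidance scale instead of the full one.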
class ConstantScaleSchedule(object):
    def __call__(
        self, sigma: Union[float, torch.Tensor], scale: Union[float, torch.Tensor]
    ) -> Union[float, torch.Tensor]:
        if isinstance(sigma, float):
            return scale
        elif isinstance(sigma, torch.Tensor):
            if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
                sigma = append_dims(sigma, scale.ndim)
            return scale * torch.ones_like(sigma)
        else:
            raise ValueError(f"Invalid sigma type {type(sigma)}.")
class ConstantGuidance(object):
    def __call__(
        self,
        uncond: torch.Tensor,
        cond: torch.Tensor,
        scale: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
            scale = append_dims(scale, cond.ndim)
        return uncond + scale * (cond - uncond)
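
# Sketch: this is the standard classifier-free guidance combination,
#   x_pred = x_uncond + scale * (x_cond - x_uncond),
# so scale = 1 reproduces the conditional prediction and scale > 1
# extrapolates away from the unconditional one.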
class VanillaCFG(object):
    def __init__(self):
        self.scale_rule = ConstantScaleRule()
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self,
        x: torch.Tensor,
        sigma: Union[float, torch.Tensor],
        scale: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        c_out = dict()
        for k in c:
            if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
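
# Usage sketch (hypothetical tensors): prepare_inputs doubles the batch so a
# single network call produces both the unconditional and conditional
# predictions, which __call__ then recombines.
def _demo_cfg():
    guider = VanillaCFG()
    x = torch.randn(2, 4, 8, 8)
    s = torch.full((2,), 1.0)
    c = {"crossattn": torch.randn(2, 77, 1024)}
    uc = {"crossattn": torch.zeros(2, 77, 1024)}
    x_in, s_in, c_in = guider.prepare_inputs(x, s, c, uc)  # batch 4
    denoised = x_in  # stand-in for denoiser(network, x_in, s_in, c_in)
    return guider(denoised, sigma=1.0, scale=7.5)  # back to batch 2
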
class MultiviewCFG(VanillaCFG):
    def __init__(self, cfg_min: float = 1.0):
        self.scale_min = cfg_min
        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(  # type: ignore
        self,
        x: torch.Tensor,
        sigma: Union[float, torch.Tensor],
        scale: Union[float, torch.Tensor],
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale, c2w, K, input_frame_mask)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred
class MultiviewTemporalCFG(MultiviewCFG):
    def __init__(self, num_frames: int, cfg_min: float = 1.0):
        super().__init__(cfg_min=cfg_min)
        self.num_frames = num_frames
        # |i - j| distances between frame indices within one window.
        distance_matrix = (
            torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
        ).abs()
        self.distance_matrix = distance_matrix

    def __call__(
        self,
        x: torch.Tensor,
        sigma: Union[float, torch.Tensor],
        scale: Union[float, torch.Tensor],
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        input_frame_mask = rearrange(
            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
        )
        # Per frame: distance to the nearest input frame, normalized to [0, 1].
        min_distance = (
            self.distance_matrix[None].to(x.device)
            + (~input_frame_mask[:, None]) * self.num_frames
        ).min(-1)[0]
        min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(
            min=1
        )
        scale = min_distance * (scale - self.scale_min) + self.scale_min
        scale = rearrange(scale, "b t ... -> (b t) ...")
        scale = append_dims(scale, x.ndim)
        return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))
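
# Sketch of the temporal ramp: within each length-T window, a frame's guidance
# scale is interpolated between cfg_min (at an input frame) and the full scale
# (at the frame farthest from any input frame), using the normalized
# nearest-input-frame distance as the interpolation weight.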
class EulerEDMSampler(object):
    def __init__(
        self,
        discretization: DDPMDiscretization,
        guider: Union[VanillaCFG, MultiviewCFG, MultiviewTemporalCFG],
        num_steps: Union[int, None] = None,
        verbose: bool = False,
        device: Union[str, torch.device] = "cuda",
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ):
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device
        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def prepare_sampling_loop(
        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: Union[int, None] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
        num_steps = num_steps or self.num_steps
        assert num_steps is not None, "num_steps must be specified"
        sigmas = self.discretization(num_steps, device=self.device)
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def get_sigma_gen(
        self, num_sigmas: int, verbose: bool = True
    ) -> Union[range, tqdm]:
        sigma_generator = range(num_sigmas - 1)
        if self.verbose and verbose:
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas - 1,
                desc="Sampling",
                leave=False,
            )
        return sigma_generator
    def sampler_step(
        self,
        sigma: torch.Tensor,
        next_sigma: torch.Tensor,
        denoiser,
        x: torch.Tensor,
        scale: Union[float, torch.Tensor],
        cond: dict,
        uc: dict,
        gamma: float = 0.0,
        **guider_kwargs,
    ) -> torch.Tensor:
        # Churn step: temporarily raise the noise level from sigma to sigma_hat
        # (the 1e-6 keeps sigma_hat > sigma so the added-noise std is real).
        sigma_hat = sigma * (gamma + 1.0) + 1e-6
        eps = torch.randn_like(x) * self.s_noise
        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
        # Denoise with CFG, then take one Euler step of the probability-flow ODE.
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        return x + dt * d
    def __call__(
        self,
        denoiser,
        x: torch.Tensor,
        scale: Union[float, torch.Tensor],
        cond: dict,
        uc: Union[dict, None] = None,
        num_steps: Union[int, None] = None,
        verbose: bool = True,
        **guider_kwargs,
    ) -> torch.Tensor:
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x,
            cond,
            uc,
            num_steps,
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
        return x
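
# End-to-end usage sketch (hypothetical denoiser and shapes; mirrors how
# create_samplers below wires things up). A real `denoiser` would be
# DiscreteDenoiser with the diffusion network bound to it.
def _demo_sampling():
    disc = DDPMDiscretization()
    sampler = EulerEDMSampler(disc, guider=VanillaCFG(), num_steps=10, device="cpu")
    denoiser = lambda x, s, c: x  # identity stand-in for the real denoiser
    x = torch.randn(2, 4, 8, 8)
    cond = {"crossattn": torch.randn(2, 77, 1024)}
    return sampler(denoiser, x, scale=2.0, cond=cond)
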
def create_samplers(
    guider_types: Union[int, list[int]],
    discretization,
    num_frames: Union[list[int], None],
    num_steps: int = 50,
    cfg_min: float = 1.2,
    device: Union[str, torch.device] = "cuda",
):
    guider_mapping = {
        0: VanillaCFG,
        1: MultiviewCFG,
        2: MultiviewTemporalCFG,
    }
    samplers = []
    if not isinstance(guider_types, (list, tuple)):
        guider_types = [guider_types]
    for i, guider_type in enumerate(guider_types):
        if guider_type not in guider_mapping:
            raise ValueError(
                f"Invalid guider type {guider_type}. "
                f"Must be one of {list(guider_mapping.keys())}"
            )
        guider_cls = guider_mapping[guider_type]
        guider_args = ()
        if guider_type > 0:
            guider_args += (cfg_min,)
        if guider_type == 2:
            assert num_frames is not None
            guider_args = (num_frames[i], cfg_min)
        guider = guider_cls(*guider_args)
        sampler = EulerEDMSampler(
            discretization=discretization,
            guider=guider,
            num_steps=num_steps,
            s_churn=0.0,
            s_tmin=0.0,
            s_tmax=999.0,
            s_noise=1.0,
            verbose=True,
            device=device,
        )
        samplers.append(sampler)
    return samplers
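
# Usage sketch (assumed arguments): one temporal-CFG sampler for 21-frame
# windows, matching guider type 2 above.
def _demo_create_samplers():
    disc = DDPMDiscretization()
    (sampler,) = create_samplers(
        guider_types=2, discretization=disc, num_frames=[21], device="cpu"
    )
    return sampler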