|
from typing import Optional, Union

import torch
|
|
|
class Transport:
    '''
    Abstract base class for diffusion / flow transport formulations.

    Subclasses define the forward interpolant x_t = alpha_t * x + sigma_t * z
    together with the training-time distribution, the network's noise
    conditioning, the regression target, and the jump map from x_t to x_r.

    Args:
        sigma_d: data standard deviation used by the formulation.
        T_max: largest time / noise level used.
        T_min: smallest time / noise level used.
        enhance_target: if True, subclasses blend the ground-truth target
            with conditional / unconditional model outputs (see `target`).
        w_gt: weight of the ground-truth term when enhancing the target.
        w_cond: weight of the conditional model output when enhancing.
        w_start: lower end of the t-window where enhancement applies.
        w_end: upper end of the t-window where enhancement applies.
    '''

    def __init__(self, sigma_d, T_max, T_min, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        self.sigma_d = sigma_d
        self.T_max = T_max
        self.T_min = T_min
        self.enhance_target = enhance_target
        self.w_gt = w_gt
        self.w_cond = w_cond
        self.w_start = w_start
        self.w_end = w_end

    def sample_t(self, batch_size, dtype, device):
        '''Draw a batch of training times t. Must be overridden.'''
        raise NotImplementedError

    def c_noise(self, t: torch.Tensor):
        '''Map time t to the network's noise-conditioning input. Must be overridden.'''
        raise NotImplementedError

    def interpolant(self, t: torch.Tensor):
        '''Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) at time t. Must be overridden.'''
        raise NotImplementedError

    def target(self, x_t: torch.Tensor, v_t: torch.Tensor, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor, r: torch.Tensor, dF_dv_dt: torch.Tensor, F_t_cond: Union[torch.Tensor, float] = 0.0, F_t_uncond: Union[torch.Tensor, float] = 0.0, enhance_target=False):
        '''Compute the regression target for the network output F. Must be overridden.'''
        raise NotImplementedError

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''Transport x_t at time t to x_r at time r using the prediction F. Must be overridden.'''
        raise NotImplementedError
|
|
|
class OT_FM(Transport):

    def __init__(self, P_mean=0.0, P_std=1.0, sigma_d=1.0, T_max=1.0, T_min=0.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        Flow-matching with linear path formulation from the paper:
        "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers"

        Args:
            P_mean, P_std: mean / std of the log-normal noise-level
                distribution used to sample training times.
            (remaining arguments are forwarded to `Transport`)
        '''
        self.P_mean = P_mean
        self.P_std = P_std
        super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        '''Linear path x_t = (1 - t) * x + t * z and its time derivatives.'''
        alpha_t = 1 - t
        sigma_t = t
        d_alpha_t = -1
        d_sigma_t = 1
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def sample_t(self, batch_size, dtype, device):
        '''Sample t in (0, 1): log-normal sigma mapped via t = sigma / (1 + sigma).'''
        rnd_normal = torch.randn((batch_size, ), dtype=dtype, device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        t = sigma / (1 + sigma)
        return t

    def c_noise(self, t: torch.Tensor):
        '''The network is conditioned directly on t.'''
        return t

    def target(
        self,
        x_t: torch.Tensor,
        v_t: torch.Tensor,
        x: torch.Tensor,
        z: torch.Tensor,
        t: torch.Tensor,
        r: torch.Tensor,
        dF_dv_dt: torch.Tensor,
        # Defaults are scalar 0.0, not None — annotated accordingly.
        F_t_cond: Union[torch.Tensor, float] = 0.0,
        F_t_uncond: Union[torch.Tensor, float] = 0.0,
        enhance_target=False,
    ):
        '''
        Regression target for the network output F for a jump from t to r.

        If `enhance_target`, the ground-truth velocity v_t is blended with
        conditional / unconditional model outputs inside [w_start, w_end];
        outside that window the weights fall back to (w_gt=1, w_cond=0).
        '''
        if enhance_target:
            in_window = (t >= self.w_start) & (t <= self.w_end)
            w_gt = torch.where(in_window, self.w_gt, 1.0)
            w_cond = torch.where(in_window, self.w_cond, 0.0)
            v_t = w_gt * v_t + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        F_target = v_t - (t - r) * dF_dv_dt
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''
        Jump from x_t (time t) to x_r (time r) along the predicted flow F.

        When s_ratio > 0, an SDE-style stochastic correction is added: the
        noise estimate z is recovered from F and Langevin-like noise scaled
        by s_ratio is injected.
        '''
        x_r = x_t - (t - r) * F
        if s_ratio > 0.0:
            z = x_t + (1 - t) * F
            epsilon = torch.randn_like(z)
            dt = t - r
            x_r = x_r - s_ratio * z * dt + torch.sqrt(s_ratio * 2 * t * dt) * epsilon
        return x_r
|
|
|
|
|
|
|
|
|
class TrigFlow(Transport):

    def __init__(self, P_mean=-1.0, P_std=1.6, sigma_d=0.5, T_max=1.57, T_min=0.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        TrigFlow formulation from the paper:
        "SIMPLIFYING, STABILIZING & SCALING CONTINUOUS-TIME CONSISTENCY MODELS"

        Args:
            P_mean, P_std: mean / std of the log-normal noise-level
                distribution used to sample training times.
            (remaining arguments are forwarded to `Transport`;
             T_max defaults to ~pi/2, the end of the trigonometric path)
        '''
        self.P_mean = P_mean
        self.P_std = P_std
        super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        '''Trigonometric path x_t = cos(t) * x + sin(t) * z and its derivatives.'''
        alpha_t = torch.cos(t)
        sigma_t = torch.sin(t)
        d_alpha_t = -torch.sin(t)
        d_sigma_t = torch.cos(t)
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def sample_t(self, batch_size, dtype, device):
        '''Sample t in [0, pi/2): log-normal sigma mapped via t = arctan(sigma).'''
        rnd_normal = torch.randn((batch_size, ), dtype=dtype, device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        t = torch.atan(sigma)
        return t

    def c_noise(self, t: torch.Tensor):
        '''The network is conditioned directly on t.'''
        return t

    def target(
        self,
        x_t: torch.Tensor,
        v_t: torch.Tensor,
        x: torch.Tensor,
        z: torch.Tensor,
        t: torch.Tensor,
        r: torch.Tensor,
        dF_dv_dt: torch.Tensor,
        # Defaults are scalar 0.0, not None — annotated accordingly.
        F_t_cond: Union[torch.Tensor, float] = 0.0,
        F_t_uncond: Union[torch.Tensor, float] = 0.0,
        enhance_target=False,
    ):
        '''
        Regression target for the network output F for a jump from t to r.

        If `enhance_target`, the ground-truth velocity v_t is blended with
        conditional / unconditional model outputs inside [w_start, w_end];
        outside that window the weights fall back to (w_gt=1, w_cond=0).
        '''
        if enhance_target:
            in_window = (t >= self.w_start) & (t <= self.w_end)
            w_gt = torch.where(in_window, self.w_gt, 1.0)
            w_cond = torch.where(in_window, self.w_cond, 0.0)
            v_t = w_gt * v_t + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        F_target = v_t - torch.tan(t - r) * (x_t + dF_dv_dt)
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''
        Rotate x_t from time t to time r along the trigonometric path.

        s_ratio is accepted for interface parity with OT_FM but unused here.
        '''
        x_r = torch.cos(t - r) * x_t - torch.sin(t - r) * F
        return x_r
|
|
|
|
|
class EDM(Transport):

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_d=0.5, T_max=80.0, T_min=0.01, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        EDM formulation from the paper:
        "Elucidating the Design Space of Diffusion-Based Generative Models"

        P_mean / P_std parameterize the log-normal distribution of the noise
        level sigma, which is used directly as the time variable t.
        '''
        self.P_mean = P_mean
        self.P_std = P_std
        super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        '''
        Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) of the EDM path at t.

        The d_alpha_t and d_sigma_t are easy to obtain:
        # from sympy import *
        # from scipy.stats import *
        # t, sigma_d = symbols('t sigma_d')
        # alpha_t = sigma_d * ((t**2 + sigma_d**2) ** (-0.5))
        # sigma_t = t * ((t**2 + sigma_d**2) ** (-0.5))
        # d_alpha_t = diff(alpha_t, t)
        # d_sigma_t = diff(sigma_t, t)
        # print(d_alpha_t)
        # print(d_sigma_t)

        NOTE(review): the sympy sketch above defines alpha_t with a leading
        sigma_d factor, but the code below omits it (alpha_t = (t**2 +
        sigma_d**2)**-0.5). The derivatives in the code match the code's
        (un-prefixed) alpha_t, so the block is internally consistent —
        confirm which normalization is intended.
        '''
        sigma_d = self.sigma_d
        alpha_t = 1 / (t**2 + sigma_d**2).sqrt()
        sigma_t = t / (t**2 + sigma_d**2).sqrt()
        # Time derivatives of the two coefficients above.
        d_alpha_t = -t / ((sigma_d ** 2 + t ** 2) ** 1.5)
        d_sigma_t = (sigma_d ** 2) / ((sigma_d ** 2 + t ** 2) ** 1.5)
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def sample_t(self, batch_size, dtype, device):
        '''Sample t = sigma, drawn log-normally: ln(sigma) ~ N(P_mean, P_std^2).'''
        rnd_normal = torch.randn((batch_size, ), dtype=dtype, device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        t = sigma
        return t

    def c_noise(self, t: torch.Tensor):
        '''EDM noise conditioning: ln(sigma) / 4.'''
        return torch.log(t) / 4

    def target(
        self,
        x_t: torch.Tensor,
        v_t: torch.Tensor,
        x: torch.Tensor,
        z: torch.Tensor,
        t: torch.Tensor,
        r: torch.Tensor,
        dF_dv_dt: torch.Tensor,
        F_t_cond: Optional[torch.Tensor] = 0.0,
        F_t_uncond: Optional[torch.Tensor] = 0.0,
        enhance_target = False,
    ):
        '''
        Regression target for the network output F for a jump from t to r.

        The hat-coefficients express a denoising-style target
        alpha_hat_t * x + sigma_hat_t * z together with its time derivative;
        Bt_dv_dBt is a precomputed closed-form weighting coefficient (the
        formulas are taken as given here and are not re-derived in this file).
        If `enhance_target`, the target is blended with conditional /
        unconditional model outputs inside the [w_start, w_end] window.
        '''
        sigma_d = self.sigma_d
        alpha_hat_t = t / (sigma_d * (t**2 + sigma_d**2).sqrt())
        sigma_hat_t = - sigma_d / (t**2 + sigma_d**2).sqrt()
        # Time derivatives of alpha_hat_t / sigma_hat_t.
        d_alpha_hat_t = -t**2/(sigma_d*(sigma_d**2 + t**2)**(3/2)) + 1/(sigma_d*(sigma_d**2 + t**2).sqrt())
        d_sigma_hat_t = sigma_d * t / ((sigma_d**2 + t**2)**(3/2))
        diffusion_target = alpha_hat_t * x + sigma_hat_t * z
        Bt_dv_dBt = (t - r) * (sigma_d**2 + t**2) * (sigma_d**3 + t**2) / (
            2*t*(r - t)*(sigma_d**2 + t**2) - t*(r - t)*(sigma_d**3 + t**2) + (sigma_d**2 + t**2)*(sigma_d**3 + t**2)
        )
        if enhance_target:
            # Inside [w_start, w_end], blend the ground-truth target with the
            # (un)conditional model outputs; outside, use it unchanged.
            w_gt = torch.where((t>=self.w_start) & (t<=self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t>=self.w_start) & (t<=self.w_end), self.w_cond, 0.0)
            diffusion_target = w_gt * diffusion_target + w_cond * F_t_cond + (1-w_gt-w_cond) * F_t_uncond
        F_target = diffusion_target + Bt_dv_dBt * (d_alpha_hat_t*x + d_sigma_hat_t*z -dF_dv_dt)
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''
        Jump from x_t (noise level t) to x_r (noise level r) using F.

        s_ratio is accepted for interface parity with OT_FM but unused here.
        '''
        sigma_d = self.sigma_d
        ratio = (t**2 + sigma_d**2).sqrt() / (r**2 + sigma_d**2).sqrt() / (sigma_d**3 + t**2)
        A_t = (sigma_d**3 + t*r) * ratio
        B_t = (sigma_d**2) * (t-r) * ratio
        x_r = A_t * x_t + B_t * F
        return x_r
|
|
|
|
|
class VP_SDE(Transport):

    def __init__(self, beta_min=0.1, beta_d=19.9, epsilon_t=1e-5, T=1000, sigma_d=1.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        Variance preserving (VP) formulation from the paper:
        "Score-Based Generative Modeling through Stochastic Differential Equations".

        beta_min / beta_d define the linear beta schedule; epsilon_t is the
        smallest sampled time (also used as T_min); T is the discrete step
        count, used only for noise conditioning in `c_noise`.
        '''
        self.beta_min = beta_min
        self.beta_d = beta_d
        self.epsilon_t = epsilon_t
        self.T = T
        # Time runs in (epsilon_t, 1]; T_max is fixed to 1.0.
        super().__init__(sigma_d, 1.0, epsilon_t, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        '''
        Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) of the VP path at t.

        The d_alpha_t and d_sigma_t are easy to obtain:
        # from sympy import *
        # from scipy.stats import *
        # t, beta_d, beta_min = symbols('t beta_d beta_min')
        # sigma = sqrt(exp(0.5 * beta_d * (t ** 2) + beta_min * t) - 1)
        # d_sigma_d_t = diff(sigma, t)
        # print(d_sigma_d_t)
        # sigma = symbols('sigma')
        # alpha_t = (sigma**2 + 1) ** (-0.5)
        # sigma_t = sigma * (sigma**2 + 1) ** (-0.5)
        # d_alpha_d_sigma = diff(alpha_t, sigma)
        # print(d_alpha_d_sigma)
        # d_sigma_d_sigma = diff(sigma_t, sigma)
        # print(d_sigma_d_sigma)
        '''
        beta_t = self.beta(t)
        alpha_t = 1 / torch.sqrt(beta_t**2 + 1)
        sigma_t = beta_t / torch.sqrt(beta_t**2 + 1)
        # Chain rule: d/dt = (d beta / dt) * (d / d beta); simplified forms of
        # the sympy products above.
        d_alpha_t = -0.5 * (self.beta_d * t + self.beta_min) / (beta_t**2 + 1).sqrt()
        d_sigma_t = 0.5 * (self.beta_d * t + self.beta_min) / (beta_t * (beta_t**2 + 1).sqrt())
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def beta(self, t: torch.Tensor):
        '''Noise level of the VP SDE: sqrt(exp(0.5*beta_d*t^2 + beta_min*t) - 1).'''
        return torch.sqrt((0.5 * self.beta_d * (t ** 2) + self. beta_min * t).exp() - 1)

    def sample_t(self, batch_size, dtype, device):
        '''Sample t uniformly in (epsilon_t, 1].'''
        rnd_uniform = torch.rand((batch_size, ), dtype=dtype, device=device)
        t = 1 + rnd_uniform * (self.epsilon_t - 1)
        return t

    def c_noise(self, t: torch.Tensor):
        '''Map continuous t onto the discrete-step scale [0, T-1] used by VP models.'''
        return (self.T - 1) * t

    def target(
        self,
        x_t: torch.Tensor,
        v_t: torch.Tensor,
        x: torch.Tensor,
        z: torch.Tensor,
        t: torch.Tensor,
        r: torch.Tensor,
        dF_dv_dt: torch.Tensor,
        F_t_cond: Optional[torch.Tensor] = 0.0,
        F_t_uncond: Optional[torch.Tensor] = 0.0,
        enhance_target = False,
    ):
        '''
        Regression target for the network output F for a jump from t to r
        (a noise prediction z with a derivative correction term).

        If `enhance_target`, the ground-truth noise z is blended with
        conditional / unconditional model outputs inside [w_start, w_end];
        outside that window the weights fall back to (w_gt=1, w_cond=0).
        '''
        if enhance_target:
            w_gt = torch.where((t>=self.w_start) & (t<=self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t>=self.w_start) & (t<=self.w_end), self.w_cond, 0.0)
            z = w_gt * z + w_cond * F_t_cond + (1-w_gt-w_cond) * F_t_uncond
        beta_t = self.beta(t)
        beta_r = self.beta(r)
        # d beta / dt at t, written via beta_t (avoids recomputing the exp).
        d_beta_t = (self.beta_d * t + self.beta_min) * (beta_t ** 2 + 1) / (2 * beta_t)
        F_target = z - dF_dv_dt * (beta_t - beta_r) / d_beta_t
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''
        Jump from x_t (time t) to x_r (time r) using the noise prediction F.

        s_ratio is accepted for interface parity with OT_FM but unused here.
        '''
        beta_t = self.beta(t)
        beta_r = self.beta(r)
        A_t = (beta_t ** 2 + 1).sqrt() / (beta_r ** 2 + 1).sqrt()
        B_t = (beta_r - beta_t) / (beta_r ** 2 + 1).sqrt()
        x_r = A_t * x_t + B_t * F
        return x_r
|
|
|
|
|
|
|
|
|
class VE_SDE(Transport):

    def __init__(self, sigma_min=0.02, sigma_max=100, sigma_d=1.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        Variance exploding (VE) formulation from the paper:
        "Score-Based Generative Modeling through Stochastic Differential Equations".

        sigma_min / sigma_max bound the noise-level range; they also serve as
        T_min / T_max of the base class.
        '''
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        super().__init__(sigma_d, sigma_max, sigma_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        '''VE path x_t = x + t * z: constant alpha, sigma growing linearly in t.'''
        return 1, t, 0, 1

    def sample_t(self, batch_size, dtype, device):
        '''Sample noise levels log-uniformly between sigma_min and sigma_max.'''
        u = torch.rand((batch_size, ), dtype=dtype, device=device)
        return self.sigma_min * ((self.sigma_max / self.sigma_min) ** u)

    def c_noise(self, t: torch.Tensor):
        '''Noise conditioning: log(t / 2).'''
        return torch.log(0.5 * t)

    def target(
        self,
        x_t: torch.Tensor,
        v_t: torch.Tensor,
        x: torch.Tensor,
        z: torch.Tensor,
        t: torch.Tensor,
        r: torch.Tensor,
        dF_dv_dt: torch.Tensor,
        F_t_cond: Optional[torch.Tensor] = 0.0,
        F_t_uncond: Optional[torch.Tensor] = 0.0,
        enhance_target = False,
    ):
        '''
        Regression target for the network output F for a jump from t to r
        (a negated noise prediction with a derivative correction term).

        If `enhance_target`, the ground-truth noise z is blended with the
        negated conditional / unconditional model outputs inside
        [w_start, w_end]; outside that window the plain z is used.
        '''
        if enhance_target:
            in_window = (t >= self.w_start) & (t <= self.w_end)
            w_gt = torch.where(in_window, self.w_gt, 1.0)
            w_cond = torch.where(in_window, self.w_cond, 0.0)
            z = w_gt * z + w_cond * (-F_t_cond) + (1 - w_gt - w_cond) * (-F_t_uncond)
        return (r - t) * dF_dv_dt - z

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''
        Move x_t from noise level t down to level r along the prediction F.

        s_ratio is accepted for interface parity with OT_FM but unused here.
        '''
        return x_t + (t - r) * F
|
|