# tim/schedulers/transports.py
import torch
from typing import Union
class Transport:
    def __init__(self, sigma_d, T_max, T_min, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        self.sigma_d = sigma_d  # data standard deviation
        self.T_max = T_max      # largest usable time
        self.T_min = T_min      # smallest usable time
        self.enhance_target = enhance_target
        self.w_gt = w_gt        # weight on the ground-truth term in the enhanced target
        self.w_cond = w_cond    # weight on the conditional prediction
        self.w_start = w_start  # start of the time window where enhancement applies
        self.w_end = w_end      # end of the time window where enhancement applies
    def sample_t(self, batch_size, dtype, device):
        # Draw a batch of training times from the transport's proposal distribution.
        raise NotImplementedError
    def c_noise(self, t: torch.Tensor):
        # Map time t to the noise-conditioning value fed to the network.
        raise NotImplementedError
    def interpolant(self, t: torch.Tensor):
        # Return alpha_t, sigma_t and their time derivatives for x_t = alpha_t * x + sigma_t * z.
        raise NotImplementedError
    def target(self, x_t: torch.Tensor, v_t: torch.Tensor, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor, r: torch.Tensor, dF_dv_dt: torch.Tensor, F_t_cond: torch.Tensor, F_t_uncond: torch.Tensor):
        # Build the regression target for the network prediction F.
        raise NotImplementedError
    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor):
        # Transition a sample at time t to an earlier time r using the prediction F.
        raise NotImplementedError
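# Illustrative sketch (added for exposition, not part of the original module):
# a minimal, hedged example of how a concrete Transport is typically driven.
# `x` stands in for a data batch; the returned tensors would feed a network
# conditioned on c_noise(t). All names below are hypothetical.
def _transport_usage_sketch(transport: Transport, x: torch.Tensor):
    # Per-sample times, reshaped to broadcast over the trailing data dimensions.
    t = transport.sample_t(x.shape[0], x.dtype, x.device).view(-1, *([1] * (x.ndim - 1)))
    z = torch.randn_like(x)  # noise endpoint of the interpolation path
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = transport.interpolant(t)
    x_t = alpha_t * x + sigma_t * z      # point on the path at time t
    v_t = d_alpha_t * x + d_sigma_t * z  # path velocity at time t
    return x_t, v_t, transport.c_noise(t)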
class OT_FM(Transport):
def __init__(self, P_mean=0.0, P_std=1.0, sigma_d=1.0, T_max=1.0, T_min=0.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
'''
Flow-matching with linear path formulation from the paper:
"SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers"
'''
self.P_mean = P_mean
self.P_std = P_std
super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)
def interpolant(self, t: torch.Tensor):
alpha_t = 1 - t
sigma_t = t
d_alpha_t = -1
d_sigma_t = 1
return alpha_t, sigma_t, d_alpha_t, d_sigma_t
def sample_t(self, batch_size, dtype, device):
rnd_normal = torch.randn((batch_size, ), dtype=dtype, device=device)
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
t = sigma / (1 + sigma) # [0, 1]
return t
def c_noise(self, t: torch.Tensor):
return t
def target(
self,
x_t: torch.Tensor,
v_t: torch.Tensor,
x: torch.Tensor,
z: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
dF_dv_dt: torch.Tensor,
        F_t_cond: Union[torch.Tensor, float] = 0.0,
        F_t_uncond: Union[torch.Tensor, float] = 0.0,
        enhance_target: bool = False,
):
        if enhance_target:
            # Within the window [w_start, w_end], mix the ground-truth velocity
            # with the conditional / unconditional network predictions.
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            v_t = w_gt * v_t + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
F_target = v_t - (t - r) * dF_dv_dt
return F_target
    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        # Deterministic update along the predicted velocity F.
        x_r = x_t - (t - r) * F
        if s_ratio > 0.0:
            # Stochastic variant: recover the noise endpoint z implied by (x_t, F),
            # then inject a scaled correction with matching fresh noise.
            z = x_t + (1 - t) * F
            epsilon = torch.randn_like(z)
            dt = t - r
            x_r = x_r - s_ratio * z * dt + torch.sqrt(s_ratio * 2 * t * dt) * epsilon
        return x_r
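# Hedged sanity check (added for exposition): with the exact linear-path
# velocity F = z - x, one deterministic `from_x_t_to_x_r` step lands exactly
# on the path point at time r. Shapes and values here are arbitrary.
def _check_ot_fm_step():
    transport = OT_FM()
    x, z = torch.randn(4, 8), torch.randn(4, 8)
    t, r = torch.full((4, 1), 0.8), torch.full((4, 1), 0.3)
    x_t = (1 - t) * x + t * z  # interpolant at time t
    x_r = transport.from_x_t_to_x_r(x_t, t, r, F=z - x)
    assert torch.allclose(x_r, (1 - r) * x + r * z, atol=1e-6)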
class TrigFlow(Transport):
def __init__(self, P_mean=-1.0, P_std=1.6, sigma_d=0.5, T_max=1.57, T_min=0.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
'''
TrigFlow formulation from the paper:
"SIMPLIFYING, STABILIZING & SCALING CONTINUOUS-TIME CONSISTENCY MODELS"
'''
self.P_mean = P_mean
self.P_std = P_std
super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)
def interpolant(self, t: torch.Tensor):
alpha_t = torch.cos(t)
sigma_t = torch.sin(t)
d_alpha_t = -torch.sin(t)
d_sigma_t = torch.cos(t)
return alpha_t, sigma_t, d_alpha_t, d_sigma_t
def sample_t(self, batch_size, dtype, device):
rnd_normal = torch.randn((batch_size, ), dtype=dtype, device=device)
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
t = torch.atan(sigma) # [0, pi/2]
return t
def c_noise(self, t: torch.Tensor):
return t
def target(
self,
x_t: torch.Tensor,
v_t: torch.Tensor,
x: torch.Tensor,
z: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
dF_dv_dt: torch.Tensor,
        F_t_cond: Union[torch.Tensor, float] = 0.0,
        F_t_uncond: Union[torch.Tensor, float] = 0.0,
        enhance_target: bool = False,
):
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            v_t = w_gt * v_t + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
F_target = v_t - torch.tan(t - r) * (x_t + dF_dv_dt)
return F_target
def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
x_r = torch.cos(t - r) * x_t - torch.sin(t - r) * F
return x_r
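# Hedged sanity check (added for exposition): with the exact TrigFlow velocity
# v_t = -sin(t) * x + cos(t) * z, the rotation in `from_x_t_to_x_r` maps the
# path point at time t exactly onto the path point at time r.
def _check_trigflow_step():
    transport = TrigFlow()
    x, z = torch.randn(4, 8), torch.randn(4, 8)
    t, r = torch.full((4, 1), 1.2), torch.full((4, 1), 0.4)
    x_t = torch.cos(t) * x + torch.sin(t) * z
    v_t = -torch.sin(t) * x + torch.cos(t) * z
    x_r = transport.from_x_t_to_x_r(x_t, t, r, F=v_t)
    assert torch.allclose(x_r, torch.cos(r) * x + torch.sin(r) * z, atol=1e-5)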
class EDM(Transport):
def __init__(self, P_mean=-1.2, P_std=1.2, sigma_d=0.5, T_max=80.0, T_min=0.01, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
'''
EDM formulation from the paper:
"Elucidating the Design Space of Diffusion-Based Generative Models"
'''
self.P_mean = P_mean
self.P_std = P_std
super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)
def interpolant(self, t: torch.Tensor):
'''
The d_alpha_t and d_sigma_t are easy to obtain:
        # from sympy import *
        # t, sigma_d = symbols('t sigma_d')
        # alpha_t = (t**2 + sigma_d**2) ** (-0.5)
# sigma_t = t * ((t**2 + sigma_d**2) ** (-0.5))
# d_alpha_t = diff(alpha_t, t)
# d_sigma_t = diff(sigma_t, t)
# print(d_alpha_t)
# print(d_sigma_t)
'''
sigma_d = self.sigma_d
alpha_t = 1 / (t**2 + sigma_d**2).sqrt()
sigma_t = t / (t**2 + sigma_d**2).sqrt()
d_alpha_t = -t / ((sigma_d ** 2 + t ** 2) ** 1.5)
d_sigma_t = (sigma_d ** 2) / ((sigma_d ** 2 + t ** 2) ** 1.5)
return alpha_t, sigma_t, d_alpha_t, d_sigma_t
def sample_t(self, batch_size, dtype, device):
rnd_normal = torch.randn((batch_size, ), dtype=dtype, device=device)
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
t = sigma # t > 0
return t
def c_noise(self, t: torch.Tensor):
return torch.log(t) / 4
def target(
self,
x_t: torch.Tensor,
v_t: torch.Tensor,
x: torch.Tensor,
z: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
dF_dv_dt: torch.Tensor,
        F_t_cond: Union[torch.Tensor, float] = 0.0,
        F_t_uncond: Union[torch.Tensor, float] = 0.0,
        enhance_target: bool = False,
):
        sigma_d = self.sigma_d
        # Coefficients (and time derivatives) of the diffusion target
        # alpha_hat_t * x + sigma_hat_t * z in the EDM parametrization.
        alpha_hat_t = t / (sigma_d * (t**2 + sigma_d**2).sqrt())
        sigma_hat_t = -sigma_d / (t**2 + sigma_d**2).sqrt()
        d_alpha_hat_t = -t**2 / (sigma_d * (sigma_d**2 + t**2) ** 1.5) + 1 / (sigma_d * (sigma_d**2 + t**2).sqrt())
        d_sigma_hat_t = sigma_d * t / ((sigma_d**2 + t**2) ** 1.5)
        diffusion_target = alpha_hat_t * x + sigma_hat_t * z
        Bt_dv_dBt = (t - r) * (sigma_d**2 + t**2) * (sigma_d**3 + t**2) / (
            2 * t * (r - t) * (sigma_d**2 + t**2) - t * (r - t) * (sigma_d**3 + t**2) + (sigma_d**2 + t**2) * (sigma_d**3 + t**2)
        )
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            diffusion_target = w_gt * diffusion_target + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        F_target = diffusion_target + Bt_dv_dBt * (d_alpha_hat_t * x + d_sigma_hat_t * z - dF_dv_dt)
return F_target
def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
sigma_d = self.sigma_d
ratio = (t**2 + sigma_d**2).sqrt() / (r**2 + sigma_d**2).sqrt() / (sigma_d**3 + t**2)
A_t = (sigma_d**3 + t*r) * ratio
B_t = (sigma_d**2) * (t-r) * ratio
x_r = A_t * x_t + B_t * F
return x_r
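# Hedged numerical check (added for exposition): the closed-form derivatives
# returned by `EDM.interpolant` should agree with central finite differences
# of alpha_t and sigma_t over a range of times.
def _check_edm_interpolant():
    transport = EDM()
    t = torch.linspace(0.1, 5.0, 16, dtype=torch.float64)
    eps = 1e-4
    a_p, s_p, _, _ = transport.interpolant(t + eps)
    a_m, s_m, _, _ = transport.interpolant(t - eps)
    _, _, d_alpha, d_sigma = transport.interpolant(t)
    assert torch.allclose((a_p - a_m) / (2 * eps), d_alpha, atol=1e-6)
    assert torch.allclose((s_p - s_m) / (2 * eps), d_sigma, atol=1e-6)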
class VP_SDE(Transport):
def __init__(self, beta_min=0.1, beta_d=19.9, epsilon_t=1e-5, T=1000, sigma_d=1.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
'''
Variance preserving (VP) formulation from the paper:
"Score-Based Generative Modeling through Stochastic Differential Equations".
'''
self.beta_min = beta_min
self.beta_d = beta_d
self.epsilon_t = epsilon_t
self.T = T
super().__init__(sigma_d, 1.0, epsilon_t, enhance_target, w_gt, w_cond, w_start, w_end)
def interpolant(self, t: torch.Tensor):
'''
The d_alpha_t and d_sigma_t are easy to obtain:
        # from sympy import *
# t, beta_d, beta_min = symbols('t beta_d beta_min')
# sigma = sqrt(exp(0.5 * beta_d * (t ** 2) + beta_min * t) - 1)
# d_sigma_d_t = diff(sigma, t)
# print(d_sigma_d_t)
# sigma = symbols('sigma')
# alpha_t = (sigma**2 + 1) ** (-0.5)
# sigma_t = sigma * (sigma**2 + 1) ** (-0.5)
# d_alpha_d_sigma = diff(alpha_t, sigma)
# print(d_alpha_d_sigma)
# d_sigma_d_sigma = diff(sigma_t, sigma)
# print(d_sigma_d_sigma)
'''
beta_t = self.beta(t)
alpha_t = 1 / torch.sqrt(beta_t**2 + 1)
sigma_t = beta_t / torch.sqrt(beta_t**2 + 1)
d_alpha_t = -0.5 * (self.beta_d * t + self.beta_min) / (beta_t**2 + 1).sqrt()
d_sigma_t = 0.5 * (self.beta_d * t + self.beta_min) / (beta_t * (beta_t**2 + 1).sqrt())
return alpha_t, sigma_t, d_alpha_t, d_sigma_t
    def beta(self, t: torch.Tensor):
        # Marginal noise level sigma(t) of the VP process (named `beta` here).
        return torch.sqrt((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1)
def sample_t(self, batch_size, dtype, device):
rnd_uniform = torch.rand((batch_size, ), dtype=dtype, device=device)
t = 1 + rnd_uniform * (self.epsilon_t - 1) # [epsilon_t, 1]
return t
def c_noise(self, t: torch.Tensor):
return (self.T - 1) * t
def target(
self,
x_t: torch.Tensor,
v_t: torch.Tensor,
x: torch.Tensor,
z: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
dF_dv_dt: torch.Tensor,
        F_t_cond: Union[torch.Tensor, float] = 0.0,
        F_t_uncond: Union[torch.Tensor, float] = 0.0,
        enhance_target: bool = False,
):
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            z = w_gt * z + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        beta_t = self.beta(t)
        beta_r = self.beta(r)
        # Time derivative of beta(t).
        d_beta_t = (self.beta_d * t + self.beta_min) * (beta_t ** 2 + 1) / (2 * beta_t)
        F_target = z - dF_dv_dt * (beta_t - beta_r) / d_beta_t
return F_target
def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
beta_t = self.beta(t)
beta_r = self.beta(r)
A_t = (beta_t ** 2 + 1).sqrt() / (beta_r ** 2 + 1).sqrt()
B_t = (beta_r - beta_t) / (beta_r ** 2 + 1).sqrt()
x_r = A_t * x_t + B_t * F
return x_r
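# Hedged sanity check (added for exposition): with F equal to the exact noise
# z, `VP_SDE.from_x_t_to_x_r` maps the VP marginal at time t exactly onto the
# VP marginal at time r (alpha * x + sigma * z at both times).
def _check_vp_sde_step():
    transport = VP_SDE()
    x = torch.randn(4, 8, dtype=torch.float64)
    z = torch.randn(4, 8, dtype=torch.float64)
    t = torch.full((4, 1), 0.7, dtype=torch.float64)
    r = torch.full((4, 1), 0.2, dtype=torch.float64)
    beta_t, beta_r = transport.beta(t), transport.beta(r)
    x_t = (x + beta_t * z) / (beta_t**2 + 1).sqrt()
    x_r = transport.from_x_t_to_x_r(x_t, t, r, F=z)
    assert torch.allclose(x_r, (x + beta_r * z) / (beta_r**2 + 1).sqrt(), atol=1e-8)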
class VE_SDE(Transport):
def __init__(self, sigma_min=0.02, sigma_max=100, sigma_d=1.0, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
'''
Variance exploding (VE) formulation from the paper:
"Score-Based Generative Modeling through Stochastic Differential Equations".
'''
self.sigma_min = sigma_min
self.sigma_max = sigma_max
super().__init__(sigma_d, sigma_max, sigma_min, enhance_target, w_gt, w_cond, w_start, w_end)
def interpolant(self, t: torch.Tensor):
alpha_t = 1
sigma_t = t
d_alpha_t = 0
d_sigma_t = 1
return alpha_t, sigma_t, d_alpha_t, d_sigma_t
def sample_t(self, batch_size, dtype, device):
rnd_uniform = torch.rand((batch_size, ), dtype=dtype, device=device)
t = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) # [sigma_min, sigma_max]
return t
def c_noise(self, t: torch.Tensor):
return torch.log(0.5 * t)
def target(
self,
x_t: torch.Tensor,
v_t: torch.Tensor,
x: torch.Tensor,
z: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
dF_dv_dt: torch.Tensor,
        F_t_cond: Union[torch.Tensor, float] = 0.0,
        F_t_uncond: Union[torch.Tensor, float] = 0.0,
        enhance_target: bool = False,
):
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            # VE's prediction F approximates -z, hence the sign flip on the guidance terms.
            z = w_gt * z + w_cond * (-F_t_cond) + (1 - w_gt - w_cond) * (-F_t_uncond)
F_target = (r - t) * dF_dv_dt - z
return F_target
def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
x_r = x_t + (t - r) * F
return x_r
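# Hedged sanity check (added for exposition): for VE, x_t = x + t * z and the
# exact prediction is F = -z, so one `from_x_t_to_x_r` step recovers x + r * z.
def _check_ve_sde_step():
    transport = VE_SDE()
    x, z = torch.randn(4, 8), torch.randn(4, 8)
    t, r = torch.full((4, 1), 10.0), torch.full((4, 1), 1.0)
    x_r = transport.from_x_t_to_x_r(x + t * z, t, r, F=-z)
    assert torch.allclose(x_r, x + r * z, atol=1e-4)

if __name__ == "__main__":
    # Run the illustrative checks defined above; they rely only on the classes
    # in this file and on torch, and are not part of the original module.
    _transport_usage_sketch(OT_FM(), torch.randn(4, 8))
    _check_ot_fm_step()
    _check_trigflow_step()
    _check_edm_interpolant()
    _check_vp_sde_step()
    _check_ve_sde_step()
    print("all transport sanity checks passed")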