|
import math |
|
import os |
|
import random |
|
import glob |
|
import pickle |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as transforms |
|
from torch.optim import Adam |
|
from torchvision.utils import make_grid |
|
from PIL import Image |
|
from transformers import ( |
|
DistilBertModel, |
|
DistilBertTokenizer, |
|
CLIPTokenizer, |
|
CLIPTextModel, |
|
) |
|
|
|
dataset_params = { |
|
"image_path": "data/CelebAMask-HQ", |
|
"image_channels": 3, |
|
"image_size": 256, |
|
"name": "celebhq", |
|
} |
|
|
|
diffusion_params = { |
|
"num_timesteps": 1000, |
|
"beta_start": 0.00085, |
|
"beta_end": 0.012, |
|
} |
|
|
|
ldm_params = { |
|
"down_channels": [256, 384, 512, 768], |
|
"mid_channels": [768, 512], |
|
"down_sample": [True, True, True], |
|
"attn_down": [True, True, True], |
|
"time_emb_dim": 512, |
|
"norm_channels": 32, |
|
"num_heads": 16, |
|
"conv_out_channels": 128, |
|
"num_down_layers": 2, |
|
"num_mid_layers": 2, |
|
"num_up_layers": 2, |
|
"condition_config": { |
|
"condition_types": ["text", "image"], |
|
"text_condition_config": { |
|
"text_embed_model": "clip", |
|
"train_text_embed_model": False, |
|
"text_embed_dim": 512, |
|
"cond_drop_prob": 0.1, |
|
}, |
|
"image_condition_config": { |
|
"image_condition_input_channels": 18, |
|
"image_condition_output_channels": 3, |
|
"image_condition_h": 512, |
|
"image_condition_w": 512, |
|
"cond_drop_prob": 0.1, |
|
}, |
|
}, |
|
} |
|
|
|
autoencoder_params = { |
|
"z_channels": 4, |
|
"codebook_size": 8192, |
|
"down_channels": [64, 128, 256, 256], |
|
"mid_channels": [256, 256], |
|
"down_sample": [True, True, True], |
|
"attn_down": [False, False, False], |
|
"norm_channels": 32, |
|
"num_heads": 4, |
|
"num_down_layers": 2, |
|
"num_mid_layers": 2, |
|
"num_up_layers": 2, |
|
} |
|
|
|
train_params = { |
|
"task_name": "celebhq", |
|
"num_samples": 1, |
|
"num_grid_rows": 1, |
|
"cf_guidance_scale": 1.0, |
|
"ldm_ckpt_name": "ddpm_ckpt_class_cond.pth", |
|
"vqvae_autoencoder_ckpt_name": "vqvae_autoencoder_ckpt.pth", |
|
"vqvae_latent_dir_name": "vqvae_latents", |
|
} |
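

# Illustrative sketch (not used elsewhere): how the latent resolution that the diffusion
# UNet operates on follows from the configs above. Each True entry in the autoencoder's
# "down_sample" list halves the spatial resolution.
def _latent_resolution_sketch():
    downscale_factor = 2 ** sum(autoencoder_params["down_sample"])
    latent_size = dataset_params["image_size"] // downscale_factor
    # With image_size=256 and three downsampling stages this gives 256 // 8 = 32,
    # so the UNet denoises 4 x 32 x 32 latents (z_channels=4).
    return latent_size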
|
|
|
|
|
def get_config_value(config, key, default_value): |
|
return config[key] if key in config else default_value |
|
|
|
|
|
def spatial_average(in_tens, keepdim=True): |
|
return in_tens.mean([2, 3], keepdim=keepdim) |
|
|
|
|
|
class LinearNoiseScheduler: |
|
def __init__(self, num_timesteps, beta_start, beta_end): |
|
self.num_timesteps = num_timesteps |
|
self.beta_start = beta_start |
|
self.beta_end = beta_end |
|
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2 |
|
self.alphas = 1.0 - self.betas |
|
self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) |
|
self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) |
|
self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) |
|
|
|
def add_noise(self, original, noise, t): |
|
|
|
batch_size = original.shape[0] |
|
sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].view( |
|
batch_size, 1, 1, 1 |
|
) |
|
sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to( |
|
original.device |
|
)[t].view(batch_size, 1, 1, 1) |
|
return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise |
|
|
|
    def sample_prev_timestep(self, xt, noise_pred, t):
        batch_size = xt.shape[0]
        alpha_cum_prod_t = self.alpha_cum_prod.to(xt.device)[t].view(batch_size, 1, 1, 1)
        sqrt_one_minus_alpha_cum_prod_t = self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[
            t
        ].view(batch_size, 1, 1, 1)

        # Predicted clean sample x0 from the noise prediction, clamped to the data range.
        x0 = (xt - sqrt_one_minus_alpha_cum_prod_t * noise_pred) / torch.sqrt(alpha_cum_prod_t)
        x0 = torch.clamp(x0, -1.0, 1.0)

        # Mean of the reverse posterior q(x_{t-1} | x_t, x0).
        betas_t = self.betas.to(xt.device)[t].view(batch_size, 1, 1, 1)
        mean = (xt - betas_t / sqrt_one_minus_alpha_cum_prod_t * noise_pred) / torch.sqrt(
            self.alphas.to(xt.device)[t].view(batch_size, 1, 1, 1)
        )

        if t[0] == 0:
            # No noise is added at the final step.
            return mean, x0
        else:
            # Posterior variance uses the cumulative product at the previous timestep.
            prev_alpha_cum_prod = self.alpha_cum_prod.to(xt.device)[(t - 1).clamp(min=0)].view(
                batch_size, 1, 1, 1
            )
            variance = (1 - prev_alpha_cum_prod) / (1 - alpha_cum_prod_t) * betas_t
            sigma = variance.sqrt()
            z = torch.randn_like(xt)
            return mean + sigma * z, x0
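

# Illustrative sketch (not called anywhere): forward-diffuse a batch of latents with the
# scheduler above. The shapes are arbitrary example values.
def _scheduler_noising_sketch():
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_params["num_timesteps"],
        beta_start=diffusion_params["beta_start"],
        beta_end=diffusion_params["beta_end"],
    )
    x0 = torch.randn(2, 4, 32, 32)          # clean latents
    noise = torch.randn_like(x0)            # epsilon ~ N(0, I)
    t = torch.randint(0, diffusion_params["num_timesteps"], (2,))
    xt = scheduler.add_noise(x0, noise, t)  # sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps
    return xt.shape                         # torch.Size([2, 4, 32, 32])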
|
|
|
|
|
def get_tokenizer_and_model(model_type, device, eval_mode=True): |
|
assert model_type in ( |
|
"bert", |
|
"clip", |
|
), "Text model can only be one of 'clip' or 'bert'" |
|
if model_type == "bert": |
|
text_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") |
|
text_model = DistilBertModel.from_pretrained("distilbert-base-uncased").to( |
|
device |
|
) |
|
else: |
|
text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16") |
|
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16").to( |
|
device |
|
) |
|
if eval_mode: |
|
text_model.eval() |
|
return text_tokenizer, text_model |
|
|
|
|
|
def get_text_representation(text, text_tokenizer, text_model, device, max_length=77): |
|
token_output = text_tokenizer( |
|
text, |
|
truncation=True, |
|
padding="max_length", |
|
return_attention_mask=True, |
|
max_length=max_length, |
|
) |
|
tokens_tensor = torch.tensor(token_output["input_ids"]).to(device) |
|
mask_tensor = torch.tensor(token_output["attention_mask"]).to(device) |
|
text_embed = text_model(tokens_tensor, attention_mask=mask_tensor).last_hidden_state |
|
return text_embed |
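

# Illustrative sketch (not called anywhere): the text representation used for
# cross-attention is the per-token hidden-state sequence, so with the CLIP text encoder
# configured above it should have shape (batch, 77, 512). Downloads weights on first use.
def _text_embedding_sketch(device=torch.device("cpu")):
    tokenizer, model = get_tokenizer_and_model("clip", device=device)
    embed = get_text_representation(["a person with black hair"], tokenizer, model, device)
    return embed.shape  # expected: torch.Size([1, 77, 512])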
|
|
|
|
|
def get_time_embedding(time_steps, temb_dim): |
|
""" |
|
Convert time steps tensor into an embedding using the sinusoidal time embedding formula |
|
time_steps: 1D tensor of length batch size |
|
temb_dim: Dimension of the embedding |
|
""" |
|
assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" |
|
|
|
|
|
factor = 10000 ** ( |
|
( |
|
torch.arange( |
|
start=0, |
|
end=temb_dim // 2, |
|
dtype=torch.float32, |
|
device=time_steps.device, |
|
) |
|
/ (temb_dim // 2) |
|
) |
|
) |
|
|
|
t_emb = time_steps.unsqueeze(dim=-1).repeat(1, temb_dim // 2) / factor |
|
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) |
|
|
|
return t_emb |
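

# Illustrative sketch (not called anywhere): the sinusoidal embedding maps a batch of
# integer timesteps to vectors of size temb_dim (a sin half concatenated with a cos half).
def _time_embedding_sketch():
    t = torch.tensor([0, 250, 999])
    t_emb = get_time_embedding(t, ldm_params["time_emb_dim"])
    return t_emb.shape  # torch.Size([3, 512])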
|
|
|
|
|
class DownBlock(nn.Module): |
|
""" |
|
Down conv block with attention. |
|
1. Resnet block with time embedding |
|
2. Attention block |
|
3. Downsample |
|
|
|
in_channels: Number of channels in the input feature map. |
|
out_channels: Number of channels produced by this block. |
|
    t_emb_dim: Dimension of the time embedding. Used only in the diffusion UNet; set to None for the autoencoder.
|
down_sample: Whether to apply downsampling at the end. |
|
num_heads: Number of attention heads (used if attention is enabled). |
|
num_layers: How many sub-blocks to apply in sequence. |
|
attn: Whether to apply self-attention |
|
norm_channels: Number of groups for GroupNorm. |
|
cross_attn: Whether to apply cross-attention. |
|
context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
t_emb_dim, |
|
down_sample, |
|
num_heads, |
|
num_layers, |
|
attn, |
|
norm_channels, |
|
cross_attn=False, |
|
context_dim=None, |
|
): |
|
super().__init__() |
|
|
|
self.num_layers = num_layers |
|
self.down_sample = down_sample |
|
self.attn = attn |
|
self.context_dim = context_dim |
|
self.cross_attn = cross_attn |
|
self.t_emb_dim = t_emb_dim |
|
|
|
self.resnet_conv_first = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm( |
|
norm_channels, in_channels if i == 0 else out_channels |
|
), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
if self.t_emb_dim is not None: |
|
self.t_emb_layers = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear( |
|
in_features=self.t_emb_dim, out_features=out_channels |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.resnet_conv_second = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=out_channels, |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.residual_input_conv = nn.ModuleList( |
|
[ |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
if self.attn: |
|
self.attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] |
|
) |
|
|
|
self.attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention( |
|
embed_dim=out_channels, num_heads=num_heads, batch_first=True |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
if self.cross_attn: |
|
assert ( |
|
context_dim is not None |
|
), "Context Dimension must be passed for cross attention" |
|
|
|
self.cross_attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] |
|
) |
|
|
|
self.cross_attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention( |
|
embed_dim=out_channels, num_heads=num_heads, batch_first=True |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.context_proj = nn.ModuleList( |
|
[ |
|
nn.Linear(in_features=context_dim, out_features=out_channels) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
self.down_sample_conv = ( |
|
nn.Conv2d( |
|
in_channels=out_channels, |
|
out_channels=out_channels, |
|
kernel_size=4, |
|
stride=2, |
|
padding=1, |
|
) |
|
if self.down_sample |
|
else nn.Identity() |
|
) |
|
|
|
    def forward(self, x, t_emb=None, context=None):
        out = x
        for i in range(self.num_layers):
            # Resnet block: GroupNorm -> SiLU -> Conv, with the time embedding added in between.
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze(dim=-1)
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)

            if self.attn:
                # Self-attention over spatial positions (flattened to a sequence of h*w tokens).
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn

            if self.cross_attn:
                # Cross-attention: spatial tokens attend to the projected conditioning context.
                assert (
                    context is not None
                ), "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert (
                    context.shape[0] == x.shape[0]
                    and context.shape[-1] == self.context_dim
                ), "Context shape does not match (batch_size, seq_len, context_dim)"
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn

        # Optional downsampling at the end of the block.
        out = self.down_sample_conv(out)
        return out
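

# Illustrative sketch (not called anywhere): a DownBlock with down_sample=True halves the
# spatial resolution and maps in_channels -> out_channels. The context tensor is a random
# stand-in for text hidden states when cross-attention is enabled.
def _down_block_sketch():
    block = DownBlock(
        in_channels=256,
        out_channels=384,
        t_emb_dim=512,
        down_sample=True,
        num_heads=16,
        num_layers=2,
        attn=True,
        norm_channels=32,
        cross_attn=True,
        context_dim=512,
    )
    x = torch.randn(1, 256, 32, 32)
    t_emb = torch.randn(1, 512)
    context = torch.randn(1, 77, 512)
    out = block(x, t_emb, context)
    return out.shape  # torch.Size([1, 384, 16, 16])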
|
|
|
|
|
class MidBlock(nn.Module): |
|
""" |
|
Mid conv block with attention. |
|
1. Resnet block with time embedding |
|
2. Attention block |
|
3. Resnet block with time embedding |
|
|
|
in_channels: Number of channels in the input feature map. |
|
out_channels: Number of channels produced by this block. |
|
    t_emb_dim: Dimension of the time embedding. Used only in the diffusion UNet; set to None for the autoencoder.
|
num_heads: Number of attention heads (used if attention is enabled). |
|
num_layers: How many sub-blocks to apply in sequence. |
|
norm_channels: Number of groups for GroupNorm. |
|
cross_attn: Whether to apply cross-attention. |
|
context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
t_emb_dim, |
|
num_heads, |
|
num_layers, |
|
norm_channels, |
|
        cross_attn=False,
|
context_dim=None, |
|
): |
|
super().__init__() |
|
|
|
self.num_layers = num_layers |
|
self.t_emb_dim = t_emb_dim |
|
self.context_dim = context_dim |
|
self.cross_attn = cross_attn |
|
|
|
self.resnet_conv_first = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm( |
|
norm_channels, in_channels if i == 0 else out_channels |
|
), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers + 1) |
|
] |
|
) |
|
|
|
|
|
if self.t_emb_dim is not None: |
|
self.t_emb_layers = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear( |
|
in_features=self.t_emb_dim, out_features=out_channels |
|
), |
|
) |
|
for i in range(num_layers + 1) |
|
] |
|
) |
|
|
|
self.resnet_conv_second = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=out_channels, |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers + 1) |
|
] |
|
) |
|
|
|
self.residual_input_conv = nn.ModuleList( |
|
[ |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
) |
|
for i in range(num_layers + 1) |
|
] |
|
) |
|
|
|
self.attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] |
|
) |
|
|
|
self.attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention( |
|
embed_dim=out_channels, num_heads=num_heads, batch_first=True |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
if self.cross_attn: |
|
assert ( |
|
context_dim is not None |
|
), "Context Dimension must be passed for cross attention" |
|
|
|
self.cross_attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] |
|
) |
|
|
|
self.cross_attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention( |
|
embed_dim=out_channels, num_heads=num_heads, batch_first=True |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.context_proj = nn.ModuleList( |
|
[ |
|
nn.Linear(in_features=context_dim, out_features=out_channels) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
def forward(self, x, t_emb=None, context=None): |
|
out = x |
|
|
|
|
|
resnet_input = out |
|
out = self.resnet_conv_first[0](out) |
|
|
|
|
|
if self.t_emb_dim is not None: |
|
|
|
out = out + self.t_emb_layers[0](t_emb).unsqueeze(dim=-1).unsqueeze( |
|
dim=-1 |
|
) |
|
|
|
out = self.resnet_conv_second[0]( |
|
out |
|
) |
|
|
|
|
|
out = out + self.residual_input_conv[0]( |
|
resnet_input |
|
) |
|
|
|
for i in range(self.num_layers): |
|
|
|
batch_size, channels, h, w = out.shape |
|
|
|
|
|
in_attn = out.reshape( |
|
batch_size, channels, h * w |
|
) |
|
in_attn = self.attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
|
|
|
|
out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) |
|
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) |
|
|
|
|
|
out = out + out_attn |
|
|
|
if self.cross_attn: |
|
assert ( |
|
context is not None |
|
), "context cannot be None if cross attention layers are used" |
|
batch_size, channels, h, w = out.shape |
|
|
|
in_attn = out.reshape( |
|
batch_size, channels, h * w |
|
) |
|
in_attn = self.cross_attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
|
|
assert ( |
|
context.shape[0] == x.shape[0] |
|
and context.shape[-1] == self.context_dim |
|
) |
|
context_proj = self.context_proj[i]( |
|
context |
|
) |
|
|
|
|
|
out_attn, attn_weights = self.cross_attentions[i]( |
|
in_attn, context_proj, context_proj |
|
) |
|
out_attn = out_attn.transpose(1, 2).reshape( |
|
batch_size, channels, h, w |
|
) |
|
|
|
|
|
out = out + out_attn |
|
|
|
|
|
resnet_input = out |
|
out = self.resnet_conv_first[i + 1]( |
|
out |
|
) |
|
|
|
|
|
if self.t_emb_dim is not None: |
|
|
|
out = out + self.t_emb_layers[i + 1](t_emb).unsqueeze(dim=-1).unsqueeze( |
|
dim=-1 |
|
) |
|
|
|
out = self.resnet_conv_second[i + 1]( |
|
out |
|
) |
|
|
|
|
|
out = out + self.residual_input_conv[i + 1]( |
|
resnet_input |
|
) |
|
|
|
return out |
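

# Illustrative sketch (not called anywhere): a MidBlock keeps the spatial resolution and
# interleaves resnet blocks with (self/cross) attention at the bottleneck.
def _mid_block_sketch():
    block = MidBlock(
        in_channels=768,
        out_channels=512,
        t_emb_dim=512,
        num_heads=16,
        num_layers=2,
        norm_channels=32,
        cross_attn=True,
        context_dim=512,
    )
    x = torch.randn(1, 768, 8, 8)
    t_emb = torch.randn(1, 512)
    context = torch.randn(1, 77, 512)
    out = block(x, t_emb, context)
    return out.shape  # torch.Size([1, 512, 8, 8])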
|
|
|
|
|
class UpBlock(nn.Module): |
|
""" |
|
Up conv block with attention. |
|
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block

    in_channels: Number of channels in the input feature map.
    out_channels: Number of channels produced by this block.
    t_emb_dim: Dimension of the time embedding. Used only in the diffusion UNet; set to None for the autoencoder.
|
up_sample: Whether to apply upsampling at the end. |
|
num_heads: Number of attention heads (used if attention is enabled). |
|
num_layers: How many sub-blocks to apply in sequence. |
|
attn: Whether to apply self-attention |
|
norm_channels: Number of groups for GroupNorm. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
t_emb_dim, |
|
up_sample, |
|
num_heads, |
|
num_layers, |
|
attn, |
|
norm_channels, |
|
): |
|
super().__init__() |
|
|
|
self.num_layers = num_layers |
|
self.up_sample = up_sample |
|
self.t_emb_dim = t_emb_dim |
|
self.attn = attn |
|
|
|
|
|
self.up_sample_conv = ( |
|
nn.ConvTranspose2d( |
|
in_channels=in_channels, |
|
out_channels=in_channels, |
|
kernel_size=4, |
|
stride=2, |
|
padding=1, |
|
) |
|
if self.up_sample |
|
else nn.Identity() |
|
) |
|
|
|
self.resnet_conv_first = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm( |
|
norm_channels, in_channels if i == 0 else out_channels |
|
), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
if self.t_emb_dim is not None: |
|
self.t_emb_layers = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear( |
|
in_features=self.t_emb_dim, out_features=out_channels |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.resnet_conv_second = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=out_channels, |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.residual_input_conv = nn.ModuleList( |
|
[ |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
if self.attn: |
|
self.attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] |
|
) |
|
|
|
self.attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention( |
|
embed_dim=out_channels, num_heads=num_heads, batch_first=True |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
def forward(self, x, out_down=None, t_emb=None): |
|
|
|
|
|
|
|
x = self.up_sample_conv( |
|
x |
|
) |
|
|
|
|
|
|
|
if out_down is not None: |
|
x = torch.cat( |
|
[x, out_down], dim=1 |
|
) |
|
|
|
out = x |
|
|
|
for i in range(self.num_layers): |
|
|
|
resnet_input = out |
|
out = self.resnet_conv_first[i]( |
|
out |
|
) |
|
|
|
|
|
if self.t_emb_dim is not None: |
|
|
|
out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( |
|
dim=-1 |
|
) |
|
|
|
out = self.resnet_conv_second[i]( |
|
out |
|
) |
|
|
|
|
|
out = out + self.residual_input_conv[i]( |
|
resnet_input |
|
) |
|
|
|
|
|
if self.attn: |
|
|
|
batch_size, channels, h, w = out.shape |
|
|
|
in_attn = out.reshape( |
|
batch_size, channels, h * w |
|
) |
|
in_attn = self.attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose( |
|
1, 2 |
|
) |
|
|
|
|
|
out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) |
|
out_attn = out_attn.transpose(1, 2).reshape( |
|
batch_size, channels, h, w |
|
) |
|
|
|
|
|
out = out + out_attn |
|
|
|
return out |
|
|
|
|
|
class UpBlockUNet(nn.Module): |
|
""" |
|
Up conv block with attention. |
|
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block

    in_channels: Number of channels in the input feature map (already doubled by the caller to account for concatenation with the corresponding DownBlock output).
    out_channels: Number of channels produced by this block.
    t_emb_dim: Dimension of the time embedding. Used only in the diffusion UNet; set to None for the autoencoder.
|
up_sample: Whether to apply upsampling at the end. |
|
num_heads: Number of attention heads (used if attention is enabled). |
|
num_layers: How many sub-blocks to apply in sequence. |
|
norm_channels: Number of groups for GroupNorm. |
|
cross_attn: Whether to apply cross-attention. |
|
context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
t_emb_dim, |
|
up_sample, |
|
num_heads, |
|
num_layers, |
|
norm_channels, |
|
cross_attn=False, |
|
context_dim=None, |
|
): |
|
super().__init__() |
|
|
|
self.num_layers = num_layers |
|
self.up_sample = up_sample |
|
self.t_emb_dim = t_emb_dim |
|
self.cross_attn = cross_attn |
|
self.context_dim = context_dim |
|
|
|
self.up_sample_conv = ( |
|
nn.ConvTranspose2d( |
|
in_channels=(in_channels // 2), |
|
out_channels=(in_channels // 2), |
|
kernel_size=4, |
|
stride=2, |
|
padding=1, |
|
) |
|
if self.up_sample |
|
else nn.Identity() |
|
) |
|
|
|
self.resnet_conv_first = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm( |
|
norm_channels, in_channels if i == 0 else out_channels |
|
), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
|
|
if self.t_emb_dim is not None: |
|
self.t_emb_layers = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear( |
|
in_features=self.t_emb_dim, out_features=out_channels |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.resnet_conv_second = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels=out_channels, |
|
out_channels=out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.residual_input_conv = nn.ModuleList( |
|
[ |
|
nn.Conv2d( |
|
in_channels=(in_channels if i == 0 else out_channels), |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
) |
|
for i in range( |
|
num_layers |
|
) |
|
] |
|
) |
|
|
|
self.attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] |
|
) |
|
|
|
self.attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention( |
|
embed_dim=out_channels, num_heads=num_heads, batch_first=True |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
if self.cross_attn: |
|
assert ( |
|
context_dim is not None |
|
), "Context Dimension must be passed for cross attention" |
|
|
|
self.cross_attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] |
|
) |
|
|
|
self.cross_attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention( |
|
embed_dim=out_channels, num_heads=num_heads, batch_first=True |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
self.context_proj = nn.ModuleList( |
|
[ |
|
nn.Linear(in_features=context_dim, out_features=out_channels) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
def forward(self, x, out_down=None, t_emb=None, context=None): |
|
|
|
|
|
|
|
x = self.up_sample_conv( |
|
x |
|
) |
|
|
|
|
|
if out_down is not None: |
|
x = torch.cat( |
|
[x, out_down], dim=1 |
|
) |
|
|
|
out = x |
|
for i in range(self.num_layers): |
|
|
|
resnet_input = out |
|
|
|
out = self.resnet_conv_first[i]( |
|
out |
|
) |
|
|
|
if self.t_emb_dim is not None: |
|
|
|
out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( |
|
dim=-1 |
|
) |
|
|
|
out = self.resnet_conv_second[i]( |
|
out |
|
) |
|
|
|
|
|
out = out + self.residual_input_conv[i]( |
|
resnet_input |
|
) |
|
|
|
|
|
batch_size, channels, h, w = ( |
|
out.shape |
|
) |
|
|
|
in_attn = out.reshape( |
|
batch_size, channels, h * w |
|
) |
|
in_attn = self.attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
|
|
|
|
out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) |
|
out_attn = out_attn.transpose(1, 2).reshape( |
|
batch_size, channels, h, w |
|
) |
|
|
|
|
|
out = out + out_attn |
|
|
|
if self.cross_attn: |
|
assert ( |
|
context is not None |
|
), "context cannot be None if cross attention layers are used" |
|
batch_size, channels, h, w = out.shape |
|
|
|
in_attn = out.reshape( |
|
batch_size, channels, h * w |
|
) |
|
in_attn = self.cross_attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose( |
|
1, 2 |
|
) |
|
|
|
assert ( |
|
len(context.shape) == 3 |
|
), "Context shape does not match batch_size, _, context_dim" |
|
|
|
assert ( |
|
context.shape[0] == x.shape[0] |
|
and context.shape[-1] == self.context_dim |
|
), "Context shape does not match batch_size, _, context_dim" |
|
context_proj = self.context_proj[i]( |
|
context |
|
) |
|
|
|
|
|
out_attn, attn_weights = self.cross_attentions[i]( |
|
in_attn, context_proj, context_proj |
|
) |
|
out_attn = out_attn.transpose(1, 2).reshape( |
|
batch_size, channels, h, w |
|
) |
|
|
|
|
|
out = out + out_attn |
|
|
|
return out |
|
|
|
|
|
class VQVAE(nn.Module): |
|
def __init__(self, image_channels, model_config): |
|
super().__init__() |
|
|
|
self.down_channels = model_config["down_channels"] |
|
self.mid_channels = model_config["mid_channels"] |
|
self.down_sample = model_config["down_sample"] |
|
self.num_down_layers = model_config["num_down_layers"] |
|
self.num_mid_layers = model_config["num_mid_layers"] |
|
self.num_up_layers = model_config["num_up_layers"] |
|
|
|
|
|
self.attns = model_config["attn_down"] |
|
|
|
|
|
self.z_channels = model_config[ |
|
"z_channels" |
|
] |
|
self.codebook_size = model_config[ |
|
"codebook_size" |
|
] |
|
self.norm_channels = model_config["norm_channels"] |
|
self.num_heads = model_config["num_heads"] |
|
|
|
assert self.mid_channels[0] == self.down_channels[-1] |
|
assert self.mid_channels[-1] == self.down_channels[-1] |
|
assert len(self.down_sample) == len(self.down_channels) - 1 |
|
assert len(self.attns) == len(self.down_channels) - 1 |
|
|
|
|
|
self.up_sample = list(reversed(self.down_sample)) |
|
|
|
|
|
self.encoder_conv_in = nn.Conv2d( |
|
in_channels=image_channels, |
|
out_channels=self.down_channels[0], |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
|
|
|
|
self.encoder_layers = nn.ModuleList([]) |
|
for i in range(len(self.down_channels) - 1): |
|
self.encoder_layers.append( |
|
DownBlock( |
|
in_channels=self.down_channels[i], |
|
out_channels=self.down_channels[i + 1], |
|
t_emb_dim=None, |
|
down_sample=self.down_sample[i], |
|
num_heads=self.num_heads, |
|
num_layers=self.num_down_layers, |
|
attn=self.attns[i], |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
|
|
self.encoder_mids = nn.ModuleList([]) |
|
for i in range(len(self.mid_channels) - 1): |
|
self.encoder_mids.append( |
|
MidBlock( |
|
in_channels=self.mid_channels[i], |
|
out_channels=self.mid_channels[i + 1], |
|
t_emb_dim=None, |
|
num_heads=self.num_heads, |
|
num_layers=self.num_mid_layers, |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
|
|
self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) |
|
|
|
self.encoder_conv_out = nn.Conv2d( |
|
in_channels=self.down_channels[-1], |
|
out_channels=self.z_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
|
|
|
|
self.pre_quant_conv = nn.Conv2d( |
|
in_channels=self.z_channels, |
|
out_channels=self.z_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
) |
|
|
|
|
|
self.embedding = nn.Embedding( |
|
self.codebook_size, self.z_channels |
|
) |
|
|
|
|
|
|
|
|
|
self.post_quant_conv = nn.Conv2d( |
|
in_channels=self.z_channels, |
|
out_channels=self.z_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
) |
|
|
|
self.decoder_conv_in = nn.Conv2d( |
|
in_channels=self.z_channels, |
|
out_channels=self.mid_channels[-1], |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
|
|
|
|
self.decoder_mids = nn.ModuleList([]) |
|
for i in reversed(range(1, len(self.mid_channels))): |
|
self.decoder_mids.append( |
|
MidBlock( |
|
in_channels=self.mid_channels[i], |
|
out_channels=self.mid_channels[i - 1], |
|
t_emb_dim=None, |
|
num_heads=self.num_heads, |
|
num_layers=self.num_mid_layers, |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
|
|
self.decoder_layers = nn.ModuleList([]) |
|
for i in reversed(range(1, len(self.down_channels))): |
|
self.decoder_layers.append( |
|
UpBlock( |
|
in_channels=self.down_channels[i], |
|
out_channels=self.down_channels[i - 1], |
|
t_emb_dim=None, |
|
up_sample=self.down_sample[i - 1], |
|
num_heads=self.num_heads, |
|
num_layers=self.num_up_layers, |
|
attn=self.attns[i - 1], |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
|
|
self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) |
|
|
|
self.decoder_conv_out = nn.Conv2d( |
|
in_channels=self.down_channels[0], |
|
out_channels=image_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
|
|
    def quantize(self, x):
        batch_size, c, h, w = x.shape

        # Flatten spatial positions: (B, C, H, W) -> (B, H*W, C).
        x = x.permute(0, 2, 3, 1)
        x = x.reshape(batch_size, -1, c)

        # L2 distance from every latent vector to every codebook entry.
        dist = torch.cdist(
            x, self.embedding.weight.unsqueeze(dim=0).repeat((batch_size, 1, 1))
        )

        # Index of the nearest codebook entry for each latent vector.
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Look up the selected codebook vectors.
        quant_out = torch.index_select(
            self.embedding.weight, 0, min_encoding_indices.view(-1)
        )

        x = x.reshape((-1, c))

        # Commitment loss pulls the encoder output towards the chosen codebook entries;
        # codebook loss pulls the codebook entries towards the encoder output.
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {
            "codebook_loss": codebook_loss,
            "commitment_loss": commitment_loss,
        }

        # Straight-through estimator: gradients flow through the encoder output unchanged.
        quant_out = x + (quant_out - x).detach()

        # Restore spatial layout: (B, H*W, C) -> (B, C, H, W).
        quant_out = quant_out.reshape(batch_size, h, w, c).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((-1, h, w))

        return quant_out, quantize_losses, min_encoding_indices
|
|
|
def encode(self, x): |
|
out = self.encoder_conv_in(x) |
|
|
|
|
|
for idx, down in enumerate(self.encoder_layers): |
|
out = down(out) |
|
|
|
|
|
for mid in self.encoder_mids: |
|
out = mid(out) |
|
|
|
out = self.encoder_norm_out(out) |
|
out = F.silu(out) |
|
|
|
out = self.encoder_conv_out( |
|
out |
|
) |
|
out = self.pre_quant_conv( |
|
out |
|
) |
|
|
|
out, quant_losses, min_encoding_indices = self.quantize( |
|
out |
|
) |
|
return out, quant_losses |
|
|
|
def decode(self, z): |
|
out = z |
|
out = self.post_quant_conv( |
|
out |
|
) |
|
out = self.decoder_conv_in( |
|
out |
|
) |
|
|
|
|
|
for mid in self.decoder_mids: |
|
out = mid(out) |
|
|
|
|
|
for idx, up in enumerate(self.decoder_layers): |
|
out = up(out) |
|
|
|
out = self.decoder_norm_out(out) |
|
out = F.silu(out) |
|
|
|
out = self.decoder_conv_out( |
|
out |
|
) |
|
return out |
|
|
|
def forward(self, x): |
|
|
|
|
|
z, quant_losses = self.encode( |
|
x |
|
) |
|
out = self.decode(z) |
|
|
|
return out, z, quant_losses |
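

# Illustrative sketch (not called anywhere): VQVAE round trip. With three downsampling
# stages, a 3 x 256 x 256 image is encoded to a quantized 4 x 32 x 32 latent and decoded
# back to image space.
def _vqvae_sketch():
    vae = VQVAE(
        image_channels=dataset_params["image_channels"],
        model_config=autoencoder_params,
    )
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        z, quant_losses = vae.encode(x)  # z: (1, 4, 32, 32)
        recon = vae.decode(z)            # recon: (1, 3, 256, 256)
    return z.shape, recon.shape, sorted(quant_losses.keys())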
|
|
|
|
|
def validate_image_conditional_input(cond_input, x): |
|
assert ( |
|
"image" in cond_input |
|
), "Model initialized with image conditioning but cond_input has no image information" |
|
assert ( |
|
cond_input["image"].shape[0] == x.shape[0] |
|
), "Batch size mismatch of image condition and input" |
|
assert ( |
|
cond_input["image"].shape[2] % x.shape[2] == 0 |
|
), "Height/Width of image condition must be divisible by latent input" |
|
|
|
|
|
def validate_class_conditional_input(cond_input, x, num_classes): |
|
assert ( |
|
"class" in cond_input |
|
), "Model initialized with class conditioning but cond_input has no class information" |
|
assert cond_input["class"].shape == ( |
|
x.shape[0], |
|
num_classes, |
|
), "Shape of class condition input must match (Batch Size, )" |
|
|
|
|
|
|
|
|
|
|
class UNet(nn.Module): |
|
""" |
|
    UNet model comprising Down blocks, Mid blocks, and Up blocks.
|
""" |
|
|
|
def __init__(self, image_channels, model_config): |
|
super().__init__() |
|
|
|
self.down_channels = model_config["down_channels"] |
|
self.mid_channels = model_config["mid_channels"] |
|
self.t_emb_dim = model_config["time_emb_dim"] |
|
self.down_sample = model_config["down_sample"] |
|
self.num_down_layers = model_config["num_down_layers"] |
|
self.num_mid_layers = model_config["num_mid_layers"] |
|
self.num_up_layers = model_config["num_up_layers"] |
|
self.attns = model_config["attn_down"] |
|
self.norm_channels = model_config["norm_channels"] |
|
self.num_heads = model_config["num_heads"] |
|
self.conv_out_channels = model_config["conv_out_channels"] |
|
|
|
assert self.mid_channels[0] == self.down_channels[-1] |
|
assert self.mid_channels[-1] == self.down_channels[-2] |
|
assert len(self.down_sample) == len(self.down_channels) - 1 |
|
assert len(self.attns) == len(self.down_channels) - 1 |
|
|
|
|
|
self.class_cond = False |
|
self.text_cond = False |
|
self.image_cond = False |
|
self.text_embed_dim = None |
|
self.condition_config = get_config_value( |
|
model_config, "condition_config", None |
|
) |
|
|
|
if self.condition_config is not None: |
|
assert ( |
|
"condition_types" in self.condition_config |
|
), "Condition Type not provided in model config" |
|
condition_types = self.condition_config["condition_types"] |
|
|
|
|
|
if "class" in condition_types: |
|
self.class_cond = True |
|
self.num_classes = self.condition_config["class_condition_config"][ |
|
"num_classes" |
|
] |
|
|
|
if "text" in condition_types: |
|
self.text_cond = True |
|
self.text_embed_dim = self.condition_config["text_condition_config"][ |
|
"text_embed_dim" |
|
] |
|
|
|
if "image" in condition_types: |
|
self.image_cond = True |
|
self.image_cond_input_channels = self.condition_config[ |
|
"image_condition_config" |
|
]["image_condition_input_channels"] |
|
self.image_cond_output_channels = self.condition_config[ |
|
"image_condition_config" |
|
]["image_condition_output_channels"] |
|
|
|
if self.class_cond: |
|
|
|
self.class_emb = nn.Embedding( |
|
self.num_classes, self.t_emb_dim |
|
) |
|
|
|
if self.image_cond: |
|
|
|
self.cond_conv_in = nn.Conv2d( |
|
in_channels=self.image_cond_input_channels, |
|
out_channels=self.image_cond_output_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
bias=False, |
|
) |
|
|
|
self.conv_in_concat = nn.Conv2d( |
|
in_channels=(image_channels + self.image_cond_output_channels), |
|
out_channels=self.down_channels[0], |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
else: |
|
self.conv_in = nn.Conv2d( |
|
in_channels=image_channels, |
|
out_channels=self.down_channels[0], |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
|
|
self.cond = self.text_cond or self.image_cond or self.class_cond |
|
|
|
|
|
self.t_proj = nn.Sequential( |
|
nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim), |
|
nn.SiLU(), |
|
nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim), |
|
) |
|
|
|
self.up_sample = list(reversed(self.down_sample)) |
|
|
|
self.downs = nn.ModuleList([]) |
|
for i in range(len(self.down_channels) - 1): |
|
|
|
self.downs.append( |
|
DownBlock( |
|
in_channels=self.down_channels[i], |
|
out_channels=self.down_channels[i + 1], |
|
t_emb_dim=self.t_emb_dim, |
|
down_sample=self.down_sample[i], |
|
num_heads=self.num_heads, |
|
num_layers=self.num_down_layers, |
|
attn=self.attns[i], |
|
norm_channels=self.norm_channels, |
|
cross_attn=self.text_cond, |
|
context_dim=self.text_embed_dim, |
|
) |
|
) |
|
|
|
self.mids = nn.ModuleList([]) |
|
for i in range(len(self.mid_channels) - 1): |
|
|
|
self.mids.append( |
|
MidBlock( |
|
in_channels=self.mid_channels[i], |
|
out_channels=self.mid_channels[i + 1], |
|
t_emb_dim=self.t_emb_dim, |
|
num_heads=self.num_heads, |
|
num_layers=self.num_mid_layers, |
|
norm_channels=self.norm_channels, |
|
cross_attn=self.text_cond, |
|
context_dim=self.text_embed_dim, |
|
) |
|
) |
|
|
|
self.ups = nn.ModuleList([]) |
|
for i in reversed(range(len(self.down_channels) - 1)): |
|
|
|
self.ups.append( |
|
UpBlockUNet( |
|
in_channels=(self.down_channels[i] * 2), |
|
out_channels=( |
|
self.down_channels[i - 1] if i != 0 else self.conv_out_channels |
|
), |
|
t_emb_dim=self.t_emb_dim, |
|
up_sample=self.down_sample[i], |
|
num_heads=self.num_heads, |
|
num_layers=self.num_up_layers, |
|
norm_channels=self.norm_channels, |
|
cross_attn=self.text_cond, |
|
context_dim=self.text_embed_dim, |
|
) |
|
) |
|
|
|
self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) |
|
|
|
self.conv_out = nn.Conv2d( |
|
in_channels=self.conv_out_channels, |
|
out_channels=image_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
|
|
    def forward(self, x, t, cond_input=None):
        if self.cond:
            assert (
                cond_input is not None
            ), "Model initialized with conditioning so cond_input cannot be None"

        if self.image_cond:
            # Spatial conditioning: resize the conditioning image (e.g. a segmentation mask)
            # to the latent resolution, project its channels, and concatenate with the input.
            validate_image_conditional_input(cond_input, x)
            image_cond = cond_input["image"]
            image_cond = F.interpolate(image_cond, size=x.shape[-2:])
            image_cond = self.cond_conv_in(image_cond)
            assert image_cond.shape[-2:] == x.shape[-2:]
            x = torch.cat([x, image_cond], dim=1)
            out = self.conv_in_concat(x)
        else:
            out = self.conv_in(x)

        # Sinusoidal time embedding followed by a small MLP projection.
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        if self.class_cond:
            # Class conditioning: add the (possibly multi-hot) class embedding to the time embedding.
            validate_class_conditional_input(cond_input, x, self.num_classes)
            class_embed = torch.matmul(cond_input["class"].float(), self.class_emb.weight)
            t_emb = t_emb + class_embed

        context_hidden_states = None
        if self.text_cond:
            assert (
                "text" in cond_input
            ), "Model initialized with text conditioning but cond_input has no text information"
            context_hidden_states = cond_input["text"]

        # Encoder path: store the skip connections consumed later by the decoder.
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb, context_hidden_states)

        for mid in self.mids:
            out = mid(out, t_emb, context_hidden_states)

        # Decoder path: pop the skip connections in reverse order.
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb, context_hidden_states)

        out = F.silu(self.norm_out(out))
        out = self.conv_out(out)
        return out
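

# Illustrative sketch (not called anywhere): one denoising step of the conditional UNet on
# a batch of latents, using random stand-ins for the CLIP text hidden states and the
# 18-channel mask condition described by ldm_params.
def _unet_forward_sketch():
    cond_cfg = ldm_params["condition_config"]
    unet = UNet(image_channels=autoencoder_params["z_channels"], model_config=ldm_params)
    xt = torch.randn(1, autoencoder_params["z_channels"], 32, 32)
    t = torch.randint(0, diffusion_params["num_timesteps"], (1,))
    cond_input = {
        "text": torch.randn(1, 77, cond_cfg["text_condition_config"]["text_embed_dim"]),
        "image": torch.randn(
            1,
            cond_cfg["image_condition_config"]["image_condition_input_channels"],
            cond_cfg["image_condition_config"]["image_condition_h"],
            cond_cfg["image_condition_config"]["image_condition_w"],
        ),
    }
    with torch.no_grad():
        noise_pred = unet(xt, t, cond_input)
    return noise_pred.shape  # torch.Size([1, 4, 32, 32])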
|
|
|
|
|
def sample_ddpm_inference( |
|
text_prompt, mask_image_pil=None, guidance_scale=1.0, device=torch.device("cpu") |
|
): |
|
""" |
|
Given a text prompt and (optionally) an image condition (as a PIL image), |
|
sample from the diffusion model and return a generated image (PIL image). |
|
""" |
|
|
|
scheduler = LinearNoiseScheduler( |
|
num_timesteps=diffusion_params["num_timesteps"], |
|
beta_start=diffusion_params["beta_start"], |
|
beta_end=diffusion_params["beta_end"], |
|
) |
|
|
|
condition_config = ldm_params.get("condition_config", None) |
|
condition_types = ( |
|
condition_config.get("condition_types", []) |
|
if condition_config is not None |
|
else [] |
|
) |
|
|
|
|
|
text_model_type = condition_config["text_condition_config"]["text_embed_model"] |
|
text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device) |
|
|
|
|
|
empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device) |
|
|
|
|
|
text_prompt_embed = get_text_representation( |
|
[text_prompt], text_tokenizer, text_model, device |
|
) |
|
|
|
|
|
|
|
if "image" in condition_types: |
|
if mask_image_pil is not None: |
|
mask_transform = transforms.Compose( |
|
[ |
|
transforms.Resize( |
|
( |
|
ldm_params["condition_config"]["image_condition_config"][ |
|
"image_condition_h" |
|
], |
|
ldm_params["condition_config"]["image_condition_config"][ |
|
"image_condition_w" |
|
], |
|
) |
|
), |
|
transforms.ToTensor(), |
|
] |
|
) |
|
mask_tensor = ( |
|
mask_transform(mask_image_pil).unsqueeze(0).to(device) |
|
) |
|
else: |
|
|
|
ic = ldm_params["condition_config"]["image_condition_config"][ |
|
"image_condition_input_channels" |
|
] |
|
H = ldm_params["condition_config"]["image_condition_config"][ |
|
"image_condition_h" |
|
] |
|
W = ldm_params["condition_config"]["image_condition_config"][ |
|
"image_condition_w" |
|
] |
|
mask_tensor = torch.zeros((1, ic, H, W), device=device) |
|
else: |
|
mask_tensor = None |
|
|
|
|
|
|
|
uncond_input = {} |
|
cond_input = {} |
|
if "text" in condition_types: |
|
uncond_input["text"] = empty_text_embed |
|
cond_input["text"] = text_prompt_embed |
|
if "image" in condition_types: |
|
|
|
uncond_input["image"] = torch.zeros_like(mask_tensor) |
|
cond_input["image"] = mask_tensor |
|
|
|
|
|
unet = UNet( |
|
image_channels=autoencoder_params["z_channels"], model_config=ldm_params |
|
).to(device) |
|
ldm_checkpoint_path = os.path.join( |
|
train_params["task_name"], train_params["ldm_ckpt_name"] |
|
) |
|
if os.path.exists(ldm_checkpoint_path): |
|
checkpoint = torch.load(ldm_checkpoint_path, map_location=device) |
|
unet.load_state_dict(checkpoint["model_state_dict"]) |
|
unet.eval() |
|
|
|
|
|
vae = VQVAE( |
|
image_channels=dataset_params["image_channels"], model_config=autoencoder_params |
|
).to(device) |
|
vae_checkpoint_path = os.path.join( |
|
train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"] |
|
) |
|
if os.path.exists(vae_checkpoint_path): |
|
checkpoint = torch.load(vae_checkpoint_path, map_location=device) |
|
vae.load_state_dict(checkpoint["model_state_dict"]) |
|
vae.eval() |
|
|
|
|
|
|
|
latent_size = dataset_params["image_size"] // ( |
|
2 ** sum(autoencoder_params["down_sample"]) |
|
) |
|
batch = train_params["num_samples"] |
|
z_channels = autoencoder_params["z_channels"] |
|
|
|
|
|
xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device) |
|
|
|
|
|
    # Reverse diffusion loop; run without gradient tracking since this is pure inference.
    T = diffusion_params["num_timesteps"]
    with torch.no_grad():
        for i in reversed(range(T)):
            t = torch.full((batch,), i, dtype=torch.long, device=device)

            # Classifier-free guidance: mix conditional and unconditional noise predictions.
            noise_pred_cond = unet(xt, t, cond_input)
            if guidance_scale > 1:
                noise_pred_uncond = unet(xt, t, uncond_input)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
            else:
                noise_pred = noise_pred_cond
            xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
|
|
|
|
|
with torch.no_grad(): |
|
generated = vae.decode(xt) |
|
generated = torch.clamp(generated, -1, 1) |
|
generated = (generated + 1) / 2 |
|
grid = make_grid(generated, nrow=1) |
|
pil_img = transforms.ToPILImage()(grid.cpu()) |
|
return pil_img |
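

if __name__ == "__main__":
    # Minimal usage sketch with an example prompt: generate one image from text.
    # Checkpoints are loaded from train_params["task_name"] if they exist; with untrained
    # weights the output is meaningless, but the pipeline still runs end to end.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = sample_ddpm_inference(
        text_prompt="a smiling woman with blonde hair",
        mask_image_pil=None,
        guidance_scale=train_params["cf_guidance_scale"],
        device=device,
    )
    image.save("generated_sample.png")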
|
|