# Vaani-Audio2Img-LDM / Vaani /_7.1_Vaani-LDM-High-Pre.py
# alpha31476's picture
# SDFT
# cb656a6 verified
# ==================================================================
# L A T E N T D I F F U S I O N M O D E L
# ==================================================================
# Author : Ashish Kumar Uchadiya
# Created : May 11, 2025
# Description: This script implements the training of a VQ-VAE model for
# image reconstruction, integrated with Latent Diffusion Models (LDMs) and
# audio conditioning. The VQ-VAE maps images to a discrete latent space,
# which is then modeled by the LDM for learning a diffusion process over the
# compressed representation. Audio features are used as conditioning inputs
# to guide the generation process. The training minimizes a combination of
# LPIPS (Learned Perceptual Image Patch Similarity) loss for perceptual
# fidelity and PatchGAN loss to enforce local realism. This setup enables
# efficient and semantically-aware generation of high-quality images driven
# by audio cues.
# ==================================================================
# I M P O R T S
# ==================================================================
import os
import torch
import torch.nn as nn
import numpy as np
from collections import namedtuple
import pandas as pd
import torchvision as tv
from torchvision.transforms import v2
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import re
import glob
import sys
import yaml
import random
import datetime
import torch.hub
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
print("TIME:", datetime.datetime.now())
# os.environ["CUDA_VISIBLE_DEVICES"] = f"{sys.argv[2]}"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("DEVICE:", device)
# ==================================================================
# H E L P E R S
# ==================================================================
from typing import Any
from argparse import Namespace
import typing
class DotDict(Namespace):
    """A simple class that builds upon `argparse.Namespace`
    in order to make chained attributes possible.

    Reading an attribute that does not exist on a regular node creates a
    *placeholder* child (``temp=True``) and stores it on the node, so that
    ``cfg.a.b = 1`` works without creating ``cfg.a`` first.  Reading a missing
    attribute on a placeholder raises ``AttributeError`` and removes the
    placeholder from its parent again.

    Names starting with an underscore (``_temp``, ``_key``, ``_parent``) are
    bookkeeping and are ignored by ``__repr__`` and ``__eq__``.
    """

    def __init__(self, temp=False, key=None, parent=None) -> None:
        self._temp = temp      # True for auto-created placeholder nodes
        self._key = key        # name under which this node is stored in `parent`
        self._parent = parent  # owning DotDict (None for roots)

    def _public_items(self) -> dict:
        """Return the user-visible attributes (names not starting with '_')."""
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def __eq__(self, other):
        if not isinstance(other, DotDict):
            return NotImplemented
        # Compare content only: including `_parent` would walk back up the
        # tree and recurse forever for placeholder nodes.
        return self._public_items() == other._public_items()

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup failed.
        state = self.__dict__
        if name in state:
            return state[name]
        # Protocol probes (copy/pickle look up __deepcopy__, __setstate__, ...)
        # and the bookkeeping fields must fail normally instead of producing
        # placeholder children or recursing on an uninitialised instance.
        if (name.startswith("__") and name.endswith("__")) or name in (
            "_temp",
            "_key",
            "_parent",
        ):
            raise AttributeError("No attribute '%s'" % name)
        if not state.get("_temp", False):
            child = DotDict(temp=True, key=name, parent=self)
            state[name] = child
            return child
        # Missing attribute on a placeholder: drop the placeholder from its
        # parent (tolerating that it was already dropped) and report failure.
        parent = state.get("_parent")
        if parent is not None:
            parent.__dict__.pop(state.get("_key"), None)
        raise AttributeError("No attribute '%s'" % name)

    def __repr__(self) -> str:
        return "DotDict(%s)" % ", ".join(
            "%s=%r" % (key, val) for key, val in self._public_items().items()
        )

    @classmethod
    def from_dict(cls, original: typing.Mapping[str, Any]) -> "DotDict":
        """Create a DotDict from a (possibly nested) dict `original`.

        Warning: this method should not be used on very deeply nested inputs,
        since it's recursively traversing the nested dictionary values.
        """
        dd = cls()
        for key, value in original.items():
            if isinstance(value, typing.Mapping):
                value = cls.from_dict(value)
            setattr(dd, key, value)
        return dd
# ==================================================================
# L P I P S
# ==================================================================
class vgg16(nn.Module):
    """Frozen torchvision VGG-16 feature stack split into five consecutive stages.

    The pretrained ``features`` layers are distributed over ``slice1`` ..
    ``slice5`` using the index ranges 0-3, 4-8, 9-15, 16-22 and 23-29; every
    sub-layer keeps its original index as its module name.  The module is put
    in eval mode and all parameters are frozen.
    """

    def __init__(self):
        super().__init__()
        features = tv.models.vgg16(
            weights=tv.models.VGG16_Weights.IMAGENET1K_V1
        ).features
        # Register the five (still empty) stages first, in order.
        for stage_no in range(1, 6):
            setattr(self, f"slice{stage_no}", torch.nn.Sequential())
        self.N_slices = 5
        # Consecutive boundaries of the layer-index ranges owned by each stage.
        bounds = (0, 4, 9, 16, 23, 30)
        for stage_no, (start, stop) in enumerate(zip(bounds, bounds[1:]), start=1):
            stage = getattr(self, f"slice{stage_no}")
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
        self.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, X):
        """Run the stages in sequence and return all five intermediate outputs
        as a ``VggOutputs`` namedtuple with fields ``h1`` .. ``h5``."""
        taps = []
        hidden = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            hidden = stage(hidden)
            taps.append(hidden)
        vgg_outputs = namedtuple("VggOutputs", ['h1', 'h2', 'h3', 'h4', 'h5'])
        return vgg_outputs(*taps)
def _spatial_average(in_tens, keepdim=True):
return in_tens.mean([2, 3], keepdim=keepdim)
def _normalize_tensor(in_feat, eps= 1e-8):
norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / norm_factor
class ScalingLayer(nn.Module):
    """Per-channel affine normalisation ``(inp - shift) / scale`` with fixed buffers."""

    def __init__(self):
        super().__init__()
        # ImageNet normalization for inputs in (0-1) would be
        #   mean = [0.485, 0.456, 0.406]
        #   std  = [0.229, 0.224, 0.225]
        # The constants below equal 2*mean - 1 and 2*std respectively.
        shift = torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        centered = inp - self.shift
        return centered / self.scale
class NetLinLayer(nn.Module):
    """Bias-free 1x1 convolution, optionally preceded by dropout.

    With ``use_dropout=True`` the dropout layer takes index 0 of ``self.model``
    and the convolution index 1; otherwise the convolution is at index 0.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        stack = []
        if use_dropout:
            stack.append(nn.Dropout())
        stack.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)
class LPIPS(nn.Module):
    """Perceptual distance built on the frozen `vgg16` feature stages.

    Both inputs pass through `ScalingLayer` and `vgg16`; for each of the five
    tapped stages the channel-normalised features are subtracted, squared,
    passed through a 1x1 `NetLinLayer` head and averaged spatially.  The five
    per-stage values are summed.

    Constructing the module downloads the head weights from ``weights_url``
    (loaded with ``strict=False``), switches to eval mode and freezes all
    parameters.  `net` is only used to build the download URL.
    """

    def __init__(self, net='vgg', version='0.1', use_dropout=True):
        super().__init__()
        self.version = version
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16()
        # One 1x1 head per tapped stage, registered both individually
        # (lin0..lin4) and collectively (lins).
        for idx, width in enumerate(self.chns):
            setattr(self, f"lin{idx}", NetLinLayer(width, use_dropout=use_dropout))
        self.lins = nn.ModuleList([getattr(self, f"lin{idx}") for idx in range(self.L)])
        # --- Original url --------------------
        # weights_url = f"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v{version}/{net}.pth"
        # --- Original forked url -------------
        weights_url = f"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth"
        # --- Original torchmetrics url -------
        # weights_url = "https://github.com/Lightning-AI/torchmetrics/raw/master/src/torchmetrics/functional/image/lpips_models/vgg.pth"
        state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
        self.load_state_dict(state_dict, strict=False)
        self.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, in0, in1, normalize=False):
        """Return one distance per batch element as a flat tensor.

        With ``normalize=True`` both inputs are first mapped with ``2 * x - 1``
        (i.e. from a [0, 1] range to [-1, 1]).
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1
        feats_a = self.net.forward(self.scaling_layer(in0))
        feats_b = self.net.forward(self.scaling_layer(in1))
        per_stage = []
        for head, fa, fb in zip(self.lins, feats_a, feats_b):
            gap = (_normalize_tensor(fa) - _normalize_tensor(fb)) ** 2
            per_stage.append(_spatial_average(head(gap), keepdim=True))
        val = sum(per_stage)
        return val.reshape(-1)
# ==================================================================
# P A T C H - G A N - D I S C R I M I N A T O R
# ==================================================================
class Discriminator(nn.Module):
    r"""
    PatchGAN Discriminator.
    Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to
    1 scalar value, we instead predict a grid of values,
    where each grid cell is the prediction of how likely
    the discriminator thinks that the image patch corresponding
    to the grid cell is real.

    Args:
        im_channels: number of channels of the input image.
        conv_channels: output channels of every convolution except the last
            (the last one always outputs a single channel).
        kernels, strides, paddings: per-convolution hyper-parameters; each
            needs at least ``len(conv_channels) + 1`` entries.  Any sequence
            type (list or tuple) is accepted.

    Raises:
        ValueError: if `kernels`, `strides` or `paddings` is too short.

    Layer layout: every stage is ``Conv2d -> BatchNorm2d -> LeakyReLU(0.2)``,
    except that the first stage has no BatchNorm (and is the only conv with a
    bias) and the last stage has neither BatchNorm nor activation.
    """

    def __init__(
        self,
        im_channels=3,
        conv_channels=(64, 128, 256),
        kernels=(4, 4, 4, 4),
        strides=(2, 2, 2, 1),
        paddings=(1, 1, 1, 1),
    ):
        super().__init__()
        self.im_channels = im_channels
        activation = nn.LeakyReLU(0.2)
        # Unpacking accepts lists and tuples alike.
        layers_dim = [self.im_channels, *conv_channels, 1]
        num_convs = len(layers_dim) - 1
        for label, values in (
            ("kernels", kernels),
            ("strides", strides),
            ("paddings", paddings),
        ):
            if len(values) < num_convs:
                raise ValueError(
                    f"{label} needs at least {num_convs} entries, got {len(values)}"
                )
        last = num_convs - 1
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        layers_dim[i],
                        layers_dim[i + 1],
                        kernel_size=kernels[i],
                        stride=strides[i],
                        padding=paddings[i],
                        bias=(i == 0),
                    ),
                    (
                        nn.BatchNorm2d(layers_dim[i + 1])
                        if i != last and i != 0
                        else nn.Identity()
                    ),
                    activation if i != last else nn.Identity(),
                )
                for i in range(num_convs)
            ]
        )

    def forward(self, x):
        """Apply all stages in order and return the patch-wise prediction map."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
# ==================================================================
# D O W N - B L O C K
# ==================================================================
class DownBlock(nn.Module):
    r"""
    Down conv block with optional attention.

    For each of the ``num_layers`` layers:
    1. Residual conv unit (GroupNorm-SiLU-Conv twice, with the time embedding
       added in between when ``t_emb_dim`` is not None, plus a 1x1-conv skip)
    2. Self-attention over the flattened spatial positions (if ``attn``)
    3. Cross-attention against the projected ``context`` (if ``cross_attn``)
    Afterwards a stride-2 convolution is applied when ``down_sample`` is set.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        down_sample,
        num_heads,
        num_layers,
        attn,
        norm_channels,
        cross_attn=False,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        # Only the first residual unit sees `in_channels`; the rest map out -> out.
        unit_inputs = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        def per_layer(factory):
            # One freshly built module per layer.
            return nn.ModuleList([factory() for _ in range(num_layers)])

        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, width_in),
                    nn.SiLU(),
                    nn.Conv2d(width_in, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for width_in in unit_inputs
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = per_layer(
                lambda: nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
            )
        self.resnet_conv_second = per_layer(
            lambda: nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
        )
        if self.attn:
            self.attention_norms = per_layer(lambda: nn.GroupNorm(norm_channels, out_channels))
            self.attentions = per_layer(
                lambda: nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            )
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = per_layer(
                lambda: nn.GroupNorm(norm_channels, out_channels)
            )
            self.cross_attentions = per_layer(
                lambda: nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            )
            self.context_proj = per_layer(lambda: nn.Linear(context_dim, out_channels))
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(width_in, out_channels, kernel_size=1) for width_in in unit_inputs]
        )
        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()

    @staticmethod
    def _as_tokens(feature_map, norm):
        """Flatten (B, C, H, W) to (B, C, H*W), apply `norm`, return (B, H*W, C)."""
        n, c, height, width = feature_map.shape
        return norm(feature_map.reshape(n, c, height * width)).transpose(1, 2)

    @staticmethod
    def _as_map(tokens, shape):
        """Inverse of the flattening: (B, H*W, C) back to the 4-D `shape`."""
        n, c, height, width = shape
        return tokens.transpose(1, 2).reshape(n, c, height, width)

    def forward(self, x, t_emb=None, context=None):
        out = x
        for i in range(self.num_layers):
            # Residual conv unit
            skip = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                # Broadcast the per-sample embedding over the spatial dims.
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](skip)

            if self.attn:
                # Self-attention, added residually
                tokens = self._as_tokens(out, self.attention_norms[i])
                attended, _ = self.attentions[i](tokens, tokens, tokens)
                out = out + self._as_map(attended, out.shape)

            if self.cross_attn:
                assert (
                    context is not None
                ), "context cannot be None if cross attention layers are used"
                tokens = self._as_tokens(out, self.cross_attention_norms[i])
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                memory = self.context_proj[i](context)
                attended, _ = self.cross_attentions[i](tokens, memory, memory)
                out = out + self._as_map(attended, out.shape)

        # Downsample (identity when disabled)
        return self.down_sample_conv(out)
# ==================================================================
# M I D - B L O C K
# ==================================================================
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.

    Sequence:
    1. One residual conv unit (with time embedding when ``t_emb_dim`` is set)
    2. ``num_layers`` repetitions of: self-attention, optional cross-attention,
       then another residual conv unit
    There are therefore ``num_layers + 1`` residual units and ``num_layers``
    attention layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        num_heads,
        num_layers,
        norm_channels,
        cross_attn=None,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        num_units = num_layers + 1
        # Only the first residual unit sees `in_channels`.
        unit_inputs = [in_channels if i == 0 else out_channels for i in range(num_units)]

        def repeated(count, factory):
            return nn.ModuleList([factory() for _ in range(count)])

        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, width_in),
                    nn.SiLU(),
                    nn.Conv2d(width_in, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for width_in in unit_inputs
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = repeated(
                num_units,
                lambda: nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)),
            )
        self.resnet_conv_second = repeated(
            num_units,
            lambda: nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ),
        )
        self.attention_norms = repeated(
            num_layers, lambda: nn.GroupNorm(norm_channels, out_channels)
        )
        self.attentions = repeated(
            num_layers,
            lambda: nn.MultiheadAttention(out_channels, num_heads, batch_first=True),
        )
        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = repeated(
                num_layers, lambda: nn.GroupNorm(norm_channels, out_channels)
            )
            self.cross_attentions = repeated(
                num_layers,
                lambda: nn.MultiheadAttention(out_channels, num_heads, batch_first=True),
            )
            self.context_proj = repeated(
                num_layers, lambda: nn.Linear(context_dim, out_channels)
            )
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(width_in, out_channels, kernel_size=1) for width_in in unit_inputs]
        )

    def _residual_unit(self, idx, inp, t_emb):
        """Apply residual conv unit `idx` to `inp` (time embedding added if configured)."""
        hidden = self.resnet_conv_first[idx](inp)
        if self.t_emb_dim is not None:
            hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](inp)

    @staticmethod
    def _as_tokens(feature_map, norm):
        """Flatten (B, C, H, W) to (B, C, H*W), apply `norm`, return (B, H*W, C)."""
        n, c, height, width = feature_map.shape
        return norm(feature_map.reshape(n, c, height * width)).transpose(1, 2)

    @staticmethod
    def _as_map(tokens, shape):
        """Inverse of the flattening: (B, H*W, C) back to the 4-D `shape`."""
        n, c, height, width = shape
        return tokens.transpose(1, 2).reshape(n, c, height, width)

    def forward(self, x, t_emb=None, context=None):
        # First resnet block
        out = self._residual_unit(0, x, t_emb)
        for i in range(self.num_layers):
            # Self-attention, added residually
            tokens = self._as_tokens(out, self.attention_norms[i])
            attended, _ = self.attentions[i](tokens, tokens, tokens)
            out = out + self._as_map(attended, out.shape)

            if self.cross_attn:
                assert (
                    context is not None
                ), "context cannot be None if cross attention layers are used"
                tokens = self._as_tokens(out, self.cross_attention_norms[i])
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                memory = self.context_proj[i](context)
                attended, _ = self.cross_attentions[i](tokens, memory, memory)
                out = out + self._as_map(attended, out.shape)

            # Following resnet block
            out = self._residual_unit(i + 1, out, t_emb)
        return out
# ==================================================================
# U P - B L O C K
# ==================================================================
class UpBlock(nn.Module):
    r"""
    Up conv block with optional attention.

    Sequence:
    1. Upsample with a stride-2 transposed convolution (if ``up_sample``)
    2. Concatenate ``out_down`` along the channel dim (if given)
    3. ``num_layers`` times: residual conv unit (with time embedding when
       ``t_emb_dim`` is set), then self-attention (if ``attn``)
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        up_sample,
        num_heads,
        num_layers,
        attn,
        norm_channels,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        # Only the first residual unit sees `in_channels`.
        unit_inputs = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        def per_layer(factory):
            return nn.ModuleList([factory() for _ in range(num_layers)])

        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, width_in),
                    nn.SiLU(),
                    nn.Conv2d(width_in, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for width_in in unit_inputs
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = per_layer(
                lambda: nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
            )
        self.resnet_conv_second = per_layer(
            lambda: nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
        )
        if self.attn:
            self.attention_norms = per_layer(lambda: nn.GroupNorm(norm_channels, out_channels))
            self.attentions = per_layer(
                lambda: nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            )
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(width_in, out_channels, kernel_size=1) for width_in in unit_inputs]
        )
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    def forward(self, x, out_down=None, t_emb=None):
        # Upsample (identity when disabled)
        out = self.up_sample_conv(x)
        # Concat with Downblock output along channels
        if out_down is not None:
            out = torch.cat([out, out_down], dim=1)

        for i in range(self.num_layers):
            # Residual conv unit
            skip = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](skip)

            if not self.attn:
                continue
            # Self-attention over flattened spatial positions, added residually
            n, c, height, width = out.shape
            tokens = self.attention_norms[i](out.reshape(n, c, height * width)).transpose(1, 2)
            attended, _ = self.attentions[i](tokens, tokens, tokens)
            out = out + attended.transpose(1, 2).reshape(n, c, height, width)
        return out
# ==================================================================
# V Q - V A E
# ==================================================================
class VQVAE(nn.Module):
    """Vector-quantised autoencoder assembled from Down/Mid/Up blocks.

    Encoder: conv-in -> DownBlocks -> MidBlocks -> GroupNorm/SiLU/conv-out ->
    1x1 pre-quant conv -> nearest-codebook quantisation.
    Decoder: 1x1 post-quant conv -> conv-in -> MidBlocks -> UpBlocks ->
    GroupNorm/SiLU/conv-out.

    `model_config` must expose: down_channels, mid_channels, down_sample,
    num_down_layers, num_mid_layers, num_up_layers, attn_down, z_channels,
    codebook_size, norm_channels, num_heads.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        cfg = model_config
        self.down_channels = cfg.down_channels
        self.mid_channels = cfg.mid_channels
        self.down_sample = cfg.down_sample
        self.num_down_layers = cfg.num_down_layers
        self.num_mid_layers = cfg.num_mid_layers
        self.num_up_layers = cfg.num_up_layers
        # Per-stage switch for attention in encoder DownBlocks / decoder UpBlocks
        self.attns = cfg.attn_down
        # Latent dimension and codebook
        self.z_channels = cfg.z_channels
        self.codebook_size = cfg.codebook_size
        self.norm_channels = cfg.norm_channels
        self.num_heads = cfg.num_heads

        # Validate the channel layout
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Decoder upsamples wherever the encoder downsampled, in reverse order
        self.up_sample = list(reversed(self.down_sample))

        ##################### Encoder ######################
        self.encoder_conv_in = nn.Conv2d(
            im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)
        )
        # Downblocks followed by Midblocks
        self.encoder_layers = nn.ModuleList([])
        for stage in range(len(self.down_channels) - 1):
            block = DownBlock(
                self.down_channels[stage],
                self.down_channels[stage + 1],
                t_emb_dim=None,
                down_sample=self.down_sample[stage],
                num_heads=self.num_heads,
                num_layers=self.num_down_layers,
                attn=self.attns[stage],
                norm_channels=self.norm_channels,
            )
            self.encoder_layers.append(block)
        self.encoder_mids = nn.ModuleList([])
        for stage in range(len(self.mid_channels) - 1):
            block = MidBlock(
                self.mid_channels[stage],
                self.mid_channels[stage + 1],
                t_emb_dim=None,
                num_heads=self.num_heads,
                num_layers=self.num_mid_layers,
                norm_channels=self.norm_channels,
            )
            self.encoder_mids.append(block)
        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(
            self.down_channels[-1], self.z_channels, kernel_size=3, padding=1
        )
        # Pre-quantization convolution
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        # Codebook
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        ##################### Decoder ######################
        # Post-quantization convolution
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(
            self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)
        )
        # Midblocks followed by Upblocks, walking the channel lists backwards
        self.decoder_mids = nn.ModuleList([])
        for stage in range(len(self.mid_channels) - 1, 0, -1):
            block = MidBlock(
                self.mid_channels[stage],
                self.mid_channels[stage - 1],
                t_emb_dim=None,
                num_heads=self.num_heads,
                num_layers=self.num_mid_layers,
                norm_channels=self.norm_channels,
            )
            self.decoder_mids.append(block)
        self.decoder_layers = nn.ModuleList([])
        for stage in range(len(self.down_channels) - 1, 0, -1):
            block = UpBlock(
                self.down_channels[stage],
                self.down_channels[stage - 1],
                t_emb_dim=None,
                up_sample=self.down_sample[stage - 1],
                num_heads=self.num_heads,
                num_layers=self.num_up_layers,
                attn=self.attns[stage - 1],
                norm_channels=self.norm_channels,
            )
            self.decoder_layers.append(block)
        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(
            self.down_channels[0], im_channels, kernel_size=3, padding=1
        )

    def quantize(self, x):
        """Snap every spatial vector of `x` (B, C, H, W) to its nearest codebook entry.

        Returns ``(quant_out, losses, indices)``: the quantised map (B, C, H, W)
        with straight-through gradients, a dict with ``codebook_loss`` and
        ``commitment_loss``, and the chosen codebook indices reshaped to
        (-1, H, W).
        """
        B, C, H, W = x.shape
        codebook = self.embedding.weight
        # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C)
        vectors = x.permute(0, 2, 3, 1)
        vectors = vectors.reshape(vectors.size(0), -1, vectors.size(-1))
        # Distances between (B, H*W, C) and the codebook repeated to (B, K, C)
        dist = torch.cdist(vectors, codebook[None, :].repeat((vectors.size(0), 1, 1)))
        # (B, H*W) index of the closest code
        min_encoding_indices = torch.argmin(dist, dim=-1)
        # Gather the selected codes: (B*H*W, C)
        quant_out = torch.index_select(codebook, 0, min_encoding_indices.view(-1))
        flat = vectors.reshape((-1, vectors.size(-1)))
        # Commitment pulls the encoder towards the code, codebook loss the reverse
        commitment_loss = torch.mean((quant_out.detach() - flat) ** 2)
        codebook_loss = torch.mean((quant_out - flat.detach()) ** 2)
        quantize_losses = {"codebook_loss": codebook_loss, "commitment_loss": commitment_loss}
        # Straight-through estimator: forward uses the code, gradient flows to `flat`
        quant_out = flat + (quant_out - flat).detach()
        # (B*H*W, C) -> (B, C, H, W)
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape(
            (-1, quant_out.size(-2), quant_out.size(-1))
        )
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode `x` and quantise it; returns ``(quantised_latent, quant_losses)``."""
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = nn.functional.silu(self.encoder_norm_out(out))
        out = self.pre_quant_conv(self.encoder_conv_out(out))
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode a (quantised) latent `z` back to image space."""
        out = self.decoder_conv_in(self.post_quant_conv(z))
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        out = nn.functional.silu(self.decoder_norm_out(out))
        return self.decoder_conv_out(out)

    def forward(self, x):
        '''Encode, quantise and reconstruct `x`.

        Returns ``(out, z, quant_losses)``; example shapes/values noted by the author:
        out: [B, 3, 256, 256]
        z: [B, 3, 64, 64]
        quant_losses: {
            codebook_loss: 0.0681,
            commitment_loss: 0.0681
        }
        '''
        z, quant_losses = self.encode(x)
        return self.decode(z), z, quant_losses
# ==================================================================
# C O N F I G U R A T I O N
# ==================================================================
import pprint
# Absolute path of the YAML experiment configuration (machine-specific).
config_path = "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/config-LDM-High-Pre.yaml"
# config_path = sys.argv[1]
# NOTE: sys.argv[1] is read further below as the run mode ("train_vae" / "train_ldm"),
# so the commented-out alternative above would clash with that use.
with open(config_path, 'r') as file:
    Config = yaml.safe_load(file)
pprint.pprint(Config, width=120)
# Wrap the nested dict so sections/keys can be read with attribute access.
# NOTE(review): DotDict returns a placeholder object instead of raising when a
# key is missing from the YAML, so a misspelled section name is not detected here.
Config = DotDict.from_dict(Config)
dataset_config = Config.dataset_params
diffusion_config = Config.diffusion_params
model_config = Config.model_params
train_config = Config.train_params
paths = Config.paths
# ==================================================================
# V A A N I - D A T A S E T
# ==================================================================
# Root directory that is searched recursively for training images.
IMAGES_PATH = paths.images_dir
def walkDIR(folder_path, include=None):
    """Recursively collect file paths below `folder_path`.

    Args:
        folder_path: directory to walk.
        include: optional iterable of filename suffixes (e.g. ``['.png']``);
            only files ending with one of them are kept (case-sensitive).
            ``None`` keeps every file.

    Returns:
        The matching paths as a list sorted by path.  `os.walk` yields entries
        in filesystem order, which differs between machines; sorting keeps
        index-based access and seeded splits of the result reproducible.

    Side effect: prints the number of files found.
    """
    suffixes = None if include is None else tuple(include)
    file_list = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            # str.endswith accepts a tuple of alternatives.
            if suffixes is None or file.endswith(suffixes):
                file_list.append(os.path.join(root, file))
    file_list.sort()
    print("Files found:", len(file_list))
    return file_list
# Collect every image file below IMAGES_PATH (suffix match is case-sensitive).
files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg'])
# Single-column table of the discovered image paths.
df = pd.DataFrame(files, columns=['image_path'])
class VaaniDataset(torch.utils.data.Dataset):
    """Image dataset over a list of file paths.

    Each item is read as RGB, resized to ``im_size x im_size`` and converted
    to float32 with ``scale=True``; no further remapping (e.g. to [-1, 1]) is
    applied here.
    """

    def __init__(self, files_paths, im_size):
        self.files_paths = files_paths
        self.im_size = im_size

    def __len__(self):
        return len(self.files_paths)

    def __getitem__(self, idx):
        path = self.files_paths[idx]
        pixels = tv.io.read_image(path, mode=tv.io.ImageReadMode.RGB)
        side = self.im_size
        # Applied in order: resize first, then dtype conversion with scaling.
        pipeline = (
            v2.Resize((side, side)),
            v2.ToDtype(torch.float32, scale=True),
        )
        for step in pipeline:
            pixels = step(pixels)
        return pixels
# Build the dataset and print the shape of one sample as a sanity check.
dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size)
image = dataset[2]
print('IMAGE SHAPE:', image.shape)
# Debug mode: keep a random 0.1 % subset (seeded, so the subset is repeatable
# for a fixed file list).  torch.manual_seed also reseeds the global RNG.
if train_config.debug:
    s = 0.001
    dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42))
print("Length of Train dataset:", len(dataset))
# The first CLI argument selects the run mode and thereby the batch size.
# Fail with an explicit message instead of an IndexError (no argument) or a
# later NameError on BATCH_SIZE (unknown mode).
RUN_MODE = sys.argv[1] if len(sys.argv) > 1 else None
if RUN_MODE == "train_vae":
    BATCH_SIZE = train_config.autoencoder_batch_size
elif RUN_MODE == "train_ldm":
    BATCH_SIZE = train_config.ldm_batch_size
else:
    raise ValueError(
        f"Expected run mode 'train_vae' or 'train_ldm' as first argument, got {RUN_MODE!r}"
    )
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=48,  # NOTE(review): hard-coded worker count; confirm it suits the host
    pin_memory=True,
    drop_last=True,
    persistent_workers=True
)
# Pull one batch up front to verify the pipeline end to end.
images = next(iter(dataloader))
print('BATCH SHAPE:', images.shape)
# ==================================================================
# M O D E L - I N I T I A L I Z A T I O N
# ==================================================================
# Re-bind the config sections used by the autoencoder code below
# (dataset_config and train_config were already bound above; autoencoder_config is new).
dataset_config = Config.dataset_params
autoencoder_config = Config.autoencoder_params
train_config = Config.train_params
# model = VQVAE(im_channels=dataset_config.im_channels,
# model_config=autoencoder_config).to(device)
# model_output = model(images.to(device))
# print('MODEL OUTPUT:')
# print(model_output[0].shape, model_output[1].shape, model_output[2])
# from torchinfo import summary
# summary(model=model,
# input_data=images.to(device),
# # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),
# col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"],
# col_width=20,
# row_settings=["var_names"],
# depth = 6,
# # device=device
# )
# exit()
# ==================================================================
# V Q - V A E - T R A I N I N G
# ==================================================================
# python your_script.py > training.log 2>&1   (redirect stdout first so that stderr also ends up in the log)
import time
def format_time(t1, t2):
    """Render the duration ``t2 - t1`` (seconds) as human-readable text.

    The largest unit shown depends on the magnitude: seconds only below one
    minute, then minutes, hours and days are added as thresholds are crossed.
    Seconds are printed with two decimals, all other units without decimals.
    """
    elapsed_time = t2 - t1
    if elapsed_time < 60:
        return f"{elapsed_time:.2f} seconds"
    if elapsed_time < 3600:
        minutes, seconds = divmod(elapsed_time, 60)
        return f"{minutes:.0f} minutes {seconds:.2f} seconds"
    if elapsed_time < 86400:
        hours, leftover = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(leftover, 60)
        return f"{hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"
    days, leftover = divmod(elapsed_time, 86400)
    hours, leftover = divmod(leftover, 3600)
    minutes, seconds = divmod(leftover, 60)
    return f"{days:.0f} days {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"
def find_checkpoints(checkpoint_path):
    """List checkpoint files named ``<checkpoint_path>_epoch<N>.pt``.

    Args:
        checkpoint_path: checkpoint prefix, optionally with a directory part.

    Returns:
        Paths (directory part as given, joined with the file name) of all
        matching files, in directory-listing order; ``[]`` if the directory
        does not exist.
    """
    directory, prefix = os.path.split(checkpoint_path)
    # The prefix is matched literally; only the epoch number is variable.
    pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt$")
    try:
        # A prefix without directory part means "current directory";
        # os.listdir('') would raise FileNotFoundError and hide existing files.
        entries = os.listdir(directory or os.curdir)
    except FileNotFoundError:
        return []
    return [os.path.join(directory, entry) for entry in entries if pattern.match(entry)]
def save_vae_checkpoint(
    total_steps, epoch, model, discriminator, optimizer_d,
    optimizer_g, metrics, checkpoint_path, logs, total_training_time
):
    """Write ``<checkpoint_path>_epoch<epoch>.pt`` and keep only the two newest checkpoints.

    The checkpoint stores the step/epoch counters, the state dicts of `model`,
    `discriminator` and both optimizers, plus `metrics`, `logs` and
    `total_training_time`.

    The file is written to a ``.tmp`` sibling first and then moved into place,
    so an interrupted save never leaves a truncated file under the final name
    (which would otherwise be picked up as the latest checkpoint on resume).
    """
    checkpoint = {
        "total_steps": total_steps,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "optimizer_d_state_dict": optimizer_d.state_dict(),
        "optimizer_g_state_dict": optimizer_g.state_dict(),
        "metrics": metrics,
        "logs": logs,
        "total_training_time": total_training_time
    }
    checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt"
    # The ".tmp" suffix keeps the partial file out of the "_epoch*.pt" glob below.
    partial_file = f"{checkpoint_file}.tmp"
    try:
        torch.save(checkpoint, partial_file)
        os.replace(partial_file, checkpoint_file)
    finally:
        # Only present if saving or moving failed.
        if os.path.exists(partial_file):
            os.remove(partial_file)
    print(f"VQVAE Checkpoint saved at {checkpoint_file}")

    # Escape the prefix so glob metacharacters in the path ([, ], *, ?) are literal.
    all_ckpts = glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt")
    # all_ckpts = find_checkpoints(checkpoint_path)

    def extract_epoch(filename):
        match = re.search(r"_epoch(\d+)\.pt", filename)
        return int(match.group(1)) if match else -1

    # Sort numerically by epoch and delete everything but the last two.
    all_ckpts = sorted(all_ckpts, key=extract_epoch)
    for old_ckpt in all_ckpts[:-2]:
        os.remove(old_ckpt)
        print(f"Removed old VQVAE checkpoint: {old_ckpt}")
def load_vae_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g, device=device):
    """Restore training state from the newest ``<checkpoint_path>_epoch<N>.pt``.

    The state dicts of `model`, `discriminator` and both optimizers are loaded
    in place.

    Returns:
        ``(total_steps, start_epoch, metrics, logs, total_training_time)``
        where ``start_epoch`` is the stored epoch + 1.  If no checkpoint file
        exists, ``(0, 0, None, [], 0)`` is returned and nothing is modified.
    """
    # Escape the prefix so glob metacharacters in the path ([, ], *, ?) are literal.
    all_ckpts = glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt")
    # all_ckpts = find_checkpoints(checkpoint_path)
    if not all_ckpts:
        print("No VQVAE checkpoint found. Starting from scratch.")
        return 0, 0, None, [], 0

    def extract_epoch(filename):
        match = re.search(r"_epoch(\d+)\.pt", filename)
        return int(match.group(1)) if match else -1

    # Highest epoch number wins (numeric, not lexicographic, comparison).
    latest_ckpt = max(all_ckpts, key=extract_epoch)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
    optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
    total_steps = checkpoint["total_steps"]
    epoch = checkpoint["epoch"]
    metrics = checkpoint["metrics"]
    # Older checkpoints may lack these two entries.
    logs = checkpoint.get("logs", [])
    total_training_time = checkpoint.get("total_training_time", 0)
    print(f"VQVAE Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}")
    return total_steps, epoch + 1, metrics, logs, total_training_time
from PIL import Image
def inference(model, dataset, save_path, epoch, device="cuda", sample_size=8):
    """Reconstruct the first `sample_size` dataset items and save a comparison grid.

    The grid has the inputs in the first row and the model reconstructions in
    the second.  It is written to
    ``<save_path>/reconstructed_images_EP-<epoch>_<sample_size>.png`` and also
    to ``output_image.png`` in the current working directory.

    Args:
        model: callable returning a 3-tuple whose first element is the reconstruction.
        dataset: indexable collection of image tensors; the code scales them
            by 255 for saving, so values are assumed to lie in [0, 1].
        save_path: output directory (created if missing).
        epoch: epoch number used in the file name.
        device: device the batch is moved to before the forward pass.
        sample_size: number of leading dataset items to reconstruct.
    """
    os.makedirs(save_path, exist_ok=True)
    image_tensors = torch.stack([dataset[i] for i in range(sample_size)], dim=0).to(device)
    with torch.no_grad():
        outputs, _, _ = model(image_tensors)
    save_input = image_tensors.detach().cpu()
    save_output = outputs.detach().cpu()
    grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
    # The reconstruction is unbounded; clamp before the uint8 cast so values
    # outside [0, 1] saturate instead of wrapping around.
    np_img = (grid.clamp(0, 1) * 255).byte().numpy().transpose(1, 2, 0)
    combined_image = Image.fromarray(np_img)
    # NOTE(review): fixed file in the working directory, overwritten on every call.
    combined_image.save("output_image.png")
    # combined_image = tv.transforms.ToPILImage()(grid)
    combined_image.save(os.path.join(save_path, f"reconstructed_images_EP-{epoch}_{sample_size}.png"))
    print(f"Reconstructed images saved at: {save_path}")
def trainVAE(Config, dataloader):
    r"""
    Train the VQ-VAE with reconstruction, codebook, commitment, LPIPS and
    (after ``train_config.disc_start`` steps) adversarial losses.

    Training resumes from the latest checkpoint returned by
    ``load_vae_checkpoint``. After every epoch a checkpoint is saved and a
    reconstruction grid is written to ``<task_name>/vqvae_recon``.

    :param Config: configuration object exposing ``dataset_params``,
        ``autoencoder_params`` and ``train_params``
    :param dataloader: iterable yielding image batches
    :return: None
    """
    dataset_config = Config.dataset_params
    autoencoder_config = Config.autoencoder_params
    train_config = Config.train_params

    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # `device` is a torch.device object at module level; test its string form
    # so the check works for both torch.device and plain str values.
    if str(device).startswith("cuda"):
        torch.cuda.manual_seed_all(seed)

    model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
    discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
    optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
    optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))

    # NOTE(review): trainLDM looks for VQ-VAE checkpoints under the prefix
    # "vqvae_ckpt" (no ".pth") -- confirm both stages agree on the file names.
    checkpoint_path = os.path.join(train_config.task_name, "vqvae_ckpt.pth")
    (total_steps, start_epoch,
     metrics, logs, total_training_time) = load_vae_checkpoint(checkpoint_path,
                                                               model, discriminator,
                                                               optimizer_d, optimizer_g)
    # makedirs also handles nested task directories and an existing directory.
    os.makedirs(train_config.task_name, exist_ok=True)

    num_epochs = train_config.autoencoder_epochs
    recon_criterion = torch.nn.MSELoss()
    disc_criterion = torch.nn.MSELoss()
    lpips_model = LPIPS().eval().to(device)
    acc_steps = train_config.autoencoder_acc_steps
    disc_step_start = train_config.disc_start
    # Offset the start so time spent in earlier (resumed) runs is included.
    start_time_total = time.time() - total_training_time

    for epoch_idx in trange(start_epoch, num_epochs, colour='red', dynamic_ncols=True):
        start_time_epoch = time.time()
        epoch_log = []
        for images in tqdm(dataloader, colour='green', dynamic_ncols=True):
            batch_start_time = time.time()
            total_steps += 1
            images = images.to(device)
            use_discriminator = total_steps > disc_step_start
            apply_update = total_steps % acc_steps == 0

            # ---------------- Generator (VQ-VAE) ----------------
            output, _, quantize_losses = model(images)
            recon_loss = recon_criterion(output, images) / acc_steps
            g_loss = (
                recon_loss
                + (train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps)
                + (train_config.commitment_beta * quantize_losses["commitment_loss"] / acc_steps)
            )
            if use_discriminator:
                # Freeze the discriminator while it scores the generator
                # output: the adversarial gradient must reach the generator
                # only. Otherwise it accumulates in the discriminator's .grad
                # and is applied by optimizer_d.step() below.
                discriminator.requires_grad_(False)
                disc_fake_pred = discriminator(output)
                disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
                g_loss += train_config.disc_weight * disc_fake_loss / acc_steps
            lpips_loss = torch.mean(lpips_model(output, images)) / acc_steps
            g_loss += train_config.perceptual_weight * lpips_loss
            g_loss.backward()
            if use_discriminator:
                discriminator.requires_grad_(True)
            # Single generator update per accumulation window (the original
            # repeated this block a second time after the discriminator step).
            if apply_update:
                optimizer_g.step()
                optimizer_g.zero_grad()

            # ---------------- Discriminator ----------------
            if use_discriminator:
                # detach(): the discriminator loss must not reach the generator.
                disc_fake_pred = discriminator(output.detach())
                disc_real_pred = discriminator(images)
                disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
                disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
                disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2 / acc_steps
                disc_loss.backward()
                if apply_update:
                    optimizer_d.step()
                    optimizer_d.zero_grad()

            batch_time = time.time() - batch_start_time
            epoch_log.append(format_time(0, batch_time))

        # Apply whatever gradients are still accumulated at the end of the epoch.
        optimizer_d.step()
        optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()

        epoch_time = time.time() - start_time_epoch
        logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log})
        total_training_time = time.time() - start_time_total
        save_vae_checkpoint(total_steps, epoch_idx + 1, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path, logs, total_training_time)

        recon_save_path = os.path.join(train_config.task_name, 'vqvae_recon')
        # `dataset` is the module-level dataset object defined elsewhere in the file.
        inference(model, dataset, recon_save_path, epoch=epoch_idx, device=device, sample_size=16)
    print("Training completed.")
# ==================================================================
# S T A R T I N G - V Q - V A E - T R A I N I N G
# ==================================================================
# trainVAE(Config, dataloader)
# python Vaani-VQVAE-Main.py | tee AE-training.log
# python Vaani-VQVAE-Main.py > AE-training.log 2>&1
# ==================================================================
# L I N E A R - N O I S E - S C H E D U L E R
# ==================================================================
class LinearNoiseScheduler:
    r"""
    DDPM noise scheduler whose betas are interpolated linearly between
    sqrt(beta_start) and sqrt(beta_end) and then squared.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # Schedule is linear in sqrt-space (same construction as the compvis repo).
        sqrt_betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps)
        self.betas = sqrt_betas ** 2
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        r"""
        Forward diffusion step: mix ``original`` with ``noise`` at timestep ``t``.

        :param original: batch the noise is applied to
        :param noise: noise tensor with the same shape as ``original``
        :param t: timesteps of the forward process, shape (B,)
        :return: noised batch with the shape of ``original``
        """
        target_device = original.device
        batch_size = original.shape[0]
        # (B,) coefficients become (B, 1, ..., 1) so they broadcast over the
        # remaining dimensions of ``original``.
        broadcast_shape = (batch_size,) + (1,) * (len(original.shape) - 1)

        signal_scale = self.sqrt_alpha_cum_prod.to(target_device)[t].reshape(broadcast_shape)
        noise_scale = self.sqrt_one_minus_alpha_cum_prod.to(target_device)[t].reshape(broadcast_shape)
        return signal_scale * original + noise_scale * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        r"""
        Reverse diffusion step: estimate x_{t-1} and x_0 from ``xt`` and the
        predicted noise.

        :param xt: sample at the current timestep
        :param noise_pred: noise predicted by the model for ``xt``
        :param t: current timestep
        :return: tuple (sample for timestep t-1, x0 estimate clamped to [-1, 1])
        """
        target_device = xt.device
        betas = self.betas.to(target_device)
        alphas = self.alphas.to(target_device)
        alpha_bar = self.alpha_cum_prod.to(target_device)
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alpha_cum_prod.to(target_device)

        x0 = (xt - sqrt_one_minus_alpha_bar[t] * noise_pred) / torch.sqrt(alpha_bar[t])
        x0 = torch.clamp(x0, -1., 1.)

        mean = xt - (betas[t] * noise_pred) / sqrt_one_minus_alpha_bar[t]
        mean = mean / torch.sqrt(alphas[t])

        # The last step returns the mean as-is, without added noise.
        if t == 0:
            return mean, x0

        variance = (1 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t])
        variance = variance * betas[t]
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(target_device)
        return mean + sigma * z, x0
# ==================================================================
# T I M E - E M B E D D I N G
# ==================================================================
def get_time_embedding(time_steps, temb_dim):
    r"""
    Build sinusoidal embeddings for a batch of timesteps.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: embedding dimension, must be even
    :return: B x temb_dim tensor; the first half holds the sines and the
        second half the cosines
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    half_dim = temb_dim // 2

    # factor_i = 10000^(i / half_dim) for i in [0, half_dim)
    exponents = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
    ) / half_dim
    factor = 10000 ** exponents

    # (B,) -> (B, half_dim), each column divided by its own factor
    angles = time_steps[:, None].repeat(1, half_dim) / factor
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
# ==================================================================
# L D M - U N E T - U P - B L O C K
# ==================================================================
class UpBlockUnet(nn.Module):
    r"""
    Up block of the LDM U-Net. Per forward pass:
    1. Optional transposed-conv upsampling
    2. Concatenation with the matching down-block output
    3. ``num_layers`` repetitions of: resnet block (with optional time
       embedding), self attention and, if enabled, cross attention
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # Only the first layer sees the concatenated input width.
        layer_inputs = [in_channels] + [out_channels] * (num_layers - 1)

        self.resnet_conv_first = nn.ModuleList(
            nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for width in layer_inputs
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            )

        self.resnet_conv_second = nn.ModuleList(
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        )

        self.attention_norms = nn.ModuleList(
            nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
        )
        self.attentions = nn.ModuleList(
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            )
            self.cross_attentions = nn.ModuleList(
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            )
            self.context_proj = nn.ModuleList(
                nn.Linear(context_dim, out_channels) for _ in range(num_layers)
            )

        self.residual_input_conv = nn.ModuleList(
            nn.Conv2d(width, out_channels, kernel_size=1) for width in layer_inputs
        )

        # The upsampler acts on the half of the channels that comes from the
        # previous block, i.e. before the skip connection is concatenated.
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None):
        r"""
        :param x: feature map coming from the previous block
        :param out_down: optional skip connection, concatenated on the channel axis
        :param t_emb: time embedding, used when ``t_emb_dim`` is not None
        :param context: B x _ x context_dim tensor, required for cross attention
        :return: feature map with ``out_channels`` channels
        """
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for layer in range(self.num_layers):
            # --- Resnet --------------------
            residual = out
            hidden = self.resnet_conv_first[layer](out)
            if self.t_emb_dim is not None:
                hidden = hidden + self.t_emb_layers[layer](t_emb)[:, :, None, None]
            hidden = self.resnet_conv_second[layer](hidden)
            out = hidden + self.residual_input_conv[layer](residual)

            # --- Self Attention ------------
            batch_size, channels, h, w = out.shape
            tokens = self.attention_norms[layer](out.reshape(batch_size, channels, h * w))
            tokens = tokens.transpose(1, 2)
            attended, _ = self.attentions[layer](tokens, tokens, tokens)
            out = out + attended.transpose(1, 2).reshape(batch_size, channels, h, w)

            # --- Cross Attention -----------
            if not self.cross_attn:
                continue
            assert context is not None, "context cannot be None if cross attention layers are used"
            batch_size, channels, h, w = out.shape
            tokens = self.cross_attention_norms[layer](out.reshape(batch_size, channels, h * w))
            tokens = tokens.transpose(1, 2)
            assert len(context.shape) == 3, \
                "Context shape does not match B,_,CONTEXT_DIM"
            assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
                "Context shape does not match B,_,CONTEXT_DIM"
            projected_context = self.context_proj[layer](context)
            attended, _ = self.cross_attentions[layer](tokens, projected_context, projected_context)
            out = out + attended.transpose(1, 2).reshape(batch_size, channels, h, w)
        return out
# ==================================================================
# L D M - U N E T
# ==================================================================
class Unet(nn.Module):
    r"""
    U-Net noise predictor made of down blocks, mid blocks and up blocks.

    When ``'audio'`` is one of the configured condition types, every block
    uses cross attention on a single context token obtained by
    attention-pooling the audio hidden states over their sequence axis.
    Without audio conditioning the blocks are built without cross attention.
    """

    def __init__(self, im_channels, model_config):
        r"""
        :param im_channels: number of channels of the input (and output) tensor
        :param model_config: configuration object providing the channel lists,
            layer counts, attention settings and ``condition_config``
        """
        super().__init__()
        self.down_channels = model_config.down_channels
        self.mid_channels = model_config.mid_channels
        self.t_emb_dim = model_config.time_emb_dim
        self.down_sample = model_config.down_sample
        self.num_down_layers = model_config.num_down_layers
        self.num_mid_layers = model_config.num_mid_layers
        self.num_up_layers = model_config.num_up_layers
        self.attns = model_config.attn_down
        self.norm_channels = model_config.norm_channels
        self.num_heads = model_config.num_heads
        self.conv_out_channels = model_config.conv_out_channels

        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Conditioning defaults: without them every later read of
        # self.audio_cond / self.audio_embed_dim fails when 'audio' is not
        # configured (or when there is no condition config at all).
        self.audio_cond = False
        self.audio_embed_dim = None
        self.condition_config = model_config.condition_config
        if self.condition_config is not None:
            condition_types = self.condition_config.condition_types
        else:
            condition_types = []
        self.cond = condition_types
        if 'audio' in condition_types:
            self.audio_cond = True
            self.audio_embed_dim = self.condition_config.audio_condition_config.audio_embed_dim

        # Initial projection from sinusoidal time embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )

        # Scores each position of the audio hidden states; the softmax of the
        # scores is used in forward() to pool (B, T, D) down to (B, D).
        # Only needed (and only constructible) with audio conditioning.
        if self.audio_cond:
            self.context_projector = nn.Sequential(
                nn.Linear(self.audio_embed_dim, 320),
                nn.SiLU(),
                nn.Linear(320, 1)
            )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        # --::----- D O W N - B l O C K S ----------------::--------------::----------------
        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            # Cross attention and context dim are only used with audio conditioning
            self.downs.append(
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i + 1],
                    self.t_emb_dim,
                    down_sample=self.down_sample[i],
                    num_heads=self.num_heads,
                    num_layers=self.num_down_layers,
                    attn=self.attns[i],
                    norm_channels=self.norm_channels,
                    cross_attn=self.audio_cond,
                    context_dim=self.audio_embed_dim
                )
            )

        # --::----- M I D - B l O C K S ----------------::--------------::----------------
        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i + 1],
                    self.t_emb_dim,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                    cross_attn=self.audio_cond,
                    context_dim=self.audio_embed_dim
                )
            )

        # --::----- U P - B l O C K S ----------------::--------------::----------------
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            self.ups.append(
                UpBlockUnet(
                    # doubled because the skip connection is concatenated
                    self.down_channels[i] * 2,
                    self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                    self.t_emb_dim,
                    up_sample=self.down_sample[i],
                    num_heads=self.num_heads,
                    num_layers=self.num_up_layers,
                    norm_channels=self.norm_channels,
                    cross_attn=self.audio_cond,
                    context_dim=self.audio_embed_dim
                )
            )

        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t, cond_input=None):
        r"""
        Predict the noise contained in ``x`` at timestep ``t``.

        :param x: noisy input, B x C x H x W
        :param t: timesteps, one per batch element
        :param cond_input: audio hidden states of shape B x T x audio_embed_dim;
            required when audio conditioning is enabled, ignored otherwise
        :return: tensor with the same number of channels as ``x``
        """
        # Shape comments below assume downblocks are [C1, C2, C3, C4],
        # midblocks are [C4, C4, C3] and downsamples are [True, True, False]
        # B x C x H x W
        out = self.conv_in(x)
        # B x C1 x H x W

        # t_emb -> B x t_emb_dim
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # --- Conditioning ---------------
        context_hidden_states = None
        if self.audio_cond:
            assert cond_input is not None, \
                "cond_input cannot be None if audio conditioning is used"
            last_hidden_state = cond_input
            weights = self.context_projector(last_hidden_state)        # B x T x 1
            weights = torch.softmax(weights, dim=1)                     # normalize across time
            pooled_embedding = (last_hidden_state * weights).sum(dim=1)  # B x D
            context_hidden_states = pooled_embedding.unsqueeze(1)        # B x 1 x D

        # --- Down Pass ------------------
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb, context_hidden_states)
        # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
        # out B x C4 x H/4 x W/4

        # --- Mid Pass ------------------
        for mid in self.mids:
            out = mid(out, t_emb, context_hidden_states)
        # out B x C3 x H/4 x W/4

        # --- Up Pass ------------------
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb, context_hidden_states)
        # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]

        out = self.norm_out(out)
        # Functional form avoids building a new SiLU module on every call.
        out = nn.functional.silu(out)
        out = self.conv_out(out)
        # out B x C x H x W
        return out
# ==================================================================
# L D M - T R A I N I N G
# ==================================================================
def find_checkpoints(checkpoint_path):
    r"""
    List the checkpoint files ``<checkpoint_path>_epoch<N>.pt`` that exist.

    :param checkpoint_path: checkpoint prefix, i.e. directory plus base name
        without the ``_epoch<N>.pt`` suffix. A bare name with no directory
        part is looked up in the current working directory.
    :return: list of matching paths ordered by ascending epoch number; empty
        when the directory does not exist or nothing matches
    """
    directory, prefix = os.path.split(checkpoint_path)
    pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt")
    try:
        # os.listdir("") raises, so an empty directory part means the cwd.
        files = os.listdir(directory or os.curdir)
    except (FileNotFoundError, NotADirectoryError):
        return []

    numbered = []
    for name in files:
        # fullmatch: the whole file name must be "<prefix>_epoch<digits>.pt".
        match = pattern.fullmatch(name)
        if match:
            numbered.append((int(match.group(1)), name))
    numbered.sort()
    return [os.path.join(directory, name) for _, name in numbered]
def save_ldm_checkpoint(checkpoint_path,
                        total_steps, epoch, model, optimizer,
                        metrics, logs, total_training_time
                        ):
    r"""
    Save an LDM checkpoint to ``<checkpoint_path>_epoch<epoch>.pt`` and keep
    only the two most recent numbered checkpoints of the same prefix.

    :param checkpoint_path: checkpoint prefix (directory plus base name)
    :param total_steps: number of training steps done so far
    :param epoch: epoch number stored in the checkpoint and in the file name
    :param model: module whose ``state_dict()`` is stored
    :param optimizer: optimizer whose ``state_dict()`` is stored
    :param metrics: metrics object stored as-is
    :param logs: log list stored as-is
    :param total_training_time: accumulated training time stored as-is
    :return: None
    """
    checkpoint = {
        "total_steps": total_steps,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "metrics": metrics,
        "logs": logs,
        "total_training_time": total_training_time
    }
    checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt"
    directory = os.path.dirname(checkpoint_file)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Write to a temporary file and rename it into place, so an interrupted
    # save never leaves a truncated file under the final checkpoint name.
    tmp_file = f"{checkpoint_file}.tmp"
    torch.save(checkpoint, tmp_file)
    os.replace(tmp_file, checkpoint_file)
    print(f"LDM Checkpoint saved at {checkpoint_file}")

    # Only files named exactly "<prefix>_epoch<digits>.pt" take part in the
    # rotation. The glob alone would also pick up names such as
    # "<prefix>_epochbest.pt" and delete them as the "oldest" checkpoints.
    name_pattern = re.compile(re.escape(os.path.basename(checkpoint_path)) + r"_epoch(\d+)\.pt")
    numbered = []
    for path in glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt"):
        match = name_pattern.fullmatch(os.path.basename(path))
        if match:
            numbered.append((int(match.group(1)), path))
    numbered.sort()

    for _, old_ckpt in numbered[:-2]:
        os.remove(old_ckpt)
        print(f"Removed old LDM checkpoint: {old_ckpt}")
def load_ldm_checkpoint(checkpoint_path, model, optimizer):
    r"""
    Restore ``model`` and ``optimizer`` from the numbered checkpoint
    ``<checkpoint_path>_epoch<N>.pt`` with the highest ``N``.

    :param checkpoint_path: checkpoint prefix (directory plus base name)
    :param model: module that receives ``model_state_dict``
    :param optimizer: optimizer that receives ``optimizer_state_dict``
    :return: tuple (total_steps, epoch to resume from, metrics, logs,
        total_training_time); ``(0, 0, None, [], 0)`` when there is no
        checkpoint
    """
    # Restrict the glob result to names of the exact form
    # "<prefix>_epoch<digits>.pt" so stray files such as
    # "<prefix>_epochbest.pt" are never mistaken for a training checkpoint.
    name_pattern = re.compile(re.escape(os.path.basename(checkpoint_path)) + r"_epoch(\d+)\.pt")
    numbered = []
    for path in glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt"):
        match = name_pattern.fullmatch(os.path.basename(path))
        if match:
            numbered.append((int(match.group(1)), path))

    if not numbered:
        print("No LDM checkpoint found. Starting from scratch.")
        return 0, 0, None, [], 0

    _, latest_ckpt = max(numbered)
    checkpoint = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    total_steps = checkpoint["total_steps"]
    epoch = checkpoint["epoch"]
    # Optional entries fall back to neutral defaults.
    metrics = checkpoint.get("metrics")
    logs = checkpoint.get("logs", [])
    total_training_time = checkpoint.get("total_training_time", 0)
    print(f"LDM Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}")
    return total_steps, epoch + 1, metrics, logs, total_training_time
def load_ldm_vae_checkpoint(checkpoint_path, vae, device=device):
    r"""
    Load the weights of the latest VQ-VAE checkpoint into ``vae``.

    Checkpoints are discovered with ``find_checkpoints`` and the one with the
    highest epoch number in its file name is used. Only the model weights are
    restored.

    :param checkpoint_path: VQ-VAE checkpoint prefix (directory plus base name)
    :param vae: module that receives ``model_state_dict``
    :param device: ``map_location`` used when reading the checkpoint
    :return: tuple (total_steps, epoch + 1, metrics, logs,
        total_training_time) read from the checkpoint, or
        ``(0, 0, None, [], 0)`` when no checkpoint is available. The success
        path previously returned None; it now has the same shape as the
        not-found path.
    """
    all_ckpts = find_checkpoints(checkpoint_path)

    def extract_epoch(filename):
        match = re.search(r"_epoch(\d+)\.pt", filename)
        return int(match.group(1)) if match else -1

    latest_ckpt = sorted(all_ckpts, key=extract_epoch)[-1] if all_ckpts else None
    # The existence check covers a file removed between listing and loading.
    if latest_ckpt is None or not os.path.exists(latest_ckpt):
        print("No VQVAE checkpoint found.")
        return 0, 0, None, [], 0

    checkpoint = torch.load(latest_ckpt, map_location=device)
    vae.load_state_dict(checkpoint["model_state_dict"])
    total_steps = checkpoint["total_steps"]
    epoch = checkpoint["epoch"]
    print(f"VQVAE Checkpoint loaded from {latest_ckpt} at epoch {epoch + 1} & step {total_steps}")
    return (total_steps, epoch + 1, checkpoint.get("metrics"),
            checkpoint.get("logs", []), checkpoint.get("total_training_time", 0))
def trainLDM(Config, dataloader):
    r"""
    Train the latent diffusion U-Net on latents produced by a frozen VQ-VAE.

    Each batch is encoded by the VQ-VAE, noised at a random timestep and the
    U-Net is trained with an MSE loss to predict that noise. After every epoch
    a checkpoint is saved, ``infer`` is run, and the config file is re-read so
    that setting ``training.continue_ldm`` to false stops training.

    :param Config: configuration object exposing ``diffusion_params``,
        ``dataset_params``, ``ldm_params``, ``autoencoder_params`` and
        ``train_params``
    :param dataloader: iterable yielding image batches
    :return: None
    """
    diffusion_config = Config.diffusion_params
    dataset_config = Config.dataset_params
    diffusion_model_config = Config.ldm_params
    autoencoder_model_config = Config.autoencoder_params
    train_config = Config.train_params
    condition_config = diffusion_model_config.condition_config

    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # `device` is a torch.device object: it supports neither `== "cuda"` in a
    # meaningful way nor `"cuda" in device`, so test its string form instead.
    on_cuda = str(device).startswith("cuda")
    if on_cuda:
        torch.cuda.manual_seed_all(seed)

    # The U-Net stays on `device`; the VQ-VAE goes to a second GPU when one
    # exists and otherwise shares `device` (single-GPU and CPU runs).
    ldm_device = device
    vqvae_device = "cuda:1" if (on_cuda and torch.cuda.device_count() > 1) else device

    # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
                                     beta_start=diffusion_config.beta_start,
                                     beta_end=diffusion_config.beta_end)

    # Defaults keep the names defined when no conditioning is configured.
    condition_types = []
    audio_model = None
    if not train_config.ldm_pretraining:
        if condition_config is not None:
            condition_types = condition_config.condition_types
            if 'audio' in condition_types:
                from msclap import CLAP  # type: ignore
                audio_model = CLAP(version='2023', use_cuda=on_cuda)

    model = Unet(im_channels=autoencoder_model_config.z_channels, model_config=diffusion_model_config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_config.ldm_lr)
    criterion = torch.nn.MSELoss()
    num_epochs = train_config.ldm_epochs

    checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "ldmH_ckpt")
    (total_steps, start_epoch, metrics, logs,
     total_training_time) = load_ldm_checkpoint(checkpoint_path, model, optimizer)

    vae = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_model_config).eval().to(vqvae_device)
    # NOTE(review): trainVAE passes "vqvae_ckpt.pth" as its checkpoint path;
    # confirm the VQ-VAE checkpoints on disk really use the prefix below.
    vae_checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "vqvae_ckpt")
    load_ldm_vae_checkpoint(vae_checkpoint_path, vae, vqvae_device)
    for param in vae.parameters():
        param.requires_grad = False
    vae.eval()

    os.makedirs(train_config.task_name, exist_ok=True)

    acc_steps = train_config.ldm_acc_steps
    audio_embed_dim = condition_config.audio_condition_config.audio_embed_dim
    # Offset the start so time spent in earlier (resumed) runs is included.
    start_time_total = time.time() - total_training_time

    model.train()
    optimizer.zero_grad()
    for epoch_idx in trange(start_epoch, num_epochs, desc=f"{device}-LDM Epoch", colour='red', dynamic_ncols=True):
        start_time_epoch = time.time()
        losses = []
        epoch_log = []

        # Load latest vqvae checkpoints
        load_ldm_vae_checkpoint(vae_checkpoint_path, vae, vqvae_device)
        for param in vae.parameters():
            param.requires_grad = False
        vae.eval()

        # for images, cond_input in tqdm(dataloader, colour='green', dynamic_ncols=True):
        for images in tqdm(dataloader, colour='green', dynamic_ncols=True):
            # NOTE(review): the dataloader yields images only, so there is no
            # real audio input yet; the audio branch below needs one.
            cond_input = None
            total_steps += 1
            batch_size = images.shape[0]

            with torch.no_grad():
                images, _ = vae.encode(images.to(vqvae_device))
            images = images.to(ldm_device)

            # All-zero context used for unconditional training and for
            # condition dropout. The last dimension must equal the U-Net's
            # audio_embed_dim, which its context projector is built for.
            empty_audio_embedding = torch.zeros((batch_size, 1500, audio_embed_dim), device=device).float()
            audio_embeddings = empty_audio_embedding
            if not train_config.ldm_pretraining and 'audio' in condition_types:
                with torch.no_grad():
                    audio_embeddings = audio_model.get_audio_embeddings(cond_input)
                # Randomly replace some conditions with the empty embedding.
                text_drop_prob = condition_config.audio_condition_config.cond_drop_prob
                text_drop_mask = torch.zeros((images.shape[0]), device=images.device).float().uniform_(0, 1) < text_drop_prob
                audio_embeddings[text_drop_mask, :, :] = empty_audio_embedding[0]

            # Sample random noise
            noise = torch.randn_like(images).to(device)
            # Sample timestep
            t = torch.randint(0, diffusion_config.num_timesteps, (images.shape[0],)).to(device)
            # Add noise to images according to timestep
            noisy_images = scheduler.add_noise(images, noise, t)
            noise_pred = model(noisy_images, t, cond_input=audio_embeddings)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss = loss / acc_steps
            loss.backward()
            if total_steps % acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Apply gradients left over from an incomplete accumulation window so
        # the checkpoint saved below reflects every batch of the epoch. The
        # original test (`== 0`) only fired when nothing was left to apply.
        if total_steps % acc_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}')
        epoch_time = time.time() - start_time_epoch
        logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log})
        total_training_time = time.time() - start_time_total
        save_ldm_checkpoint(checkpoint_path, total_steps, epoch_idx + 1, model, optimizer, metrics, logs, total_training_time)
        infer(Config)

        # Re-read the config after every epoch so training can be stopped
        # from outside; the context manager closes the file handle.
        with open(config_path, 'r') as config_file:
            train_continue = DotDict.from_dict(yaml.safe_load(config_file))
        # Kept as an equality test on purpose: only an explicit false value
        # stops training, a None flag does not.
        if train_continue.training.continue_ldm == False:  # noqa: E712
            print('LDM Training Stoped ...')
            break
    print('Done Training ...')
# ==================================================================
# L D M - S A M P L I N G
# ==================================================================
def sample(model, scheduler, train_config, diffusion_model_config,
           autoencoder_model_config, diffusion_config, dataset_config,
           vae, audio_model
           ):
    r"""
    Sample stepwise by going backward one timestep at a time.

    At every step the current x0 prediction is saved as
    ``<task_name>/samplesH/x0_<step>.png``. Only the final step (0) is
    decoded by the VQ-VAE; earlier steps save the latent x0 prediction
    directly.

    :param model: noise prediction network, called as ``model(xt, t, cond)``
    :param scheduler: object providing ``sample_prev_timestep``
    :param train_config: provides num_samples, ldm_pretraining,
        cf_guidance_scale, num_grid_rows and task_name
    :param diffusion_model_config: provides the audio condition config
    :param autoencoder_model_config: provides z_channels and down_sample
    :param diffusion_config: provides num_timesteps
    :param dataset_config: provides im_size
    :param vae: autoencoder used to decode the final latent
    :param audio_model: currently unused
    :return: None
    """
    # Latent resolution: image size halved once per enabled downsample stage.
    im_size = dataset_config.im_size // 2 ** sum(autoencoder_model_config.down_sample)
    xt = torch.randn((train_config.num_samples,
                      autoencoder_model_config.z_channels,
                      im_size,
                      im_size)).to(device)

    audio_embed_dim = diffusion_model_config.condition_config.audio_condition_config.audio_embed_dim
    # The last dimension must equal the U-Net's audio_embed_dim.
    empty_audio_embedding = torch.zeros((train_config.num_samples, 1500, audio_embed_dim), device=device).float()

    # No conditional input is built yet, with or without ldm_pretraining, so
    # the empty embedding is used in both cases. Previously the
    # non-pretraining path left `audio_embeddings` undefined (NameError).
    audio_embeddings = empty_audio_embedding
    uncond_input = empty_audio_embedding
    cond_input = audio_embeddings

    samples_dir = os.path.join(train_config.task_name, 'samplesH')
    os.makedirs(samples_dir, exist_ok=True)
    cf_guidance_scale = train_config.cf_guidance_scale

    for i in tqdm(reversed(range(diffusion_config.num_timesteps)),
                  total=diffusion_config.num_timesteps,
                  colour='blue', desc="Sampling", dynamic_ncols=True):
        # Get prediction of noise
        t = (torch.ones((xt.shape[0],)) * i).long().to(device)
        noise_pred_cond = model(xt, t, cond_input)

        if cf_guidance_scale > 1:
            # Classifier-free guidance: move away from the unconditional prediction.
            noise_pred_uncond = model(xt, t, uncond_input)
            noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond)
        else:
            noise_pred = noise_pred_cond

        # Use scheduler to get x0 and xt-1
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

        if i == 0:
            # Decode ONLY the final image to save time
            ims = vae.decode(xt)
        else:
            ims = x0_pred

        # Map [-1, 1] to [0, 1] before converting to 8-bit.
        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=train_config.num_grid_rows)
        np_img = (grid * 255).byte().numpy().transpose(1, 2, 0)
        img = Image.fromarray(np_img)
        img.save(os.path.join(samples_dir, 'x0_{}.png'.format(i)))
        img.close()
def infer(Config):
    r"""
    Build the U-Net and VQ-VAE, load their weights and run ``sample``.

    The U-Net weights come from ``<task_name>/<ldm_ckpt_name>``. If that
    checkpoint also holds ``vae_state_dict`` it is used for the VQ-VAE;
    otherwise the latest VQ-VAE checkpoint is loaded with
    ``load_ldm_vae_checkpoint`` (checkpoints written by
    ``save_ldm_checkpoint`` contain no VQ-VAE weights).

    :param Config: configuration object exposing ``diffusion_params``,
        ``dataset_params``, ``ldm_params``, ``autoencoder_params`` and
        ``train_params``
    :return: None
    """
    diffusion_config = Config.diffusion_params
    dataset_config = Config.dataset_params
    diffusion_model_config = Config.ldm_params
    autoencoder_model_config = Config.autoencoder_params
    train_config = Config.train_params

    # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
                                     beta_start=diffusion_config.beta_start,
                                     beta_end=diffusion_config.beta_end)
    model = Unet(im_channels=autoencoder_model_config.z_channels,
                 model_config=diffusion_model_config).eval().to(device)
    vae = VQVAE(im_channels=dataset_config.im_channels,
                model_config=autoencoder_model_config).eval().to(device)

    vae_loaded = False
    checkpoint_path = os.path.join(train_config.task_name, train_config.ldm_ckpt_name)
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        if "vae_state_dict" in checkpoint:
            vae.load_state_dict(checkpoint["vae_state_dict"])
            vae_loaded = True
            print('Loaded unet & vae checkpoint')
        else:
            print('Loaded unet checkpoint')
    else:
        # Make the fallback visible instead of silently sampling from an
        # untrained network.
        print(f"LDM checkpoint not found at {checkpoint_path}. Sampling with randomly initialised U-Net weights.")

    if not vae_loaded:
        # Same VQ-VAE checkpoint prefix that trainLDM uses.
        vae_checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "vqvae_ckpt")
        load_ldm_vae_checkpoint(vae_checkpoint_path, vae, device)

    # Create output directories
    os.makedirs(train_config.task_name, exist_ok=True)

    with torch.no_grad():
        sample(model, scheduler, train_config, diffusion_model_config,
               autoencoder_model_config, diffusion_config, dataset_config, vae, None)
# ==================================================================
# S T A R T I N G - L D M - T R A I N I N G
# ==================================================================
# trainLDM(Config, dataloader)
# The first command-line argument selects the mode: 'train_vae', 'train_ldm',
# or anything else for inference. Running without an argument also falls back
# to inference instead of failing with an IndexError on sys.argv[1].
run_mode = sys.argv[1] if len(sys.argv) > 1 else 'infer'
if run_mode == 'train_vae':
    trainVAE(Config, dataloader)
elif run_mode == 'train_ldm':
    trainLDM(Config, dataloader)
else:
    infer(Config)
# git add . && git commit -m "LDM" && git push -u origin master
# huggingface-cli upload alpha31476/Vaani-Audio2Img-LDM . --commit-message "SDFT"