# ==================================================================
# L A T E N T   D I F F U S I O N   M O D E L
# ==================================================================
# Author     : Ashish Kumar Uchadiya
# Created    : May 11, 2025
# Description: This script implements the training of a VQ-VAE model for
# image reconstruction, integrated with Latent Diffusion Models (LDMs) and
# audio conditioning. The VQ-VAE maps images to a discrete latent space,
# which is then modeled by the LDM for learning a diffusion process over the
# compressed representation. Audio features are used as conditioning inputs
# to guide the generation process. The autoencoder training minimizes a
# combination of reconstruction (MSE), vector-quantization (codebook and
# commitment), LPIPS (Learned Perceptual Image Patch Similarity) and
# PatchGAN adversarial losses. This setup enables efficient and
# semantically-aware generation of high-quality images driven by audio cues.
#
# Usage: python <this_script>.py {train_vae|train_ldm}
# ==================================================================
# I M P O R T S
# ==================================================================
import datetime
import glob
import os
import pprint
import random
import re
import sys
import time
import typing
from argparse import Namespace
from collections import namedtuple
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.hub
import torch.nn as nn
import torchvision as tv
import yaml
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2
from torchvision.utils import make_grid
from tqdm import tqdm, trange

print("TIME:", datetime.datetime.now())
# os.environ["CUDA_VISIBLE_DEVICES"] = f"{sys.argv[2]}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", device)


# ==================================================================
# H E L P E R S
# ==================================================================
class DotDict(Namespace):
    """A simple class that builds upon `argparse.Namespace`
    in order to make chained attributes possible.

    Reading a missing attribute on a regular DotDict creates (and stores) an
    empty placeholder DotDict under that name, so `dd.a.b = 1` works. Reading a
    missing attribute *through* a placeholder removes the placeholder from its
    parent again and raises AttributeError.
    """

    # Internal bookkeeping attributes; never treated as user keys.
    _INTERNAL = ("_temp", "_key", "_parent")

    def __init__(self, temp=False, key=None, parent=None) -> None:
        self._temp = temp
        self._key = key
        self._parent = parent

    def __eq__(self, other):
        if not isinstance(other, DotDict):
            return NotImplemented
        return vars(self) == vars(other)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails. Dunder/protocol probes
        # (e.g. `__deepcopy__`, `__getstate__` from copy/pickle) and the internal
        # slots (missing while an instance is being rebuilt) must fail normally;
        # otherwise a placeholder DotDict would be mistaken for the real hook.
        if name.startswith("__") or name in DotDict._INTERNAL:
            raise AttributeError("No attribute '%s'" % name)
        if not self._temp:
            child = DotDict(temp=True, key=name, parent=self)
            self.__dict__[name] = child
            return child
        # Missing attribute read through a placeholder: drop the placeholder from
        # its parent -- but only if the parent still holds *this* placeholder, so
        # a real value assigned later is never deleted -- then fail.
        parent = self._parent
        if parent is not None and parent.__dict__.get(self._key) is self:
            del parent.__dict__[self._key]
        raise AttributeError("No attribute '%s'" % name)

    def __repr__(self) -> str:
        # Private bookkeeping fields are hidden; printing `_parent` would recurse
        # forever between a parent and its placeholder child.
        public = [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")]
        if not public:
            return "DotDict()"
        return "DotDict(%s)" % ", ".join("%s=%s" % (k, repr(v)) for k, v in public)

    @classmethod
    def from_dict(cls, original: typing.Mapping[str, Any]) -> "DotDict":
        """Create a DotDict from a (possibly nested) dict `original`.

        Warning: this method should not be used on very deeply nested inputs,
        since it's recursively traversing the nested dictionary values.
        """
        dd = cls()
        for key, value in original.items():
            if isinstance(value, typing.Mapping):
                value = cls.from_dict(value)
            setattr(dd, key, value)
        return dd


# ==================================================================
# L P I P S
# ==================================================================
# Named container for the five VGG feature maps (built once, not per forward).
_VggOutputs = namedtuple("VggOutputs", ["h1", "h2", "h3", "h4", "h5"])


class vgg16(nn.Module):
    """Frozen torchvision VGG16 feature extractor split into the five slices
    (torchvision `features` indices [0,4), [4,9), [9,16), [16,23), [23,30))
    whose outputs LPIPS compares."""

    def __init__(self):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.models.vgg16(
            weights=tv.models.VGG16_Weights.IMAGENET1K_V1
        ).features
        self.N_slices = 5
        boundaries = [(0, 4), (4, 9), (9, 16), (16, 23), (23, 30)]
        for slice_idx, (start, end) in enumerate(boundaries, start=1):
            block = torch.nn.Sequential()
            for x in range(start, end):
                # Keep torchvision's layer index as the child name.
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, f"slice{slice_idx}", block)
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X):
        h1 = self.slice1(X)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        h5 = self.slice5(h4)
        return _VggOutputs(h1, h2, h3, h4, h5)


def _spatial_average(in_tens, keepdim=True):
    """Mean over the two spatial dimensions (dims 2 and 3)."""
    return in_tens.mean([2, 3], keepdim=keepdim)


def _normalize_tensor(in_feat, eps=1e-8):
    """Unit-normalize features along the channel dimension (dim 1)."""
    norm_factor = torch.sqrt(eps + torch.sum(in_feat ** 2, dim=1, keepdim=True))
    return in_feat / norm_factor


class ScalingLayer(nn.Module):
    """Standardize inputs in [-1, 1] with LPIPS' per-channel shift/scale.

    The constants are the ImageNet statistics (mean = [0.485, 0.456, 0.406],
    std = [0.229, 0.224, 0.225], defined for [0, 1] inputs) re-expressed for
    inputs in [-1, 1]: shift = 2 * mean - 1, scale = 2 * std.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer("shift", torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer("scale", torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv (optionally preceded by dropout)."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class LPIPS(nn.Module):
    """Frozen LPIPS (v0.1, VGG backbone) perceptual distance.

    Downloads the pretrained linear-head weights on construction and returns one
    distance per batch element from `forward`.
    """

    def __init__(self, net="vgg", version="0.1", use_dropout=True):
        super(LPIPS, self).__init__()
        self.version = version
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16()
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])

        # --- Original url -------------------
        # weights_url = f"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v{version}/{net}.pth"
        # --- Original forked url ------------
        weights_url = f"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth"
        # --- Original torchmetrics url ------
        # weights_url = "https://github.com/Lightning-AI/torchmetrics/raw/master/src/torchmetrics/functional/image/lpips_models/vgg.pth"
        state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location="cpu")
        # strict=False because the file only carries the linear heads (the VGG
        # trunk comes from torchvision). Still refuse to run with randomly
        # initialized heads: `lin<k>.*` keys must all be present. (`lins.<k>.*`
        # are aliases of the same modules and are expected to be "missing".)
        incompatible = self.load_state_dict(state_dict, strict=False)
        missing_heads = [k for k in incompatible.missing_keys if re.match(r"lin\d", k)]
        if missing_heads:
            raise RuntimeError(
                f"LPIPS weights from {weights_url} are missing linear heads: {missing_heads}"
            )

        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, in0, in1, normalize=False):
        """Return the LPIPS distance between `in0` and `in1`, shape (batch,).

        Inputs are expected in [-1, 1]; pass `normalize=True` for [0, 1] inputs.
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        res = []
        for kk in range(self.L):
            feats0 = _normalize_tensor(outs0[kk])
            feats1 = _normalize_tensor(outs1[kk])
            diff = (feats0 - feats1) ** 2
            res.append(_spatial_average(self.lins[kk](diff), keepdim=True))
        val = sum(res)
        return val.reshape(-1)


# ==================================================================
# P A T C H - G A N - D I S C R I M I N A T O R
# ==================================================================
class Discriminator(nn.Module):
    r"""
    PatchGAN Discriminator.
    Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to
    1 scalar value, we instead predict a grid of values, where each
    grid cell is the discriminator's prediction of how likely the
    corresponding image patch is real.
    """

    def __init__(
        self,
        im_channels=3,
        conv_channels=(64, 128, 256),
        kernels=(4, 4, 4, 4),
        strides=(2, 2, 2, 1),
        paddings=(1, 1, 1, 1),
    ):
        super().__init__()
        self.im_channels = im_channels
        activation = nn.LeakyReLU(0.2)
        # Tuple defaults avoid shared mutable default arguments; list() keeps
        # callers passing lists working.
        layers_dim = [self.im_channels] + list(conv_channels) + [1]
        last = len(layers_dim) - 2
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        layers_dim[i],
                        layers_dim[i + 1],
                        kernel_size=kernels[i],
                        stride=strides[i],
                        padding=paddings[i],
                        bias=(i == 0),
                    ),
                    # No norm on the first and last conv; no activation on the last.
                    nn.BatchNorm2d(layers_dim[i + 1]) if i != last and i != 0 else nn.Identity(),
                    activation if i != last else nn.Identity(),
                )
                for i in range(len(layers_dim) - 1)
            ]
        )

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out


# ==================================================================
# D O W N - B L O C K
# ==================================================================
class DownBlock(nn.Module):
    r"""
    Down conv block with attention.
    Sequence of following blocks
    1. Resnet block with time embedding
    2. Attention block
    3. Downsample
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        down_sample,
        num_heads,
        num_layers,
        attn,
        norm_channels,
        cross_attn=False,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(
                        in_channels if i == 0 else out_channels,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(num_layers)
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                    for _ in range(num_layers)
                ]
            )
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )

        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.down_sample_conv = (
            nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
        )

    def forward(self, x, t_emb=None, context=None):
        out = x
        for i in range(self.num_layers):
            # Resnet block of Unet
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)

            if self.attn:
                # Attention block of Unet (tokens = spatial positions)
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn

            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn

        # Downsample
        out = self.down_sample_conv(out)
        return out


# ==================================================================
# M I D - B L O C K
# ==================================================================
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.
    Sequence of following blocks
    1. Resnet block with time embedding
    2. Attention block
    3. Resnet block with time embedding
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        num_heads,
        num_layers,
        norm_channels,
        cross_attn=None,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        # num_layers + 1 resnet blocks: one before the first attention and one
        # after each attention.
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(
                        in_channels if i == 0 else out_channels,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(num_layers + 1)
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                    for _ in range(num_layers + 1)
                ]
            )
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers + 1)
            ]
        )

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
        )
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )

        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers + 1)
            ]
        )

    def forward(self, x, t_emb=None, context=None):
        out = x

        # First resnet block
        resnet_input = out
        out = self.resnet_conv_first[0](out)
        if self.t_emb_dim is not None:
            out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[0](out)
        out = out + self.residual_input_conv[0](resnet_input)

        for i in range(self.num_layers):
            # Attention Block
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn

            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.cross_attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
                context_proj = self.context_proj[i](context)
                out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn

            # Resnet Block
            resnet_input = out
            out = self.resnet_conv_first[i + 1](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i + 1](out)
            out = out + self.residual_input_conv[i + 1](resnet_input)

        return out


# ==================================================================
# U P - B L O C K
# ==================================================================
class UpBlock(nn.Module):
    r"""
    Up conv block with attention.
    Sequence of following blocks
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        up_sample,
        num_heads,
        num_layers,
        attn,
        norm_channels,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(
                        in_channels if i == 0 else out_channels,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(num_layers)
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                    for _ in range(num_layers)
                ]
            )

        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )
        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )

        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.up_sample_conv = (
            nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)
            if self.up_sample
            else nn.Identity()
        )

    def forward(self, x, out_down=None, t_emb=None):
        # Upsample
        x = self.up_sample_conv(x)

        # Concat with Downblock output
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            # Resnet Block
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)

            # Self Attention
            if self.attn:
                batch_size, channels, h, w = out.shape
                in_attn = out.reshape(batch_size, channels, h * w)
                in_attn = self.attention_norms[i](in_attn)
                in_attn = in_attn.transpose(1, 2)
                out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
                out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
                out = out + out_attn
        return out


# ==================================================================
# V Q - V A E
# ==================================================================
class VQVAE(nn.Module):
    """VQ-VAE: conv encoder -> nearest-codebook quantization -> conv decoder."""

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config.down_channels
        self.mid_channels = model_config.mid_channels
        self.down_sample = model_config.down_sample
        self.num_down_layers = model_config.num_down_layers
        self.num_mid_layers = model_config.num_mid_layers
        self.num_up_layers = model_config.num_up_layers

        # To disable attention in Downblock of Encoder and Upblock of Decoder
        self.attns = model_config.attn_down

        # Latent Dimension
        self.z_channels = model_config.z_channels
        self.codebook_size = model_config.codebook_size
        self.norm_channels = model_config.norm_channels
        self.num_heads = model_config.num_heads

        # Assertion to validate the channel information
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Wherever we use downsampling in encoder correspondingly use
        # upsampling in decoder
        self.up_sample = list(reversed(self.down_sample))

        ##################### Encoder ######################
        self.encoder_conv_in = nn.Conv2d(
            im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)
        )

        # Downblock + Midblock
        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i + 1],
                    t_emb_dim=None,
                    down_sample=self.down_sample[i],
                    num_heads=self.num_heads,
                    num_layers=self.num_down_layers,
                    attn=self.attns[i],
                    norm_channels=self.norm_channels,
                )
            )

        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i + 1],
                    t_emb_dim=None,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                )
            )

        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(
            self.down_channels[-1], self.z_channels, kernel_size=3, padding=1
        )

        # Pre Quantization Convolution
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Codebook
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
        ####################################################

        ##################### Decoder ######################

        # Post Quantization Convolution
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(
            self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)
        )

        # Midblock + Upblock
        self.decoder_mids = nn.ModuleList([])
        for i in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i - 1],
                    t_emb_dim=None,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                )
            )

        self.decoder_layers = nn.ModuleList([])
        for i in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(
                UpBlock(
                    self.down_channels[i],
                    self.down_channels[i - 1],
                    t_emb_dim=None,
                    up_sample=self.down_sample[i - 1],
                    num_heads=self.num_heads,
                    num_layers=self.num_up_layers,
                    attn=self.attns[i - 1],
                    norm_channels=self.norm_channels,
                )
            )

        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(
            self.down_channels[0], im_channels, kernel_size=3, padding=1
        )

    def quantize(self, x):
        """Snap each spatial vector of `x` (B, C, H, W) to its nearest codebook entry.

        Returns:
            quant_out: (B, C, H, W) quantized tensor with straight-through gradients.
            quantize_losses: dict with "codebook_loss" and "commitment_loss".
            min_encoding_indices: (B, H, W) chosen codebook indices.
        """
        B, C, H, W = x.shape

        # B, C, H, W -> B, H, W, C
        x = x.permute(0, 2, 3, 1)

        # B, H, W, C -> B, H*W, C
        x = x.reshape(x.size(0), -1, x.size(-1))

        # Find nearest embedding/codebook vector
        # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
        dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
        # (B, H*W)
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Replace encoder output with nearest codebook
        # quant_out -> B*H*W, C
        quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))

        # x -> B*H*W, C
        x = x.reshape((-1, x.size(-1)))
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {"codebook_loss": codebook_loss, "commitment_loss": commitment_loss}

        # Straight through estimation
        quant_out = x + (quant_out - x).detach()

        # quant_out -> B, C, H, W
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape(
            (-1, quant_out.size(-2), quant_out.size(-1))
        )
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode images to quantized latents; returns (latents, quantize_losses)."""
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode (quantized) latents back to image space."""
        out = self.post_quant_conv(z)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        out = self.decoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Return (reconstruction, quantized latents, quantize_losses).

        Example shapes noted by the author for their config (they depend on
        `down_sample` and `z_channels`): out [B, 3, 256, 256], z [B, 3, 64, 64].
        """
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses


# ==================================================================
# C O N F I G U R A T I O N
# ==================================================================
# The training mode comes from the first CLI argument; validate it up front so
# a missing/typo'd argument fails before the dataset is scanned.
TRAIN_MODES = ("train_vae", "train_ldm")
if len(sys.argv) < 2 or sys.argv[1] not in TRAIN_MODES:
    raise SystemExit(
        f"Usage: python {os.path.basename(sys.argv[0])} {{{'|'.join(TRAIN_MODES)}}}"
    )
MODE = sys.argv[1]

config_path = "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/config-LDM-High-Pre.yaml"
# config_path = sys.argv[1]
with open(config_path, "r") as file:
    Config = yaml.safe_load(file)
pprint.pprint(Config, width=120)
Config = DotDict.from_dict(Config)

dataset_config = Config.dataset_params
diffusion_config = Config.diffusion_params
model_config = Config.model_params
train_config = Config.train_params
paths = Config.paths


# ==================================================================
# V A A N I - D A T A S E T
# ==================================================================
IMAGES_PATH = paths.images_dir


def walkDIR(folder_path, include=None):
    """Recursively collect file paths under `folder_path`.

    Args:
        folder_path: Root directory to walk.
        include: Optional iterable of extensions (e.g. ['.png']). Matching is
            case-insensitive so '.PNG' / '.JPG' files are found too. None keeps
            every file.

    Returns:
        Sorted list of matching paths. Sorting makes the order (and therefore
        any seeded random split of it) independent of filesystem listing order.
    """
    suffixes = None if include is None else tuple(ext.lower() for ext in include)
    file_list = [
        os.path.join(root, name)
        for root, _, names in os.walk(folder_path)
        for name in names
        if suffixes is None or name.lower().endswith(suffixes)
    ]
    file_list.sort()
    print("Files found:", len(file_list))
    return file_list


files = walkDIR(IMAGES_PATH, include=[".png", ".jpeg", ".jpg"])
df = pd.DataFrame(files, columns=["image_path"])


class VaaniDataset(torch.utils.data.Dataset):
    """Image dataset: each item is the file read as RGB, resized to
    (im_size, im_size) and converted to float32 in [0, 1], shape (3, im_size, im_size)."""

    def __init__(self, files_paths, im_size):
        self.files_paths = files_paths
        self.im_size = im_size
        # Built once rather than on every __getitem__; same ops, same order.
        self.transform = v2.Compose(
            [
                v2.Resize((im_size, im_size)),
                v2.ToDtype(torch.float32, scale=True),
            ]
        )

    def __len__(self):
        return len(self.files_paths)

    def __getitem__(self, idx):
        image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)
        # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)
        # NOTE: images stay in [0, 1]; the [-1, 1] mapping is intentionally off.
        # image = 2*image - 1
        return self.transform(image)


dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size)
if len(dataset) == 0:
    raise SystemExit(f"No images found under {IMAGES_PATH}")
image = dataset[0]
print("IMAGE SHAPE:", image.shape)

# Read the flag from the underlying dict: on a DotDict a missing key yields an
# (always truthy) placeholder, which would silently enable debug sub-sampling.
if vars(train_config).get("debug", False):
    s = 0.001
    # Dedicated generator: same split as seeding with 42, without reseeding the
    # global RNG as a side effect.
    dataset, _ = torch.utils.data.random_split(
        dataset, [s, 1 - s], generator=torch.Generator().manual_seed(42)
    )
print("Length of Train dataset:", len(dataset))

BATCH_SIZE = (
    train_config.autoencoder_batch_size if MODE == "train_vae" else train_config.ldm_batch_size
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=48,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,
)

images = next(iter(dataloader))
print("BATCH SHAPE:", images.shape)


# ==================================================================
# M O D E L - I N I T I A L I Z A T I O N
# ==================================================================
dataset_config = Config.dataset_params
autoencoder_config = Config.autoencoder_params
train_config = Config.train_params

# model = VQVAE(im_channels=dataset_config.im_channels,
#               model_config=autoencoder_config).to(device)

# model_output = model(images.to(device))
# print('MODEL OUTPUT:')
# print(model_output[0].shape, model_output[1].shape, model_output[2])

# from torchinfo import summary
# summary(model=model,
#         input_data=images.to(device),
#         # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),
#         col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"],
#         col_width=20,
#         row_settings=["var_names"],
#         depth = 6,
#         # device=device
# )
# exit()


# ==================================================================
# V Q - V A E - T R A I N I N G
# ==================================================================
# python your_script.py 2>&1 > training.log


def format_time(t1, t2):
    """Format the elapsed time `t2 - t1` (seconds) as a human-readable string."""
    elapsed_time = t2 - t1
    if elapsed_time < 60:
        return f"{elapsed_time:.2f} seconds"
    total_minutes, seconds = divmod(elapsed_time, 60)
    if elapsed_time < 3600:
        return f"{total_minutes:.0f} minutes {seconds:.2f} seconds"
    total_hours, minutes = divmod(total_minutes, 60)
    if elapsed_time < 86400:
        return f"{total_hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"
    days, hours = divmod(total_hours, 24)
    return f"{days:.0f} days {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"


def find_checkpoints(checkpoint_path):
    """Return paths in checkpoint_path's directory named `<basename>_epoch<N>.pt`."""
    directory = os.path.dirname(checkpoint_path)
    prefix = os.path.basename(checkpoint_path)
    pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt$")
    try:
        files = os.listdir(directory)
    except FileNotFoundError:
        return []
    return [os.path.join(directory, f) for f in files if pattern.match(f)]


_CKPT_EPOCH_RE = re.compile(r"_epoch(\d+)\.pt")


def _checkpoint_epoch(filename):
    """Epoch number encoded in a checkpoint file name, or -1 if absent."""
    match = _CKPT_EPOCH_RE.search(os.path.basename(filename))
    return int(match.group(1)) if match else -1


def _sorted_vae_checkpoints(checkpoint_path):
    """Checkpoints written by save_vae_checkpoint, oldest epoch first."""
    # glob.escape: the path itself may contain glob metacharacters ([, *, ?).
    pattern = f"{glob.escape(checkpoint_path)}_epoch*.pt"
    return sorted(glob.glob(pattern), key=_checkpoint_epoch)


def save_vae_checkpoint(
    total_steps,
    epoch,
    model,
    discriminator,
    optimizer_d,
    optimizer_g,
    metrics,
    checkpoint_path,
    logs,
    total_training_time,
):
    """Write `<checkpoint_path>_epoch<epoch>.pt` and keep only the two newest.

    `epoch` is the number of completed epochs (trainVAE passes epoch_idx + 1).
    """
    checkpoint = {
        "total_steps": total_steps,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "optimizer_d_state_dict": optimizer_d.state_dict(),
        "optimizer_g_state_dict": optimizer_g.state_dict(),
        "metrics": metrics,
        "logs": logs,
        "total_training_time": total_training_time,
    }
    checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt"
    # Save to a temp file and atomically move it into place, so an interrupted
    # save never leaves a truncated file as the "latest" checkpoint.
    tmp_file = f"{checkpoint_file}.tmp"
    torch.save(checkpoint, tmp_file)
    os.replace(tmp_file, checkpoint_file)
    print(f"VQVAE Checkpoint saved at {checkpoint_file}")

    for old_ckpt in _sorted_vae_checkpoints(checkpoint_path)[:-2]:
        os.remove(old_ckpt)
        print(f"Removed old VQVAE checkpoint: {old_ckpt}")


def load_vae_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g, device=device):
    """Restore model/discriminator/optimizer state from the newest checkpoint.

    Returns:
        (total_steps, start_epoch, metrics, logs, total_training_time), where
        start_epoch is the 0-based index of the next epoch to run. Returns
        (0, 0, None, [], 0) when no checkpoint exists.
    """
    all_ckpts = _sorted_vae_checkpoints(checkpoint_path)
    if not all_ckpts:
        print("No VQVAE checkpoint found. Starting from scratch.")
        return 0, 0, None, [], 0

    latest_ckpt = all_ckpts[-1]
    checkpoint = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
    optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
    total_steps = checkpoint["total_steps"]
    epoch = checkpoint["epoch"]
    metrics = checkpoint.get("metrics")
    logs = checkpoint.get("logs", [])
    total_training_time = checkpoint.get("total_training_time", 0)
    print(f"VQVAE Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}")
    # The stored epoch already counts *completed* epochs, so it is exactly the
    # 0-based index of the next epoch; returning epoch + 1 would skip one.
    return total_steps, epoch, metrics, logs, total_training_time


def inference(model, dataset, save_path, epoch, device="cuda", sample_size=8):
    """Save a grid of the first `sample_size` dataset images (top row) and their
    reconstructions (bottom row) as a PNG under `save_path`.

    Uses at most len(dataset) images; does nothing if the dataset is empty.
    """
    os.makedirs(save_path, exist_ok=True)
    sample_size = min(sample_size, len(dataset))
    if sample_size == 0:
        print("No images available for reconstruction preview.")
        return

    image_tensors = torch.stack([dataset[i] for i in range(sample_size)]).to(device)

    was_training = model.training
    model.eval()
    with torch.no_grad():
        outputs, _, _ = model(image_tensors)
    model.train(was_training)

    grid = make_grid(torch.cat([image_tensors.cpu(), outputs.cpu()], dim=0), nrow=sample_size)
    # Decoder outputs are unbounded: clamp to [0, 1] before the uint8 cast so
    # out-of-range values saturate instead of wrapping around.
    np_img = (grid.clamp(0, 1) * 255).to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
    combined_image = Image.fromarray(np_img)
    combined_image.save(os.path.join(save_path, f"reconstructed_images_EP-{epoch}_{sample_size}.png"))
    print(f"Reconstructed images saved at: {save_path}")


def trainVAE(Config, dataloader):
    """Train the VQ-VAE together with a PatchGAN discriminator.

    Resumes from the newest checkpoint under `train_params.task_name` and, after
    every epoch, saves a checkpoint and a reconstruction preview.
    """
    dataset_config = Config.dataset_params
    autoencoder_config = Config.autoencoder_params
    train_config = Config.train_params

    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # `device` is a torch.device, which never compares equal to the str "cuda".
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
    discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)

    optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
    optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))

    os.makedirs(train_config.task_name, exist_ok=True)
    checkpoint_path = os.path.join(train_config.task_name, "vqvae_ckpt.pth")
    total_steps, start_epoch, metrics, logs, total_training_time = load_vae_checkpoint(
        checkpoint_path, model, discriminator, optimizer_d, optimizer_g
    )

    num_epochs = train_config.autoencoder_epochs
    recon_criterion = torch.nn.MSELoss()
    disc_criterion = torch.nn.MSELoss()
    lpips_model = LPIPS().eval().to(device)

    acc_steps = train_config.autoencoder_acc_steps
    disc_step_start = train_config.disc_start
    start_time_total = time.time() - total_training_time

    for epoch_idx in trange(start_epoch, num_epochs, colour="red", dynamic_ncols=True):
        start_time_epoch = time.time()
        epoch_log = []

        for images in tqdm(dataloader, colour="green", dynamic_ncols=True):
            batch_start_time = time.time()
            total_steps += 1
            images = images.to(device)
            output, z, quantize_losses = model(images)
            use_disc = total_steps > disc_step_start

            ######### Generator (VQ-VAE) #########
            recon_loss = recon_criterion(output, images) / acc_steps
            g_loss = (
                recon_loss
                + (train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps)
                + (train_config.commitment_beta * quantize_losses["commitment_loss"] / acc_steps)
            )
            if use_disc:
                # Freeze D while computing the generator's adversarial term, so
                # its gradient ("call fakes real") does not accumulate into D's
                # .grad and get applied by optimizer_d.
                discriminator.requires_grad_(False)
                disc_fake_pred = discriminator(output)
                disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
                g_loss = g_loss + train_config.disc_weight * disc_fake_loss / acc_steps
            # Dataset images are in [0, 1]; LPIPS expects [-1, 1], hence normalize=True.
            lpips_loss = torch.mean(lpips_model(output, images, normalize=True)) / acc_steps
            g_loss = g_loss + train_config.perceptual_weight * lpips_loss
            g_loss.backward()
            discriminator.requires_grad_(True)

            if total_steps % acc_steps == 0:
                optimizer_g.step()
                optimizer_g.zero_grad()

            ######### Discriminator #########
            if use_disc:
                disc_fake_pred = discriminator(output.detach())
                disc_real_pred = discriminator(images)
                disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
                disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
                disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2 / acc_steps
                disc_loss.backward()
                if total_steps % acc_steps == 0:
                    optimizer_d.step()
                    optimizer_d.zero_grad()

            batch_time = time.time() - batch_start_time
            epoch_log.append(format_time(0, batch_time))

        # Flush gradients left over from an incomplete accumulation window.
        optimizer_d.step()
        optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()

        epoch_time = time.time() - start_time_epoch
        logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log})

        total_training_time = time.time() - start_time_total
        save_vae_checkpoint(
            total_steps, epoch_idx + 1, model, discriminator, optimizer_d, optimizer_g,
            metrics, checkpoint_path, logs, total_training_time,
        )

        recon_save_path = os.path.join(train_config.task_name, "vqvae_recon")
        inference(model, dataloader.dataset, recon_save_path, epoch=epoch_idx, device=device, sample_size=16)

    print("Training completed.")


# ==================================================================
# S T A R T I N G - V Q - V A E - T R A I N I N G
# ==================================================================
# trainVAE(Config, dataloader)

# python Vaani-VQVAE-Main.py | tee AE-training.log
# python Vaani-VQVAE-Main.py > AE-training.log 2>&1


# ==================================================================
# L I N E A R - N O I S E - S C H E D U L E R
# ==================================================================
class LinearNoiseScheduler:
    r"""
    Class for the linear noise scheduler that is used in DDPM.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        # Mimicking how compvis repo creates schedule
        self.betas = (
            torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
        )
        self.alphas = 1.
- self.betas self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) def add_noise(self, original, noise, t): r""" Forward method for diffusion :param original: Image on which noise is to be applied :param noise: Random Noise Tensor (from normal dist) :param t: timestep of the forward process of shape -> (B,) :return: """ original_shape = original.shape batch_size = original_shape[0] sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size) sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size) # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W) for _ in range(len(original_shape) - 1): sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1) for _ in range(len(original_shape) - 1): sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1) # Apply and Return Forward process equation return (sqrt_alpha_cum_prod.to(original.device) * original + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise) def sample_prev_timestep(self, xt, noise_pred, t): r""" Use the noise prediction by model to get xt-1 using xt and the nosie predicted :param xt: current timestep sample :param noise_pred: model noise prediction :param t: current timestep we are at :return: """ x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) / torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) x0 = torch.clamp(x0, -1., 1.) 
mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) if t == 0: return mean, x0 else: variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) variance = variance * self.betas.to(xt.device)[t] sigma = variance ** 0.5 z = torch.randn(xt.shape).to(xt.device) # OR # variance = self.betas[t] # sigma = variance ** 0.5 # z = torch.randn(xt.shape).to(xt.device) return mean + sigma * z, x0 # ================================================================== # T I M E - E M B E D D I N G # ================================================================== def get_time_embedding(time_steps, temb_dim): r""" Convert time steps tensor into an embedding using the sinusoidal time embedding formula :param time_steps: 1D tensor of length batch size :param temb_dim: Dimension of the embedding :return: BxD embedding representation of B time steps """ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" # factor = 10000^(2i/d_model) factor = 10000 ** ((torch.arange( start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) ) # pos / factor # timesteps B -> B, 1 -> B, temb_dim t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) return t_emb # ================================================================== # L D M - U N E T - U P - B L O C K # ================================================================== class UpBlockUnet(nn.Module): r""" Up conv block with attention. Sequence of following blocks 1. Upsample 1. Concatenate Down block output 2. Resnet block with time embedding 3. 
Attention Block """ def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None): super().__init__() self.num_layers = num_layers self.up_sample = up_sample self.t_emb_dim = t_emb_dim self.cross_attn = cross_attn self.context_dim = context_dim self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for i in range(num_layers) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList([ nn.Sequential( nn.SiLU(), nn.Linear(t_emb_dim, out_channels) ) for _ in range(num_layers) ]) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers) ] ) self.attention_norms = nn.ModuleList( [ nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers) ] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) if self.cross_attn: assert context_dim is not None, "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)] ) self.context_proj = nn.ModuleList( [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers) ] ) self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1) \ if self.up_sample else nn.Identity() def forward(self, x, out_down=None, 
t_emb=None, context=None): x = self.up_sample_conv(x) if out_down is not None: x = torch.cat([x, out_down], dim=1) out = x for i in range(self.num_layers): # --- Resnet -------------------- resnet_input = out out = self.resnet_conv_first[i](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] out = self.resnet_conv_second[i](out) out = out + self.residual_input_conv[i](resnet_input) # --- Self Attention ------------ batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn # --- Cross Attention ----------- if self.cross_attn: assert context is not None, "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) assert len(context.shape) == 3, \ "Context shape does not match B,_,CONTEXT_DIM" assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\ "Context shape does not match B,_,CONTEXT_DIM" context_proj = self.context_proj[i](context) out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn return out # ================================================================== # L D M - U N E T # ================================================================== class Unet(nn.Module): r""" Unet model comprising Down blocks, Midblocks and Uplocks """ def __init__(self, im_channels, model_config): super().__init__() self.down_channels = model_config.down_channels self.mid_channels = model_config.mid_channels self.t_emb_dim = model_config.time_emb_dim self.down_sample = 
model_config.down_sample self.num_down_layers = model_config.num_down_layers self.num_mid_layers = model_config.num_mid_layers self.num_up_layers = model_config.num_up_layers self.attns = model_config.attn_down self.norm_channels = model_config.norm_channels self.num_heads = model_config.num_heads self.conv_out_channels = model_config.conv_out_channels assert self.mid_channels[0] == self.down_channels[-1] assert self.mid_channels[-1] == self.down_channels[-2] assert len(self.down_sample) == len(self.down_channels) - 1 assert len(self.attns) == len(self.down_channels) - 1 self.condition_config = model_config.condition_config self.cond = condition_types = self.condition_config.condition_types if 'audio' in condition_types: self.audio_cond = True self.audio_embed_dim = self.condition_config.audio_condition_config.audio_embed_dim # Initial projection from sinusoidal time embedding self.t_proj = nn.Sequential( nn.Linear(self.t_emb_dim, self.t_emb_dim), nn.SiLU(), nn.Linear(self.t_emb_dim, self.t_emb_dim), ) # Context projection for whisper Encoder last hidden state # [B, 1500, 1280] -> [B, 1280] self.context_projector = nn.Sequential( nn.Linear(self.audio_embed_dim, 320), nn.SiLU(), nn.Linear(320, 1) ) self.up_sample = list(reversed(self.down_sample)) self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1) # --::----- D O W N - B l O C K S ----------------::--------------::---------------- self.downs = nn.ModuleList([]) for i in range(len(self.down_channels) - 1): # Cross Attention and Context Dim only needed if text or audio condition is present self.downs.append( DownBlock( self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim, down_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_down_layers, attn=self.attns[i], norm_channels=self.norm_channels, cross_attn=self.audio_cond, context_dim=self.audio_embed_dim ) ) # --::----- M I D - B l O C K S ----------------::--------------::---------------- self.mids 
= nn.ModuleList([]) for i in range(len(self.mid_channels) - 1): self.mids.append( MidBlock( self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels, cross_attn=self.audio_cond, context_dim=self.audio_embed_dim ) ) # --::----- U P - B l O C K S ----------------::--------------::---------------- self.ups = nn.ModuleList([]) for i in reversed(range(len(self.down_channels) - 1)): self.ups.append( UpBlockUnet( self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels, self.t_emb_dim, up_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_up_layers, norm_channels=self.norm_channels, cross_attn=self.audio_cond, context_dim=self.audio_embed_dim ) ) self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1) def forward(self, x, t, cond_input=None): # Shapes assuming downblocks are [C1, C2, C3, C4] # Shapes assuming midblocks are [C4, C4, C3] # Shapes assuming downsamples are [True, True, False] # B x C x H x W out = self.conv_in(x) # B x C1 x H x W # t_emb -> B x t_emb_dim t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim) t_emb = self.t_proj(t_emb) # --- Conditioning --------------- if self.audio_cond: # context_hidden_states = cond_input # print(self.audio_cond, cond_input.shape) last_hidden_state = cond_input weights = self.context_projector(last_hidden_state) weights = torch.softmax(weights, dim=1) # Normalize across time pooled_embedding = (last_hidden_state * weights).sum(dim=1) # [1, 512] context_hidden_states = pooled_embedding.unsqueeze(1) # print(context_hidden_states.shape) # exit() # --- Down Pass ------------------ down_outs = [] for idx, down in enumerate(self.downs): down_outs.append(out) out = down(out, t_emb, context_hidden_states) # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 
x H/4 x W/4] # out B x C4 x H/4 x W/4 # --- Mid Pass ------------------ for mid in self.mids: out = mid(out, t_emb, context_hidden_states) # out B x C3 x H/4 x W/4 # --- Up Pass ------------------ for up in self.ups: down_out = down_outs.pop() out = up(out, down_out, t_emb, context_hidden_states) # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W] out = self.norm_out(out) out = nn.SiLU()(out) out = self.conv_out(out) # out B x C x H x W return out # ================================================================== # L D M - T R A I N I N G # ================================================================== def find_checkpoints(checkpoint_path): directory = os.path.dirname(checkpoint_path) prefix = os.path.basename(checkpoint_path) pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt$") try: files = os.listdir(directory) except FileNotFoundError: return [] return [ os.path.join(directory, f) for f in files if pattern.match(f) ] def save_ldm_checkpoint(checkpoint_path, total_steps, epoch, model, optimizer, metrics, logs, total_training_time ): checkpoint = { "total_steps": total_steps, "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "metrics": metrics, "logs": logs, "total_training_time": total_training_time } checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt" torch.save(checkpoint, checkpoint_file) print(f"LDM Checkpoint saved at {checkpoint_file}") all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") # all_ckpts = find_checkpoints(checkpoint_path) def extract_epoch(filename): match = re.search(r"_epoch(\d+)\.pt", filename) return int(match.group(1)) if match else -1 all_ckpts = sorted(all_ckpts, key=extract_epoch) for old_ckpt in all_ckpts[:-2]: os.remove(old_ckpt) print(f"Removed old LDM checkpoint: {old_ckpt}") def load_ldm_checkpoint(checkpoint_path, model, optimizer): all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") # all_ckpts = find_checkpoints(checkpoint_path) if not 
all_ckpts: print("No LDM checkpoint found. Starting from scratch.") return 0, 0, None, [], 0 def extract_epoch(filename): match = re.search(r"_epoch(\d+)\.pt", filename) return int(match.group(1)) if match else -1 all_ckpts = sorted(all_ckpts, key=extract_epoch) latest_ckpt = all_ckpts[-1] if os.path.exists(latest_ckpt): checkpoint = torch.load(latest_ckpt, map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) total_steps = checkpoint["total_steps"] epoch = checkpoint["epoch"] metrics = checkpoint["metrics"] logs = checkpoint.get("logs", []) total_training_time = checkpoint.get("total_training_time", 0) print(f"LDM Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}") return total_steps, epoch + 1, metrics, logs, total_training_time else: print("No LDM checkpoint found. Starting from scratch.") return 0, 0, None, [], 0 def load_ldm_vae_checkpoint(checkpoint_path, vae, device=device): # all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") all_ckpts = find_checkpoints(checkpoint_path) if not all_ckpts: print("No VQVAE checkpoint found.") return 0, 0, None, [], 0 def extract_epoch(filename): match = re.search(r"_epoch(\d+)\.pt", filename) return int(match.group(1)) if match else -1 all_ckpts = sorted(all_ckpts, key=extract_epoch) latest_ckpt = all_ckpts[-1] if os.path.exists(latest_ckpt): checkpoint = torch.load(latest_ckpt, map_location=device) vae.load_state_dict(checkpoint["model_state_dict"]) total_steps = checkpoint["total_steps"] epoch = checkpoint["epoch"] print(f"VQVAE Checkpoint loaded from {latest_ckpt} at epoch {epoch + 1} & step {total_steps}") def trainLDM(Config, dataloader): diffusion_config = Config.diffusion_params dataset_config = Config.dataset_params diffusion_model_config = Config.ldm_params autoencoder_model_config = Config.autoencoder_params train_config = Config.train_params condition_config = 
diffusion_model_config.condition_config seed = train_config.seed torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if device == "cuda": torch.cuda.manual_seed_all(seed) vqvae_device = "cuda:1" ldm_device = "cuda:0" # ldm_device = vqvae_device = device # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps) scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps, beta_start=diffusion_config.beta_start, beta_end=diffusion_config.beta_end) if not train_config.ldm_pretraining: if condition_config is not None: condition_types = condition_config.condition_types if 'audio' in condition_types: from msclap import CLAP # type: ignore audio_model = CLAP(version = '2023', use_cuda=(True if "cuda" in device else False)) model = Unet(im_channels=autoencoder_model_config.z_channels, model_config=diffusion_model_config).to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=train_config.ldm_lr) criterion = torch.nn.MSELoss() num_epochs = train_config.ldm_epochs checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "ldmH_ckpt") (total_steps, start_epoch, metrics, logs, total_training_time) = load_ldm_checkpoint(checkpoint_path, model, optimizer) vae = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_model_config).eval().to(vqvae_device) vae_checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "vqvae_ckpt") load_ldm_vae_checkpoint(vae_checkpoint_path, vae, vqvae_device) for param in vae.parameters(): param.requires_grad = False vae.eval() if not os.path.exists(train_config.task_name): os.makedirs(train_config.task_name, exist_ok=True) acc_steps = train_config.ldm_acc_steps disc_step_start = train_config.disc_start start_time_total = time.time() - total_training_time model.train() optimizer.zero_grad() for epoch_idx in trange(start_epoch, num_epochs, desc=f"{device}-LDM Epoch", colour='red', dynamic_ncols=True): start_time_epoch = time.time() losses = [] epoch_log = [] # Load 
latest vqvae checkpoints vae_checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "vqvae_ckpt") load_ldm_vae_checkpoint(vae_checkpoint_path, vae, vqvae_device) for param in vae.parameters(): param.requires_grad = False vae.eval() # for images, cond_input in tqdm(dataloader, colour='green', dynamic_ncols=True): for images in tqdm(dataloader, colour='green', dynamic_ncols=True): cond_input = None batch_start_time = time.time() total_steps += 1 batch_size = images.shape[0] # images = images.to(device) with torch.no_grad(): images, _ = vae.encode(images.to(vqvae_device)) images = images.to(ldm_device) # Conditional Input audio_embed_dim = condition_config.audio_condition_config.audio_embed_dim # empty_audio_embedding = torch.zeros(audio_embed_dim, device=device).float().unsqueeze(0).repeat(batch_size, 1).unsqueeze(1) # empty_audio_embedding = torch.zeros((1500,1280), device=device).float().unsqueeze(0).repeat(batch_size, 1).unsqueeze(1) empty_audio_embedding = torch.zeros((batch_size, 1500, 1280), device=device).float() if not train_config.ldm_pretraining: if 'audio' in condition_types: with torch.no_grad(): audio_embeddings = audio_model.get_audio_embeddings(cond_input) text_drop_prob = condition_config.audio_condition_config.cond_drop_prob text_drop_mask = torch.zeros((images.shape[0]), device=images.device).float().uniform_(0, 1) < text_drop_prob audio_embeddings[text_drop_mask, :, :] = empty_audio_embedding[0] else: audio_embeddings = empty_audio_embedding # Sample random noise noise = torch.randn_like(images).to(device) # Sample timestep t = torch.randint(0, diffusion_config.num_timesteps, (images.shape[0],)).to(device) # Add noise to images according to timestep noisy_images = scheduler.add_noise(images, noise, t) noise_pred = model(noisy_images, t, cond_input=audio_embeddings) loss = criterion(noise_pred, noise) losses.append(loss.item()) loss = loss / acc_steps loss.backward() if total_steps % acc_steps == 0: optimizer.step() 
optimizer.zero_grad() if total_steps % acc_steps == 0: optimizer.step() optimizer.zero_grad() print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}') epoch_time = time.time() - start_time_epoch logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log}) total_training_time = time.time() - start_time_total save_ldm_checkpoint(checkpoint_path, total_steps, epoch_idx + 1, model, optimizer, metrics, logs, total_training_time) infer(Config) # Checking to conntinue training train_continue = DotDict.from_dict(yaml.safe_load(open(config_path, 'r'))) if train_continue.training.continue_ldm == False: print('LDM Training Stoped ...') break print('Done Training ...') # ================================================================== # L D M - S A M P L I N G # ================================================================== def sample(model, scheduler, train_config, diffusion_model_config, autoencoder_model_config, diffusion_config, dataset_config, vae, audio_model ): r""" Sample stepwise by going backward one timestep at a time. 
We save the x0 predictions """ im_size = dataset_config.im_size // 2**sum(autoencoder_model_config.down_sample) xt = torch.randn((train_config.num_samples, autoencoder_model_config.z_channels, im_size, im_size)).to(device) audio_embed_dim = diffusion_model_config.condition_config.audio_condition_config.audio_embed_dim # empty_audio_embedding = torch.zeros(audio_embed_dim, device=device).float() # empty_audio_embedding = torch.zeros(audio_embed_dim, device=device).float().unsqueeze(0) # empty_audio_embedding = empty_audio_embedding.repeat(train_config.num_samples, 1).unsqueeze(1) empty_audio_embedding = torch.zeros((train_config.num_samples, 1500, 1280), device=device).float() if not train_config.ldm_pretraining: # Create Conditional input pass else: audio_embeddings = empty_audio_embedding uncond_input = empty_audio_embedding cond_input = audio_embeddings save_count = 0 for i in tqdm(reversed(range(diffusion_config.num_timesteps)), total=diffusion_config.num_timesteps, colour='blue', desc="Sampling", dynamic_ncols=True): # Get prediction of noise t = (torch.ones((xt.shape[0],)) * i).long().to(device) # t = torch.as_tensor(i).unsqueeze(0).to(device) noise_pred_cond = model(xt, t, cond_input) cf_guidance_scale = train_config.cf_guidance_scale if cf_guidance_scale > 1: noise_pred_uncond = model(xt, t, uncond_input) noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond) else: noise_pred = noise_pred_cond # Use scheduler to get x0 and xt-1 xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) # Save x0 #ims = torch.clamp(xt, -1., 1.).detach().cpu() if i == 0: # Decode ONLY the final iamge to save time ims = vae.decode(xt) else: # ims = xt ims = x0_pred ims = torch.clamp(ims, -1., 1.).detach().cpu() ims = (ims + 1) / 2 grid = make_grid(ims, nrow=train_config.num_grid_rows) # img = tv.transforms.ToPILImage()(grid) np_img = (grid * 255).byte().numpy().transpose(1, 2, 0) img = Image.fromarray(np_img) 
if not os.path.exists(os.path.join(train_config.task_name, 'samplesH')): os.makedirs(os.path.join(train_config.task_name, 'samplesH'), exist_ok=True) img.save(os.path.join(train_config.task_name, 'samplesH', 'x0_{}.png'.format(i))) img.close() def infer(Config): diffusion_config = Config.diffusion_params dataset_config = Config.dataset_params diffusion_model_config = Config.ldm_params autoencoder_model_config = Config.autoencoder_params train_config = Config.train_params # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps) scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps, beta_start=diffusion_config.beta_start, beta_end=diffusion_config.beta_end) model = Unet(im_channels=autoencoder_model_config.z_channels, model_config=diffusion_model_config).eval().to(device) vae = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_model_config).eval().to(device) if os.path.exists(os.path.join(train_config.task_name, train_config.ldm_ckpt_name)): checkpoint_path = os.path.join(train_config.task_name, train_config.ldm_ckpt_name) checkpoint = torch.load(checkpoint_path, map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) vae.load_state_dict(checkpoint["vae_state_dict"]) print('Loaded unet & vae checkpoint') # Create output directories if not os.path.exists(train_config.task_name): os.makedirs(train_config.task_name, exist_ok=True) with torch.no_grad(): sample(model, scheduler, train_config, diffusion_model_config, autoencoder_model_config, diffusion_config, dataset_config, vae, None) # ================================================================== # S T A R T I N G - L D M - T R A I N I N G # ================================================================== # trainLDM(Config, dataloader) if sys.argv[1] == 'train_vae': trainVAE(Config, dataloader) elif sys.argv[1] == 'train_ldm': trainLDM(Config, dataloader) else: infer(Config) # git add . 
&& git commit -m "LDM" && git push -u origin master # huggingface-cli upload alpha31476/Vaani-Audio2Img-LDM . --commit-message "SDFT"