# ================================================================== # V Q - V A E T R A I N I N G # ================================================================== # Author : Ashish Kumar Uchadiya # Created : November 3, 2024 # Description: This script implements the training of a VQ-VAE model for # image reconstruction. It uses LPIPS (Learned Perceptual Image Patch Similarity) # loss to capture perceptual differences and PatchGAN loss to enforce local # realism. The model maps images to a discrete latent space and reconstructs # high-fidelity outputs by minimizing these combined losses. # ================================================================== # I M P O R T S # ================================================================== import os # os.environ["CUDA_VISIBLE_DEVICES"] = "1" import torch import torch.nn as nn import numpy as np from collections import namedtuple import pandas as pd import torchvision as tv from torchvision.transforms import v2 from tqdm.auto import tqdm, trange import matplotlib.pyplot as plt from PIL import Image import re import glob import sys import yaml import random import datetime import torch.hub from torch.utils.data import Dataset, DataLoader from torchvision.utils import make_grid print("TIME:", datetime.datetime.now()) os.environ["CUDA_VISIBLE_DEVICES"] = f"{sys.argv[1]}" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print("DEVICE:", device) # ================================================================== # H E L P E R S # ================================================================== from typing import Any from argparse import Namespace import typing class DotDict(Namespace): """A simple class that builds upon `argparse.Namespace` in order to make chained attributes possible.""" def __init__(self, temp=False, key=None, parent=None) -> None: self._temp = temp self._key = key self._parent = parent def __eq__(self, other): if not isinstance(other, DotDict): return NotImplemented return vars(self) == vars(other) def __getattr__(self, __name: str) -> Any: if __name not in self.__dict__ and not self._temp: self.__dict__[__name] = DotDict(temp=True, key=__name, parent=self) else: del self._parent.__dict__[self._key] raise AttributeError("No attribute '%s'" % __name) return self.__dict__[__name] def __repr__(self) -> str: item_keys = [k for k in self.__dict__ if not k.startswith("_")] if len(item_keys) == 0: return "DotDict()" elif len(item_keys) == 1: key = item_keys[0] val = self.__dict__[key] return "DotDict(%s=%s)" % (key, repr(val)) else: return "DotDict(%s)" % ", ".join( "%s=%s" % (key, repr(val)) for key, val in self.__dict__.items() ) @classmethod def from_dict(cls, original: typing.Mapping[str, any]) -> "DotDict": """Create a DotDict from a (possibly nested) dict `original`. Warning: this method should not be used on very deeply nested inputs, since it's recursively traversing the nested dictionary values. """ dd = DotDict() for key, value in original.items(): if isinstance(value, typing.Mapping): value = cls.from_dict(value) setattr(dd, key, value) return dd # ================================================================== # L P I P S # ================================================================== class vgg16(nn.Module): def __init__(self): super(vgg16, self).__init__() vgg_pretrained_features = tv.models.vgg16( weights=tv.models.VGG16_Weights.IMAGENET1K_V1 ).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.eval() for param in self.parameters(): param.requires_grad = False def forward(self, X): h1 = self.slice1(X) h2 = self.slice2(h1) h3 = self.slice3(h2) h4 = self.slice4(h3) h5 = self.slice5(h4) vgg_outputs = namedtuple("VggOutputs", ['h1', 'h2', 'h3', 'h4', 'h5']) out = vgg_outputs(h1, h2, h3, h4, h5) return out def _spatial_average(in_tens, keepdim=True): return in_tens.mean([2, 3], keepdim=keepdim) def _normalize_tensor(in_feat, eps= 1e-8): norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True)) return in_feat / norm_factor class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() # Imagnet normalization for (0-1) # mean = [0.485, 0.456, 0.406] # std = [0.229, 0.224, 0.225] self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): ''' A single linear layer which does a 1x1 conv ''' def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(), ] if (use_dropout) else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class LPIPS(nn.Module): def __init__(self, net='vgg', version='0.1', use_dropout=True): super(LPIPS, self).__init__() self.version = version self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] self.L = len(self.chns) self.net = vgg16() self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]) # --- Orignal url -------------------- # weights_url = f"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v{version}/{net}.pth" # --- Orignal Forked url ------------- weights_url = f"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth" # --- Orignal torchmetric url -------- # weights_url = "https://github.com/Lightning-AI/torchmetrics/raw/master/src/torchmetrics/functional/image/lpips_models/vgg.pth" state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu') self.load_state_dict(state_dict, strict=False) self.eval() for param in self.parameters(): param.requires_grad = False def forward(self, in0, in1, normalize=False): # Scale the inputs to -1 to +1 range if input in [0,1] if normalize: in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) # in0_input, in1_input = in0, in1 outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) diffs = {} for kk in range(self.L): feats0 = _normalize_tensor(outs0[kk]) feats1 = _normalize_tensor(outs1[kk]) diffs[kk] = (feats0 - feats1) ** 2 res = [_spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] val = sum(res) return val.reshape(-1) # ================================================================== # P A T C H - G A N - D I S C R I M I N A T O R # ================================================================== class Discriminator(nn.Module): r""" PatchGAN Discriminator. Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to 1 scalar value , we instead predict grid of values. Where each grid is prediction of how likely the discriminator thinks that the image patch corresponding to the grid cell is real """ def __init__( self, im_channels=3, conv_channels=[64, 128, 256], kernels=[4, 4, 4, 4], strides=[2, 2, 2, 1], paddings=[1, 1, 1, 1], ): super().__init__() self.im_channels = im_channels activation = nn.LeakyReLU(0.2) layers_dim = [self.im_channels] + conv_channels + [1] self.layers = nn.ModuleList( [ nn.Sequential( nn.Conv2d( layers_dim[i], layers_dim[i + 1], kernel_size=kernels[i], stride=strides[i], padding=paddings[i], bias=False if i != 0 else True, ), ( nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity() ), activation if i != len(layers_dim) - 2 else nn.Identity(), ) for i in range(len(layers_dim) - 1) ] ) def forward(self, x): out = x for layer in self.layers: out = layer(out) return out # ================================================================== # D O W E - B L O C K # ================================================================== class DownBlock(nn.Module): r""" Down conv block with attention. Sequence of following block 1. Resnet block with time embedding 2. Attention block 3. Downsample """ def __init__( self, in_channels, out_channels, t_emb_dim, down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None, ): super().__init__() self.num_layers = num_layers self.down_sample = down_sample self.attn = attn self.context_dim = context_dim self.cross_attn = cross_attn self.t_emb_dim = t_emb_dim self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d( in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1, ), ) for i in range(num_layers) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList( [ nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels)) for _ in range(num_layers) ] ) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers) ] ) if self.attn: self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) if self.cross_attn: assert context_dim is not None, "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) self.context_proj = nn.ModuleList( [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers) ] ) self.down_sample_conv = ( nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity() ) def forward(self, x, t_emb=None, context=None): out = x for i in range(self.num_layers): # Resnet block of Unet resnet_input = out out = self.resnet_conv_first[i](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] out = self.resnet_conv_second[i](out) out = out + self.residual_input_conv[i](resnet_input) if self.attn: # Attention block of Unet batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn if self.cross_attn: assert ( context is not None ), "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim context_proj = self.context_proj[i](context) out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn # Downsample out = self.down_sample_conv(out) return out # ================================================================== # M I D - B L O C K # ================================================================== class MidBlock(nn.Module): r""" Mid conv block with attention. Sequence of following blocks 1. Resnet block with time embedding 2. Attention block 3. Resnet block with time embedding """ def __init__( self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None, ): super().__init__() self.num_layers = num_layers self.t_emb_dim = t_emb_dim self.context_dim = context_dim self.cross_attn = cross_attn self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d( in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1, ), ) for i in range(num_layers + 1) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList( [ nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) for _ in range(num_layers + 1) ] ) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers + 1) ] ) self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) if self.cross_attn: assert context_dim is not None, "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) self.context_proj = nn.ModuleList( [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers + 1) ] ) def forward(self, x, t_emb=None, context=None): out = x # First resnet block resnet_input = out out = self.resnet_conv_first[0](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] out = self.resnet_conv_second[0](out) out = out + self.residual_input_conv[0](resnet_input) for i in range(self.num_layers): # Attention Block batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn if self.cross_attn: assert ( context is not None ), "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim context_proj = self.context_proj[i](context) out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn # Resnet Block resnet_input = out out = self.resnet_conv_first[i + 1](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] out = self.resnet_conv_second[i + 1](out) out = out + self.residual_input_conv[i + 1](resnet_input) return out # ================================================================== # U P - B L O C K # ================================================================== class UpBlock(nn.Module): r""" Up conv block with attention. Sequence of following blocks 1. Upsample 1. Concatenate Down block output 2. Resnet block with time embedding 3. Attention Block """ def __init__( self, in_channels, out_channels, t_emb_dim, up_sample, num_heads, num_layers, attn, norm_channels, ): super().__init__() self.num_layers = num_layers self.up_sample = up_sample self.t_emb_dim = t_emb_dim self.attn = attn self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d( in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1, ), ) for i in range(num_layers) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList( [ nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) for _ in range(num_layers) ] ) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers) ] ) if self.attn: self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers) ] ) self.up_sample_conv = ( nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1) if self.up_sample else nn.Identity() ) def forward(self, x, out_down=None, t_emb=None): # Upsample x = self.up_sample_conv(x) # Concat with Downblock output if out_down is not None: x = torch.cat([x, out_down], dim=1) out = x for i in range(self.num_layers): # Resnet Block resnet_input = out out = self.resnet_conv_first[i](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] out = self.resnet_conv_second[i](out) out = out + self.residual_input_conv[i](resnet_input) # Self Attention if self.attn: batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn return out # ================================================================== # V Q - V A E # ================================================================== class VQVAE(nn.Module): def __init__(self, im_channels, model_config): super().__init__() self.down_channels = model_config.down_channels self.mid_channels = model_config.mid_channels self.down_sample = model_config.down_sample self.num_down_layers = model_config.num_down_layers self.num_mid_layers = model_config.num_mid_layers self.num_up_layers = model_config.num_up_layers # To disable attention in Downblock of Encoder and Upblock of Decoder self.attns = model_config.attn_down # Latent Dimension self.z_channels = model_config.z_channels self.codebook_size = model_config.codebook_size self.norm_channels = model_config.norm_channels self.num_heads = model_config.num_heads # Assertion to validate the channel information assert self.mid_channels[0] == self.down_channels[-1] assert self.mid_channels[-1] == self.down_channels[-1] assert len(self.down_sample) == len(self.down_channels) - 1 assert len(self.attns) == len(self.down_channels) - 1 # Wherever we use downsampling in encoder correspondingly use # upsampling in decoder self.up_sample = list(reversed(self.down_sample)) ##################### Encoder ###################### self.encoder_conv_in = nn.Conv2d( im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1) ) # Downblock + Midblock self.encoder_layers = nn.ModuleList([]) for i in range(len(self.down_channels) - 1): self.encoder_layers.append( DownBlock( self.down_channels[i], self.down_channels[i + 1], t_emb_dim=None, down_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_down_layers, attn=self.attns[i], norm_channels=self.norm_channels, ) ) self.encoder_mids = nn.ModuleList([]) for i in range(len(self.mid_channels) - 1): self.encoder_mids.append( MidBlock( self.mid_channels[i], self.mid_channels[i + 1], t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels, ) ) self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) self.encoder_conv_out = nn.Conv2d( self.down_channels[-1], self.z_channels, kernel_size=3, padding=1 ) # Pre Quantization Convolution self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) # Codebook self.embedding = nn.Embedding(self.codebook_size, self.z_channels) #################################################### ##################### Decoder ###################### # Post Quantization Convolution self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) self.decoder_conv_in = nn.Conv2d( self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1) ) # Midblock + Upblock self.decoder_mids = nn.ModuleList([]) for i in reversed(range(1, len(self.mid_channels))): self.decoder_mids.append( MidBlock( self.mid_channels[i], self.mid_channels[i - 1], t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels, ) ) self.decoder_layers = nn.ModuleList([]) for i in reversed(range(1, len(self.down_channels))): self.decoder_layers.append( UpBlock( self.down_channels[i], self.down_channels[i - 1], t_emb_dim=None, up_sample=self.down_sample[i - 1], num_heads=self.num_heads, num_layers=self.num_up_layers, attn=self.attns[i - 1], norm_channels=self.norm_channels, ) ) self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) self.decoder_conv_out = nn.Conv2d( self.down_channels[0], im_channels, kernel_size=3, padding=1 ) def quantize(self, x): B, C, H, W = x.shape # B, C, H, W -> B, H, W, C x = x.permute(0, 2, 3, 1) # B, H, W, C -> B, H*W, C x = x.reshape(x.size(0), -1, x.size(-1)) # Find nearest embedding/codebook vector # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K) dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1))) # (B, H*W) min_encoding_indices = torch.argmin(dist, dim=-1) # Replace encoder output with nearest codebook # quant_out -> B*H*W, C quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) # x -> B*H*W, C x = x.reshape((-1, x.size(-1))) commmitment_loss = torch.mean((quant_out.detach() - x) ** 2) codebook_loss = torch.mean((quant_out - x.detach()) ** 2) quantize_losses = {"codebook_loss": codebook_loss, "commitment_loss": commmitment_loss} # Straight through estimation quant_out = x + (quant_out - x).detach() # quant_out -> B, C, H, W quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) min_encoding_indices = min_encoding_indices.reshape( (-1, quant_out.size(-2), quant_out.size(-1)) ) return quant_out, quantize_losses, min_encoding_indices def encode(self, x): out = self.encoder_conv_in(x) for idx, down in enumerate(self.encoder_layers): out = down(out) for mid in self.encoder_mids: out = mid(out) out = self.encoder_norm_out(out) out = nn.SiLU()(out) out = self.encoder_conv_out(out) out = self.pre_quant_conv(out) out, quant_losses, _ = self.quantize(out) return out, quant_losses def decode(self, z): out = z out = self.post_quant_conv(out) out = self.decoder_conv_in(out) for mid in self.decoder_mids: out = mid(out) for idx, up in enumerate(self.decoder_layers): out = up(out) out = self.decoder_norm_out(out) out = nn.SiLU()(out) out = self.decoder_conv_out(out) return out def forward(self, x): '''out: [B, 3, 256, 256] z: [B, 3, 64, 64] quant_losses: { codebook_loss: 0.0681, commitment_loss: 0.0681 } ''' z, quant_losses = self.encode(x) out = self.decode(z) return out, z, quant_losses # ================================================================== # C O N F I G U R A T I O N # ================================================================== import pprint config_path = "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/config-VQVAE.yaml" with open(config_path, 'r') as file: Config = yaml.safe_load(file) pprint.pprint(Config, width=120) Config = DotDict.from_dict(Config) dataset_config = Config.dataset_params diffusion_config = Config.diffusion_params model_config = Config.model_params train_config = Config.train_params paths = Config.paths # ================================================================== # V A A N I - D A T A S E T # ================================================================== IMAGES_PATH = paths.images_dir def walkDIR(folder_path, include=None): file_list = [] for root, _, files in os.walk(folder_path): for file in files: if include is None or any(file.endswith(ext) for ext in include): file_list.append(os.path.join(root, file)) print("Files found:", len(file_list)) return file_list files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg']) df = pd.DataFrame(files, columns=['image_path']) class VaaniDataset(torch.utils.data.Dataset): def __init__(self, files_paths, im_size): self.files_paths = files_paths self.im_size = im_size def __len__(self): return len(self.files_paths) def __getitem__(self, idx): image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB) # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB) image = v2.Resize((self.im_size,self.im_size))(image) image = v2.ToDtype(torch.float32, scale=True)(image) # image = 2*image - 1 return image dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size) image = dataset[2] print('IMAGE SHAPE:', image.shape) # s = 0.001 # dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42)) # print("Length of Train dataset:", len(dataset)) dataloader = torch.utils.data.DataLoader( dataset, batch_size=train_config.autoencoder_batch_size, shuffle=True, num_workers=10, pin_memory=False, drop_last=True, persistent_workers=True ) images = next(iter(dataloader)) print('BATCH SHAPE:', images.shape) # ================================================================== # M O D E L - I N I T I L I Z A T I O N # ================================================================== dataset_config = Config.dataset_params autoencoder_config = Config.autoencoder_params train_config = Config.train_params model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device) # model_output = model(images.to(device)) # print('MODEL OUTPUT:') # print(model_output[0].shape, model_output[1].shape, model_output[2]) # from torchinfo import summary # summary(model=model, # input_data=images.to(device), # # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE), # col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"], # col_width=20, # row_settings=["var_names"], # depth = 6, # # device=device # ) # exit() # ================================================================== # V Q - V A E - T R A I N I N G # ================================================================== # python your_script.py 2>&1 > training.log import time def format_time(t1, t2): elapsed_time = t2 - t1 if elapsed_time < 60: return f"{elapsed_time:.2f} seconds" elif elapsed_time < 3600: minutes = elapsed_time // 60 seconds = elapsed_time % 60 return f"{minutes:.0f} minutes {seconds:.2f} seconds" elif elapsed_time < 86400: hours = elapsed_time // 3600 remainder = elapsed_time % 3600 minutes = remainder // 60 seconds = remainder % 60 return f"{hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds" else: days = elapsed_time // 86400 remainder = elapsed_time % 86400 hours = remainder // 3600 remainder = remainder % 3600 minutes = remainder // 60 seconds = remainder % 60 return f"{days:.0f} days {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds" def save_vae_checkpoint( total_steps, epoch, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path, logs, total_training_time ): checkpoint = { "total_steps": total_steps, "epoch": epoch, "model_state_dict": model.state_dict(), "discriminator_state_dict": discriminator.state_dict(), "optimizer_d_state_dict": optimizer_d.state_dict(), "optimizer_g_state_dict": optimizer_g.state_dict(), "metrics": metrics, "logs": logs, "total_training_time": total_training_time } checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt" torch.save(checkpoint, checkpoint_file) print(f"Checkpoint saved at {checkpoint_file}") all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") def extract_epoch(filename): match = re.search(r"_epoch(\d+)\.pt", filename) return int(match.group(1)) if match else -1 all_ckpts = sorted(all_ckpts, key=extract_epoch) for old_ckpt in all_ckpts[:-2]: os.remove(old_ckpt) print(f"Removed old checkpoint: {old_ckpt}") def load_vae_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g, device=device): all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") if not all_ckpts: print("No checkpoint found. Starting from scratch.") return 0, 0, None, [], 0 def extract_epoch(filename): match = re.search(r"_epoch(\d+)\.pt", filename) return int(match.group(1)) if match else -1 all_ckpts = sorted(all_ckpts, key=extract_epoch) latest_ckpt = all_ckpts[-1] if os.path.exists(latest_ckpt): checkpoint = torch.load(latest_ckpt, map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) discriminator.load_state_dict(checkpoint["discriminator_state_dict"]) optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"]) optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"]) total_steps = checkpoint["total_steps"] epoch = checkpoint["epoch"] metrics = checkpoint["metrics"] logs = checkpoint.get("logs", []) total_training_time = checkpoint.get("total_training_time", 0) print(f"Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}") return total_steps, epoch + 1, metrics, logs, total_training_time def inference(model, dataset, save_path, epoch, device="cuda", sample_size=8): if not os.path.exists(save_path): os.makedirs(save_path) image_tensors = [] for i in range(sample_size): image_tensors.append(dataset[i].unsqueeze(0)) image_tensors = torch.cat(image_tensors, dim=0).to(device) with torch.no_grad(): outputs, _, _ = model(image_tensors) save_input = image_tensors.detach().cpu() save_output = outputs.detach().cpu() grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size, normalize=True, scale_each=True) np_img = (grid * 255).byte().numpy().transpose(1, 2, 0) combined_image = Image.fromarray(np_img) # combined_image = tv.transforms.ToPILImage()(grid) combined_image.save(os.path.join(save_path, f"reconstructed_images_EP-{epoch}_{sample_size}.png")) print(f"Reconstructed images saved at: {save_path}") def trainVAE(Config, dataloader): dataset_config = Config.dataset_params autoencoder_config = Config.autoencoder_params train_config = Config.train_params paths = Config.paths seed = train_config.seed torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if device == "cuda": torch.cuda.manual_seed_all(seed) model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device) discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device) optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) checkpoint_path = os.path.join(train_config.task_name, "vqvae_ckpt") total_steps, start_epoch, metrics, logs, total_training_time = load_vae_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g) if not os.path.exists(train_config.task_name): os.mkdir(train_config.task_name) num_epochs = train_config.autoencoder_epochs recon_criterion = torch.nn.MSELoss() disc_criterion = torch.nn.MSELoss() lpips_model = LPIPS().eval().to(device) acc_steps = train_config.autoencoder_acc_steps disc_step_start = train_config.disc_start start_time_total = time.time() - total_training_time model.train() for epoch_idx in trange(start_epoch, num_epochs, colour='red', dynamic_ncols=True): start_time_epoch = time.time() epoch_log = [] optimizer_g.zero_grad() optimizer_d.zero_grad() for images in tqdm(dataloader, colour='green', dynamic_ncols=True): batch_start_time = time.time() total_steps += 1 images = images.to(device) model_output = model(images) output, z, quantize_losses = model_output recon_loss = recon_criterion(output, images) / acc_steps g_loss = ( recon_loss + (train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps) + (train_config.commitment_beta * quantize_losses["commitment_loss"] / acc_steps) ) if total_steps > disc_step_start: disc_fake_pred = discriminator(output) disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) g_loss += train_config.disc_weight * disc_fake_loss / acc_steps lpips_loss = torch.mean(lpips_model(output, images)) / acc_steps g_loss += train_config.perceptual_weight * lpips_loss g_loss.backward() if total_steps % acc_steps == 0: optimizer_g.step() optimizer_g.zero_grad() if total_steps > disc_step_start: disc_fake_pred = discriminator(output.detach()) disc_real_pred = discriminator(images) # disc_loss = (disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred)) + # disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))) / 2 / acc_steps disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros(disc_fake_pred.shape, device=disc_fake_pred.device)) disc_real_loss = disc_criterion(disc_real_pred, torch.ones(disc_real_pred.shape, device=disc_real_pred.device)) disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2 / acc_steps disc_loss.backward() if total_steps % acc_steps == 0: optimizer_d.step() optimizer_d.zero_grad() if total_steps % acc_steps == 0: optimizer_g.step() optimizer_g.zero_grad() batch_time = time.time() - batch_start_time epoch_log.append(format_time(0, batch_time)) optimizer_d.step() optimizer_d.zero_grad() optimizer_g.step() optimizer_g.zero_grad() epoch_time = time.time() - start_time_epoch logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log}) total_training_time = time.time() - start_time_total save_vae_checkpoint(total_steps, epoch_idx + 1, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path, logs, total_training_time) recon_save_path = os.path.join(train_config.task_name, 'vqvae_recon') inference(model, dataset, recon_save_path, epoch=epoch_idx, device=device, sample_size=16) # Checking to conntinue training train_continue = DotDict.from_dict(yaml.safe_load(open(config_path, 'r'))) if train_continue.training.continue_vqvae == False: print('VQVAE Training Stoped ...') break print("Training completed.") # ================================================================== # S T A R T I N G - T R A I N I N G # ================================================================== trainVAE(Config, dataloader) # python Vaani-VQVAE-Main.py | tee AE-training.log # python Vaani-VQVAE-Main.py > AE-training.log 2>&1 # huggingface-cli upload alpha31476/Vaani-Audio2Img-LDM . --commit-message "LDM-train-pass, checking results"