{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "490e66c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Script Execution Time: 2025-05-10 02:33:55.745055\n", "cuda\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/IITB/ai-at-ieor/23m1521/.conda/envs/flash2/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "# ==================================================================\n", "# V Q - V A E T R A I N I N G\n", "# ==================================================================\n", "# Author : Ashish Kumar Uchadiya\n", "# Created : November 3, 2024\n", "# Description: This script implements the training of a VQ-VAE model for\n", "# image reconstruction. It uses LPIPS (Learned Perceptual Image Patch Similarity)\n", "# loss to capture perceptual differences and PatchGAN loss to enforce local\n", "# realism. The model maps images to a discrete latent space and reconstructs\n", "# high-fidelity outputs by minimizing these combined losses.\n", "# ==================================================================\n", "# I M P O R T S\n", "# ==================================================================\n", "\n", "\n", "import os\n", "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", "\n", "import torch\n", "import torch.nn as nn\n", "import numpy as np\n", "from collections import namedtuple\n", "\n", "import pandas as pd\n", "import torchvision as tv\n", "from torchvision.transforms import v2\n", "from tqdm import tqdm, trange\n", "import matplotlib.pyplot as plt\n", "\n", "import yaml\n", "import random\n", "import datetime\n", "import torch.hub\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision.utils import make_grid\n", "\n", "from accelerate import Accelerator\n", "\n", "import datetime\n", "print(\"Script Execution Time:\", datetime.datetime.now())\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(device)" ] }, { "cell_type": "code", "execution_count": 2, "id": "adca2c76", "metadata": {}, "outputs": [], "source": [ "\n", "# ==================================================================\n", "# H E L P E R S\n", "# ==================================================================\n", "from typing import Any\n", "from argparse import Namespace\n", "import typing\n", "\n", "\n", "class DotDict(Namespace):\n", " \"\"\"A simple class that builds upon `argparse.Namespace`\n", " in order to make chained attributes possible.\"\"\"\n", "\n", " def __init__(self, temp=False, key=None, parent=None) -> None:\n", " self._temp = temp\n", " self._key = key\n", " self._parent = parent\n", "\n", " def __eq__(self, other):\n", " if not isinstance(other, DotDict):\n", " return NotImplemented\n", " return vars(self) == vars(other)\n", "\n", " def __getattr__(self, __name: str) -> Any:\n", " if __name not in self.__dict__ and not self._temp:\n", " self.__dict__[__name] = DotDict(temp=True, key=__name, parent=self)\n", " else:\n", " del self._parent.__dict__[self._key]\n", " raise AttributeError(\"No attribute '%s'\" % __name)\n", " return self.__dict__[__name]\n", "\n", " def __repr__(self) -> str:\n", " item_keys = [k for k in self.__dict__ if not k.startswith(\"_\")]\n", "\n", " if len(item_keys) == 0:\n", " return \"DotDict()\"\n", " elif len(item_keys) 
== 1:\n", " key = item_keys[0]\n", " val = self.__dict__[key]\n", " return \"DotDict(%s=%s)\" % (key, repr(val))\n", " else:\n", " return \"DotDict(%s)\" % \", \".join(\n", " \"%s=%s\" % (key, repr(val)) for key, val in self.__dict__.items()\n", " )\n", "\n", " @classmethod\n", " def from_dict(cls, original: typing.Mapping[str, any]) -> \"DotDict\":\n", " \"\"\"Create a DotDict from a (possibly nested) dict `original`.\n", " Warning: this method should not be used on very deeply nested inputs,\n", " since it's recursively traversing the nested dictionary values.\n", " \"\"\"\n", " dd = DotDict()\n", " for key, value in original.items():\n", " if isinstance(value, typing.Mapping):\n", " value = cls.from_dict(value)\n", " setattr(dd, key, value)\n", " return dd\n", " \n", " \n", "# ==================================================================\n", "# L P I P S \n", "# ==================================================================\n", "class vgg16(nn.Module):\n", " def __init__(self):\n", " super(vgg16, self).__init__()\n", " vgg_pretrained_features = tv.models.vgg16(\n", " weights=tv.models.VGG16_Weights.IMAGENET1K_V1\n", " ).features\n", " self.slice1 = torch.nn.Sequential()\n", " self.slice2 = torch.nn.Sequential()\n", " self.slice3 = torch.nn.Sequential()\n", " self.slice4 = torch.nn.Sequential()\n", " self.slice5 = torch.nn.Sequential()\n", " self.N_slices = 5\n", " for x in range(4):\n", " self.slice1.add_module(str(x), vgg_pretrained_features[x])\n", " for x in range(4, 9):\n", " self.slice2.add_module(str(x), vgg_pretrained_features[x])\n", " for x in range(9, 16):\n", " self.slice3.add_module(str(x), vgg_pretrained_features[x])\n", " for x in range(16, 23):\n", " self.slice4.add_module(str(x), vgg_pretrained_features[x])\n", " for x in range(23, 30):\n", " self.slice5.add_module(str(x), vgg_pretrained_features[x])\n", " \n", " self.eval()\n", " for param in self.parameters():\n", " param.requires_grad = False\n", "\n", " def forward(self, X):\n", " h1 = self.slice1(X)\n", " h2 = self.slice2(h1)\n", " h3 = self.slice3(h2)\n", " h4 = self.slice4(h3)\n", " h5 = self.slice5(h4)\n", " vgg_outputs = namedtuple(\"VggOutputs\", ['h1', 'h2', 'h3', 'h4', 'h5'])\n", " out = vgg_outputs(h1, h2, h3, h4, h5)\n", " return out\n", "\n", "\n", "def _spatial_average(in_tens, keepdim=True):\n", " return in_tens.mean([2, 3], keepdim=keepdim)\n", "\n", "\n", "def _normalize_tensor(in_feat, eps= 1e-8):\n", " norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))\n", " return in_feat / norm_factor\n", "\n", "\n", "class ScalingLayer(nn.Module):\n", " def __init__(self):\n", " super(ScalingLayer, self).__init__()\n", " # Imagnet normalization for (0-1)\n", " # mean = [0.485, 0.456, 0.406]\n", " # std = [0.229, 0.224, 0.225]\n", "\n", " self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n", " self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])\n", "\n", " def forward(self, inp):\n", " return (inp - self.shift) / self.scale\n", "\n", "\n", "class NetLinLayer(nn.Module):\n", " ''' A single linear layer which does a 1x1 conv '''\n", " def __init__(self, chn_in, chn_out=1, use_dropout=False):\n", " super(NetLinLayer, self).__init__()\n", " layers = [nn.Dropout(), ] if (use_dropout) else []\n", " layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]\n", " self.model = nn.Sequential(*layers)\n", "\n", " def forward(self, x):\n", " return self.model(x)\n", "\n", "\n", "class LPIPS(nn.Module):\n", 
" def __init__(self, net='vgg', version='0.1', use_dropout=True):\n", " super(LPIPS, self).__init__()\n", " self.version = version\n", " self.scaling_layer = ScalingLayer()\n", " self.chns = [64, 128, 256, 512, 512]\n", " self.L = len(self.chns)\n", " self.net = vgg16()\n", " self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n", " self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n", " self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n", " self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n", " self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n", " self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])\n", "\n", " # --- Orignal url --------------------\n", " # weights_url = f\"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v{version}/{net}.pth\"\n", " \n", " # --- Orignal Forked url -------------\n", " weights_url = f\"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth\"\n", " \n", " # --- Orignal torchmetric url --------\n", " # weights_url = \"https://github.com/Lightning-AI/torchmetrics/raw/master/src/torchmetrics/functional/image/lpips_models/vgg.pth\"\n", " \n", " state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')\n", " self.load_state_dict(state_dict, strict=False)\n", " \n", " self.eval()\n", " for param in self.parameters():\n", " param.requires_grad = False\n", "\n", " def forward(self, in0, in1, normalize=False):\n", " # Scale the inputs to -1 to +1 range if input in [0,1]\n", " if normalize:\n", " in0 = 2 * in0 - 1\n", " in1 = 2 * in1 - 1\n", "\n", " in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)\n", " # in0_input, in1_input = in0, in1\n", " \n", " outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)\n", " \n", " diffs = {}\n", " for kk in range(self.L):\n", " feats0 = _normalize_tensor(outs0[kk])\n", " feats1 = _normalize_tensor(outs1[kk])\n", " diffs[kk] = (feats0 - feats1) ** 2\n", " \n", " res = [_spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]\n", " val = sum(res)\n", " return val.reshape(-1)\n", "\n", "\n", "# ==================================================================\n", "# P A T C H - G A N - D I S C R I M I N A T O R\n", "# ==================================================================\n", "class Discriminator(nn.Module):\n", " r\"\"\"\n", " PatchGAN Discriminator.\n", " Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to\n", " 1 scalar value , we instead predict grid of values.\n", " Where each grid is prediction of how likely\n", " the discriminator thinks that the image patch corresponding\n", " to the grid cell is real\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " im_channels=3,\n", " conv_channels=[64, 128, 256],\n", " kernels=[4, 4, 4, 4],\n", " strides=[2, 2, 2, 1],\n", " paddings=[1, 1, 1, 1],\n", " ):\n", " super().__init__()\n", " self.im_channels = im_channels\n", " activation = nn.LeakyReLU(0.2)\n", " layers_dim = [self.im_channels] + conv_channels + [1]\n", " self.layers = nn.ModuleList(\n", " [\n", " nn.Sequential(\n", " nn.Conv2d(\n", " layers_dim[i],\n", " layers_dim[i + 1],\n", " kernel_size=kernels[i],\n", " stride=strides[i],\n", " padding=paddings[i],\n", " bias=False if i != 0 else True,\n", " ),\n", " (\n", " nn.BatchNorm2d(layers_dim[i + 1])\n", " if i != len(layers_dim) - 2 and i != 0\n", " else nn.Identity()\n", " ),\n", " activation if 
i != len(layers_dim) - 2 else nn.Identity(),\n", " )\n", " for i in range(len(layers_dim) - 1)\n", " ]\n", " )\n", "\n", " def forward(self, x):\n", " out = x\n", " for layer in self.layers:\n", " out = layer(out)\n", " return out\n", "\n", "\n", "\n", "# ==================================================================\n", "# D O W E - B L O C K\n", "# ==================================================================\n", "class DownBlock(nn.Module):\n", " r\"\"\"\n", " Down conv block with attention.\n", " Sequence of following block\n", " 1. Resnet block with time embedding\n", " 2. Attention block\n", " 3. Downsample\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " in_channels,\n", " out_channels,\n", " t_emb_dim,\n", " down_sample,\n", " num_heads,\n", " num_layers,\n", " attn,\n", " norm_channels,\n", " cross_attn=False,\n", " context_dim=None,\n", " ):\n", " super().__init__()\n", " self.num_layers = num_layers\n", " self.down_sample = down_sample\n", " self.attn = attn\n", " self.context_dim = context_dim\n", " self.cross_attn = cross_attn\n", " self.t_emb_dim = t_emb_dim\n", " self.resnet_conv_first = nn.ModuleList(\n", " [\n", " nn.Sequential(\n", " nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),\n", " nn.SiLU(),\n", " nn.Conv2d(\n", " in_channels if i == 0 else out_channels,\n", " out_channels,\n", " kernel_size=3,\n", " stride=1,\n", " padding=1,\n", " ),\n", " )\n", " for i in range(num_layers)\n", " ]\n", " )\n", " if self.t_emb_dim is not None:\n", " self.t_emb_layers = nn.ModuleList(\n", " [\n", " nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " self.resnet_conv_second = nn.ModuleList(\n", " [\n", " nn.Sequential(\n", " nn.GroupNorm(norm_channels, out_channels),\n", " nn.SiLU(),\n", " nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),\n", " )\n", " for _ in range(num_layers)\n", " ]\n", " )\n", "\n", " if self.attn:\n", " self.attention_norms = nn.ModuleList(\n", " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n", " )\n", "\n", " self.attentions = nn.ModuleList(\n", " [\n", " nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True)\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " if self.cross_attn:\n", " assert context_dim is not None, \"Context Dimension must be passed for cross attention\"\n", " self.cross_attention_norms = nn.ModuleList(\n", " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n", " )\n", " self.cross_attentions = nn.ModuleList(\n", " [\n", " nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True)\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " self.context_proj = nn.ModuleList(\n", " [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]\n", " )\n", " self.residual_input_conv = nn.ModuleList(\n", " [\n", " nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)\n", " for i in range(num_layers)\n", " ]\n", " )\n", " self.down_sample_conv = (\n", " nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity()\n", " )\n", "\n", " def forward(self, x, t_emb=None, context=None):\n", " out = x\n", " for i in range(self.num_layers):\n", " # Resnet block of Unet\n", "\n", " resnet_input = out\n", " out = self.resnet_conv_first[i](out)\n", " if self.t_emb_dim is not None:\n", " out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]\n", " out = 
self.resnet_conv_second[i](out)\n", " out = out + self.residual_input_conv[i](resnet_input)\n", "\n", " if self.attn:\n", " # Attention block of Unet\n", "\n", " batch_size, channels, h, w = out.shape\n", " in_attn = out.reshape(batch_size, channels, h * w)\n", " in_attn = self.attention_norms[i](in_attn)\n", " in_attn = in_attn.transpose(1, 2)\n", " out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)\n", " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n", " out = out + out_attn\n", " if self.cross_attn:\n", " assert (\n", " context is not None\n", " ), \"context cannot be None if cross attention layers are used\"\n", " batch_size, channels, h, w = out.shape\n", " in_attn = out.reshape(batch_size, channels, h * w)\n", " in_attn = self.cross_attention_norms[i](in_attn)\n", " in_attn = in_attn.transpose(1, 2)\n", " assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim\n", " context_proj = self.context_proj[i](context)\n", " out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)\n", " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n", " out = out + out_attn\n", " # Downsample\n", "\n", " out = self.down_sample_conv(out)\n", " return out\n", "\n", "\n", "\n", "# ==================================================================\n", "# M I D - B L O C K\n", "# ==================================================================\n", "class MidBlock(nn.Module):\n", " r\"\"\"\n", " Mid conv block with attention.\n", " Sequence of following blocks\n", " 1. Resnet block with time embedding\n", " 2. Attention block\n", " 3. Resnet block with time embedding\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " in_channels,\n", " out_channels,\n", " t_emb_dim,\n", " num_heads,\n", " num_layers,\n", " norm_channels,\n", " cross_attn=None,\n", " context_dim=None,\n", " ):\n", " super().__init__()\n", " self.num_layers = num_layers\n", " self.t_emb_dim = t_emb_dim\n", " self.context_dim = context_dim\n", " self.cross_attn = cross_attn\n", " self.resnet_conv_first = nn.ModuleList(\n", " [\n", " nn.Sequential(\n", " nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),\n", " nn.SiLU(),\n", " nn.Conv2d(\n", " in_channels if i == 0 else out_channels,\n", " out_channels,\n", " kernel_size=3,\n", " stride=1,\n", " padding=1,\n", " ),\n", " )\n", " for i in range(num_layers + 1)\n", " ]\n", " )\n", "\n", " if self.t_emb_dim is not None:\n", " self.t_emb_layers = nn.ModuleList(\n", " [\n", " nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))\n", " for _ in range(num_layers + 1)\n", " ]\n", " )\n", " self.resnet_conv_second = nn.ModuleList(\n", " [\n", " nn.Sequential(\n", " nn.GroupNorm(norm_channels, out_channels),\n", " nn.SiLU(),\n", " nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),\n", " )\n", " for _ in range(num_layers + 1)\n", " ]\n", " )\n", "\n", " self.attention_norms = nn.ModuleList(\n", " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n", " )\n", "\n", " self.attentions = nn.ModuleList(\n", " [\n", " nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True)\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " if self.cross_attn:\n", " assert context_dim is not None, \"Context Dimension must be passed for cross attention\"\n", " self.cross_attention_norms = nn.ModuleList(\n", " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n", " )\n", " self.cross_attentions = 
nn.ModuleList(\n", " [\n", " nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True)\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " self.context_proj = nn.ModuleList(\n", " [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]\n", " )\n", " self.residual_input_conv = nn.ModuleList(\n", " [\n", " nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)\n", " for i in range(num_layers + 1)\n", " ]\n", " )\n", "\n", " def forward(self, x, t_emb=None, context=None):\n", " out = x\n", "\n", " # First resnet block\n", "\n", " resnet_input = out\n", " out = self.resnet_conv_first[0](out)\n", " if self.t_emb_dim is not None:\n", " out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]\n", " out = self.resnet_conv_second[0](out)\n", " out = out + self.residual_input_conv[0](resnet_input)\n", "\n", " for i in range(self.num_layers):\n", " # Attention Block\n", "\n", " batch_size, channels, h, w = out.shape\n", " in_attn = out.reshape(batch_size, channels, h * w)\n", " in_attn = self.attention_norms[i](in_attn)\n", " in_attn = in_attn.transpose(1, 2)\n", " out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)\n", " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n", " out = out + out_attn\n", "\n", " if self.cross_attn:\n", " assert (\n", " context is not None\n", " ), \"context cannot be None if cross attention layers are used\"\n", " batch_size, channels, h, w = out.shape\n", " in_attn = out.reshape(batch_size, channels, h * w)\n", " in_attn = self.cross_attention_norms[i](in_attn)\n", " in_attn = in_attn.transpose(1, 2)\n", " assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim\n", " context_proj = self.context_proj[i](context)\n", " out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)\n", " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n", " out = out + out_attn\n", " # Resnet Block\n", "\n", " resnet_input = out\n", " out = self.resnet_conv_first[i + 1](out)\n", " if self.t_emb_dim is not None:\n", " out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]\n", " out = self.resnet_conv_second[i + 1](out)\n", " out = out + self.residual_input_conv[i + 1](resnet_input)\n", " return out\n", "\n", "\n", "# ==================================================================\n", "# U P - B L O C K\n", "# ==================================================================\n", "class UpBlock(nn.Module):\n", " r\"\"\"\n", " Up conv block with attention.\n", " Sequence of following blocks\n", " 1. Upsample\n", " 1. Concatenate Down block output\n", " 2. Resnet block with time embedding\n", " 3. 
Attention Block\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " in_channels,\n", " out_channels,\n", " t_emb_dim,\n", " up_sample,\n", " num_heads,\n", " num_layers,\n", " attn,\n", " norm_channels,\n", " ):\n", " super().__init__()\n", " self.num_layers = num_layers\n", " self.up_sample = up_sample\n", " self.t_emb_dim = t_emb_dim\n", " self.attn = attn\n", " self.resnet_conv_first = nn.ModuleList(\n", " [\n", " nn.Sequential(\n", " nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),\n", " nn.SiLU(),\n", " nn.Conv2d(\n", " in_channels if i == 0 else out_channels,\n", " out_channels,\n", " kernel_size=3,\n", " stride=1,\n", " padding=1,\n", " ),\n", " )\n", " for i in range(num_layers)\n", " ]\n", " )\n", "\n", " if self.t_emb_dim is not None:\n", " self.t_emb_layers = nn.ModuleList(\n", " [\n", " nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " self.resnet_conv_second = nn.ModuleList(\n", " [\n", " nn.Sequential(\n", " nn.GroupNorm(norm_channels, out_channels),\n", " nn.SiLU(),\n", " nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),\n", " )\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " if self.attn:\n", " self.attention_norms = nn.ModuleList(\n", " [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]\n", " )\n", "\n", " self.attentions = nn.ModuleList(\n", " [\n", " nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True)\n", " for _ in range(num_layers)\n", " ]\n", " )\n", " self.residual_input_conv = nn.ModuleList(\n", " [\n", " nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)\n", " for i in range(num_layers)\n", " ]\n", " )\n", " self.up_sample_conv = (\n", " nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1)\n", " if self.up_sample\n", " else nn.Identity()\n", " )\n", "\n", " def forward(self, x, out_down=None, t_emb=None):\n", " # Upsample\n", "\n", " x = self.up_sample_conv(x)\n", "\n", " # Concat with Downblock output\n", "\n", " if out_down is not None:\n", " x = torch.cat([x, out_down], dim=1)\n", " out = x\n", " for i in range(self.num_layers):\n", " # Resnet Block\n", "\n", " resnet_input = out\n", " out = self.resnet_conv_first[i](out)\n", " if self.t_emb_dim is not None:\n", " out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]\n", " out = self.resnet_conv_second[i](out)\n", " out = out + self.residual_input_conv[i](resnet_input)\n", "\n", " # Self Attention\n", "\n", " if self.attn:\n", " batch_size, channels, h, w = out.shape\n", " in_attn = out.reshape(batch_size, channels, h * w)\n", " in_attn = self.attention_norms[i](in_attn)\n", " in_attn = in_attn.transpose(1, 2)\n", " out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)\n", " out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)\n", " out = out + out_attn\n", " return out\n", "\n", "\n", "# ==================================================================\n", "# V Q - V A E\n", "# ==================================================================\n", "class VQVAE(nn.Module):\n", " def __init__(self, im_channels, model_config):\n", " super().__init__()\n", " self.down_channels = model_config.down_channels\n", " self.mid_channels = model_config.mid_channels\n", " self.down_sample = model_config.down_sample\n", " self.num_down_layers = model_config.num_down_layers\n", " self.num_mid_layers = model_config.num_mid_layers\n", " self.num_up_layers = model_config.num_up_layers\n", "\n", " 
# To disable attention in Downblock of Encoder and Upblock of Decoder\n", " self.attns = model_config.attn_down\n", "\n", " # Latent Dimension\n", " self.z_channels = model_config.z_channels\n", " self.codebook_size = model_config.codebook_size\n", " self.norm_channels = model_config.norm_channels\n", " self.num_heads = model_config.num_heads\n", "\n", " # Assertion to validate the channel information\n", " assert self.mid_channels[0] == self.down_channels[-1]\n", " assert self.mid_channels[-1] == self.down_channels[-1]\n", " assert len(self.down_sample) == len(self.down_channels) - 1\n", " assert len(self.attns) == len(self.down_channels) - 1\n", "\n", " # Wherever we use downsampling in encoder correspondingly use\n", " # upsampling in decoder\n", " self.up_sample = list(reversed(self.down_sample))\n", "\n", " ##################### Encoder ######################\n", " self.encoder_conv_in = nn.Conv2d(\n", " im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)\n", " )\n", "\n", " # Downblock + Midblock\n", " self.encoder_layers = nn.ModuleList([])\n", " for i in range(len(self.down_channels) - 1):\n", " self.encoder_layers.append(\n", " DownBlock(\n", " self.down_channels[i],\n", " self.down_channels[i + 1],\n", " t_emb_dim=None,\n", " down_sample=self.down_sample[i],\n", " num_heads=self.num_heads,\n", " num_layers=self.num_down_layers,\n", " attn=self.attns[i],\n", " norm_channels=self.norm_channels,\n", " )\n", " )\n", " self.encoder_mids = nn.ModuleList([])\n", " for i in range(len(self.mid_channels) - 1):\n", " self.encoder_mids.append(\n", " MidBlock(\n", " self.mid_channels[i],\n", " self.mid_channels[i + 1],\n", " t_emb_dim=None,\n", " num_heads=self.num_heads,\n", " num_layers=self.num_mid_layers,\n", " norm_channels=self.norm_channels,\n", " )\n", " )\n", " self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])\n", " self.encoder_conv_out = nn.Conv2d(\n", " self.down_channels[-1], self.z_channels, kernel_size=3, padding=1\n", " )\n", "\n", " # Pre Quantization Convolution\n", " self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)\n", "\n", " # Codebook\n", " self.embedding = nn.Embedding(self.codebook_size, self.z_channels)\n", " ####################################################\n", "\n", " ##################### Decoder ######################\n", "\n", " # Post Quantization Convolution\n", " self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)\n", " self.decoder_conv_in = nn.Conv2d(\n", " self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)\n", " )\n", "\n", " # Midblock + Upblock\n", " self.decoder_mids = nn.ModuleList([])\n", " for i in reversed(range(1, len(self.mid_channels))):\n", " self.decoder_mids.append(\n", " MidBlock(\n", " self.mid_channels[i],\n", " self.mid_channels[i - 1],\n", " t_emb_dim=None,\n", " num_heads=self.num_heads,\n", " num_layers=self.num_mid_layers,\n", " norm_channels=self.norm_channels,\n", " )\n", " )\n", " self.decoder_layers = nn.ModuleList([])\n", " for i in reversed(range(1, len(self.down_channels))):\n", " self.decoder_layers.append(\n", " UpBlock(\n", " self.down_channels[i],\n", " self.down_channels[i - 1],\n", " t_emb_dim=None,\n", " up_sample=self.down_sample[i - 1],\n", " num_heads=self.num_heads,\n", " num_layers=self.num_up_layers,\n", " attn=self.attns[i - 1],\n", " norm_channels=self.norm_channels,\n", " )\n", " )\n", " self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])\n", " 
self.decoder_conv_out = nn.Conv2d(\n", " self.down_channels[0], im_channels, kernel_size=3, padding=1\n", " )\n", "\n", " def quantize(self, x):\n", " B, C, H, W = x.shape\n", "\n", " # B, C, H, W -> B, H, W, C\n", " x = x.permute(0, 2, 3, 1)\n", "\n", " # B, H, W, C -> B, H*W, C\n", " x = x.reshape(x.size(0), -1, x.size(-1))\n", "\n", " # Find nearest embedding/codebook vector\n", " # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)\n", " dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))\n", " # (B, H*W)\n", " min_encoding_indices = torch.argmin(dist, dim=-1)\n", "\n", " # Replace encoder output with nearest codebook\n", " # quant_out -> B*H*W, C\n", " quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))\n", "\n", " # x -> B*H*W, C\n", " x = x.reshape((-1, x.size(-1)))\n", " commmitment_loss = torch.mean((quant_out.detach() - x) ** 2)\n", " codebook_loss = torch.mean((quant_out - x.detach()) ** 2)\n", " quantize_losses = {\"codebook_loss\": codebook_loss, \"commitment_loss\": commmitment_loss}\n", " # Straight through estimation\n", " quant_out = x + (quant_out - x).detach()\n", "\n", " # quant_out -> B, C, H, W\n", " quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)\n", " min_encoding_indices = min_encoding_indices.reshape(\n", " (-1, quant_out.size(-2), quant_out.size(-1))\n", " )\n", " return quant_out, quantize_losses, min_encoding_indices\n", "\n", " def encode(self, x):\n", " out = self.encoder_conv_in(x)\n", " for idx, down in enumerate(self.encoder_layers):\n", " out = down(out)\n", " for mid in self.encoder_mids:\n", " out = mid(out)\n", " out = self.encoder_norm_out(out)\n", " out = nn.SiLU()(out)\n", " out = self.encoder_conv_out(out)\n", " out = self.pre_quant_conv(out)\n", " out, quant_losses, _ = self.quantize(out)\n", " return out, quant_losses\n", "\n", " def decode(self, z):\n", " out = z\n", " out = self.post_quant_conv(out)\n", " out = self.decoder_conv_in(out)\n", " for mid in self.decoder_mids:\n", " out = mid(out)\n", " for idx, up in enumerate(self.decoder_layers):\n", " out = up(out)\n", " out = self.decoder_norm_out(out)\n", " out = nn.SiLU()(out)\n", " out = self.decoder_conv_out(out)\n", " return out\n", "\n", " def forward(self, x):\n", " '''out: [B, 3, 256, 256]\n", " z: [B, 3, 64, 64]\n", " quant_losses: {\n", " codebook_loss: 0.0681,\n", " commitment_loss: 0.0681\n", " }\n", " '''\n", " z, quant_losses = self.encode(x)\n", " out = self.decode(z)\n", " return out, z, quant_losses" ] }, { "cell_type": "code", "execution_count": 3, "id": "27943460", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'autoencoder_params': {'attn_down': [False, True],\n", " 'codebook_size': 20,\n", " 'down_channels': [32, 64, 128],\n", " 'down_sample': [True, True],\n", " 'mid_channels': [128, 128],\n", " 'norm_channels': 32,\n", " 'num_down_layers': 4,\n", " 'num_heads': 8,\n", " 'num_mid_layers': 4,\n", " 'num_up_layers': 4,\n", " 'z_channels': 3},\n", " 'dataset_params': {'im_channels': 3, 'im_size': 128},\n", " 'diffusion_params': {'beta_end': 0.0195, 'beta_start': 0.0015, 'num_timesteps': 1000},\n", " 'ldm_params': {'attn_down': [True, True, True],\n", " 'conv_out_channels': 128,\n", " 'down_channels': [128, 256, 256, 256],\n", " 'down_sample': [False, False, False],\n", " 'mid_channels': [256, 256],\n", " 'norm_channels': 32,\n", " 'num_down_layers': 2,\n", " 'num_heads': 16,\n", " 'num_mid_layers': 2,\n", " 'num_up_layers': 2,\n", " 'time_emb_dim': 
256},\n", " 'paths': {'images_dir': '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images'},\n", " 'train_params': {'autoencoder_acc_steps': 1,\n", " 'autoencoder_batch_size': 8,\n", " 'autoencoder_epochs': 30,\n", " 'autoencoder_img_save_steps': 8,\n", " 'autoencoder_lr': 0.0001,\n", " 'codebook_weight': 1,\n", " 'commitment_beta': 0.2,\n", " 'disc_start': 1000,\n", " 'disc_weight': 0.5,\n", " 'kl_weight': 5e-06,\n", " 'ldm_batch_size': 1,\n", " 'ldm_ckpt_name': 'ddpm_ckpt_Acc.pth',\n", " 'ldm_epochs': 10,\n", " 'ldm_lr': 1e-05,\n", " 'num_grid_rows': 3,\n", " 'num_samples': 9,\n", " 'perceptual_weight': 1,\n", " 'save_latents': True,\n", " 'seed': 4422,\n", " 'task_name': 'VaaniLDM_Acc',\n", " 'vqvae_ckpt_name': 'vqvaq_ckpt_Acc.pth',\n", " 'vqvae_latent_dir_name': 'vqvae_latents'},\n", " 'training': {'_continue_': True}}\n" ] } ], "source": [ "# ==================================================================\n", "# C O N F I G U R A T I O N\n", "# ==================================================================\n", "import pprint\n", "# config_path = \"/home/23m1521/ashish/MTP/Vaani/config-Acc.yaml\"\n", "config_path = \"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/config-Acc.yaml\"\n", "with open(config_path, 'r') as file:\n", " Config = yaml.safe_load(file)\n", " pprint.pprint(Config, width=120)\n", "\n", "Config = DotDict.from_dict(Config)\n", "dataset_config = Config.dataset_params\n", "diffusion_config = Config.diffusion_params\n", "model_config = Config.model_params\n", "train_config = Config.train_params\n", "paths = Config.paths" ] }, { "cell_type": "code", "execution_count": 4, "id": "06b8c833", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files found: 128807\n", "IMAGE SHAPE: torch.Size([3, 128, 128])\n" ] } ], "source": [ "# ==================================================================\n", "# V A A N I - D A T A S E T\n", "# ==================================================================\n", "IMAGES_PATH = paths.images_dir\n", "\n", "def walkDIR(folder_path, include=None):\n", " file_list = []\n", " for root, _, files in os.walk(folder_path):\n", " for file in files:\n", " if include is None or any(file.endswith(ext) for ext in include):\n", " file_list.append(os.path.join(root, file))\n", " print(\"Files found:\", len(file_list))\n", " return file_list\n", "\n", "files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg'])\n", "df = pd.DataFrame(files, columns=['image_path'])\n", "\n", "class VaaniDataset(torch.utils.data.Dataset):\n", " def __init__(self, files_paths, im_size):\n", " self.files_paths = files_paths\n", " self.im_size = im_size\n", "\n", " def __len__(self):\n", " return len(self.files_paths)\n", "\n", " def __getitem__(self, idx):\n", " image = tv.io.decode_image(self.files_paths[idx], mode='RGB')\n", " image = v2.Resize((self.im_size,self.im_size))(image)\n", " image = v2.ToDtype(torch.float32, scale=True)(image)\n", " # image = 2*image - 1\n", " return image\n", "\n", "dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size)\n", "image = dataset[2]\n", "print('IMAGE SHAPE:', image.shape)" ] }, { "cell_type": "markdown", "id": "e1a1b642", "metadata": {}, "source": [ "### *VQVAE* Inference" ] }, { "cell_type": "code", "execution_count": 10, "id": "7306dd05", "metadata": {}, "outputs": [], "source": [ "def load_checkpoint(checkpoint_path, model):\n", " if os.path.exists(checkpoint_path):\n", " checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)\n", " 
model.load_state_dict(checkpoint[\"model_state_dict\"])\n", " total_steps = checkpoint[\"total_steps\"]\n", " epoch = checkpoint[\"epoch\"]\n", " total_training_time = checkpoint.get(\"total_training_time\", 0)\n", " print(f\"Checkpoint loaded. \"\n", " f\"Epoch: {epoch}, Step: {total_steps}, Training Time: {total_training_time}\")\n", " return epoch\n", " else:\n", " print(\"No checkpoint found. Starting from scratch.\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "a1ffc82f", "metadata": {}, "outputs": [], "source": [ "# ==================================================================\n", "# M O D E L - I N I T I L I Z A T I O N\n", "# ==================================================================\n", "dataset_config = Config.dataset_params\n", "autoencoder_config = Config.autoencoder_params\n", "train_config = Config.train_params\n", "\n", "model = VQVAE(im_channels=dataset_config.im_channels, \n", " model_config=autoencoder_config).to(device=device)" ] }, { "cell_type": "code", "execution_count": 12, "id": "a3a12633", "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "PytorchStreamReader failed reading file byteorder: invalid header or archive is corrupted", "output_type": "error", "traceback": [ "\u001b[31m---------------------------------------------------------------------------\u001b[39m", "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m checkpoint_path = \u001b[33mr\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/VaaniLDM_Acc/vqvaq_ckpt.pth\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m epoch = \u001b[43mload_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcheckpoint_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 3\u001b[39m, in \u001b[36mload_checkpoint\u001b[39m\u001b[34m(checkpoint_path, model)\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mload_checkpoint\u001b[39m(checkpoint_path, model):\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m os.path.exists(checkpoint_path):\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m checkpoint = \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcheckpoint_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweights_only\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 4\u001b[39m model.load_state_dict(checkpoint[\u001b[33m\"\u001b[39m\u001b[33mmodel_state_dict\u001b[39m\u001b[33m\"\u001b[39m])\n\u001b[32m 5\u001b[39m total_steps = checkpoint[\u001b[33m\"\u001b[39m\u001b[33mtotal_steps\u001b[39m\u001b[33m\"\u001b[39m]\n", "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/flash2/lib/python3.12/site-packages/torch/serialization.py:1351\u001b[39m, in \u001b[36mload\u001b[39m\u001b[34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[39m\n\u001b[32m 1349\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m weights_only:\n\u001b[32m 1350\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> 
\u001b[39m\u001b[32m1351\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_load\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1352\u001b[39m \u001b[43m \u001b[49m\u001b[43mopened_zipfile\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1353\u001b[39m \u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1354\u001b[39m \u001b[43m \u001b[49m\u001b[43m_weights_only_unpickler\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1355\u001b[39m \u001b[43m \u001b[49m\u001b[43moverall_storage\u001b[49m\u001b[43m=\u001b[49m\u001b[43moverall_storage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1356\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mpickle_load_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1357\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1358\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m pickle.UnpicklingError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 1359\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m pickle.UnpicklingError(_get_wo_message(\u001b[38;5;28mstr\u001b[39m(e))) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n", "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/flash2/lib/python3.12/site-packages/torch/serialization.py:1731\u001b[39m, in \u001b[36m_load\u001b[39m\u001b[34m(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)\u001b[39m\n\u001b[32m 1729\u001b[39m byteorderdata = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1730\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m zip_file.has_record(byteordername):\n\u001b[32m-> \u001b[39m\u001b[32m1731\u001b[39m byteorderdata = \u001b[43mzip_file\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_record\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbyteordername\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1732\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m byteorderdata \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m [\u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mlittle\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33mb\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mbig\u001b[39m\u001b[33m\"\u001b[39m]:\n\u001b[32m 1733\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mUnknown endianness type: \u001b[39m\u001b[33m\"\u001b[39m + byteorderdata.decode())\n", "\u001b[31mRuntimeError\u001b[39m: PytorchStreamReader failed reading file byteorder: invalid header or archive is corrupted" ] } ], "source": [ "checkpoint_path = r\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/VaaniLDM_Acc/vqvaq_ckpt.pth\"\n", "epoch = load_checkpoint(checkpoint_path, model)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }