import os |
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from collections import namedtuple |
|
|
|
import pandas as pd |
|
import torchvision as tv |
|
from torchvision.transforms import v2 |
|
from tqdm import tqdm, trange |
|
import matplotlib.pyplot as plt |
|
|
|
import yaml |
|
import shutil |
|
import sys |
|
import random |
|
import datetime |
|
import torch.hub |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision.utils import make_grid |
|
|
|
from accelerate import Accelerator |
|
|
|
print("TIME:", datetime.datetime.now()) |
|
from typing import Any |
|
from argparse import Namespace |
|
import typing |
|
|
|
|
|
class DotDict(Namespace): |
|
"""A simple class that builds upon `argparse.Namespace` |
|
in order to make chained attributes possible.""" |
|
|
|
def __init__(self, temp=False, key=None, parent=None) -> None: |
|
self._temp = temp |
|
self._key = key |
|
self._parent = parent |
|
|
|
def __eq__(self, other): |
|
if not isinstance(other, DotDict): |
|
return NotImplemented |
|
return vars(self) == vars(other) |
|
|
|
def __getattr__(self, __name: str) -> Any: |
|
if __name not in self.__dict__ and not self._temp: |
|
self.__dict__[__name] = DotDict(temp=True, key=__name, parent=self) |
|
else: |
|
del self._parent.__dict__[self._key] |
|
raise AttributeError("No attribute '%s'" % __name) |
|
return self.__dict__[__name] |
|
|
|
def __repr__(self) -> str: |
|
item_keys = [k for k in self.__dict__ if not k.startswith("_")] |
|
|
|
if len(item_keys) == 0: |
|
return "DotDict()" |
|
elif len(item_keys) == 1: |
|
key = item_keys[0] |
|
val = self.__dict__[key] |
|
return "DotDict(%s=%s)" % (key, repr(val)) |
|
else: |
|
return "DotDict(%s)" % ", ".join( |
|
"%s=%s" % (key, repr(val)) for key, val in self.__dict__.items() |
|
) |
|
|
|
@classmethod |
|
    def from_dict(cls, original: typing.Mapping[str, Any]) -> "DotDict":
|
"""Create a DotDict from a (possibly nested) dict `original`. |
|
Warning: this method should not be used on very deeply nested inputs, |
|
since it's recursively traversing the nested dictionary values. |
|
""" |
|
dd = DotDict() |
|
for key, value in original.items(): |
|
if isinstance(value, typing.Mapping): |
|
value = cls.from_dict(value) |
|
setattr(dd, key, value) |
|
return dd |
|
|
|
|
|
|
|
|
|
|
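# VGG16 feature extractor used by LPIPS: torchvision's vgg16().features is split at
# indices 4 / 9 / 16 / 23 / 30, i.e. just after the relu1_2, relu2_2, relu3_3,
# relu4_3 and relu5_3 activations used by the original LPIPS metric.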
|
class vgg16(nn.Module): |
|
def __init__(self): |
|
super(vgg16, self).__init__() |
|
vgg_pretrained_features = tv.models.vgg16( |
|
weights=tv.models.VGG16_Weights.IMAGENET1K_V1 |
|
).features |
|
self.slice1 = torch.nn.Sequential() |
|
self.slice2 = torch.nn.Sequential() |
|
self.slice3 = torch.nn.Sequential() |
|
self.slice4 = torch.nn.Sequential() |
|
self.slice5 = torch.nn.Sequential() |
|
self.N_slices = 5 |
|
for x in range(4): |
|
self.slice1.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(4, 9): |
|
self.slice2.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(9, 16): |
|
self.slice3.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(16, 23): |
|
self.slice4.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(23, 30): |
|
self.slice5.add_module(str(x), vgg_pretrained_features[x]) |
|
|
|
self.eval() |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward(self, X): |
|
h1 = self.slice1(X) |
|
h2 = self.slice2(h1) |
|
h3 = self.slice3(h2) |
|
h4 = self.slice4(h3) |
|
h5 = self.slice5(h4) |
|
vgg_outputs = namedtuple("VggOutputs", ['h1', 'h2', 'h3', 'h4', 'h5']) |
|
out = vgg_outputs(h1, h2, h3, h4, h5) |
|
return out |
|
|
|
|
|
def _spatial_average(in_tens, keepdim=True): |
|
return in_tens.mean([2, 3], keepdim=keepdim) |
|
|
|
|
|
def _normalize_tensor(in_feat, eps=1e-8):
|
norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True)) |
|
return in_feat / norm_factor |
|
|
|
|
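# Channel-wise shift/scale constants from the reference LPIPS implementation; they
# re-normalize inputs that are assumed to already lie in [-1, 1].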
|
class ScalingLayer(nn.Module): |
|
def __init__(self): |
|
super(ScalingLayer, self).__init__() |
|
|
|
|
|
|
|
|
|
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) |
|
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) |
|
|
|
def forward(self, inp): |
|
return (inp - self.shift) / self.scale |
|
|
|
|
|
class NetLinLayer(nn.Module): |
|
''' A single linear layer which does a 1x1 conv ''' |
|
def __init__(self, chn_in, chn_out=1, use_dropout=False): |
|
super(NetLinLayer, self).__init__() |
|
layers = [nn.Dropout(), ] if (use_dropout) else [] |
|
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] |
|
self.model = nn.Sequential(*layers) |
|
|
|
def forward(self, x): |
|
return self.model(x) |
|
|
|
|
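# LPIPS distance as computed below: unit-normalize each VGG feature map along the channel
# axis, take the squared difference between the two inputs' features, weight it with a
# learned 1x1 convolution (lin0..lin4), average spatially, and sum over the five levels.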
|
class LPIPS(nn.Module): |
|
def __init__(self, net='vgg', version='0.1', use_dropout=True): |
|
super(LPIPS, self).__init__() |
|
self.version = version |
|
self.scaling_layer = ScalingLayer() |
|
self.chns = [64, 128, 256, 512, 512] |
|
self.L = len(self.chns) |
|
self.net = vgg16() |
|
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) |
|
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) |
|
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) |
|
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) |
|
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) |
|
self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]) |
|
|
|
|
|
|
|
|
|
|
|
weights_url = f"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth" |
|
|
|
|
|
|
|
|
|
state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu') |
|
self.load_state_dict(state_dict, strict=False) |
|
|
|
self.eval() |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward(self, in0, in1, normalize=False): |
|
|
|
if normalize: |
|
in0 = 2 * in0 - 1 |
|
in1 = 2 * in1 - 1 |
|
|
|
in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) |
|
|
|
|
|
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) |
|
|
|
diffs = {} |
|
for kk in range(self.L): |
|
feats0 = _normalize_tensor(outs0[kk]) |
|
feats1 = _normalize_tensor(outs1[kk]) |
|
diffs[kk] = (feats0 - feats1) ** 2 |
|
|
|
res = [_spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] |
|
val = sum(res) |
|
return val.reshape(-1) |
|
|
|
|
|
|
|
|
|
|
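# With the default kernels/strides/paddings, a 256x256 input yields a 31x31 grid of
# patch logits (256 -> 128 -> 64 -> 32 -> 31).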
|
class Discriminator(nn.Module): |
|
r""" |
|
PatchGAN Discriminator. |
|
Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to |
|
1 scalar value , we instead predict grid of values. |
|
Where each grid is prediction of how likely |
|
the discriminator thinks that the image patch corresponding |
|
to the grid cell is real |
|
""" |
|
|
|
def __init__( |
|
self, |
|
im_channels=3, |
|
conv_channels=[64, 128, 256], |
|
kernels=[4, 4, 4, 4], |
|
strides=[2, 2, 2, 1], |
|
paddings=[1, 1, 1, 1], |
|
): |
|
super().__init__() |
|
self.im_channels = im_channels |
|
activation = nn.LeakyReLU(0.2) |
|
layers_dim = [self.im_channels] + conv_channels + [1] |
|
self.layers = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.Conv2d( |
|
layers_dim[i], |
|
layers_dim[i + 1], |
|
kernel_size=kernels[i], |
|
stride=strides[i], |
|
padding=paddings[i], |
|
bias=False if i != 0 else True, |
|
), |
|
( |
|
nn.BatchNorm2d(layers_dim[i + 1]) |
|
if i != len(layers_dim) - 2 and i != 0 |
|
else nn.Identity() |
|
), |
|
activation if i != len(layers_dim) - 2 else nn.Identity(), |
|
) |
|
for i in range(len(layers_dim) - 1) |
|
] |
|
) |
|
|
|
def forward(self, x): |
|
out = x |
|
for layer in self.layers: |
|
out = layer(out) |
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
class DownBlock(nn.Module): |
|
r""" |
|
Down conv block with attention. |
|
Sequence of following block |
|
1. Resnet block with time embedding |
|
2. Attention block |
|
3. Downsample |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
t_emb_dim, |
|
down_sample, |
|
num_heads, |
|
num_layers, |
|
attn, |
|
norm_channels, |
|
cross_attn=False, |
|
context_dim=None, |
|
): |
|
super().__init__() |
|
self.num_layers = num_layers |
|
self.down_sample = down_sample |
|
self.attn = attn |
|
self.context_dim = context_dim |
|
self.cross_attn = cross_attn |
|
self.t_emb_dim = t_emb_dim |
|
self.resnet_conv_first = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels if i == 0 else out_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
if self.t_emb_dim is not None: |
|
self.t_emb_layers = nn.ModuleList( |
|
[ |
|
nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels)) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.resnet_conv_second = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), |
|
) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
|
|
if self.attn: |
|
self.attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] |
|
) |
|
|
|
self.attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
if self.cross_attn: |
|
assert context_dim is not None, "Context Dimension must be passed for cross attention" |
|
self.cross_attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] |
|
) |
|
self.cross_attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.context_proj = nn.ModuleList( |
|
[nn.Linear(context_dim, out_channels) for _ in range(num_layers)] |
|
) |
|
self.residual_input_conv = nn.ModuleList( |
|
[ |
|
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) |
|
for i in range(num_layers) |
|
] |
|
) |
|
self.down_sample_conv = ( |
|
nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity() |
|
) |
|
|
|
def forward(self, x, t_emb=None, context=None): |
|
out = x |
|
for i in range(self.num_layers): |
|
|
|
|
|
resnet_input = out |
|
out = self.resnet_conv_first[i](out) |
|
if self.t_emb_dim is not None: |
|
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] |
|
out = self.resnet_conv_second[i](out) |
|
out = out + self.residual_input_conv[i](resnet_input) |
|
|
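            # Self-attention over spatial positions: (B, C, H, W) is flattened to
            # (B, H*W, C) for nn.MultiheadAttention (batch_first=True), then reshaped
            # back and added as a residual.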
|
if self.attn: |
|
|
|
|
|
batch_size, channels, h, w = out.shape |
|
in_attn = out.reshape(batch_size, channels, h * w) |
|
in_attn = self.attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) |
|
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) |
|
out = out + out_attn |
|
if self.cross_attn: |
|
assert ( |
|
context is not None |
|
), "context cannot be None if cross attention layers are used" |
|
batch_size, channels, h, w = out.shape |
|
in_attn = out.reshape(batch_size, channels, h * w) |
|
in_attn = self.cross_attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim |
|
context_proj = self.context_proj[i](context) |
|
out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) |
|
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) |
|
out = out + out_attn |
|
|
|
|
|
out = self.down_sample_conv(out) |
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
class MidBlock(nn.Module): |
|
r""" |
|
Mid conv block with attention. |
|
Sequence of following blocks |
|
1. Resnet block with time embedding |
|
2. Attention block |
|
3. Resnet block with time embedding |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
t_emb_dim, |
|
num_heads, |
|
num_layers, |
|
norm_channels, |
|
cross_attn=None, |
|
context_dim=None, |
|
): |
|
super().__init__() |
|
self.num_layers = num_layers |
|
self.t_emb_dim = t_emb_dim |
|
self.context_dim = context_dim |
|
self.cross_attn = cross_attn |
|
self.resnet_conv_first = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels if i == 0 else out_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers + 1) |
|
] |
|
) |
|
|
|
if self.t_emb_dim is not None: |
|
self.t_emb_layers = nn.ModuleList( |
|
[ |
|
nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) |
|
for _ in range(num_layers + 1) |
|
] |
|
) |
|
self.resnet_conv_second = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), |
|
) |
|
for _ in range(num_layers + 1) |
|
] |
|
) |
|
|
|
self.attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] |
|
) |
|
|
|
self.attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
if self.cross_attn: |
|
assert context_dim is not None, "Context Dimension must be passed for cross attention" |
|
self.cross_attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] |
|
) |
|
self.cross_attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.context_proj = nn.ModuleList( |
|
[nn.Linear(context_dim, out_channels) for _ in range(num_layers)] |
|
) |
|
self.residual_input_conv = nn.ModuleList( |
|
[ |
|
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) |
|
for i in range(num_layers + 1) |
|
] |
|
) |
|
|
|
def forward(self, x, t_emb=None, context=None): |
|
out = x |
|
|
|
|
|
|
|
resnet_input = out |
|
out = self.resnet_conv_first[0](out) |
|
if self.t_emb_dim is not None: |
|
out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] |
|
out = self.resnet_conv_second[0](out) |
|
out = out + self.residual_input_conv[0](resnet_input) |
|
|
|
for i in range(self.num_layers): |
|
|
|
|
|
batch_size, channels, h, w = out.shape |
|
in_attn = out.reshape(batch_size, channels, h * w) |
|
in_attn = self.attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) |
|
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) |
|
out = out + out_attn |
|
|
|
if self.cross_attn: |
|
assert ( |
|
context is not None |
|
), "context cannot be None if cross attention layers are used" |
|
batch_size, channels, h, w = out.shape |
|
in_attn = out.reshape(batch_size, channels, h * w) |
|
in_attn = self.cross_attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim |
|
context_proj = self.context_proj[i](context) |
|
out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) |
|
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) |
|
out = out + out_attn |
|
|
|
|
|
resnet_input = out |
|
out = self.resnet_conv_first[i + 1](out) |
|
if self.t_emb_dim is not None: |
|
out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] |
|
out = self.resnet_conv_second[i + 1](out) |
|
out = out + self.residual_input_conv[i + 1](resnet_input) |
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class UpBlock(nn.Module): |
|
r""" |
|
Up conv block with attention. |
|
Sequence of following blocks |
|
1. Upsample |
|
1. Concatenate Down block output |
|
2. Resnet block with time embedding |
|
3. Attention Block |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
t_emb_dim, |
|
up_sample, |
|
num_heads, |
|
num_layers, |
|
attn, |
|
norm_channels, |
|
): |
|
super().__init__() |
|
self.num_layers = num_layers |
|
self.up_sample = up_sample |
|
self.t_emb_dim = t_emb_dim |
|
self.attn = attn |
|
self.resnet_conv_first = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d( |
|
in_channels if i == 0 else out_channels, |
|
out_channels, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
) |
|
for i in range(num_layers) |
|
] |
|
) |
|
|
|
if self.t_emb_dim is not None: |
|
self.t_emb_layers = nn.ModuleList( |
|
[ |
|
nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.resnet_conv_second = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.GroupNorm(norm_channels, out_channels), |
|
nn.SiLU(), |
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), |
|
) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
if self.attn: |
|
self.attention_norms = nn.ModuleList( |
|
[nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] |
|
) |
|
|
|
self.attentions = nn.ModuleList( |
|
[ |
|
nn.MultiheadAttention(embed_dim=out_channels, num_heads=num_heads, batch_first=True) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
self.residual_input_conv = nn.ModuleList( |
|
[ |
|
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) |
|
for i in range(num_layers) |
|
] |
|
) |
|
self.up_sample_conv = ( |
|
nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1) |
|
if self.up_sample |
|
else nn.Identity() |
|
) |
|
|
|
def forward(self, x, out_down=None, t_emb=None): |
|
|
|
|
|
x = self.up_sample_conv(x) |
|
|
|
|
|
|
|
if out_down is not None: |
|
x = torch.cat([x, out_down], dim=1) |
|
out = x |
|
for i in range(self.num_layers): |
|
|
|
|
|
resnet_input = out |
|
out = self.resnet_conv_first[i](out) |
|
if self.t_emb_dim is not None: |
|
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] |
|
out = self.resnet_conv_second[i](out) |
|
out = out + self.residual_input_conv[i](resnet_input) |
|
|
|
|
|
|
|
if self.attn: |
|
batch_size, channels, h, w = out.shape |
|
in_attn = out.reshape(batch_size, channels, h * w) |
|
in_attn = self.attention_norms[i](in_attn) |
|
in_attn = in_attn.transpose(1, 2) |
|
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) |
|
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) |
|
out = out + out_attn |
|
return out |
|
|
|
|
|
|
|
|
|
|
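# VQ-VAE: the encoder (down + mid blocks) compresses the image to a z_channels latent map,
# which is quantized against a learned nn.Embedding codebook; the decoder mirrors the
# encoder to reconstruct the image from the quantized latents.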
|
class VQVAE(nn.Module): |
|
def __init__(self, im_channels, model_config): |
|
super().__init__() |
|
self.down_channels = model_config.down_channels |
|
self.mid_channels = model_config.mid_channels |
|
self.down_sample = model_config.down_sample |
|
self.num_down_layers = model_config.num_down_layers |
|
self.num_mid_layers = model_config.num_mid_layers |
|
self.num_up_layers = model_config.num_up_layers |
|
|
|
|
|
self.attns = model_config.attn_down |
|
|
|
|
|
self.z_channels = model_config.z_channels |
|
self.codebook_size = model_config.codebook_size |
|
self.norm_channels = model_config.norm_channels |
|
self.num_heads = model_config.num_heads |
|
|
|
|
|
assert self.mid_channels[0] == self.down_channels[-1] |
|
assert self.mid_channels[-1] == self.down_channels[-1] |
|
assert len(self.down_sample) == len(self.down_channels) - 1 |
|
assert len(self.attns) == len(self.down_channels) - 1 |
|
|
|
|
|
|
|
self.up_sample = list(reversed(self.down_sample)) |
|
|
|
|
|
self.encoder_conv_in = nn.Conv2d( |
|
im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1) |
|
) |
|
|
|
|
|
self.encoder_layers = nn.ModuleList([]) |
|
for i in range(len(self.down_channels) - 1): |
|
self.encoder_layers.append( |
|
DownBlock( |
|
self.down_channels[i], |
|
self.down_channels[i + 1], |
|
t_emb_dim=None, |
|
down_sample=self.down_sample[i], |
|
num_heads=self.num_heads, |
|
num_layers=self.num_down_layers, |
|
attn=self.attns[i], |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
self.encoder_mids = nn.ModuleList([]) |
|
for i in range(len(self.mid_channels) - 1): |
|
self.encoder_mids.append( |
|
MidBlock( |
|
self.mid_channels[i], |
|
self.mid_channels[i + 1], |
|
t_emb_dim=None, |
|
num_heads=self.num_heads, |
|
num_layers=self.num_mid_layers, |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) |
|
self.encoder_conv_out = nn.Conv2d( |
|
self.down_channels[-1], self.z_channels, kernel_size=3, padding=1 |
|
) |
|
|
|
|
|
self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) |
|
|
|
|
|
self.embedding = nn.Embedding(self.codebook_size, self.z_channels) |
|
|
|
|
|
|
|
|
|
|
|
self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) |
|
self.decoder_conv_in = nn.Conv2d( |
|
self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1) |
|
) |
|
|
|
|
|
self.decoder_mids = nn.ModuleList([]) |
|
for i in reversed(range(1, len(self.mid_channels))): |
|
self.decoder_mids.append( |
|
MidBlock( |
|
self.mid_channels[i], |
|
self.mid_channels[i - 1], |
|
t_emb_dim=None, |
|
num_heads=self.num_heads, |
|
num_layers=self.num_mid_layers, |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
self.decoder_layers = nn.ModuleList([]) |
|
for i in reversed(range(1, len(self.down_channels))): |
|
self.decoder_layers.append( |
|
UpBlock( |
|
self.down_channels[i], |
|
self.down_channels[i - 1], |
|
t_emb_dim=None, |
|
up_sample=self.down_sample[i - 1], |
|
num_heads=self.num_heads, |
|
num_layers=self.num_up_layers, |
|
attn=self.attns[i - 1], |
|
norm_channels=self.norm_channels, |
|
) |
|
) |
|
self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) |
|
self.decoder_conv_out = nn.Conv2d( |
|
self.down_channels[0], im_channels, kernel_size=3, padding=1 |
|
) |
|
|
|
def quantize(self, x): |
|
B, C, H, W = x.shape |
|
|
|
|
|
x = x.permute(0, 2, 3, 1) |
|
|
|
|
|
x = x.reshape(x.size(0), -1, x.size(-1)) |
|
|
|
|
|
|
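        # Nearest-neighbour codebook lookup: pairwise L2 distances between every latent
        # vector and every codebook embedding, then take the index of the closest code.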
|
dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1))) |
|
|
|
min_encoding_indices = torch.argmin(dist, dim=-1) |
|
|
|
|
|
|
|
quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) |
|
|
|
|
|
x = x.reshape((-1, x.size(-1))) |
|
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {"codebook_loss": codebook_loss, "commitment_loss": commitment_loss}
|
|
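        # Straight-through estimator: the forward pass uses the quantized codes, while
        # gradients flow to the encoder output as if quantization were the identity.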
|
quant_out = x + (quant_out - x).detach() |
|
|
|
|
|
quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) |
|
min_encoding_indices = min_encoding_indices.reshape( |
|
(-1, quant_out.size(-2), quant_out.size(-1)) |
|
) |
|
return quant_out, quantize_losses, min_encoding_indices |
|
|
|
def encode(self, x): |
|
out = self.encoder_conv_in(x) |
|
for idx, down in enumerate(self.encoder_layers): |
|
out = down(out) |
|
for mid in self.encoder_mids: |
|
out = mid(out) |
|
out = self.encoder_norm_out(out) |
|
out = nn.SiLU()(out) |
|
out = self.encoder_conv_out(out) |
|
out = self.pre_quant_conv(out) |
|
out, quant_losses, _ = self.quantize(out) |
|
return out, quant_losses |
|
|
|
def decode(self, z): |
|
out = z |
|
out = self.post_quant_conv(out) |
|
out = self.decoder_conv_in(out) |
|
for mid in self.decoder_mids: |
|
out = mid(out) |
|
for idx, up in enumerate(self.decoder_layers): |
|
out = up(out) |
|
out = self.decoder_norm_out(out) |
|
out = nn.SiLU()(out) |
|
out = self.decoder_conv_out(out) |
|
return out |
|
|
|
def forward(self, x): |
|
'''out: [B, 3, 256, 256] |
|
z: [B, 3, 64, 64] |
|
quant_losses: { |
|
codebook_loss: 0.0681, |
|
commitment_loss: 0.0681 |
|
} |
|
''' |
|
z, quant_losses = self.encode(x) |
|
out = self.decode(z) |
|
return out, z, quant_losses |
|
|
|
|
|
|
|
|
|
|
|
import pprint |
|
|
|
|
|
config_path = sys.argv[1] |
|
with open(config_path, 'r') as file: |
|
Config = yaml.safe_load(file) |
|
pprint.pprint(Config, width=120) |
|
|
|
Config = DotDict.from_dict(Config) |
|
dataset_config = Config.dataset_params |
|
diffusion_config = Config.diffusion_params |
|
model_config = Config.model_params |
|
train_config = Config.train_params |
|
paths = Config.paths |
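# Config keys read by this script (all come from the YAML file passed on the command line):
#   dataset_params:     im_size, im_channels
#   autoencoder_params: down_channels, mid_channels, down_sample, attn_down,
#                       num_down_layers, num_mid_layers, num_up_layers,
#                       z_channels, codebook_size, norm_channels, num_heads
#   train_params:       seed, task_name, autoencoder_batch_size, autoencoder_epochs,
#                       autoencoder_lr, autoencoder_acc_steps, disc_start, disc_weight,
#                       codebook_weight, commitment_beta, perceptual_weight
#   paths:              images_dir
# diffusion_params and model_params are loaded above but not used by this VQ-VAE stage.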
|
|
|
|
|
|
|
|
|
|
|
IMAGES_PATH = paths.images_dir |
|
|
|
def walkDIR(folder_path, include=None): |
|
file_list = [] |
|
for root, _, files in os.walk(folder_path): |
|
for file in files: |
|
if include is None or any(file.endswith(ext) for ext in include): |
|
file_list.append(os.path.join(root, file)) |
|
print("Files found:", len(file_list)) |
|
return file_list |
|
|
|
files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg']) |
|
df = pd.DataFrame(files, columns=['image_path']) |
|
|
|
class VaaniDataset(torch.utils.data.Dataset): |
|
def __init__(self, files_paths, im_size): |
|
self.files_paths = files_paths |
|
self.im_size = im_size |
|
|
|
def __len__(self): |
|
return len(self.files_paths) |
|
|
|
def __getitem__(self, idx): |
|
image = tv.io.decode_image(self.files_paths[idx], mode='RGB') |
|
image = v2.Resize((self.im_size,self.im_size))(image) |
|
image = v2.ToDtype(torch.float32, scale=True)(image) |
|
|
|
return image |
|
|
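# Each dataset item is a float32 RGB tensor of shape (3, im_size, im_size) scaled to [0, 1].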
|
dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size) |
|
|
|
|
|
print("Length of Train dataset:", len(dataset)) |
|
|
|
|
|
|
|
|
|
accelerator = Accelerator() |
|
device = accelerator.device |
|
print("DEVICE:", device) |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
dataset, |
|
batch_size=train_config.autoencoder_batch_size, |
|
shuffle=True, |
|
|
|
|
|
drop_last=True, |
|
|
|
) |
|
dataset_config = Config.dataset_params |
|
autoencoder_config = Config.autoencoder_params |
|
train_config = Config.train_params |
|
import time |
|
|
|
def format_time(t1, t2): |
|
elapsed_time = t2 - t1 |
|
if elapsed_time < 60: |
|
return f"{elapsed_time:.2f} seconds" |
|
elif elapsed_time < 3600: |
|
minutes = elapsed_time // 60 |
|
seconds = elapsed_time % 60 |
|
return f"{minutes:.0f} minutes {seconds:.2f} seconds" |
|
elif elapsed_time < 86400: |
|
hours = elapsed_time // 3600 |
|
remainder = elapsed_time % 3600 |
|
minutes = remainder // 60 |
|
seconds = remainder % 60 |
|
return f"{hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds" |
|
else: |
|
days = elapsed_time // 86400 |
|
remainder = elapsed_time % 86400 |
|
hours = remainder // 3600 |
|
remainder = remainder % 3600 |
|
minutes = remainder // 60 |
|
seconds = remainder % 60 |
|
return f"{days:.0f} days {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds" |
|
|
|
def save_training_meta(training_meta_path, |
|
total_steps, epoch, metrics, logs, total_training_time |
|
): |
|
checkpoint = { |
|
"total_steps": total_steps, |
|
"epoch": epoch, |
|
"metrics": metrics, |
|
"logs": logs, |
|
"total_training_time": total_training_time |
|
} |
|
print(training_meta_path) |
|
tmp_path = training_meta_path + ".tmp" |
|
|
|
torch.save(checkpoint, tmp_path) |
|
print(os.listdir("./VaaniLDM_Acc")) |
|
|
|
os.replace(tmp_path, training_meta_path) |
|
|
|
print(f"Checkpoint saved after {total_steps} steps at epoch {epoch}") |
|
|
|
def load_training_meta(training_meta_path): |
|
if os.path.exists(training_meta_path): |
|
checkpoint = torch.load(training_meta_path) |
|
total_steps = checkpoint["total_steps"] |
|
epoch = checkpoint["epoch"] |
|
metrics = checkpoint["metrics"] |
|
logs = checkpoint.get("logs", []) |
|
total_training_time = checkpoint.get("total_training_time", 0) |
|
print(f"Checkpoint loaded. Resuming from epoch {epoch + 1}, step {total_steps}") |
|
return total_steps, epoch + 1, metrics, logs, total_training_time |
|
else: |
|
print("No checkpoint found. Starting from scratch.") |
|
return 0, 0, None, [], 0 |
|
|
|
def inference(model, dataset, save_path, epoch, device="cuda", sample_size=8): |
|
if not os.path.exists(save_path): |
|
os.makedirs(save_path) |
|
|
|
image_tensors = [] |
|
for i in range(sample_size): |
|
image_tensors.append(dataset[i].unsqueeze(0)) |
|
|
|
    image_tensors = torch.cat(image_tensors, dim=0).to(device)
|
with torch.no_grad(): |
|
outputs, _, _ = model(image_tensors) |
|
|
|
save_input = image_tensors.detach().cpu() |
|
save_output = outputs.detach().cpu() |
|
|
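    # Grid layout: first row holds the inputs, second row the reconstructions.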
|
grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size) |
|
|
|
combined_image = tv.transforms.ToPILImage()(grid) |
|
combined_image.save(os.path.join(save_path, f"reconstructed_images_EP-{epoch}_{sample_size}.png")) |
|
|
|
print(f"Reconstructed images saved at: {save_path}") |
|
|
|
|
|
def trainVAE(Config, dataloader): |
|
dataset_config = Config.dataset_params |
|
autoencoder_config = Config.autoencoder_params |
|
train_config = Config.train_params |
|
paths = Config.paths |
|
|
|
seed = train_config.seed |
|
torch.manual_seed(seed) |
|
np.random.seed(seed) |
|
random.seed(seed) |
|
    if device.type == "cuda":
|
torch.cuda.manual_seed_all(seed) |
|
|
|
model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config) |
|
discriminator = Discriminator(im_channels=dataset_config.im_channels) |
|
|
|
optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) |
|
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) |
|
|
|
cwd = os.getcwd() |
|
checkpoint_path = os.path.join(cwd, train_config.task_name, "vqvae_ckpt") |
|
training_meta_path = os.path.join(cwd, train_config.task_name, "training_meta.pth") |
|
total_steps, start_epoch, metrics, logs, total_training_time = load_training_meta(training_meta_path) |
|
|
|
if not os.path.exists(train_config.task_name): |
|
os.mkdir(train_config.task_name) |
|
|
|
num_epochs = train_config.autoencoder_epochs |
|
recon_criterion = torch.nn.MSELoss() |
|
disc_criterion = torch.nn.MSELoss() |
|
lpips_model = LPIPS().eval() |
|
|
|
acc_steps = train_config.autoencoder_acc_steps |
|
disc_step_start = train_config.disc_start |
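    # Gradient accumulation: each loss term is divided by acc_steps and the optimizers
    # only step every acc_steps batches. The discriminator (and the adversarial term in
    # the generator loss) is enabled only after disc_start optimizer steps.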
|
|
|
|
|
model, discriminator, lpips_model, optimizer_g, optimizer_d, dataloader = accelerator.prepare( |
|
model, discriminator, lpips_model, optimizer_g, optimizer_d, dataloader |
|
) |
|
|
|
accelerator.register_for_checkpointing(model) |
|
accelerator.register_for_checkpointing(discriminator) |
|
accelerator.register_for_checkpointing(lpips_model) |
|
accelerator.register_for_checkpointing(optimizer_g) |
|
accelerator.register_for_checkpointing(optimizer_d) |
|
|
|
if os.path.exists(checkpoint_path): |
|
|
|
accelerator.load_state(checkpoint_path) |
|
|
|
start_time_total = time.time() - total_training_time |
|
|
|
for epoch_idx in trange(start_epoch, num_epochs, ncols=70, colour='blue'): |
|
start_time_epoch = time.time() |
|
epoch_log = [] |
|
|
|
for images in tqdm(dataloader, ncols=70, colour='blue'): |
|
|
|
batch_start_time = time.time() |
|
total_steps += 1 |
|
|
|
|
model_output = model(images) |
|
output, z, quantize_losses = model_output |
|
|
|
recon_loss = recon_criterion(output, images) / acc_steps |
|
|
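            # Generator objective: MSE reconstruction + weighted VQ codebook/commitment
            # losses, plus an LPIPS perceptual term and, after disc_step_start, an
            # adversarial term added below.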
|
g_loss = ( |
|
recon_loss |
|
+ (train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps) |
|
+ (train_config.commitment_beta * quantize_losses["commitment_loss"] / acc_steps) |
|
) |
|
|
|
if total_steps > disc_step_start: |
|
                # Adversarial term for the generator, computed on the whole reconstructed batch.
                disc_fake_pred = discriminator(output)
|
disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) |
|
g_loss += train_config.disc_weight * disc_fake_loss / acc_steps |
|
|
|
            # Dataset images are in [0, 1]; normalize=True rescales them (and the
            # reconstruction) to the [-1, 1] range the LPIPS network expects.
            lpips_loss = torch.mean(lpips_model(output, images, normalize=True)) / acc_steps
            g_loss += train_config.perceptual_weight * lpips_loss
|
|
|
|
|
accelerator.backward(g_loss) |
|
|
|
if total_steps > disc_step_start: |
|
fake = output |
|
disc_fake_pred = discriminator(fake.detach()) |
|
disc_real_pred = discriminator(images) |
|
|
|
|
|
|
|
|
|
                # *_like targets stay on the same device/dtype as the predictions.
                disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
                disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
|
|
|
disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2 / acc_steps |
|
|
|
|
|
accelerator.backward(disc_loss) |
|
|
|
if total_steps % acc_steps == 0: |
|
optimizer_d.step() |
|
optimizer_d.zero_grad() |
|
|
|
if total_steps % acc_steps == 0: |
|
optimizer_g.step() |
|
optimizer_g.zero_grad() |
|
|
|
batch_time = time.time() - batch_start_time |
|
epoch_log.append(format_time(0, batch_time)) |
|
|
|
optimizer_d.step() |
|
optimizer_d.zero_grad() |
|
optimizer_g.step() |
|
optimizer_g.zero_grad() |
|
|
|
epoch_time = time.time() - start_time_epoch |
|
logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log}) |
|
|
|
total_training_time = time.time() - start_time_total |
|
|
|
|
|
if accelerator.is_main_process: |
|
save_training_meta(training_meta_path, total_steps, epoch_idx + 1, metrics, logs, total_training_time) |
|
print('Training Metadata Saved in', training_meta_path) |
|
shutil.copy2(training_meta_path, os.path.join(cwd, train_config.task_name, "training_meta-1.pth")) |
|
print('Copied to', os.path.join(cwd, train_config.task_name, "training_meta-1.pth")) |
|
accelerator.save_state(checkpoint_path, safe_serialization=True) |
|
print('States Saved in', checkpoint_path) |
|
shutil.copytree(checkpoint_path, os.path.join(cwd, train_config.task_name, "vqvae_ckpt-1"), dirs_exist_ok=True) |
|
            print('Copied to', os.path.join(cwd, train_config.task_name, "vqvae_ckpt-1"))
|
|
|
recon_save_path = os.path.join(cwd, train_config.task_name, 'vqvae_recon') |
|
|
|
|
|
accelerator.end_training() |
|
print("Training completed.") |
|
if __name__ == "__main__": |
|
trainVAE(Config, dataloader) |
|