RishabA committed on
Commit e470bd9 · 1 Parent(s): 107afe5

Upload 7 files
celebhq/.DS_Store ADDED
Binary file (6.15 kB).
 
celebhq/vqvae_autoencoder_ckpt.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56ef462bb47bf43acbdefd47fb4563d38cd526a6a4f405e2fd69ab6158e695c7
+ size 272212900
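
This file is stored as a Git LFS pointer, so the ~272 MB checkpoint itself has to be fetched (for example with git lfs pull) before it can be loaded. A minimal loading sketch, assuming the checkpoint uses the "model_state_dict" key that inference.py reads:

# Sketch only: inspect the VQ-VAE checkpoint once the LFS object has been downloaded.
# Assumes the "model_state_dict" layout that inference.py expects.
import torch

ckpt = torch.load("celebhq/vqvae_autoencoder_ckpt.pth", map_location="cpu")
state_dict = ckpt["model_state_dict"]
print(f"{len(state_dict)} parameter tensors in the VQ-VAE checkpoint")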
config.json ADDED
@@ -0,0 +1,64 @@
+ {
+   "dataset_params": {
+     "image_path": "data/CelebAMask-HQ",
+     "image_channels": 3,
+     "image_size": 256,
+     "name": "celebhq"
+   },
+   "diffusion_params": {
+     "num_timesteps": 1000,
+     "beta_start": 0.00085,
+     "beta_end": 0.012
+   },
+   "ldm_params": {
+     "down_channels": [256, 384, 512, 768],
+     "mid_channels": [768, 512],
+     "down_sample": [true, true, true],
+     "attn_down": [true, true, true],
+     "time_emb_dim": 512,
+     "norm_channels": 32,
+     "num_heads": 16,
+     "conv_out_channels": 128,
+     "num_down_layers": 2,
+     "num_mid_layers": 2,
+     "num_up_layers": 2,
+     "condition_config": {
+       "condition_types": ["text", "image"],
+       "text_condition_config": {
+         "text_embed_model": "clip",
+         "train_text_embed_model": false,
+         "text_embed_dim": 512,
+         "cond_drop_prob": 0.1
+       },
+       "image_condition_config": {
+         "image_condition_input_channels": 18,
+         "image_condition_output_channels": 3,
+         "image_condition_h": 512,
+         "image_condition_w": 512,
+         "cond_drop_prob": 0.1
+       }
+     }
+   },
+   "autoencoder_params": {
+     "z_channels": 4,
+     "codebook_size": 8192,
+     "down_channels": [64, 128, 256, 256],
+     "mid_channels": [256, 256],
+     "down_sample": [true, true, true],
+     "attn_down": [false, false, false],
+     "norm_channels": 32,
+     "num_heads": 4,
+     "num_down_layers": 2,
+     "num_mid_layers": 2,
+     "num_up_layers": 2
+   },
+   "train_params": {
+     "task_name": "celebhq",
+     "num_samples": 1,
+     "num_grid_rows": 1,
+     "cf_guidance_scale": 1.0,
+     "ldm_ckpt_name": "ddpm_ckpt_class_cond.pth",
+     "vqvae_autoencoder_ckpt_name": "vqvae_autoencoder_ckpt.pth",
+     "vqvae_latent_dir_name": "vqvae_latents"
+   }
+ }
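
For orientation, the latent resolution that the diffusion UNet samples in follows from dataset_params and autoencoder_params; a short sketch of that arithmetic (the same calculation inference.py performs):

# Sketch: latent resolution implied by config.json (mirrors the latent_size computation in inference.py).
import json

with open("config.json") as f:
    config = json.load(f)

image_size = config["dataset_params"]["image_size"]                  # 256
num_downsamples = sum(config["autoencoder_params"]["down_sample"])   # three true flags -> 3
latent_size = image_size // (2 ** num_downsamples)                   # 256 // 8 = 32
z_channels = config["autoencoder_params"]["z_channels"]              # 4
print(f"Diffusion operates on {z_channels} x {latent_size} x {latent_size} latents")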
inference.py ADDED
@@ -0,0 +1,169 @@
+ # inference.py
+ import os
+ import json
+ import torch
+ import argparse
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from torchvision.utils import make_grid
+ import torch.nn.functional as F
+
+ # Import your model definitions and helper functions
+ from model import (
+     UNet,
+     VQVAE,
+     LinearNoiseScheduler,
+     get_tokenizer_and_model,
+     get_text_representation,
+     get_time_embedding,
+ )
+
+
+ def load_config(config_path="config.json"):
+     with open(config_path, "r") as f:
+         config = json.load(f)
+     return config
+
+
+ def sample_ddpm_inference(
+     text_prompt, mask_image_path=None, guidance_scale=1.0, device=torch.device("cpu")
+ ):
+     config = load_config()
+
+     diffusion_params = config["diffusion_params"]
+     ldm_params = config["ldm_params"]
+     autoencoder_params = config["autoencoder_params"]
+     train_params = config["train_params"]
+     dataset_params = config["dataset_params"]
+
+     # Create the noise scheduler
+     scheduler = LinearNoiseScheduler(
+         num_timesteps=diffusion_params["num_timesteps"],
+         beta_start=diffusion_params["beta_start"],
+         beta_end=diffusion_params["beta_end"],
+     )
+
+     # Conditioning configuration
+     condition_config = ldm_params.get("condition_config", {})
+     condition_types = condition_config.get("condition_types", [])
+
+     # Text conditioning
+     text_model_type = condition_config["text_condition_config"]["text_embed_model"]
+     text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device)
+     empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)
+     text_prompt_embed = get_text_representation(
+         [text_prompt], text_tokenizer, text_model, device
+     )
+
+     # Image conditioning
+     if "image" in condition_types:
+         if mask_image_path is not None:
+             mask_image = Image.open(mask_image_path).convert("RGB")
+             mask_transform = transforms.Compose(
+                 [
+                     transforms.Resize(
+                         (
+                             ldm_params["condition_config"]["image_condition_config"][
+                                 "image_condition_h"
+                             ],
+                             ldm_params["condition_config"]["image_condition_config"][
+                                 "image_condition_w"
+                             ],
+                         )
+                     ),
+                     transforms.ToTensor(),
+                 ]
+             )
+             mask_tensor = mask_transform(mask_image).unsqueeze(0).to(device)
+         else:
+             ic = ldm_params["condition_config"]["image_condition_config"][
+                 "image_condition_input_channels"
+             ]
+             H = ldm_params["condition_config"]["image_condition_config"][
+                 "image_condition_h"
+             ]
+             W = ldm_params["condition_config"]["image_condition_config"][
+                 "image_condition_w"
+             ]
+             mask_tensor = torch.zeros((1, ic, H, W), device=device)
+     else:
+         mask_tensor = None
+
+     # Build conditioning dictionaries
+     uncond_input = {}
+     cond_input = {}
+     if "text" in condition_types:
+         uncond_input["text"] = empty_text_embed
+         cond_input["text"] = text_prompt_embed
+     if "image" in condition_types:
+         uncond_input["image"] = torch.zeros_like(mask_tensor)
+         cond_input["image"] = mask_tensor
+
+     # Instantiate and load UNet model
+     unet = UNet(autoencoder_params["z_channels"], ldm_params).to(device)
+     ldm_ckpt_path = os.path.join(
+         train_params["task_name"], train_params["ldm_ckpt_name"]
+     )
+     if os.path.exists(ldm_ckpt_path):
+         ckpt = torch.load(ldm_ckpt_path, map_location=device)
+         unet.load_state_dict(ckpt["model_state_dict"])
+     unet.eval()
+
+     # Instantiate and load VQVAE autoencoder
+     vae = VQVAE(dataset_params["image_channels"], autoencoder_params).to(device)
+     vae_ckpt_path = os.path.join(
+         train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
+     )
+     if os.path.exists(vae_ckpt_path):
+         ckpt = torch.load(vae_ckpt_path, map_location=device)
+         vae.load_state_dict(ckpt["model_state_dict"])
+     vae.eval()
+
+     # Determine latent space size (simplified calculation)
+     latent_size = dataset_params["image_size"] // (
+         2 ** sum(autoencoder_params["down_sample"])
+     )
+     batch = train_params["num_samples"]
+     z_channels = autoencoder_params["z_channels"]
+
+     # Sample initial latent noise
+     xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
+
+     T = diffusion_params["num_timesteps"]
+     for i in reversed(range(T)):
+         t = torch.full((batch,), i, dtype=torch.long, device=device)
+         noise_pred_cond = unet(xt, t, cond_input)
+         if guidance_scale > 1:
+             noise_pred_uncond = unet(xt, t, uncond_input)
+             noise_pred = noise_pred_uncond + guidance_scale * (
+                 noise_pred_cond - noise_pred_uncond
+             )
+         else:
+             noise_pred = noise_pred_cond
+         xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
+
+     with torch.no_grad():
+         generated = vae.decode(xt)
+     generated = torch.clamp(generated, -1, 1)
+     generated = (generated + 1) / 2  # Scale to [0, 1]
+     grid = make_grid(generated, nrow=1)
+     pil_img = transforms.ToPILImage()(grid.cpu())
+     return pil_img
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Run model inference")
+     parser.add_argument(
+         "--text", type=str, required=True, help="Text prompt for conditioning"
+     )
+     parser.add_argument(
+         "--mask",
+         type=str,
+         default=None,
+         help="Path to mask image for conditioning (optional)",
+     )
+     args = parser.parse_args()
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     result_img = sample_ddpm_inference(args.text, args.mask, device=device)
+     result_img.save("generated.png")
+     print("Generated image saved as generated.png")
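
Besides the CLI entry point above, the sampler can be called directly; a hypothetical usage sketch (the prompt text is illustrative, and config.json plus the celebhq/ checkpoints are assumed to be present locally):

# Hypothetical usage sketch: invoke the sampler programmatically instead of via argparse.
import torch
from inference import sample_ddpm_inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img = sample_ddpm_inference(
    text_prompt="A smiling woman with blond hair",  # illustrative prompt
    mask_image_path=None,   # optional path to a CelebAMask-HQ mask image
    guidance_scale=1.0,     # matches cf_guidance_scale in train_params
    device=device,
)
img.save("generated.png")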
model.py ADDED
@@ -0,0 +1,1797 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import glob
5
+ import pickle
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as transforms
10
+ from torch.optim import Adam
11
+ from torchvision.utils import make_grid
12
+ from PIL import Image
13
+ from transformers import (
14
+ DistilBertModel,
15
+ DistilBertTokenizer,
16
+ CLIPTokenizer,
17
+ CLIPTextModel,
18
+ )
19
+
20
+ dataset_params = {
21
+ "image_path": "data/CelebAMask-HQ",
22
+ "image_channels": 3,
23
+ "image_size": 256,
24
+ "name": "celebhq",
25
+ }
26
+
27
+ diffusion_params = {
28
+ "num_timesteps": 1000,
29
+ "beta_start": 0.00085,
30
+ "beta_end": 0.012,
31
+ }
32
+
33
+ ldm_params = {
34
+ "down_channels": [256, 384, 512, 768],
35
+ "mid_channels": [768, 512],
36
+ "down_sample": [True, True, True],
37
+ "attn_down": [True, True, True], # Attention in Down/Up blocks for diffusion
38
+ "time_emb_dim": 512,
39
+ "norm_channels": 32,
40
+ "num_heads": 16,
41
+ "conv_out_channels": 128,
42
+ "num_down_layers": 2,
43
+ "num_mid_layers": 2,
44
+ "num_up_layers": 2,
45
+ "condition_config": {
46
+ "condition_types": ["text", "image"],
47
+ "text_condition_config": {
48
+ "text_embed_model": "clip", # or "bert"
49
+ "train_text_embed_model": False,
50
+ "text_embed_dim": 512,
51
+ "cond_drop_prob": 0.1,
52
+ },
53
+ "image_condition_config": {
54
+ "image_condition_input_channels": 18,
55
+ "image_condition_output_channels": 3,
56
+ "image_condition_h": 512,
57
+ "image_condition_w": 512,
58
+ "cond_drop_prob": 0.1,
59
+ },
60
+ },
61
+ }
62
+
63
+ autoencoder_params = {
64
+ "z_channels": 4,
65
+ "codebook_size": 8192,
66
+ "down_channels": [64, 128, 256, 256],
67
+ "mid_channels": [256, 256],
68
+ "down_sample": [True, True, True],
69
+ "attn_down": [False, False, False],
70
+ "norm_channels": 32,
71
+ "num_heads": 4,
72
+ "num_down_layers": 2,
73
+ "num_mid_layers": 2,
74
+ "num_up_layers": 2,
75
+ }
76
+
77
+ train_params = {
78
+ "task_name": "celebhq", # Folder name in which model checkpoints are stored
79
+ "num_samples": 1,
80
+ "num_grid_rows": 1,
81
+ "cf_guidance_scale": 1.0,
82
+ "ldm_ckpt_name": "ddpm_ckpt_class_cond.pth",
83
+ "vqvae_autoencoder_ckpt_name": "vqvae_autoencoder_ckpt.pth",
84
+ "vqvae_latent_dir_name": "vqvae_latents",
85
+ }
86
+
87
+
88
+ def get_config_value(config, key, default_value):
89
+ return config[key] if key in config else default_value
90
+
91
+
92
+ def spatial_average(in_tens, keepdim=True):
93
+ return in_tens.mean([2, 3], keepdim=keepdim)
94
+
95
+
96
+ class LinearNoiseScheduler:
97
+ def __init__(self, num_timesteps, beta_start, beta_end):
98
+ self.num_timesteps = num_timesteps
99
+ self.beta_start = beta_start
100
+ self.beta_end = beta_end
101
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
102
+ self.alphas = 1.0 - self.betas
103
+ self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
104
+ self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
105
+ self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
106
+
107
+ def add_noise(self, original, noise, t):
108
+ # original: (batch_size, c, h, w), t: tensor of timesteps (batch_size,)
109
+ batch_size = original.shape[0]
110
+ sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].view(
111
+ batch_size, 1, 1, 1
112
+ )
113
+ sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(
114
+ original.device
115
+ )[t].view(batch_size, 1, 1, 1)
116
+ return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise
117
+
118
+ def sample_prev_timestep(self, xt, noise_pred, t):
119
+ batch_size = xt.shape[0]
120
+ alpha_cum_prod_t = self.alpha_cum_prod.to(xt.device)[t].view(
121
+ batch_size, 1, 1, 1
122
+ )
123
+ sqrt_one_minus_alpha_cum_prod_t = self.sqrt_one_minus_alpha_cum_prod.to(
124
+ xt.device
125
+ )[t].view(batch_size, 1, 1, 1)
126
+ x0 = (xt - sqrt_one_minus_alpha_cum_prod_t * noise_pred) / torch.sqrt(
127
+ alpha_cum_prod_t
128
+ )
129
+ x0 = torch.clamp(x0, -1.0, 1.0)
130
+ betas_t = self.betas.to(xt.device)[t].view(batch_size, 1, 1, 1)
131
+ mean = (
132
+ xt - betas_t / sqrt_one_minus_alpha_cum_prod_t * noise_pred
133
+ ) / torch.sqrt(self.alphas.to(xt.device)[t].view(batch_size, 1, 1, 1))
134
+ if t[0] == 0:
135
+ return mean, x0
136
+ else:
137
+ prev_alpha_cum_prod = self.alpha_cum_prod.to(xt.device)[
138
+ (t - 1).clamp(min=0)
139
+ ].view(batch_size, 1, 1, 1)
140
+ variance = (1 - prev_alpha_cum_prod) / (1 - alpha_cum_prod_t) * betas_t
141
+ sigma = variance.sqrt()
142
+ z = torch.randn_like(xt)
143
+ return mean + sigma * z, x0
144
+
145
+
146
+ def get_tokenizer_and_model(model_type, device, eval_mode=True):
147
+ assert model_type in (
148
+ "bert",
149
+ "clip",
150
+ ), "Text model can only be one of 'clip' or 'bert'"
151
+ if model_type == "bert":
152
+ text_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
153
+ text_model = DistilBertModel.from_pretrained("distilbert-base-uncased").to(
154
+ device
155
+ )
156
+ else:
157
+ text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
158
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16").to(
159
+ device
160
+ )
161
+ if eval_mode:
162
+ text_model.eval()
163
+ return text_tokenizer, text_model
164
+
165
+
166
+ def get_text_representation(text, text_tokenizer, text_model, device, max_length=77):
167
+ token_output = text_tokenizer(
168
+ text,
169
+ truncation=True,
170
+ padding="max_length",
171
+ return_attention_mask=True,
172
+ max_length=max_length,
173
+ )
174
+ tokens_tensor = torch.tensor(token_output["input_ids"]).to(device)
175
+ mask_tensor = torch.tensor(token_output["attention_mask"]).to(device)
176
+ text_embed = text_model(tokens_tensor, attention_mask=mask_tensor).last_hidden_state
177
+ return text_embed
178
+
179
+
180
+ def get_time_embedding(time_steps, temb_dim):
181
+ """
182
+ Convert time steps tensor into an embedding using the sinusoidal time embedding formula
183
+ time_steps: 1D tensor of length batch size
184
+ temb_dim: Dimension of the embedding
185
+ """
186
+ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
187
+
188
+ # factor = 10000^(2i/d_model)
189
+ factor = 10000 ** (
190
+ (
191
+ torch.arange(
192
+ start=0,
193
+ end=temb_dim // 2,
194
+ dtype=torch.float32,
195
+ device=time_steps.device,
196
+ )
197
+ / (temb_dim // 2)
198
+ )
199
+ )
200
+
201
+ t_emb = time_steps.unsqueeze(dim=-1).repeat(1, temb_dim // 2) / factor
202
+ t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
203
+
204
+ return t_emb # (batch_size, temb_dim)
205
+
206
+
207
+ class DownBlock(nn.Module):
208
+ """
209
+ Down conv block with attention.
210
+ 1. Resnet block with time embedding
211
+ 2. Attention block
212
+ 3. Downsample
213
+
214
+ in_channels: Number of channels in the input feature map.
215
+ out_channels: Number of channels produced by this block.
216
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
217
+ down_sample: Whether to apply downsampling at the end.
218
+ num_heads: Number of attention heads (used if attention is enabled).
219
+ num_layers: How many sub-blocks to apply in sequence.
220
+ attn: Whether to apply self-attention
221
+ norm_channels: Number of groups for GroupNorm.
222
+ cross_attn: Whether to apply cross-attention.
223
+ context_dim: If performing cross-attention, provide a context_dim for extra conditioning context.
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ in_channels,
229
+ out_channels,
230
+ t_emb_dim,
231
+ down_sample,
232
+ num_heads,
233
+ num_layers,
234
+ attn,
235
+ norm_channels,
236
+ cross_attn=False,
237
+ context_dim=None,
238
+ ):
239
+ super().__init__()
240
+
241
+ self.num_layers = num_layers
242
+ self.down_sample = down_sample
243
+ self.attn = attn
244
+ self.context_dim = context_dim
245
+ self.cross_attn = cross_attn
246
+ self.t_emb_dim = t_emb_dim
247
+
248
+ self.resnet_conv_first = nn.ModuleList(
249
+ [
250
+ nn.Sequential(
251
+ nn.GroupNorm(
252
+ norm_channels, in_channels if i == 0 else out_channels
253
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
254
+ nn.SiLU(),
255
+ nn.Conv2d(
256
+ in_channels=(in_channels if i == 0 else out_channels),
257
+ out_channels=out_channels,
258
+ kernel_size=3,
259
+ stride=1,
260
+ padding=1,
261
+ ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w)
262
+ )
263
+ for i in range(num_layers)
264
+ ]
265
+ )
266
+
267
+ # Only add the time embedding for diffusion and not AutoEncoder
268
+ if self.t_emb_dim is not None:
269
+ self.t_emb_layers = nn.ModuleList(
270
+ [
271
+ nn.Sequential(
272
+ nn.SiLU(),
273
+ nn.Linear(
274
+ in_features=self.t_emb_dim, out_features=out_channels
275
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
276
+ )
277
+ for i in range(num_layers)
278
+ ]
279
+ )
280
+
281
+ self.resnet_conv_second = nn.ModuleList(
282
+ [
283
+ nn.Sequential(
284
+ nn.GroupNorm(norm_channels, out_channels),
285
+ nn.SiLU(),
286
+ nn.Conv2d(
287
+ in_channels=out_channels,
288
+ out_channels=out_channels,
289
+ kernel_size=3,
290
+ stride=1,
291
+ padding=1,
292
+ ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
293
+ )
294
+ for i in range(num_layers)
295
+ ]
296
+ )
297
+
298
+ self.residual_input_conv = nn.ModuleList(
299
+ [
300
+ nn.Conv2d(
301
+ in_channels=(in_channels if i == 0 else out_channels),
302
+ out_channels=out_channels,
303
+ kernel_size=1,
304
+ stride=1,
305
+ padding=0,
306
+ ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w)
307
+ for i in range(num_layers)
308
+ ]
309
+ )
310
+
311
+ if self.attn:
312
+ self.attention_norms = nn.ModuleList(
313
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
314
+ )
315
+
316
+ self.attentions = nn.ModuleList(
317
+ [
318
+ nn.MultiheadAttention(
319
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
320
+ )
321
+ for i in range(num_layers)
322
+ ]
323
+ )
324
+
325
+ # Cross attention for text conditioning
326
+ if self.cross_attn:
327
+ assert (
328
+ context_dim is not None
329
+ ), "Context Dimension must be passed for cross attention"
330
+
331
+ self.cross_attention_norms = nn.ModuleList(
332
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
333
+ )
334
+
335
+ self.cross_attentions = nn.ModuleList(
336
+ [
337
+ nn.MultiheadAttention(
338
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
339
+ )
340
+ for i in range(num_layers)
341
+ ]
342
+ )
343
+
344
+ self.context_proj = nn.ModuleList(
345
+ [
346
+ nn.Linear(in_features=context_dim, out_features=out_channels)
347
+ for i in range(num_layers)
348
+ ]
349
+ )
350
+
351
+ # Down sample by a factor of 2
352
+ self.down_sample_conv = (
353
+ nn.Conv2d(
354
+ in_channels=out_channels,
355
+ out_channels=out_channels,
356
+ kernel_size=4,
357
+ stride=2,
358
+ padding=1,
359
+ )
360
+ if self.down_sample
361
+ else nn.Identity()
362
+ ) # (batch_size, out_channels, h / 2, w / 2)
363
+
364
+ def forward(self, x, t_emb=None, context=None):
365
+ out = x
366
+ for i in range(self.num_layers):
367
+ # Resnet block of UNET
368
+ resnet_input = out # (batch_size, c, h, w)
369
+
370
+ out = self.resnet_conv_first[i](out) # (batch_size, out_channels, h, w)
371
+
372
+ # Only add the time embedding for diffusion and not AutoEncoder
373
+ if self.t_emb_dim is not None:
374
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
375
+ out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze(
376
+ dim=-1
377
+ ) # (batch_size, out_channels, h, w)
378
+
379
+ out = self.resnet_conv_second[i](
380
+ out
381
+ ) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
382
+
383
+ # Residual Connection
384
+ out = out + self.residual_input_conv[i](
385
+ resnet_input
386
+ ) # (batch_size, out_channels, h, w)
387
+
388
+ # Only do for Diffusion and not for AutoEncoder
389
+ if self.attn:
390
+ # Attention block of UNET
391
+ batch_size, channels, h, w = (
392
+ out.shape
393
+ ) # (batch_size, out_channels, h, w)
394
+
395
+ in_attn = out.reshape(
396
+ batch_size, channels, h * w
397
+ ) # (batch_size, out_channels, h * w)
398
+ in_attn = self.attention_norms[i](in_attn)
399
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
400
+
401
+ # Self-Attention
402
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
403
+ out_attn = out_attn.transpose(1, 2).reshape(
404
+ batch_size, channels, h, w
405
+ ) # (batch_size, out_channels h, w)
406
+
407
+ # Skip connection
408
+ out = out + out_attn # (batch_size, out_channels h, w)
409
+
410
+ if self.cross_attn:
411
+ assert (
412
+ context is not None
413
+ ), "context cannot be None if cross attention layers are used"
414
+
415
+ batch_size, channels, h, w = (
416
+ out.shape
417
+ ) # (batch_size, out_channels, h, w)
418
+
419
+ in_attn = out.reshape(
420
+ batch_size, channels, h * w
421
+ ) # (batch_size, out_channels, h * w)
422
+ in_attn = self.cross_attention_norms[i](in_attn)
423
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
424
+
425
+ assert (
426
+ context.shape[0] == x.shape[0]
427
+ and context.shape[-1] == self.context_dim
428
+ ) # Make sure the batch_size and context_dim match with the model's parameters
429
+ context_proj = self.context_proj[i](
430
+ context
431
+ ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels)
432
+
433
+ # Cross-Attention
434
+ out_attn, attn_weights = self.cross_attentions[i](
435
+ in_attn, context_proj, context_proj
436
+ ) # (batch_size, h * w, out_channels)
437
+ out_attn = out_attn.transpose(1, 2).reshape(
438
+ batch_size, channels, h, w
439
+ ) # (batch_size, out_channels, h, w)
440
+
441
+ # Skip Connection
442
+ out = out + out_attn # (batch_size, out_channels, h, w)
443
+
444
+ # Downsampling
445
+ out = self.down_sample_conv(out) # (batch_size, out_channels, h / 2, w / 2)
446
+ return out
447
+
448
+
449
+ class MidBlock(nn.Module):
450
+ """
451
+ Mid conv block with attention.
452
+ 1. Resnet block with time embedding
453
+ 2. Attention block
454
+ 3. Resnet block with time embedding
455
+
456
+ in_channels: Number of channels in the input feature map.
457
+ out_channels: Number of channels produced by this block.
458
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
459
+ num_heads: Number of attention heads (used if attention is enabled).
460
+ num_layers: How many sub-blocks to apply in sequence.
461
+ norm_channels: Number of groups for GroupNorm.
462
+ cross_attn: Whether to apply cross-attention.
463
+ context_dim: If performing cross-attention, provide a context_dim for extra conditioning context.
464
+ """
465
+
466
+ def __init__(
467
+ self,
468
+ in_channels,
469
+ out_channels,
470
+ t_emb_dim,
471
+ num_heads,
472
+ num_layers,
473
+ norm_channels,
474
+ cross_attn=None,
475
+ context_dim=None,
476
+ ):
477
+ super().__init__()
478
+
479
+ self.num_layers = num_layers
480
+ self.t_emb_dim = t_emb_dim
481
+ self.context_dim = context_dim
482
+ self.cross_attn = cross_attn
483
+
484
+ self.resnet_conv_first = nn.ModuleList(
485
+ [
486
+ nn.Sequential(
487
+ nn.GroupNorm(
488
+ norm_channels, in_channels if i == 0 else out_channels
489
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
490
+ nn.SiLU(),
491
+ nn.Conv2d(
492
+ in_channels=(in_channels if i == 0 else out_channels),
493
+ out_channels=out_channels,
494
+ kernel_size=3,
495
+ stride=1,
496
+ padding=1,
497
+ ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w)
498
+ )
499
+ for i in range(num_layers + 1)
500
+ ]
501
+ )
502
+
503
+ # Only add the time embedding for diffusion and not AutoEncoder
504
+ if self.t_emb_dim is not None:
505
+ self.t_emb_layers = nn.ModuleList(
506
+ [
507
+ nn.Sequential(
508
+ nn.SiLU(),
509
+ nn.Linear(
510
+ in_features=self.t_emb_dim, out_features=out_channels
511
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
512
+ )
513
+ for i in range(num_layers + 1)
514
+ ]
515
+ )
516
+
517
+ self.resnet_conv_second = nn.ModuleList(
518
+ [
519
+ nn.Sequential(
520
+ nn.GroupNorm(norm_channels, out_channels),
521
+ nn.SiLU(),
522
+ nn.Conv2d(
523
+ in_channels=out_channels,
524
+ out_channels=out_channels,
525
+ kernel_size=3,
526
+ stride=1,
527
+ padding=1,
528
+ ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
529
+ )
530
+ for i in range(num_layers + 1)
531
+ ]
532
+ )
533
+
534
+ self.residual_input_conv = nn.ModuleList(
535
+ [
536
+ nn.Conv2d(
537
+ in_channels=(in_channels if i == 0 else out_channels),
538
+ out_channels=out_channels,
539
+ kernel_size=1,
540
+ stride=1,
541
+ padding=0,
542
+ ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w)
543
+ for i in range(num_layers + 1)
544
+ ]
545
+ )
546
+
547
+ self.attention_norms = nn.ModuleList(
548
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
549
+ )
550
+
551
+ self.attentions = nn.ModuleList(
552
+ [
553
+ nn.MultiheadAttention(
554
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
555
+ )
556
+ for i in range(num_layers)
557
+ ]
558
+ )
559
+
560
+ # Cross attention for text conditioning
561
+ if self.cross_attn:
562
+ assert (
563
+ context_dim is not None
564
+ ), "Context Dimension must be passed for cross attention"
565
+
566
+ self.cross_attention_norms = nn.ModuleList(
567
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
568
+ )
569
+
570
+ self.cross_attentions = nn.ModuleList(
571
+ [
572
+ nn.MultiheadAttention(
573
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
574
+ )
575
+ for i in range(num_layers)
576
+ ]
577
+ )
578
+
579
+ self.context_proj = nn.ModuleList(
580
+ [
581
+ nn.Linear(in_features=context_dim, out_features=out_channels)
582
+ for i in range(num_layers)
583
+ ]
584
+ )
585
+
586
+ def forward(self, x, t_emb=None, context=None):
587
+ out = x
588
+
589
+ # First ResNet block
590
+ resnet_input = out # (batch_size, c, h, w)
591
+ out = self.resnet_conv_first[0](out) # (batch_size, out_channels, h, w)
592
+
593
+ # Only add the time embedding for diffusion and not AutoEncoder
594
+ if self.t_emb_dim is not None:
595
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
596
+ out = out + self.t_emb_layers[0](t_emb).unsqueeze(dim=-1).unsqueeze(
597
+ dim=-1
598
+ ) # (batch_size, out_channels, h, w)
599
+
600
+ out = self.resnet_conv_second[0](
601
+ out
602
+ ) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
603
+
604
+ # Residual Connection
605
+ out = out + self.residual_input_conv[0](
606
+ resnet_input
607
+ ) # (batch_size, out_channels, h, w)
608
+
609
+ for i in range(self.num_layers):
610
+ # Attention Block
611
+ batch_size, channels, h, w = out.shape # (batch_size, out_channels, h, w)
612
+
613
+ # Do for both Diffusion and AutoEncoder
614
+ in_attn = out.reshape(
615
+ batch_size, channels, h * w
616
+ ) # (batch_size, out_channels, h * w)
617
+ in_attn = self.attention_norms[i](in_attn)
618
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
619
+
620
+ # Self-Attention
621
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
622
+ out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
623
+
624
+ # Skip connection
625
+ out = out + out_attn # (batch_size, out_channels h, w)
626
+
627
+ if self.cross_attn:
628
+ assert (
629
+ context is not None
630
+ ), "context cannot be None if cross attention layers are used"
631
+ batch_size, channels, h, w = out.shape
632
+
633
+ in_attn = out.reshape(
634
+ batch_size, channels, h * w
635
+ ) # (batch_size, out_channels, h * w)
636
+ in_attn = self.cross_attention_norms[i](in_attn)
637
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels)
638
+
639
+ assert (
640
+ context.shape[0] == x.shape[0]
641
+ and context.shape[-1] == self.context_dim
642
+ ) # Make sure the batch_size and context_dim match with the model's parameters
643
+ context_proj = self.context_proj[i](
644
+ context
645
+ ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels)
646
+
647
+ # Cross-Attention
648
+ out_attn, attn_weights = self.cross_attentions[i](
649
+ in_attn, context_proj, context_proj
650
+ )
651
+ out_attn = out_attn.transpose(1, 2).reshape(
652
+ batch_size, channels, h, w
653
+ ) # (batch_size, out_channels, h, w)
654
+
655
+ # Skip Connection
656
+ out = out + out_attn # (batch_size, out_channels h, w)
657
+
658
+ # Resnet Block
659
+ resnet_input = out
660
+ out = self.resnet_conv_first[i + 1](
661
+ out
662
+ ) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w)
663
+
664
+ # Only add the time embedding for diffusion and not AutoEncoder
665
+ if self.t_emb_dim is not None:
666
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
667
+ out = out + self.t_emb_layers[i + 1](t_emb).unsqueeze(dim=-1).unsqueeze(
668
+ dim=-1
669
+ ) # (batch_size, out_channels h, w)
670
+
671
+ out = self.resnet_conv_second[i + 1](
672
+ out
673
+ ) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w)
674
+
675
+ # Residual Connection
676
+ out = out + self.residual_input_conv[i + 1](
677
+ resnet_input
678
+ ) # (batch_size, out_channels, h, w)
679
+
680
+ return out
681
+
682
+
683
+ class UpBlock(nn.Module):
684
+ """
685
+ Up conv block with attention.
686
+ 1. Upsample
687
+ 2. Concatenate Down block output
688
+ 3. Resnet block with time embedding
689
+ 4. Attention Block
690
+
691
+ in_channels: Number of channels in the input feature map.
692
+ out_channels: Number of channels produced by this block.
693
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
694
+ up_sample: Whether to apply upsampling at the end.
695
+ num_heads: Number of attention heads (used if attention is enabled).
696
+ num_layers: How many sub-blocks to apply in sequence.
697
+ attn: Whether to apply self-attention
698
+ norm_channels: Number of groups for GroupNorm.
699
+ """
700
+
701
+ def __init__(
702
+ self,
703
+ in_channels,
704
+ out_channels,
705
+ t_emb_dim,
706
+ up_sample,
707
+ num_heads,
708
+ num_layers,
709
+ attn,
710
+ norm_channels,
711
+ ):
712
+ super().__init__()
713
+
714
+ self.num_layers = num_layers
715
+ self.up_sample = up_sample
716
+ self.t_emb_dim = t_emb_dim
717
+ self.attn = attn
718
+
719
+ # Upsample by a factor of 2
720
+ self.up_sample_conv = (
721
+ nn.ConvTranspose2d(
722
+ in_channels=in_channels,
723
+ out_channels=in_channels,
724
+ kernel_size=4,
725
+ stride=2,
726
+ padding=1,
727
+ )
728
+ if self.up_sample
729
+ else nn.Identity()
730
+ ) # (batch_size, c, h * 2, w * 2)
731
+
732
+ self.resnet_conv_first = nn.ModuleList(
733
+ [
734
+ nn.Sequential(
735
+ nn.GroupNorm(
736
+ norm_channels, in_channels if i == 0 else out_channels
737
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
738
+ nn.SiLU(),
739
+ nn.Conv2d(
740
+ in_channels=(in_channels if i == 0 else out_channels),
741
+ out_channels=out_channels,
742
+ kernel_size=3,
743
+ stride=1,
744
+ padding=1,
745
+ ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w)
746
+ )
747
+ for i in range(num_layers)
748
+ ]
749
+ )
750
+
751
+ # Only add the time embedding for diffusion and not AutoEncoder
752
+ if self.t_emb_dim is not None:
753
+ self.t_emb_layers = nn.ModuleList(
754
+ [
755
+ nn.Sequential(
756
+ nn.SiLU(),
757
+ nn.Linear(
758
+ in_features=self.t_emb_dim, out_features=out_channels
759
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
760
+ )
761
+ for i in range(num_layers)
762
+ ]
763
+ )
764
+
765
+ self.resnet_conv_second = nn.ModuleList(
766
+ [
767
+ nn.Sequential(
768
+ nn.GroupNorm(norm_channels, out_channels),
769
+ nn.SiLU(),
770
+ nn.Conv2d(
771
+ in_channels=out_channels,
772
+ out_channels=out_channels,
773
+ kernel_size=3,
774
+ stride=1,
775
+ padding=1,
776
+ ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w)
777
+ )
778
+ for i in range(num_layers)
779
+ ]
780
+ )
781
+
782
+ self.residual_input_conv = nn.ModuleList(
783
+ [
784
+ nn.Conv2d(
785
+ in_channels=(in_channels if i == 0 else out_channels),
786
+ out_channels=out_channels,
787
+ kernel_size=1,
788
+ stride=1,
789
+ padding=0,
790
+ ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w)
791
+ for i in range(num_layers)
792
+ ]
793
+ )
794
+
795
+ if self.attn:
796
+ self.attention_norms = nn.ModuleList(
797
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
798
+ )
799
+
800
+ self.attentions = nn.ModuleList(
801
+ [
802
+ nn.MultiheadAttention(
803
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
804
+ )
805
+ for i in range(num_layers)
806
+ ]
807
+ )
808
+
809
+ def forward(self, x, out_down=None, t_emb=None):
810
+ # x shape: (batch_size, c, h, w)
811
+
812
+ # Upsample
813
+ x = self.up_sample_conv(
814
+ x
815
+ ) # (batch_size, c, h, w) -> (batch_size, c, h * 2, w * 2)
816
+
817
+ # *Only do for diffusion
818
+ # Concatenate with the output of respective DownBlock
819
+ if out_down is not None:
820
+ x = torch.cat(
821
+ [x, out_down], dim=1
822
+ ) # (batch_size, c, h * 2, w * 2) -> (batch_size, c * 2, h * 2, w * 2)
823
+
824
+ out = x # (batch_size, c, h * 2, w * 2)
825
+
826
+ for i in range(self.num_layers):
827
+ # Resnet block
828
+ resnet_input = out
829
+ out = self.resnet_conv_first[i](
830
+ out
831
+ ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
832
+
833
+ # Only add the time embedding for diffusion and not AutoEncoder
834
+ if self.t_emb_dim is not None:
835
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
836
+ out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze(
837
+ dim=-1
838
+ ) # (batch_size, out_channels, h * 2, w * 2)
839
+
840
+ out = self.resnet_conv_second[i](
841
+ out
842
+ ) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
843
+
844
+ # Residual Connection
845
+ out = out + self.residual_input_conv[i](
846
+ resnet_input
847
+ ) # (batch_size, out_channels, h * 2, w * 2)
848
+
849
+ # Only do for Diffusion and not for AutoEncoder
850
+ if self.attn:
851
+ # Attention block of UNET
852
+ batch_size, channels, h, w = out.shape
853
+
854
+ in_attn = out.reshape(
855
+ batch_size, channels, h * w
856
+ ) # (batch_size, out_channels, h * w * 4)
857
+ in_attn = self.attention_norms[i](in_attn)
858
+ in_attn = in_attn.transpose(
859
+ 1, 2
860
+ ) # (batch_size, h * w * 4, out_channels)
861
+
862
+ # Self-Attention
863
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
864
+ out_attn = out_attn.transpose(1, 2).reshape(
865
+ batch_size, channels, h, w
866
+ ) # (batch_size, out_channels h * 2, w * 2)
867
+
868
+ # Skip connection
869
+ out = out + out_attn # (batch_size, out_channels h * 2, w * 2)
870
+
871
+ return out # (batch_size, out_channels h * 2, w * 2)
872
+
873
+
874
+ class UpBlockUNet(nn.Module):
875
+ """
876
+ Up conv block with attention.
877
+ 1. Upsample
878
+ 2. Concatenate Down block output
879
+ 3. Resnet block with time embedding
880
+ 4. Attention Block
881
+
882
+ in_channels: Number of channels in the input feature map. (It is passed in multiplied by 2 for concatenation with DownBlock output)
883
+ out_channels: Number of channels produced by this block.
884
+ t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None.
885
+ up_sample: Whether to apply upsampling at the end.
886
+ num_heads: Number of attention heads (used if attention is enabled).
887
+ num_layers: How many sub-blocks to apply in sequence.
888
+ norm_channels: Number of groups for GroupNorm.
889
+ cross_attn: Whether to apply cross-attention.
890
+ context_dim: If performing cross-attention, provide a context_dim for extra conditioning context.
891
+ """
892
+
893
+ def __init__(
894
+ self,
895
+ in_channels,
896
+ out_channels,
897
+ t_emb_dim,
898
+ up_sample,
899
+ num_heads,
900
+ num_layers,
901
+ norm_channels,
902
+ cross_attn=False,
903
+ context_dim=None,
904
+ ):
905
+ super().__init__()
906
+
907
+ self.num_layers = num_layers
908
+ self.up_sample = up_sample
909
+ self.t_emb_dim = t_emb_dim
910
+ self.cross_attn = cross_attn
911
+ self.context_dim = context_dim
912
+
913
+ self.up_sample_conv = (
914
+ nn.ConvTranspose2d(
915
+ in_channels=(in_channels // 2),
916
+ out_channels=(in_channels // 2),
917
+ kernel_size=4,
918
+ stride=2,
919
+ padding=1,
920
+ )
921
+ if self.up_sample
922
+ else nn.Identity()
923
+ ) # (batch_size, in_channels // 2, h * 2, w * 2)
924
+
925
+ self.resnet_conv_first = nn.ModuleList(
926
+ [
927
+ nn.Sequential(
928
+ nn.GroupNorm(
929
+ norm_channels, in_channels if i == 0 else out_channels
930
+ ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels
931
+ nn.SiLU(),
932
+ nn.Conv2d(
933
+ in_channels=(in_channels if i == 0 else out_channels),
934
+ out_channels=out_channels,
935
+ kernel_size=3,
936
+ stride=1,
937
+ padding=1,
938
+ ), # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) - Starts at in_channels and not in_channels // 2 because of concatenation
939
+ )
940
+ for i in range(num_layers)
941
+ ]
942
+ )
943
+
944
+ # Only add the time embedding if needed for UNET in diffusion
945
+ # Do not add the time embedding in the AutoEncoder
946
+ if self.t_emb_dim is not None:
947
+ self.t_emb_layers = nn.ModuleList(
948
+ [
949
+ nn.Sequential(
950
+ nn.SiLU(),
951
+ nn.Linear(
952
+ in_features=self.t_emb_dim, out_features=out_channels
953
+ ), # (batch_size, t_emb_dim) -> (batch_size, out_channels)
954
+ )
955
+ for i in range(num_layers)
956
+ ]
957
+ )
958
+
959
+ self.resnet_conv_second = nn.ModuleList(
960
+ [
961
+ nn.Sequential(
962
+ nn.GroupNorm(norm_channels, out_channels),
963
+ nn.SiLU(),
964
+ nn.Conv2d(
965
+ in_channels=out_channels,
966
+ out_channels=out_channels,
967
+ kernel_size=3,
968
+ stride=1,
969
+ padding=1,
970
+ ), # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
971
+ )
972
+ for i in range(num_layers)
973
+ ]
974
+ )
975
+
976
+ self.residual_input_conv = nn.ModuleList(
977
+ [
978
+ nn.Conv2d(
979
+ in_channels=(in_channels if i == 0 else out_channels),
980
+ out_channels=out_channels,
981
+ kernel_size=1,
982
+ stride=1,
983
+ padding=0,
984
+ )
985
+ for i in range(
986
+ num_layers
987
+ ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
988
+ ]
989
+ )
990
+
991
+ self.attention_norms = nn.ModuleList(
992
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
993
+ )
994
+
995
+ self.attentions = nn.ModuleList(
996
+ [
997
+ nn.MultiheadAttention(
998
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
999
+ )
1000
+ for i in range(num_layers)
1001
+ ]
1002
+ )
1003
+
1004
+ # Cross attention for text conditioning
1005
+ if self.cross_attn:
1006
+ assert (
1007
+ context_dim is not None
1008
+ ), "Context Dimension must be passed for cross attention"
1009
+
1010
+ self.cross_attention_norms = nn.ModuleList(
1011
+ [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)]
1012
+ )
1013
+
1014
+ self.cross_attentions = nn.ModuleList(
1015
+ [
1016
+ nn.MultiheadAttention(
1017
+ embed_dim=out_channels, num_heads=num_heads, batch_first=True
1018
+ )
1019
+ for i in range(num_layers)
1020
+ ]
1021
+ )
1022
+
1023
+ self.context_proj = nn.ModuleList(
1024
+ [
1025
+ nn.Linear(in_features=context_dim, out_features=out_channels)
1026
+ for i in range(num_layers)
1027
+ ]
1028
+ )
1029
+
1030
+ def forward(self, x, out_down=None, t_emb=None, context=None):
1031
+ # x shape: (batch_size, in_channels // 2, h, w)
1032
+
1033
+ # Upsample
1034
+ x = self.up_sample_conv(
1035
+ x
1036
+ ) # (batch_size, in_channels // 2, h, w) -> (batch_size, in_channels // 2, h * 2, w * 2)
1037
+
1038
+ # Concatenate with the output of respective DownBlock
1039
+ if out_down is not None:
1040
+ x = torch.cat(
1041
+ [x, out_down], dim=1
1042
+ ) # (batch_size, in_channels // 2, h * 2, w * 2) -> (batch_size, in_channels, h * 2, w * 2)
1043
+
1044
+ out = x # (batch_size, in_channels, h * 2, w * 2)
1045
+ for i in range(self.num_layers):
1046
+ # Resnet block
1047
+ resnet_input = out
1048
+
1049
+ out = self.resnet_conv_first[i](
1050
+ out
1051
+ ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
1052
+
1053
+ if self.t_emb_dim is not None:
1054
+ # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1)
1055
+ out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze(
1056
+ dim=-1
1057
+ ) # (batch_size, out_channels, h * 2, w * 2)
1058
+
1059
+ out = self.resnet_conv_second[i](
1060
+ out
1061
+ ) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2)
1062
+
1063
+ # Residual Connection
1064
+ out = out + self.residual_input_conv[i](
1065
+ resnet_input
1066
+ ) # (batch_size, out_channels, h * 2, w * 2)
1067
+
1068
+ # Attention block of UNET
1069
+ batch_size, channels, h, w = (
1070
+ out.shape
1071
+ ) # (batch_size, out_channels, h * 2, w * 2)
1072
+
1073
+ in_attn = out.reshape(
1074
+ batch_size, channels, h * w
1075
+ ) # (batch_size, out_channels, h * w * 4)
1076
+ in_attn = self.attention_norms[i](in_attn)
1077
+ in_attn = in_attn.transpose(1, 2) # (batch_size, h * w * 4, out_channels)
1078
+
1079
+ # Self-Attention
1080
+ out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn)
1081
+ out_attn = out_attn.transpose(1, 2).reshape(
1082
+ batch_size, channels, h, w
1083
+ ) # (batch_size, out_channels h * 2, w * 2)
1084
+
1085
+ # Skip connection
1086
+ out = out + out_attn # (batch_size, out_channels h * 2, w * 2)
1087
+
1088
+ if self.cross_attn:
1089
+ assert (
1090
+ context is not None
1091
+ ), "context cannot be None if cross attention layers are used"
1092
+ batch_size, channels, h, w = out.shape
1093
+
1094
+ in_attn = out.reshape(
1095
+ batch_size, channels, h * w
1096
+ ) # (batch_size, out_channels, h * w * 4)
1097
+ in_attn = self.cross_attention_norms[i](in_attn)
1098
+ in_attn = in_attn.transpose(
1099
+ 1, 2
1100
+ ) # (batch_size, h * w * 4, out_channels)
1101
+
1102
+ assert (
1103
+ len(context.shape) == 3
1104
+ ), "Context shape does not match batch_size, _, context_dim"
1105
+
1106
+ assert (
1107
+ context.shape[0] == x.shape[0]
1108
+ and context.shape[-1] == self.context_dim
1109
+ ), "Context shape does not match batch_size, _, context_dim" # Make sure the batch_size and context_dim match with the model's parameters
1110
+ context_proj = self.context_proj[i](
1111
+ context
1112
+ ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels)
1113
+
1114
+ # Cross-Attention
1115
+ out_attn, attn_weights = self.cross_attentions[i](
1116
+ in_attn, context_proj, context_proj
1117
+ )
1118
+ out_attn = out_attn.transpose(1, 2).reshape(
1119
+ batch_size, channels, h, w
1120
+ ) # (batch_size, out_channels, h * 2, w * 2)
1121
+
1122
+ # Skip Connection
1123
+ out = out + out_attn # (batch_size, out_channels h * 2, w * 2)
1124
+
1125
+ return out # (batch_size, out_channels h * 2, w * 2)
1126
+
1127
+
1128
+ class VQVAE(nn.Module):
1129
+ def __init__(self, image_channels, model_config):
1130
+ super().__init__()
1131
+
1132
+ self.down_channels = model_config["down_channels"]
1133
+ self.mid_channels = model_config["mid_channels"]
1134
+ self.down_sample = model_config["down_sample"]
1135
+ self.num_down_layers = model_config["num_down_layers"]
1136
+ self.num_mid_layers = model_config["num_mid_layers"]
1137
+ self.num_up_layers = model_config["num_up_layers"]
1138
+
1139
+ # To disable attention in Downblock of Encoder and Upblock of Decoder
1140
+ self.attns = model_config["attn_down"]
1141
+
1142
+ # Latent Dimension
1143
+ self.z_channels = model_config[
1144
+ "z_channels"
1145
+ ] # number of channels in the latent representation
1146
+ self.codebook_size = model_config[
1147
+ "codebook_size"
1148
+ ] # number of discrete code vectors available
1149
+ self.norm_channels = model_config["norm_channels"]
1150
+ self.num_heads = model_config["num_heads"]
1151
+
1152
+ assert self.mid_channels[0] == self.down_channels[-1]
1153
+ assert self.mid_channels[-1] == self.down_channels[-1]
1154
+ assert len(self.down_sample) == len(self.down_channels) - 1
1155
+ assert len(self.attns) == len(self.down_channels) - 1
1156
+
1157
+ # Wherever we downsample in the encoder, use upsampling in the decoder at the corresponding location
1158
+ self.up_sample = list(reversed(self.down_sample))
1159
+
1160
+ # Encoder
1161
+ self.encoder_conv_in = nn.Conv2d(
1162
+ in_channels=image_channels,
1163
+ out_channels=self.down_channels[0],
1164
+ kernel_size=3,
1165
+ stride=1,
1166
+ padding=1,
1167
+ ) # (batch_size, 3, h, w) -> (batch_size, c, h, w)
1168
+
1169
+ # Downblock + Midblock
1170
+ self.encoder_layers = nn.ModuleList([])
1171
+ for i in range(len(self.down_channels) - 1):
1172
+ self.encoder_layers.append(
1173
+ DownBlock(
1174
+ in_channels=self.down_channels[i],
1175
+ out_channels=self.down_channels[i + 1],
1176
+ t_emb_dim=None,
1177
+ down_sample=self.down_sample[i],
1178
+ num_heads=self.num_heads,
1179
+ num_layers=self.num_down_layers,
1180
+ attn=self.attns[i],
1181
+ norm_channels=self.norm_channels,
1182
+ )
1183
+ )
1184
+
1185
+ self.encoder_mids = nn.ModuleList([])
1186
+ for i in range(len(self.mid_channels) - 1):
1187
+ self.encoder_mids.append(
1188
+ MidBlock(
1189
+ in_channels=self.mid_channels[i],
1190
+ out_channels=self.mid_channels[i + 1],
1191
+ t_emb_dim=None,
1192
+ num_heads=self.num_heads,
1193
+ num_layers=self.num_mid_layers,
1194
+ norm_channels=self.norm_channels,
1195
+ )
1196
+ )
1197
+
1198
+ self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
1199
+
1200
+ self.encoder_conv_out = nn.Conv2d(
1201
+ in_channels=self.down_channels[-1],
1202
+ out_channels=self.z_channels,
1203
+ kernel_size=3,
1204
+ stride=1,
1205
+ padding=1,
1206
+ ) # (batch_size, z_channels, h', w')
1207
+
1208
+ # Pre Quantization Convolution
1209
+ self.pre_quant_conv = nn.Conv2d(
1210
+ in_channels=self.z_channels,
1211
+ out_channels=self.z_channels,
1212
+ kernel_size=1,
1213
+ stride=1,
1214
+ padding=0,
1215
+ ) # (batch_size, z_channels, h', w')
1216
+
1217
+ # Codebook Vectors
1218
+ self.embedding = nn.Embedding(
1219
+ self.codebook_size, self.z_channels
1220
+ ) # (codebook_size, z_channels)
1221
+
1222
+ # Decoder
1223
+
1224
+ # Post Quantization Convolution
1225
+ self.post_quant_conv = nn.Conv2d(
1226
+ in_channels=self.z_channels,
1227
+ out_channels=self.z_channels,
1228
+ kernel_size=1,
1229
+ stride=1,
1230
+ padding=0,
1231
+ ) # (batch_size, z_channels, h', w')
1232
+
1233
+ self.decoder_conv_in = nn.Conv2d(
1234
+ in_channels=self.z_channels,
1235
+ out_channels=self.mid_channels[-1],
1236
+ kernel_size=3,
1237
+ stride=1,
1238
+ padding=1,
1239
+ ) # (batch_size, c, h', w')
1240
+
1241
+ # Midblock + Upblock
1242
+ self.decoder_mids = nn.ModuleList([])
1243
+ for i in reversed(range(1, len(self.mid_channels))):
1244
+ self.decoder_mids.append(
1245
+ MidBlock(
1246
+ in_channels=self.mid_channels[i],
1247
+ out_channels=self.mid_channels[i - 1],
1248
+ t_emb_dim=None,
1249
+ num_heads=self.num_heads,
1250
+ num_layers=self.num_mid_layers,
1251
+ norm_channels=self.norm_channels,
1252
+ )
1253
+ )
1254
+
1255
+ self.decoder_layers = nn.ModuleList([])
1256
+ for i in reversed(range(1, len(self.down_channels))):
1257
+ self.decoder_layers.append(
1258
+ UpBlock(
1259
+ in_channels=self.down_channels[i],
1260
+ out_channels=self.down_channels[i - 1],
1261
+ t_emb_dim=None,
1262
+ up_sample=self.down_sample[i - 1],
1263
+ num_heads=self.num_heads,
1264
+ num_layers=self.num_up_layers,
1265
+ attn=self.attns[i - 1],
1266
+ norm_channels=self.norm_channels,
1267
+ )
1268
+ )
1269
+
1270
+ self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
1271
+
1272
+ self.decoder_conv_out = nn.Conv2d(
1273
+ in_channels=self.down_channels[0],
1274
+ out_channels=image_channels,
1275
+ kernel_size=3,
1276
+ stride=1,
1277
+ padding=1,
1278
+ ) # (batch_size, c, h, w)
1279
+
1280
+ def quantize(self, x):
1281
+ batch_size, c, h, w = x.shape # (batch_size, z_channels, h, w)
1282
+
1283
+ x = x.permute(
1284
+ 0, 2, 3, 1
1285
+ ) # (batch_size, z_channels, h, w) -> (batch_size, h, w, z_channels)
1286
+ x = x.reshape(
1287
+ batch_size, -1, c
1288
+ ) # (batch_size, h, w, z_channels) -> (batch_size, h * w, z_channels)
1289
+
1290
+ # Find the nearest codebook vector with distance between (batch_size, h * w, z_channels) and (batch_size, code_book_size, z_channels) -> (batch_size, h * w, code_book_size)
1291
+ dist = torch.cdist(
1292
+ x, self.embedding.weight.unsqueeze(dim=0).repeat((batch_size, 1, 1))
1293
+ ) # cdist calculates the batched p-norm distance
1294
+
1295
+ # (batch_size, h * w) Get the index of the closest codebook vector
1296
+ min_encoding_indices = torch.argmin(dist, dim=-1)
1297
+
1298
+ # Replace the encoder output with the nearest codebook
1299
+ quant_out = torch.index_select(
1300
+ self.embedding.weight, 0, min_encoding_indices.view(-1)
1301
+ ) # (batch_size, h * w, z_channels)
1302
+
1303
+ x = x.reshape((-1, c)) # (batch_size * h * w, z_channels)
1304
+
1305
+ # Commitment and Codebook Loss using MSE
1306
+ commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
1307
+ codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
1308
+
1309
+ quantize_losses = {
1310
+ "codebook_loss": codebook_loss,
1311
+ "commitment_loss": commitment_loss,
1312
+ }
1313
+
1314
+ # Straight through estimation
1315
+ quant_out = x + (quant_out - x).detach()
1316
+
1317
+ quant_out = quant_out.reshape(batch_size, h, w, c).permute(
1318
+ 0, 3, 1, 2
1319
+ ) # (batch_size, z_channels, h, w)
1320
+ min_encoding_indices = min_encoding_indices.reshape(
1321
+ (-1, h, w)
1322
+ ) # (batch_size, h, w)
1323
+
1324
+ return quant_out, quantize_losses, min_encoding_indices
1325
+
1326
+ def encode(self, x):
1327
+ out = self.encoder_conv_in(x) # (batch_size, self.down_channels[0], h, w)
1328
+
1329
+ # (batch_size, self.down_channels[0], h, w) -> (batch_size, self.down_channels[-1], h', w')
1330
+ for idx, down in enumerate(self.encoder_layers):
1331
+ out = down(out)
1332
+
1333
+ # (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.mid_channels[-1], h', w')
1334
+ for mid in self.encoder_mids:
1335
+ out = mid(out)
1336
+
1337
+ out = self.encoder_norm_out(out)
1338
+ out = F.silu(out)
1339
+
1340
+ out = self.encoder_conv_out(
1341
+ out
1342
+ ) # (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.z_channels, h', w')
1343
+ out = self.pre_quant_conv(
1344
+ out
1345
+ ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w')
1346
+
1347
+ out, quant_losses, min_encoding_indices = self.quantize(
1348
+ out
1349
+ ) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss), (batch_size, h, w)
1350
+ return out, quant_losses
1351
+
+ def decode(self, z):
+ out = z
+ out = self.post_quant_conv(
+ out
+ ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w')
+ out = self.decoder_conv_in(
+ out
+ ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.mid_channels[-1], h', w')
+
+ # (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.down_channels[-1], h', w')
+ for mid in self.decoder_mids:
+ out = mid(out)
+
+ # (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.down_channels[0], h, w)
+ for up in self.decoder_layers:
+ out = up(out)
+
+ out = self.decoder_norm_out(out)
+ out = F.silu(out)
+
+ out = self.decoder_conv_out(
+ out
+ ) # (batch_size, self.down_channels[0], h, w) -> (batch_size, c, h, w)
+ return out
+
+ def forward(self, x):
+ # x shape: (batch_size, c, h, w)
+
+ z, quant_losses = self.encode(
+ x
+ ) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss)
+ out = self.decode(z) # (batch_size, c, h, w)
+
+ return out, z, quant_losses
+
+
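As a quick sanity check of the autoencoder interface above, the sketch below pushes a random batch through the model and combines the reconstruction loss with the two quantization losses the way VQ-VAE training typically does. The toy config values, the plain MSE reconstruction loss, and the commitment weight beta = 0.2 are illustrative assumptions, not values taken from this repository's training code; the real model is built from the autoencoder_params section of config.json.

import torch
import torch.nn.functional as F

toy_autoencoder_params = {        # hypothetical config in the same format as autoencoder_params
    "z_channels": 4, "codebook_size": 1024,
    "down_channels": [32, 64, 128, 128], "mid_channels": [128, 128],
    "down_sample": [True, True, True], "attn_down": [False, False, False],
    "norm_channels": 32, "num_heads": 4,
    "num_down_layers": 1, "num_mid_layers": 1, "num_up_layers": 1,
}
vae = VQVAE(image_channels=3, model_config=toy_autoencoder_params)

x = torch.randn(2, 3, 64, 64)                 # fake image batch in roughly [-1, 1]
recon, z, quant_losses = vae(x)               # z: (2, 4, 8, 8) after three downsamplings

beta = 0.2                                    # assumed commitment-loss weight
loss = (
    F.mse_loss(recon, x)
    + quant_losses["codebook_loss"]
    + beta * quant_losses["commitment_loss"]
)
print(recon.shape, z.shape, float(loss))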
+ def validate_image_conditional_input(cond_input, x):
+ assert (
+ "image" in cond_input
+ ), "Model initialized with image conditioning but cond_input has no image information"
+ assert (
+ cond_input["image"].shape[0] == x.shape[0]
+ ), "Batch size mismatch of image condition and input"
+ assert (
+ cond_input["image"].shape[2] % x.shape[2] == 0
+ ), "Height/Width of image condition must be divisible by the height/width of the latent input"
+
+
+ def validate_class_conditional_input(cond_input, x, num_classes):
+ assert (
+ "class" in cond_input
+ ), "Model initialized with class conditioning but cond_input has no class information"
+ assert cond_input["class"].shape == (
+ x.shape[0],
+ num_classes,
+ ), "Shape of class condition input must match (batch_size, num_classes)"
+
+
+ def get_config_value(config, key, default_value):
+ return config.get(key, default_value)
+
+
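The two validators above are easiest to understand with a concrete conditioning dictionary. In the sketch below the latent size, mask channel count, and number of classes are hypothetical placeholders; the point is only the structure of cond_input that the UNet's forward pass expects.

import torch
import torch.nn.functional as F

batch_size, num_classes = 4, 10
x_latent = torch.randn(batch_size, 4, 32, 32)   # stand-in for the noisy latent fed to the UNet

cond_input = {
    # spatial mask condition; its H/W must be divisible by the latent H/W
    "image": torch.randn(batch_size, 18, 512, 512),
    # class condition as one-hot vectors of shape (batch_size, num_classes)
    "class": F.one_hot(torch.randint(0, num_classes, (batch_size,)), num_classes).float(),
}

validate_image_conditional_input(cond_input, x_latent)                # 512 % 32 == 0, passes
validate_class_conditional_input(cond_input, x_latent, num_classes)   # shape matches, passes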
+ class UNet(nn.Module):
+ """
+ UNet model comprising
+ Down blocks, Mid blocks, and Up blocks
+ """
+
+ def __init__(self, image_channels, model_config):
+ super().__init__()
+
+ self.down_channels = model_config["down_channels"]
+ self.mid_channels = model_config["mid_channels"]
+ self.t_emb_dim = model_config["time_emb_dim"]
+ self.down_sample = model_config["down_sample"]
+ self.num_down_layers = model_config["num_down_layers"]
+ self.num_mid_layers = model_config["num_mid_layers"]
+ self.num_up_layers = model_config["num_up_layers"]
+ self.attns = model_config["attn_down"]
+ self.norm_channels = model_config["norm_channels"]
+ self.num_heads = model_config["num_heads"]
+ self.conv_out_channels = model_config["conv_out_channels"]
+
+ assert self.mid_channels[0] == self.down_channels[-1]
+ assert self.mid_channels[-1] == self.down_channels[-2]
+ assert len(self.down_sample) == len(self.down_channels) - 1
+ assert len(self.attns) == len(self.down_channels) - 1
+
+ # Class, mask, and text conditioning config
+ self.class_cond = False
+ self.text_cond = False
+ self.image_cond = False
+ self.text_embed_dim = None
+ self.condition_config = get_config_value(
+ model_config, "condition_config", None
+ ) # Get the dictionary containing conditioning information
+
+ if self.condition_config is not None:
+ assert (
+ "condition_types" in self.condition_config
+ ), "Condition types not provided in model config"
+ condition_types = self.condition_config["condition_types"]
+
+ # For class, text, and image conditioning, get the necessary parameters
+ if "class" in condition_types:
+ self.class_cond = True
+ self.num_classes = self.condition_config["class_condition_config"][
+ "num_classes"
+ ]
+
+ if "text" in condition_types:
+ self.text_cond = True
+ self.text_embed_dim = self.condition_config["text_condition_config"][
+ "text_embed_dim"
+ ]
+
+ if "image" in condition_types:
+ self.image_cond = True
+ self.image_cond_input_channels = self.condition_config[
+ "image_condition_config"
+ ]["image_condition_input_channels"]
+ self.image_cond_output_channels = self.condition_config[
+ "image_condition_config"
+ ]["image_condition_output_channels"]
+
+ if self.class_cond:
+ # For class conditioning, a learned class embedding is added to the time embedding; for unconditional generation no class embedding is added
+ self.class_emb = nn.Embedding(
+ self.num_classes, self.t_emb_dim
+ ) # (num_classes, t_emb_dim)
+
+ if self.image_cond:
+ # Map the mask image to an image_cond_output_channels-channel image and concatenate it with the input along the channel dimension
+ self.cond_conv_in = nn.Conv2d(
+ in_channels=self.image_cond_input_channels,
+ out_channels=self.image_cond_output_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+
+ self.conv_in_concat = nn.Conv2d(
+ in_channels=(image_channels + self.image_cond_output_channels),
+ out_channels=self.down_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ else:
+ self.conv_in = nn.Conv2d(
+ in_channels=image_channels,
+ out_channels=self.down_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ) # (batch_size, image_channels, h, w) -> (batch_size, self.down_channels[0], h, w)
+
+ self.cond = self.text_cond or self.image_cond or self.class_cond
+
+ # Initial projection from the sinusoidal time embedding
+ self.t_proj = nn.Sequential(
+ nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim),
+ nn.SiLU(),
+ nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim),
+ ) # (batch_size, t_emb_dim)
+
+ self.up_sample = list(reversed(self.down_sample))
+
+ self.downs = nn.ModuleList([])
+ for i in range(len(self.down_channels) - 1):
+ # Cross-attention and context_dim are only used for text conditioning
+ self.downs.append(
+ DownBlock(
+ in_channels=self.down_channels[i],
+ out_channels=self.down_channels[i + 1],
+ t_emb_dim=self.t_emb_dim,
+ down_sample=self.down_sample[i],
+ num_heads=self.num_heads,
+ num_layers=self.num_down_layers,
+ attn=self.attns[i],
+ norm_channels=self.norm_channels,
+ cross_attn=self.text_cond,
+ context_dim=self.text_embed_dim,
+ )
+ )
+
+ self.mids = nn.ModuleList([])
+ for i in range(len(self.mid_channels) - 1):
+ # Cross-attention and context_dim are only used for text conditioning
+ self.mids.append(
+ MidBlock(
+ in_channels=self.mid_channels[i],
+ out_channels=self.mid_channels[i + 1],
+ t_emb_dim=self.t_emb_dim,
+ num_heads=self.num_heads,
+ num_layers=self.num_mid_layers,
+ norm_channels=self.norm_channels,
+ cross_attn=self.text_cond,
+ context_dim=self.text_embed_dim,
+ )
+ )
+
+ self.ups = nn.ModuleList([])
+ for i in reversed(range(len(self.down_channels) - 1)):
+ # Cross-attention and context_dim are only used for text conditioning
+ self.ups.append(
+ UpBlockUNet(
+ in_channels=(self.down_channels[i] * 2),
+ out_channels=(
+ self.down_channels[i - 1] if i != 0 else self.conv_out_channels
+ ),
+ t_emb_dim=self.t_emb_dim,
+ up_sample=self.down_sample[i],
+ num_heads=self.num_heads,
+ num_layers=self.num_up_layers,
+ norm_channels=self.norm_channels,
+ cross_attn=self.text_cond,
+ context_dim=self.text_embed_dim,
+ )
+ )
+
+ self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
+
+ self.conv_out = nn.Conv2d(
+ in_channels=self.conv_out_channels,
+ out_channels=image_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ) # (batch_size, conv_out_channels, h, w) -> (batch_size, image_channels, h, w)
+
1584
+ def forward(self, x, t, cond_input=None):
1585
+ # x shape: (batch_size, c, h, w)
1586
+ # cond_input is the conditioning vector
1587
+ # For class conditioning, it will be a one-hot vector of size # (batch_size, num_classes)
1588
+
1589
+ if self.cond:
1590
+ assert (
1591
+ cond_input is not None
1592
+ ), "Model initialized with conditioning so cond_input cannot be None"
1593
+
1594
+ if self.image_cond:
1595
+ # Mask Conditioning
1596
+ validate_image_conditional_input(cond_input, x)
1597
+ image_cond = cond_input["image"]
1598
+ image_cond = F.interpolate(image_cond, size=x.shape[-2:])
1599
+ image_cond = self.cond_conv_in(image_cond)
1600
+ assert image_cond.shape[-2:] == x.shape[-2:]
1601
+
1602
+ x = torch.cat(
1603
+ [x, image_cond], dim=1
1604
+ ) # (batch_size, image_channels + image_cond_output_channels, h, w)
1605
+ out = self.conv_in_concat(x) # (batch_size, down_channels[0], h, w)
1606
+ else:
1607
+ out = self.conv_in(x) # (batch_size, down_channels[0], h, w)
1608
+
1609
+ t_emb = get_time_embedding(
1610
+ torch.as_tensor(t).long(), self.t_emb_dim
1611
+ ) # (batch_size, t_emb_dim)
1612
+ t_emb = self.t_proj(t_emb) # (batch_size, t_emb_dim)
1613
+
1614
+ # Class Conditioning
1615
+ if self.class_cond:
1616
+ validate_class_conditional_input(cond_input, x, self.num_classes)
1617
+
1618
+ # Take the matrix for class embedding vectors and matrix multiply it with the embedding matrix to get the class embedding for all images in a batch
1619
+ class_embed = torch.matmul(
1620
+ cond_input["class"].float(), self.class_emb.weight
1621
+ ) # (batch_size, t_emb_dim)
1622
+ t_emb += class_embed # Add the class embedding to the time embedding
1623
+
1624
+ context_hidden_states = None
1625
+
1626
+ # Only use context hidden states in cross-attention for text conditioning
1627
+ if self.text_cond:
1628
+ assert (
1629
+ "text" in cond_input
1630
+ ), "Model initialized with text conditioning but cond_input has no text information"
1631
+ context_hidden_states = cond_input["text"]
1632
+
1633
+ down_outs = []
1634
+ for idx, down in enumerate(self.downs):
1635
+ down_outs.append(out)
1636
+ out = down(
1637
+ out, t_emb, context_hidden_states
1638
+ ) # Use context_hidden_states for cross-attention
1639
+ # out = (batch_size, c4, h / 4, w / 4)
1640
+
1641
+ for mid in self.mids:
1642
+ out = mid(out, t_emb, context_hidden_states)
1643
+ # out = (batch_size, c3, h / 4, w / 4)
1644
+
1645
+ for up in self.ups:
1646
+ down_out = down_outs.pop()
1647
+ out = up(out, down_out, t_emb, context_hidden_states)
1648
+ # out = (batch_size, self.conv_out_channels, h, w)
1649
+
1650
+ out = F.silu(self.norm_out(out))
1651
+ out = self.conv_out(
1652
+ out
1653
+ ) # (batch_size, self.conv_out_channels, h, w) -> (batch_size, image_channels, h, w)
1654
+
1655
+ return out # (batch_size, image_channels, h, w)
1656
+
1657
+
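Before wiring the UNet into the sampling function below, it can help to smoke-test its input/output contract on random tensors. This is a hypothetical sketch: the toy config values are chosen only to satisfy the asserts in __init__ (mid_channels[0] == down_channels[-1], mid_channels[-1] == down_channels[-2], channel counts divisible by norm_channels), and conditioning is left out entirely; the real model is configured from ldm_params.

import torch

toy_ldm_params = {
    "down_channels": [64, 128, 256],
    "mid_channels": [256, 128],
    "down_sample": [True, True],
    "attn_down": [True, True],
    "time_emb_dim": 128,
    "norm_channels": 32,
    "num_heads": 4,
    "conv_out_channels": 64,
    "num_down_layers": 1,
    "num_mid_layers": 1,
    "num_up_layers": 1,
}

unet = UNet(image_channels=4, model_config=toy_ldm_params)  # 4 = latent z_channels
xt = torch.randn(2, 4, 32, 32)                              # noisy latents
t = torch.randint(0, 1000, (2,))                            # one timestep per sample
with torch.no_grad():
    eps = unet(xt, t)                                       # predicted noise
print(eps.shape)                                            # torch.Size([2, 4, 32, 32])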
+ def sample_ddpm_inference(
+ text_prompt, mask_image_pil=None, guidance_scale=1.0, device=torch.device("cpu")
+ ):
+ """
+ Given a text prompt and (optionally) an image condition (as a PIL image),
+ sample from the diffusion model and return a generated image (PIL image).
+ """
+ # Create the noise scheduler
+ scheduler = LinearNoiseScheduler(
+ num_timesteps=diffusion_params["num_timesteps"],
+ beta_start=diffusion_params["beta_start"],
+ beta_end=diffusion_params["beta_end"],
+ )
+ # Get the conditioning config from ldm_params
+ condition_config = ldm_params.get("condition_config", None)
+ condition_types = (
+ condition_config.get("condition_types", [])
+ if condition_config is not None
+ else []
+ )
+
+ # Load the text tokenizer/model for conditioning
+ text_model_type = condition_config["text_condition_config"]["text_embed_model"]
+ text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)
+
+ # Get the empty-text representation for classifier-free guidance
+ empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)
+
+ # Get the text representation of the input prompt
+ text_prompt_embed = get_text_representation(
+ [text_prompt], text_tokenizer, text_model, device
+ )
+
+ # Prepare image conditioning:
+ # If the user uploaded a mask image (a PIL image), convert it; otherwise, use zeros.
+ if "image" in condition_types:
+ if mask_image_pil is not None:
+ mask_transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (
+ ldm_params["condition_config"]["image_condition_config"][
+ "image_condition_h"
+ ],
+ ldm_params["condition_config"]["image_condition_config"][
+ "image_condition_w"
+ ],
+ )
+ ),
+ transforms.ToTensor(),
+ ]
+ )
+ mask_tensor = (
+ mask_transform(mask_image_pil).unsqueeze(0).to(device)
+ ) # (1, channels, H, W)
+ else:
+ # Create a zero mask with the required number of channels (e.g. 18)
+ ic = ldm_params["condition_config"]["image_condition_config"][
+ "image_condition_input_channels"
+ ]
+ H = ldm_params["condition_config"]["image_condition_config"][
+ "image_condition_h"
+ ]
+ W = ldm_params["condition_config"]["image_condition_config"][
+ "image_condition_w"
+ ]
+ mask_tensor = torch.zeros((1, ic, H, W), device=device)
+ else:
+ mask_tensor = None
+
+ # Build conditioning dictionaries for classifier-free guidance:
+ # for the unconditional pass, use the empty text embedding and a zero mask.
+ uncond_input = {}
+ cond_input = {}
+ if "text" in condition_types:
+ uncond_input["text"] = empty_text_embed
+ cond_input["text"] = text_prompt_embed
+ if "image" in condition_types:
+ # Use zeros for the unconditional input and the provided mask for the conditional input.
+ uncond_input["image"] = torch.zeros_like(mask_tensor)
+ cond_input["image"] = mask_tensor
+
+ # Load the diffusion UNet (assumed to be pretrained and saved)
+ unet = UNet(
+ image_channels=autoencoder_params["z_channels"], model_config=ldm_params
+ ).to(device)
+ ldm_checkpoint_path = os.path.join(
+ train_params["task_name"], train_params["ldm_ckpt_name"]
+ )
+ if os.path.exists(ldm_checkpoint_path):
+ checkpoint = torch.load(ldm_checkpoint_path, map_location=device)
+ unet.load_state_dict(checkpoint["model_state_dict"])
+ unet.eval()
+
+ # Load the VQVAE (assumed to be pretrained and saved)
+ vae = VQVAE(
+ image_channels=dataset_params["image_channels"], model_config=autoencoder_params
+ ).to(device)
+ vae_checkpoint_path = os.path.join(
+ train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
+ )
+ if os.path.exists(vae_checkpoint_path):
+ checkpoint = torch.load(vae_checkpoint_path, map_location=device)
+ vae.load_state_dict(checkpoint["model_state_dict"])
+ vae.eval()
+
+ # Determine the latent shape from the VQVAE: (batch, z_channels, H_lat, W_lat)
+ # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
+ latent_size = dataset_params["image_size"] // (
+ 2 ** sum(autoencoder_params["down_sample"])
+ )
+ batch = train_params["num_samples"]
+ z_channels = autoencoder_params["z_channels"]
+
+ # Sample the initial latent noise
+ xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
+
+ # Sampling loop (reverse diffusion)
+ T = diffusion_params["num_timesteps"]
+ for i in reversed(range(T)):
+ t = torch.full((batch,), i, dtype=torch.long, device=device)
+ # Get the conditional noise prediction
+ noise_pred_cond = unet(xt, t, cond_input)
+ if guidance_scale > 1:
+ noise_pred_uncond = unet(xt, t, uncond_input)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_cond - noise_pred_uncond
+ )
+ else:
+ noise_pred = noise_pred_cond
+ xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
+
+ # Finally, decode the final latent with the VQVAE decoder
+ with torch.no_grad():
+ generated = vae.decode(xt)
+ generated = torch.clamp(generated, -1, 1)
+ generated = (generated + 1) / 2 # scale to [0, 1]
+ grid = make_grid(generated, nrow=1)
+ pil_img = transforms.ToPILImage()(grid.cpu())
+ return pil_img
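A minimal sketch of how this entry point might be called from a script, assuming the checkpoints named in config.json are present under the task directory; the prompt text, guidance scale, and output filename are arbitrary example values.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1 enables classifier-free guidance:
# noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)
img = sample_ddpm_inference(
    text_prompt="She is a woman with blond hair.",
    mask_image_pil=None,      # falls back to an all-zero mask condition
    guidance_scale=4.0,
    device=device,
)
img.save("generated_face.png")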
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ transformers
+ gradio
+ spacy
+ datasets
+ Pillow