update
This view is limited to 50 files because it contains too many changes.
- .gitattributes +7 -0
- .gitignore +6 -0
- app.py +166 -0
- assets/app_examples/0.png +3 -0
- assets/app_examples/1.png +3 -0
- assets/app_examples/2.png +3 -0
- assets/app_examples/3.png +3 -0
- assets/app_examples/4.png +3 -0
- assets/comparison_of_generation.png +3 -0
- assets/overview.png +3 -0
- evaluations/character_error_rate.py +27 -0
- evaluations/evaluate_images.py +130 -0
- evaluations/ocr.py +44 -0
- evaluations/word_error_rate.py +30 -0
- examples/get_result.py +94 -0
- examples/run.sh +85 -0
- examples/submit.sh +10 -0
- main.py +99 -0
- requirements.txt +21 -0
- src/__init__.py +0 -0
- src/data_loader.py +61 -0
- src/data_processing.py +89 -0
- src/model_processing.py +409 -0
- src/utils.py +47 -0
- src/vaes/gpt_image/gpt_image.py +48 -0
- src/vaes/stable_diffusion/vae.py +23 -0
- src/vqvaes/__init__.py +0 -0
- src/vqvaes/anole/anole.py +706 -0
- src/vqvaes/bsqvit/attention_mask.py +42 -0
- src/vqvaes/bsqvit/bsqvit.py +150 -0
- src/vqvaes/bsqvit/quantizer/bsq.py +223 -0
- src/vqvaes/bsqvit/quantizer/vq.py +152 -0
- src/vqvaes/bsqvit/stylegan_utils/custom_ops.py +126 -0
- src/vqvaes/bsqvit/stylegan_utils/misc.py +40 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cpp +99 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cu +176 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.h +38 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.py +226 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_gradfix.py +170 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_resample.py +155 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cpp +103 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cu +353 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.h +59 -0
- src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.py +382 -0
- src/vqvaes/bsqvit/transformer.py +416 -0
- src/vqvaes/flowmo/flowmo.py +945 -0
- src/vqvaes/flowmo/lookup_free_quantize.py +396 -0
- src/vqvaes/infinity/conv.py +107 -0
- src/vqvaes/infinity/dynamic_resolution.py +147 -0
- src/vqvaes/infinity/flux_vqgan.py +771 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/app_examples/0.png filter=lfs diff=lfs merge=lfs -text
+assets/app_examples/1.png filter=lfs diff=lfs merge=lfs -text
+assets/overview.png filter=lfs diff=lfs merge=lfs -text
+assets/app_examples/2.png filter=lfs diff=lfs merge=lfs -text
+assets/app_examples/4.png filter=lfs diff=lfs merge=lfs -text
+assets/comparison_of_generation.png filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,6 @@
__pycache__/
*.ckpt
checkpoints/
results/
VTBench_models/
README.md
app.py
ADDED
@@ -0,0 +1,166 @@
import os
import spaces
import subprocess
import sys

# REQUIREMENTS_FILE = "requirements.txt"
# if os.path.exists(REQUIREMENTS_FILE):
#     try:
#         print("Installing dependencies from requirements.txt...")
#         subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_FILE])
#         print("Dependencies installed successfully.")
#     except subprocess.CalledProcessError as e:
#         print(f"Failed to install dependencies: {e}")
# else:
#     print("requirements.txt not found.")

import gradio as gr
from src.data_processing import pil_to_tensor, tensor_to_pil
from PIL import Image
from src.model_processing import get_model
from huggingface_hub import snapshot_download
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

MODEL_DIR = "./VTBench_models"
if not os.path.exists(MODEL_DIR):
    print("Downloading VTBench_models from Hugging Face...")
    snapshot_download(
        repo_id="huaweilin/VTBench_models",
        local_dir=MODEL_DIR,
        local_dir_use_symlinks=False
    )
    print("Download complete.")

example_image_paths = [f"assets/app_examples/{i}.png" for i in range(0, 5)]

model_name_mapping = {
    "SD3.5L": "SD3.5L",
    "chameleon": "Chameleon",
    # "flowmo_lo": "FlowMo Lo",
    # "flowmo_hi": "FlowMo Hi",
    # "gpt4o": "GPT-4o",
    "janus_pro_1b": "Janus Pro 1B/7B",
    # "llamagen-ds8": "LlamaGen ds8",
    # "llamagen-ds16": "LlamaGen ds16",
    # "llamagen-ds16-t2i": "LlamaGen ds16 T2I",
    # "maskbit_16bit": "MaskBiT 16bit",
    # "maskbit_18bit": "MaskBiT 18bit",
    # "open_magvit2": "OpenMagViT",
    # "titok_b64": "Titok-b64",
    # "titok_bl64": "Titok-bl64",
    # "titok_s128": "Titok-s128",
    # "titok_bl128": "Titok-bl128",
    # "titok_l32": "Titok-l32",
    # "titok_sl256": "Titok-sl256",
    # "var_256": "VAR-256",
    # "var_512": "VAR-512",
    # "FLUX.1-dev": "FLUX.1-dev",
    # "infinity_d32": "Infinity-d32",
    # "infinity_d64": "Infinity-d64",
    # "bsqvit": "BSQ-VIT",
}

def load_model(model_name):
    model, data_params = get_model(MODEL_DIR, model_name)
    model = model.to(device)
    model.eval()
    return model, data_params

model_dict = {
    model_name: load_model(model_name)
    for model_name in model_name_mapping
}

placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))

@spaces.GPU
def process_selected_models(uploaded_image, selected_models):
    results = []
    for model_name in model_name_mapping:
        if uploaded_image is None:
            results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (No input)"))
        elif model_name in selected_models:
            try:
                model, data_params = model_dict[model_name]
                pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
                output = model(pixel_values)[0]
                reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)
                results.append(gr.update(value=reconstructed_image, label=model_name_mapping[model_name]))
            except Exception as e:
                print(f"Error in model {model_name}: {e}")
                results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Error)"))
        else:
            results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Not selected)"))
    return results

with gr.Blocks() as demo:
    gr.Markdown("## VTBench")

    gr.Markdown("---")

    image_input = gr.Image(
        type="pil",
        label="Upload an image",
        width=512,
        height=512,
    )

    gr.Markdown("### Click on an example image to use it as input:")
    example_rows = [example_image_paths[i:i+5] for i in range(0, len(example_image_paths), 5)]
    for row in example_rows:
        with gr.Row():
            for path in row:
                ex_img = gr.Image(
                    value=path,
                    show_label=False,
                    interactive=True,
                    width=256,
                    height=256,
                )

                def make_loader(p=path):
                    def load_img():
                        return Image.open(p)
                    return load_img

                ex_img.select(fn=make_loader(), outputs=image_input)

    gr.Markdown("---")

    gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
    model_selector = gr.CheckboxGroup(
        choices=list(model_name_mapping.keys()),
        label="Select models to run",
        value=["SD3.5L", "chameleon", "janus_pro_1b"],
        interactive=True,
    )
    run_button = gr.Button("Start Processing")

    image_outputs = []
    model_items = list(model_name_mapping.items())

    n_columns = 5
    output_rows = [model_items[i:i+n_columns] for i in range(0, len(model_items), n_columns)]

    with gr.Column():
        for row in output_rows:
            with gr.Row():
                for model_name, display_name in row:
                    out_img = gr.Image(
                        label=f"{display_name} (Not run)",
                        value=placeholder_image,
                        width=512,
                        height=512,
                    )
                    image_outputs.append(out_img)

    run_button.click(
        fn=process_selected_models,
        inputs=[image_input, model_selector],
        outputs=image_outputs
    )

demo.launch()
assets/app_examples/0.png
ADDED (binary image, stored with Git LFS)
assets/app_examples/1.png
ADDED (binary image, stored with Git LFS)
assets/app_examples/2.png
ADDED (binary image, stored with Git LFS)
assets/app_examples/3.png
ADDED (binary image, stored with Git LFS)
assets/app_examples/4.png
ADDED (binary image, stored with Git LFS)
assets/comparison_of_generation.png
ADDED (binary image, stored with Git LFS)
assets/overview.png
ADDED (binary image, stored with Git LFS)
evaluations/character_error_rate.py
ADDED
@@ -0,0 +1,27 @@
import torch
from torchmetrics import Metric
from ocr import OCR
import Levenshtein


class CharacterErrorRate(Metric):
    def __init__(self, ocr, dist_sync_on_step=False):
        # super().__init__(dist_sync_on_step=dist_sync_on_step)
        super().__init__()
        self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_chars", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.ocr = ocr

    def update(self, pred_images, target_images):
        for pred_img, target_img in zip(pred_images, target_images):
            pred_text = self.ocr.predict(pred_img)
            target_text = self.ocr.predict(target_img)

            dist = Levenshtein.distance(pred_text, target_text)
            self.total_errors += dist
            self.total_chars += len(target_text)

    def compute(self):
        if self.total_chars == 0:
            return torch.tensor(0.0)
        return self.total_errors / self.total_chars
evaluations/evaluate_images.py
ADDED
@@ -0,0 +1,130 @@
import os
import argparse
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn.functional as F
from ocr import OCR
from character_error_rate import CharacterErrorRate
from word_error_rate import WordErrorRate
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
    LearnedPerceptualImagePatchSimilarity,
    FrechetInceptionDistance,
)


class ImageFolderPairDataset(Dataset):
    def __init__(self, dir1, dir2, transform=None):
        self.dir1 = dir1
        self.dir2 = dir2
        self.filenames = sorted(os.listdir(dir1))
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        img1 = Image.open(os.path.join(self.dir1, name)).convert("RGB")
        img2 = Image.open(os.path.join(self.dir2, name)).convert("RGB")
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2


def evaluate(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    transform = transforms.Compose(
        [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]
    )

    dataset = ImageFolderPairDataset(
        args.original_dir, args.reconstructed_dir, transform
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )

    if "cer" in args.metrics or "wer" in args.metrics:
        ocr = OCR(device)

    # Metrics init
    metrics = {}

    if "psnr" in args.metrics:
        metrics["psnr"] = PeakSignalNoiseRatio().to(device)
    if "ssim" in args.metrics:
        metrics["ssim"] = StructuralSimilarityIndexMeasure().to(device)
    if "lpips" in args.metrics:
        metrics["lpips"] = LearnedPerceptualImagePatchSimilarity().to(device)
    if "fid" in args.metrics:
        metrics["fid"] = FrechetInceptionDistance().to(device)
    if "cer" in args.metrics:
        metrics["cer"] = CharacterErrorRate(ocr)
    if "wer" in args.metrics:
        metrics["wer"] = WordErrorRate(ocr)

    for batch in tqdm(loader, desc="Evaluating"):
        # img1, img1_path, img2, img2_path = [b.to(device) for b in batch]
        img1, img2 = [b.to(device) for b in batch]

        if "psnr" in metrics:
            metrics["psnr"].update(img2, img1)
        if "ssim" in metrics:
            metrics["ssim"].update(img2, img1)
        if "lpips" in metrics:
            metrics["lpips"].update(img2, img1)
        if "cer" in metrics:
            metrics["cer"].update(img2, img1)
        if "wer" in metrics:
            metrics["wer"].update(img2, img1)
        if "fid" in metrics:
            img1_uint8 = (img1 * 255).clamp(0, 255).to(torch.uint8)
            img2_uint8 = (img2 * 255).clamp(0, 255).to(torch.uint8)
            metrics["fid"].update(img1_uint8, real=True)
            metrics["fid"].update(img2_uint8, real=False)

    print("\nResults:")
    for name, metric in metrics.items():
        print(f"{name.upper()}", end="\t")
    print()
    for name, metric in metrics.items():
        result = metric.compute().item()
        print(f"{result:.4f}", end="\t")
    print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--original_dir", type=str, required=True, help="Path to original images"
    )
    parser.add_argument(
        "--reconstructed_dir",
        type=str,
        required=True,
        help="Path to reconstructed images",
    )
    parser.add_argument(
        "--metrics",
        nargs="+",
        default=["psnr", "ssim", "lpips", "fid"],
        help="Metrics to compute: psnr, ssim, lpips, fid",
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for processing"
    )
    parser.add_argument("--image_size", type=int, default=256, help="Image resize size")
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of workers for DataLoader"
    )
    args = parser.parse_args()

    evaluate(args)
evaluations/ocr.py
ADDED
@@ -0,0 +1,44 @@
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch


class OCR:
    def __init__(self, device="cpu"):
        self.device = torch.device(device)
        self.model = AutoModelForImageTextToText.from_pretrained(
            "google/gemma-3-12b-it",
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained("google/gemma-3-12b-it")

        self.messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": "Extract and output only the text from the image in its original language. If there is no text, return nothing.",
                    },
                ],
            },
        ]

    def predict(self, image):
        image = (
            (image * 255).clamp(0, 255).to(torch.uint8).permute((1, 2, 0)).cpu().numpy()
        )
        image = Image.fromarray(image).convert("RGB").resize((1024, 1024))
        prompt = self.processor.apply_chat_template(
            self.messages, add_generation_prompt=True
        )
        inputs = self.processor(text=prompt, images=[image], return_tensors="pt").to(
            self.device
        )
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
        generated_text = self.processor.batch_decode(
            generated_ids[:, inputs.input_ids.shape[-1] :], skip_special_tokens=True
        )[0]
        return generated_text
evaluations/word_error_rate.py
ADDED
@@ -0,0 +1,30 @@
import torch
from torchmetrics import Metric
import Levenshtein


class WordErrorRate(Metric):
    def __init__(self, ocr, dist_sync_on_step=False):
        # super().__init__(dist_sync_on_step=dist_sync_on_step)
        super().__init__()
        self.ocr = ocr
        self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_words", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, pred_images, target_images):
        for pred_img, target_img in zip(pred_images, target_images):
            pred_text = self.ocr.predict(pred_img)
            target_text = self.ocr.predict(target_img)

            pred_words = pred_text.strip().split()
            target_words = target_text.strip().split()

            dist = Levenshtein.distance(" ".join(pred_words), " ".join(target_words))

            self.total_errors += dist
            self.total_words += len(target_words)

    def compute(self):
        if self.total_words == 0:
            return torch.tensor(0.0)
        return self.total_errors / self.total_words
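
A minimal usage sketch for the two OCR-based metrics above (not part of the diff). It assumes a GPU for the Gemma-3 OCR model; the random tensors only stand in for batches of reconstructed and original images in [0, 1], shaped (B, C, H, W):

import torch
from ocr import OCR
from character_error_rate import CharacterErrorRate
from word_error_rate import WordErrorRate

# Placeholder batches; in practice these come from the reconstructed/original image folders.
pred_batch = torch.rand(2, 3, 256, 256)
target_batch = torch.rand(2, 3, 256, 256)

ocr = OCR("cuda" if torch.cuda.is_available() else "cpu")  # loads google/gemma-3-12b-it
cer = CharacterErrorRate(ocr)
wer = WordErrorRate(ocr)

cer.update(pred_batch, target_batch)
wer.update(pred_batch, target_batch)
print("CER:", cer.compute().item(), "WER:", wer.compute().item())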
examples/get_result.py
ADDED
@@ -0,0 +1,94 @@
import os
import pandas as pd

root_dir = "./"

model_name_mapping = {
    "flowmo_lo": "FlowMo Lo",
    "flowmo_hi": "FlowMo Hi",
    "gpt4o": "GPT-4o",
    "janus_pro_1b": "Janus Pro 1B/7B",
    "llamagen-ds8": "LlamaGen ds8",
    "llamagen-ds16": "LlamaGen ds16",
    "llamagen-ds16-t2i": "LlamaGen ds16 T2I",
    "maskbit_16bit": "MaskBiT 16bit",
    "maskbit_18bit": "MaskBiT 18bit",
    "open_magvit2": "OpenMagViT",
    "titok_b64": "Titok-b64",
    "titok_bl64": "Titok-bl64",
    "titok_s128": "Titok-s128",
    "titok_bl128": "Titok-bl128",
    "titok_l32": "Titok-l32",
    "titok_sl256": "Titok-sl256",
    "var_256": "VAR-256",
    "var_512": "VAR-512",
    "SD3.5L": "SD3.5L",
    "FLUX.1-dev": "FLUX.1-dev",
    "infinity_d32": "Infinity-d32",
    "infinity_d64": "Infinity-d64",
    "chameleon": "Chameleon",
    "bsqvit": "BSQ-VIT",
}

output_order = [
    "FlowMo Lo",
    "FlowMo Hi",
    "MaskBiT 16bit",
    "MaskBiT 18bit",
    "Titok-l32",
    "Titok-b64",
    "Titok-s128",
    "Titok-bl64",
    "Titok-bl128",
    "Titok-sl256",
    "OpenMagViT",
    "LlamaGen ds8",
    "BSQ-VIT",
    "VAR-256",
    "Janus Pro 1B/7B",
    "Chameleon",
    "LlamaGen ds16",
    "LlamaGen ds16 T2I",
    "VAR-512",
    "Infinity-d32",
    "Infinity-d64",
    "SD3.5L",
    "FLUX.1-dev",
    "GPT-4o",
]

for dataset_name in os.listdir(root_dir):
    dataset_path = os.path.join(root_dir, dataset_name)
    if not os.path.isdir(dataset_path):
        continue

    results = {}

    for model_dir in os.listdir(dataset_path):
        model_path = os.path.join(dataset_path, model_dir)
        result_file = os.path.join(model_path, "result.txt")

        if os.path.isfile(result_file):
            with open(result_file, "r", encoding="utf-8") as f:
                lines = f.readlines()

            if len(lines) >= 2:
                metrics_line = lines[-2].strip()
                values_line = lines[-1].strip()

                metrics = metrics_line.split()
                values = values_line.split()

                mapped_name = model_name_mapping.get(model_dir, model_dir)
                results[mapped_name] = values

    if results:
        header = "\t".join(metrics)
        print(f"{dataset_name}\t{header}")
        for model_name in output_order:
            if model_name in results:
                values = results[model_name]
                print(f"{model_name}\t" + "\t".join(values))
            else:
                print(f"{model_name}\t" + "no result")
        print()
examples/run.sh
ADDED
@@ -0,0 +1,85 @@
#!/bin/bash

dataset_name_list=("task1-imagenet" "task1-high-resolution" "task1-varying-resolution" "task2-detail-preservation" "task3-movie-posters" "task3-arxiv-abstracts" "task3-multilingual_Chinese" "task3-multilingual_Hindi" "task3-multilingual_Japanese" "task3-multilingual_Korean")
model_name_list=("chameleon" "llamagen-ds16" "llamagen-ds8" "flowmo_lo" "flowmo_hi" "open_magvit2" "titok_l32" "titok_b64" "titok_s128" "titok_bl64" "titok_bl128" "titok_sl256" "janus_pro_1b" "maskbit_18bit" "maskbit_16bit" "var_256" "var_512" "SD3.5L" "gpt4o" "llamagen-ds16-t2i" "infinity_d32" "infinity_d64" "bsqvit" "FLUX.1-dev")

batch_size=1

if command -v sbatch >/dev/null 2>&1; then
    has_slurm=true
else
    has_slurm=false
fi

shell_dir=$(cd "$(dirname "$0")";pwd)
echo "shell_dir: ${shell_dir}"
base_path="${shell_dir}/../"

for dataset_name in "${dataset_name_list[@]}"
do
    cd ${shell_dir}
    folder_dir="${dataset_name}"
    mkdir ${folder_dir}

    metrics="fid ssim psnr lpips"
    split_name="test"
    n_take=-1

    if [[ $dataset_name == task3-multilingual_* ]]; then
        split_name="${dataset_name##*_}"
        dataset_name="${dataset_name%_*}"
    fi
    if [ "$dataset_name" = "task1-imagenet" ]; then
        split_name="val"
    fi

    if [ "$dataset_name" = "task1-varying-resolution" ]; then
        batch_size=1
    fi
    if [ "$dataset_name" = "task3-movie-posters" ]; then
        metrics="fid ssim psnr lpips cer wer"
    fi
    if [ "$dataset_name" = "task3-arxiv-abstracts" ]; then
        metrics="fid ssim psnr lpips cer wer"
    fi
    if [ "$dataset_name" = "task3-multilingual" ]; then
        metrics="fid ssim psnr lpips cer"
    fi

    for model_name in "${model_name_list[@]}"
    do
        if [ "$dataset_name" = "task1-imagenet" ] && [ "$model_name" = "gpt4o" ]; then
            n_take=100
        fi
        cd ${shell_dir}

        work_dir="${folder_dir}/${model_name}"
        echo "model_name: ${model_name}, work_dir: ${work_dir}"
        mkdir ${work_dir}

        cp submit.sh ${work_dir}

        cd ${work_dir}
        sed -i "s|{model_name}|${model_name}|g" submit.sh
        sed -i "s|{split_name}|${split_name}|g" submit.sh
        sed -i "s|{dataset_name}|${dataset_name}|g" submit.sh
        sed -i "s|{batch_size}|${batch_size}|g" submit.sh
        sed -i "s|{base_path}|${base_path}|g" submit.sh
        sed -i "s|{metrics}|${metrics}|g" submit.sh
        sed -i "s|{n_take}|${n_take}|g" submit.sh

        # if [ "$has_slurm" = true ]; then
        #     res=$(sbatch ./submit.sh)
        #     res=($res)
        #     task_id=${res[-1]}
        #     echo "task_id: ${task_id}"
        #     touch "task_id_${task_id}"
        # else
        #     echo "Slurm not detected, running with bash..."
        #     bash ./submit.sh
        # fi

        bash ./submit.sh

    done
done
examples/submit.sh
ADDED
@@ -0,0 +1,10 @@
#!/bin/bash
# Put your slurm commands here

accelerate launch --num_processes=1 {base_path}/main.py --batch_size {batch_size} --model_name {model_name} --split_name {split_name} --dataset_name {dataset_name} --output_dir {model_name}_results --n_take {n_take}
python {base_path}/evaluations/evaluate_images.py \
    --original_dir {model_name}_results/original_images \
    --reconstructed_dir {model_name}_results/reconstructed_images/ \
    --metrics {metrics} \
    --batch_size 16 \
    --num_workers 8 | tee result.txt
main.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
import os
import PIL
import pickle
import torch
import argparse
import json
from PIL import Image
import torch.nn as nn
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from src.data_loader import DataCollatorForSupervisedDataset, get_dataset
from src.data_processing import tensor_to_pil
from src.model_processing import get_model
from PIL import Image
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="chameleon")
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--dataset_name", type=str, default="task3-movie-posters")
parser.add_argument("--split_name", type=str, default="test")
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--begin_id", default=0, type=int)
parser.add_argument("--n_take", default=-1, type=int)
args = parser.parse_args()

batch_size = args.batch_size
output_dir = args.output_dir

accelerator = Accelerator()

if accelerator.is_main_process and output_dir is not None:
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(f"{output_dir}/original_images", exist_ok=True)
    os.makedirs(f"{output_dir}/reconstructed_images", exist_ok=True)
    os.makedirs(f"{output_dir}/results", exist_ok=True)

model, data_params = get_model(args.model_path, args.model_name)
dataset = get_dataset(args.dataset_name, args.split_name, None if args.n_take <= 0 else args.n_take)
data_collator = DataCollatorForSupervisedDataset(args.dataset_name, **data_params)
dataloader = DataLoader(
    dataset, batch_size=batch_size, num_workers=0, collate_fn=data_collator
)

model, dataloader = accelerator.prepare(model, dataloader)
print("Model prepared...")


def save_results(
    pixel_values, reconstructed_image, idx, output_dir, data_params
):
    if reconstructed_image is None:
        return

    ori_img = tensor_to_pil(pixel_values, **data_params)
    rec_img = tensor_to_pil(reconstructed_image, **data_params)

    ori_img.save(f"{output_dir}/original_images/{idx:08d}.png")
    rec_img.save(f"{output_dir}/reconstructed_images/{idx:08d}.png")

    result = {
        "ori_img": ori_img,
        "rec_img": rec_img,
    }

    with open(f"{output_dir}/results/{idx:08d}.pickle", "wb") as fw:
        pickle.dump(result, fw)


executor = ThreadPoolExecutor(max_workers=16)
with torch.no_grad():
    print("Begin data loading...")
    for batch in tqdm(dataloader):
        pixel_values = batch["image"]
        reconstructed_images = model(pixel_values)
        if isinstance(reconstructed_images, tuple):
            reconstructed_images = reconstructed_images[0]

        if output_dir is not None:
            idx_list = batch["idx"]
            original_images = pixel_values.detach().cpu()
            if not isinstance(reconstructed_images, list):
                reconstructed_images = reconstructed_images.detach().cpu()
            for i in range(pixel_values.shape[0]):
                executor.submit(
                    save_results,
                    original_images[i],
                    reconstructed_images[i],
                    idx_list[i],
                    output_dir,
                    data_params,
                )

executor.shutdown(wait=True)
requirements.txt
ADDED
@@ -0,0 +1,21 @@
numpy
mup==1.0.0
einops
omegaconf
lightning==2.3.3
piq
python-Levenshtein
verovio
pytorch_fid
transformers
torch-fidelity
accelerate
datasets
git+https://github.com/deepseek-ai/Janus.git
diffusers
openai
imageio
huggingface_hub
gradio
torch
torchvision
src/__init__.py
ADDED
File without changes
src/data_loader.py
ADDED
@@ -0,0 +1,61 @@
import PIL
from PIL import Image
from dataclasses import dataclass, field
from datasets import load_dataset
import torch
from .data_processing import pil_to_tensor


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __init__(self, dataset_name, **kwargs):
        override_params = {}
        if dataset_name == "DIV2K":
            override_params = {
                "target_image_size": -1,
                "lock_ratio": True,
                "center_crop": False,
                "padding": False,
            }
        if dataset_name == "imagenet":
            override_params = {"center_crop": True, "padding": False}
        if dataset_name == "movie_posters":
            override_params = {"center_crop": True, "padding": False}
        if dataset_name == "high_quality_1024":
            override_params = {"target_image_size": (1024, 1024)}

        self.data_params = {**kwargs, **override_params}

    def __call__(self, instances):
        images = torch.stack(
            [
                pil_to_tensor(instance["image"], **self.data_params)
                for instance in instances
            ],
            dim=0,
        )
        idx = [instance["idx"] for instance in instances]
        return dict(image=images, idx=idx)


class ImagenetDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_name, split_name="test", n_take=None):
        print(dataset_name, split_name)
        ds = load_dataset("huaweilin/VTBench", name=dataset_name, split=split_name if n_take is None else f"{split_name}[:{n_take}]")
        self.image_list = ds["image"]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        return dict(
            image=self.image_list[idx],
            idx=idx,
        )


def get_dataset(dataset_name, split_name, n_take):
    dataset = ImagenetDataset(dataset_name, split_name, n_take)
    return dataset
src/data_processing.py
ADDED
@@ -0,0 +1,89 @@
import numpy as np
import PIL
from PIL import Image
import torch


def pil_to_tensor(
    img: Image.Image,
    target_image_size=512,
    lock_ratio=True,
    center_crop=True,
    padding=False,
    standardize=True,
    **kwarg
) -> torch.Tensor:
    if img.mode != "RGB":
        img = img.convert("RGB")

    if isinstance(target_image_size, int):
        target_size = (target_image_size, target_image_size)
        if target_image_size < 0:
            target_size = img.size
    else:
        target_size = target_image_size  # (width, height)

    if lock_ratio:
        original_width, original_height = img.size
        target_width, target_height = target_size

        scale_w = target_width / original_width
        scale_h = target_height / original_height

        if center_crop:
            scale = max(scale_w, scale_h)
        elif padding:
            scale = min(scale_w, scale_h)
        else:
            scale = 1.0  # fallback

        new_size = (round(original_width * scale), round(original_height * scale))
        img = img.resize(new_size, Image.LANCZOS)

        if center_crop:
            left = (img.width - target_width) // 2
            top = (img.height - target_height) // 2
            img = img.crop((left, top, left + target_width, top + target_height))
        elif padding:
            new_img = Image.new("RGB", target_size, (0, 0, 0))
            left = (target_width - img.width) // 2
            top = (target_height - img.height) // 2
            new_img.paste(img, (left, top))
            img = new_img
    else:
        img = img.resize(target_size, Image.LANCZOS)

    np_img = np.array(img) / 255.0  # Normalize to [0, 1]
    if standardize:
        np_img = np_img * 2 - 1  # Scale to [-1, 1]
    tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float()  # (C, H, W)

    return tensor_img


def tensor_to_pil(chw_tensor: torch.Tensor, standardize=True, **kwarg) -> PIL.Image:
    # Ensure detachment and move tensor to CPU.
    detached_chw_tensor = chw_tensor.detach().cpu()

    # Normalize tensor to [0, 1] range from [-1, 1] range.
    if standardize:
        normalized_chw_tensor = (
            torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
        ) / 2.0
    else:
        normalized_chw_tensor = torch.clamp(detached_chw_tensor, 0.0, 1.0)

    # Permute CHW tensor to HWC format and convert to NumPy array.
    hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()

    # Convert to an 8-bit unsigned integer format.
    image_array_uint8 = (hwc_array * 255).astype(np.uint8)

    # Convert NumPy array to PIL Image.
    pil_image = Image.fromarray(image_array_uint8)

    # Convert image to RGB if it is not already.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    return pil_image
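
A minimal round-trip sketch for the two helpers above (not part of the diff); the 256x256 target size is an arbitrary example, and the input path is one of the example images added in this commit:

from PIL import Image
from src.data_processing import pil_to_tensor, tensor_to_pil

img = Image.open("assets/app_examples/0.png")

# Scale-and-center-crop to 256x256, mapped to [-1, 1] because standardize=True.
x = pil_to_tensor(img, target_image_size=(256, 256), lock_ratio=True, center_crop=True)
print(x.shape)  # torch.Size([3, 256, 256])

# Map back to a PIL image; standardize must match the value used when encoding.
out = tensor_to_pil(x, standardize=True)
out.save("roundtrip.png")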
src/model_processing.py
ADDED
@@ -0,0 +1,409 @@
import requests
import os
import yaml
from .utils import get_ckpt, get_yaml_config


def download_ckpt_yaml(model_path, model_name, ckpt_path, yaml_url=None):
    def download_file(url, save_path):
        response = requests.get(url)
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            f.write(response.content)

    # os.makedirs(model_path, exist_ok=True)
    local_dir = os.path.join(model_path, model_name)
    os.makedirs(local_dir, exist_ok=True)

    ckpt_name = ckpt_path.split("/")[-1]
    local_ckpt_path = os.path.join(local_dir, ckpt_name)
    if not os.path.exists(local_ckpt_path):
        print(f"Downloading CKPT to {local_ckpt_path}")
        download_file(ckpt_path, local_ckpt_path)

    if yaml_url:
        yaml_name = yaml_url.split("/")[-1]
        local_yaml_path = os.path.join(local_dir, yaml_name)
        if not os.path.exists(local_yaml_path):
            print(f"Downloading YAML to {local_yaml_path}")
            download_file(yaml_url, local_yaml_path)
        return local_ckpt_path, local_yaml_path

    return local_ckpt_path, None


def get_model(model_path, model_name):
    model = None
    data_params = {
        "target_image_size": (512, 512),
        "lock_ratio": True,
        "center_crop": True,
        "padding": False,
    }

    if model_name.lower() == "anole":
        from src.vqvaes.anole.anole import VQModel
        yaml_url = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.yaml"
        ckpt_path = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.ckpt"

        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "anole", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)

        params = config["model"]["params"]
        if "lossconfig" in params:
            del params["lossconfig"]
        params["ckpt_path"] = ckpt_path
        model = VQModel(**params)
        data_params = {
            "target_image_size": (512, 512),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif model_name.lower() == "chameleon":
        from src.vqvaes.anole.anole import VQModel

        yaml_url = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.yaml"
        ckpt_path = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.ckpt"
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "chameleon", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)

        params = config["model"]["params"]
        if "lossconfig" in params:
            del params["lossconfig"]
        params["ckpt_path"] = ckpt_path
        model = VQModel(**params)
        data_params = {
            "target_image_size": (512, 512),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif model_name.lower() == "llamagen-ds16":
        from src.vqvaes.llamagen.llamagen import VQ_models
        ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt"
        if model_path is not None:
            ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16", ckpt_path, None)

        model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8)
        model.load_state_dict(get_ckpt(ckpt_path, key="model"))
        data_params = {
            "target_image_size": (512, 512),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif model_name.lower() == "llamagen-ds16-t2i":
        from src.vqvaes.llamagen.llamagen import VQ_models
        ckpt_path = "https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt"
        if model_path is not None:
            ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16-t2i", ckpt_path, None)

        model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8)
        model.load_state_dict(get_ckpt(ckpt_path, key="model"))
        data_params = {
            "target_image_size": (512, 512),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif model_name.lower() == "llamagen-ds8":
        from src.vqvaes.llamagen.llamagen import VQ_models
        ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt"
        if model_path is not None:
            ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds8", ckpt_path, None)

        model = VQ_models["VQ-8"](codebook_size=16384, codebook_embed_dim=8)
        model.load_state_dict(get_ckpt(ckpt_path, key="model"))
        data_params = {
            "target_image_size": (256, 256),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif model_name.lower() == "flowmo_lo":
        from src.vqvaes.flowmo.flowmo import build_model
        yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
        ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_lo.pth"
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_lo", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)

        config.model.context_dim = 18
        model = build_model(config)
        model.load_state_dict(
            get_ckpt(ckpt_path, key="model_ema_state_dict")
        )
        data_params = {
            "target_image_size": (256, 256),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif model_name.lower() == "flowmo_hi":
        from src.vqvaes.flowmo.flowmo import build_model

        yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
        ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_hi.pth"
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_hi", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)

        config.model.context_dim = 56
        config.model.codebook_size_for_entropy = 14
        model = build_model(config)
        model.load_state_dict(
            get_ckpt(ckpt_path, key="model_ema_state_dict")
        )
        data_params = {
            "target_image_size": (256, 256),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif model_name.lower() == "open_magvit2":
        from src.vqvaes.open_magvit2.open_magvit2 import VQModel

        yaml_url = "https://raw.githubusercontent.com/TencentARC/SEED-Voken/refs/heads/main/configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml"
        ckpt_path = "https://huggingface.co/TencentARC/Open-MAGVIT2-Tokenizer-256-resolution/resolve/main/imagenet_256_L.ckpt"
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "open_magvit2", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)

        model = VQModel(**config.model.init_args)
        model.load_state_dict(get_ckpt(ckpt_path, key="state_dict"))
        data_params = {
            "target_image_size": (256, 256),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }

    elif "maskbit" in model_name.lower():
        from src.vqvaes.maskbit.maskbit import ConvVQModel

        if "16bit" in model_name.lower():
            yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_16bit.yaml"
            ckpt_path = "https://huggingface.co/markweber/maskbit_tokenizer_16bit/resolve/main/maskbit_tokenizer_16bit.bin"
            if model_path is not None:
                ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-16bit", ckpt_path, yaml_url)
        elif "18bit" in model_name.lower():
            yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_18bit.yaml"
            ckpt_path = "https://huggingface.co/markweber/maskbit_tokenizer_18bit/resolve/main/maskbit_tokenizer_18bit.bin"
            if model_path is not None:
                ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-18bit", ckpt_path, yaml_url)
        else:
            raise Exception(f"Unsupported model: {model_name}")

        config = get_yaml_config(yaml_url)
        model = ConvVQModel(config.model.vq_model, legacy=False)
        model.load_pretrained(get_ckpt(ckpt_path, key=None))
        data_params = {
            "target_image_size": (256, 256),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
            "standardize": False,
        }

    elif "bsqvit" in model_name.lower():
        from src.vqvaes.bsqvit.bsqvit import VITBSQModel

        yaml_url = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/config.yaml"
        ckpt_path = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/checkpoint.pt"
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "bsqvit", ckpt_path, yaml_url)

        config = get_yaml_config(yaml_url)
        model = VITBSQModel(**config["model"]["params"])
        model.init_from_ckpt(get_ckpt(ckpt_path, key="state_dict"))
        data_params = {
            "target_image_size": (256, 256),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
            "standardize": False,
        }

    elif "titok" in model_name.lower():
        from src.vqvaes.titok.titok import TiTok

        ckpt_path = None
        if "bl64" in model_name.lower():
            ckpt_path = "yucornetto/tokenizer_titok_bl64_vq8k_imagenet"
        elif "bl128" in model_name.lower():
            ckpt_path = "yucornetto/tokenizer_titok_bl128_vq8k_imagenet"
        elif "sl256" in model_name.lower():
            ckpt_path = "yucornetto/tokenizer_titok_sl256_vq8k_imagenet"
        elif "l32" in model_name.lower():
            ckpt_path = "yucornetto/tokenizer_titok_l32_imagenet"
        elif "b64" in model_name.lower():
            ckpt_path = "yucornetto/tokenizer_titok_b64_imagenet"
        elif "s128" in model_name.lower():
            ckpt_path = "yucornetto/tokenizer_titok_s128_imagenet"
        else:
            raise Exception(f"Unsupported model: {model_name}")

        model = TiTok.from_pretrained(ckpt_path)
        data_params = {
            "target_image_size": (256, 256),
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
            "standardize": False,
        }

    elif "janus_pro" in model_name.lower():
        from janus.models import MultiModalityCausalLM
        from src.vqvaes.janus_pro.janus_pro import forward
        import types

        model = MultiModalityCausalLM.from_pretrained(
            "deepseek-ai/Janus-Pro-7B", trust_remote_code=True
        ).gen_vision_model
        model.forward = types.MethodType(forward, model)
        data_params = {
            "target_image_size": (384, 384),
            "lock_ratio": True,
            "center_crop": False,
            "padding": True,
        }

    elif "var" in model_name.lower():
        from src.vqvaes.var.var_vq import VQVAE

        ckpt_path = "https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth"
        if model_path is not None:
            ckpt_path, _ = download_ckpt_yaml(model_path, "var", ckpt_path, None)

        v_patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        if "512" in model_name.lower():
            v_patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)
        model = VQVAE(
            vocab_size=4096,
            z_channels=32,
            ch=160,
            test_mode=True,
            share_quant_resi=4,
            v_patch_nums=v_patch_nums,
        )
        model.load_state_dict(get_ckpt(ckpt_path, key=None))
        data_params = {
            "target_image_size": (
                (512, 512) if "512" in model_name.lower() else (256, 256)
            ),
            "lock_ratio": True,
            "center_crop": False,
            "padding": True,
            "standardize": False,
        }

    elif (
        "infinity" in model_name.lower()
    ):  # "infinity_d32", "infinity_d64", "infinity_d56_f8_14_patchify"
        from src.vqvaes.infinity.vae import vae_model

        if "d32" in model_name:
            ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d32.pth"
            codebook_dim = 32
            if model_path is not None:
                ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d32", ckpt_path, None)
        elif "d64" in model_name:
            ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d64.pth"
            codebook_dim = 64
            if model_path is not None:
                ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d64", ckpt_path, None)

        schedule_mode = "dynamic"
        codebook_size = 2**codebook_dim
        patch_size = 16
        encoder_ch_mult = [1, 2, 4, 4, 4]
        decoder_ch_mult = [1, 2, 4, 4, 4]

        ckpt = get_ckpt(ckpt_path, key=None)
        model = vae_model(
            ckpt,
            schedule_mode,
            codebook_dim,
            codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        )

        data_params = {
            "target_image_size": (1024, 1024),
            "lock_ratio": True,
            "center_crop": False,
            "padding": True,
            "standardize": False,
        }

    elif "sd3.5l" in model_name.lower():  # SD3.5L
        from src.vaes.stable_diffusion.vae import forward
        from diffusers import AutoencoderKL
        import types

        model = AutoencoderKL.from_pretrained(
            "huaweilin/stable-diffusion-3.5-large-vae", subfolder="vae"
        )
        model.forward = types.MethodType(forward, model)
        data_params = {
            "target_image_size": (1024, 1024),
            "lock_ratio": True,
            "center_crop": False,
            "padding": True,
            "standardize": True,
        }

    elif "FLUX.1-dev".lower() in model_name.lower():  # FLUX.1-dev
        from src.vaes.stable_diffusion.vae import forward
        from diffusers import AutoencoderKL
        import types

        model = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-dev", subfolder="vae"
        )
        model.forward = types.MethodType(forward, model)
        data_params = {
            "target_image_size": (1024, 1024),
            "lock_ratio": True,
            "center_crop": False,
            "padding": True,
            "standardize": True,
        }

    elif "gpt4o" in model_name.lower():
        from src.vaes.gpt_image.gpt_image import GPTImage

        data_params = {
            "target_image_size": (1024, 1024),
            "lock_ratio": True,
            "center_crop": False,
            "padding": True,
            "standardize": False,
        }
        model = GPTImage(data_params)

    else:
        raise Exception(f"Unsupported model: \"{model_name}\"")

    try:
        trainable_params = sum(p.numel() for p in model.parameters())
        print("trainable_params:", trainable_params)
    except Exception as e:
        print(e)
        pass

    model.eval()
    return model, data_params
src/utils.py
ADDED
@@ -0,0 +1,47 @@
import os
from omegaconf import OmegaConf
import torch
import tempfile
from safetensors.torch import load_file
import requests
import yaml

def get_ckpt(path, key="state_dict"):
    is_url = path.startswith("http://") or path.startswith("https://")
    suffix = os.path.splitext(path)[-1]

    if is_url:
        print(f"Loading checkpoint from URL: {path}")
        with tempfile.NamedTemporaryFile(suffix=suffix) as tmp_file:
            response = requests.get(path)
            response.raise_for_status()
            tmp_file.write(response.content)
            tmp_file.flush()
            ckpt_path = tmp_file.name

            if suffix == ".safetensors":
                checkpoint = load_file(ckpt_path)
            else:
                checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    else:
        print(f"Loading checkpoint from local path: {path}")
        if suffix == ".safetensors":
            checkpoint = load_file(path)
        else:
            checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    if key is not None and key in checkpoint:
        checkpoint = checkpoint[key]

    return checkpoint


def get_yaml_config(path):
    if path.startswith("http://") or path.startswith("https://"):
        response = requests.get(path)
        response.raise_for_status()
        config = OmegaConf.create(response.text)
    else:
        with open(path, 'r') as f:
            config = OmegaConf.load(f)
    return config
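A quick sketch of the two helpers for reviewers (the checkpoint URL is the Infinity d32 one referenced above in src/model_processing.py; the YAML path is hypothetical):

from src.utils import get_ckpt, get_yaml_config

# Remote .pth checkpoints are streamed to a temp file and loaded on CPU; key=None keeps the raw dict.
ckpt = get_ckpt(
    "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d32.pth",
    key=None,
)
print(len(ckpt))

# cfg = get_yaml_config("configs/some_model.yaml")  # hypothetical local path; URLs work too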
src/vaes/gpt_image/gpt_image.py
ADDED
@@ -0,0 +1,48 @@
import base64
from torchvision.transforms.functional import to_pil_image
from openai import OpenAI
import io
import torch
import numpy as np
from PIL import Image
from ...data_processing import tensor_to_pil, pil_to_tensor


class GPTImage:
    def __init__(self, data_params):
        self.client = OpenAI(organization="org-xZTnLOf1k9s04LEoKKjl4jOB")
        self.prompt = "Please recreate the exact same image without any alterations. Please preserve the original resolution (1024*1024)."
        self.data_params = data_params

    def eval(self):
        pass

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, input):
        results = []
        for image in input:
            image = tensor_to_pil(image, **self.data_params)
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            buffer.seek(0)
            image_file = ("image.png", buffer, "image/png")

            try:
                result = self.client.images.edit(
                    model="gpt-image-1",
                    image=image_file,
                    prompt=self.prompt,
                    n=1,
                    size="1024x1024",
                )
                image_base64 = result.data[0].b64_json
                image_bytes = base64.b64decode(image_base64)
                image = Image.open(io.BytesIO(image_bytes))
                results.append(pil_to_tensor(image, **self.data_params))
            except Exception as e:
                print("💥 Unexpected error occurred:", e)
                results.append(None)

        return results, None, None
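Reviewer note: GPTImage goes through the OpenAI Images API, so it needs credentials; the OpenAI Python client reads OPENAI_API_KEY from the environment by default. A minimal sketch, with the data_params dict mirroring the gpt4o branch in src/model_processing.py:

import torch
from src.vaes.gpt_image.gpt_image import GPTImage

data_params = {
    "target_image_size": (1024, 1024),
    "lock_ratio": True,
    "center_crop": False,
    "padding": True,
    "standardize": False,
}
model = GPTImage(data_params)
# Each input image triggers one images.edit call; failures come back as None entries.
# results, _, _ = model(torch.rand(1, 3, 1024, 1024))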
src/vaes/stable_diffusion/vae.py
ADDED
@@ -0,0 +1,23 @@
def forward(
    self,
    sample,
    sample_posterior=False,
    return_dict=True,
    generator=None,
):
    r"""
    Args:
        sample (`torch.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            Whether to sample from the posterior.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
    """
    x = sample
    posterior = self.encode(x).latent_dist
    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        z = posterior.mode()
    dec = self.decode(z).sample
    return dec, None, None
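This module only defines a replacement forward; it is bound onto a pretrained AutoencoderKL at load time, exactly as the SD3.5L and FLUX.1-dev branches of src/model_processing.py do:

import types
import torch
from diffusers import AutoencoderKL
from src.vaes.stable_diffusion.vae import forward

vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
vae.forward = types.MethodType(forward, vae)  # encode -> posterior.mode() -> decode
vae.eval()
# with torch.no_grad():
#     recon, _, _ = vae(torch.randn(1, 3, 1024, 1024))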
src/vqvaes/__init__.py
ADDED
File without changes
src/vqvaes/anole/anole.py
ADDED
@@ -0,0 +1,706 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

"""
Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
[with minimal dependencies]

This implementation is inference-only -- training steps and optimizer components
introduce significant additional dependencies
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...utils import get_ckpt


class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self,
        n_e,
        e_dim,
        beta,
        remap=None,
        unknown_index="random",
        sane_index_shape=False,
        legacy=True,
    ):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
                device=new.device
            )
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2
            * torch.einsum(
                "bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
            )
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
                (z_q - z.detach()) ** 2
            )
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
                (z_q - z.detach()) ** 2
            )

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(
                z.shape[0], -1
            )  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3]
            )

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


# Alias
VectorQuantizer = VectorQuantizer2


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise ValueError("Unexpected attention type")


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class VQModel(nn.Module):
    def __init__(
        self,
        ddconfig,
        n_embed,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        scheduler_config=None,
        lr_g_factor=1.0,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = VectorQuantizer(
            n_embed,
            embed_dim,
            beta=0.25,
            remap=remap,
            sane_index_shape=sane_index_shape,
        )
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.image_key = image_key
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    def init_from_ckpt(self, path, ignore_keys=list()):
        if path.startswith("http://") or path.startswith("https://"):
            sd = get_ckpt(path)
        else:
            print(f"Loading checkpoint from local path: {path}")
            sd = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print(f"Deleting key {k} from state_dict.")
                    del sd[k]

        self.load_state_dict(sd, strict=False)
        print(f"VQModel loaded from {path}")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    # def forward(self, input):
    #     quant, diff, _ = self.encode(input)
    #     dec = self.decode(quant)
    #     return dec, diff

    def forward(self, input):
        quant, diff, [_, _, img_toks] = self.encode(input)

        batch_size, n_channel, height, width = (
            input.shape[0],
            quant.shape[-1],
            quant.shape[-2],
            quant.shape[-3],
        )
        codebook_entry = self.quantize.get_codebook_entry(
            img_toks, (batch_size, n_channel, height, width)
        )
        pixels = self.decode(codebook_entry)

        return pixels, img_toks, quant

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
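A hedged round-trip sketch for VQModel. The ddconfig values below are small placeholders chosen only so the module builds (GroupNorm needs channel counts divisible by 32); they are not the Anole/Chameleon release configuration, whose checkpoint and config are loaded elsewhere in this repo.

import torch
from src.vqvaes.anole.anole import VQModel

ddconfig = dict(  # placeholder values, see note above
    ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
    dropout=0.0, in_channels=3, resolution=64, z_channels=8, double_z=False,
)
model = VQModel(ddconfig, n_embed=512, embed_dim=8).eval()
with torch.no_grad():
    pixels, img_toks, quant = model(torch.randn(1, 3, 64, 64))
print(pixels.shape, img_toks.shape, quant.shape)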
src/vqvaes/bsqvit/attention_mask.py
ADDED
@@ -0,0 +1,42 @@
import torch


def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
    if mask_type.lower() == 'none' or mask_type is None:
        return None
    elif mask_type.lower() == 'block-causal':
        return _block_caulsal_mask_impl(sequence_length, device, **kwargs)
    elif mask_type.lower() == 'causal':
        return _caulsal_mask_impl(sequence_length, device, **kwargs)
    else:
        raise NotImplementedError(f"Mask type {mask_type} not implemented")


def _block_caulsal_mask_impl(sequence_length, device, block_size=16, **kwargs):
    """
    Create a block-causal mask
    """
    assert sequence_length % block_size == 0, "for block causal masks sequence length must be divisible by block size"
    blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device)
    block_diag_enable_mask = torch.block_diag(*blocks)
    causal_enable_mask = torch.ones(sequence_length, sequence_length, device=device).tril_(0)
    disable_mask = ((block_diag_enable_mask + causal_enable_mask) < 0.5)
    return disable_mask


def _caulsal_mask_impl(sequence_length, device, **kwargs):
    """
    Create a causal mask
    """
    causal_disable_mask = torch.triu(
        torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device),
        diagonal=1,
    )
    return causal_disable_mask


if __name__ == '__main__':
    mask = get_attention_mask(9, "cuda", mask_type="block-causal", block_size=3)
    print(mask)
    mask = get_attention_mask(9, "cuda", mask_type="causal")
    print(mask)
src/vqvaes/bsqvit/bsqvit.py
ADDED
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .quantizer.bsq import BinarySphericalQuantizer
from .quantizer.vq import VectorQuantizer
from .transformer import TransformerDecoder, TransformerEncoder


class VITVQModel(nn.Module):
    def __init__(self, vitconfig, n_embed, embed_dim,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None,
                 ):
        super().__init__()
        self.encoder = TransformerEncoder(**vitconfig)
        dvitconfig = vitconfig if dvitconfig is None else dvitconfig
        self.decoder = TransformerDecoder(**dvitconfig, logit_laplace=logit_laplace)
        if self.training and grad_checkpointing:
            self.encoder.set_grad_checkpointing(True, selective=selective_checkpointing)
            self.decoder.set_grad_checkpointing(True, selective=selective_checkpointing)

        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.setup_quantizer()

        self.quant_embed = nn.Linear(in_features=vitconfig['width'], out_features=embed_dim)
        self.post_quant_embed = nn.Linear(in_features=embed_dim, out_features=dvitconfig['width'])
        self.l2_norm = l2_norm
        self.logit_laplace = logit_laplace
        self.clamp_range = clamp_range

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def setup_quantizer(self):
        self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, l2_norm=self.l2_norm, beta=0.25, input_format='blc')

    # def init_from_ckpt(self, ckpt_path, ignore_keys=[]):
    def init_from_ckpt(self, state_dict, ignore_keys=[]):
        state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith('module.')}
        filtered_state_dict = {k: v for k, v in state_dict.items() if all([not k.startswith(ig) for ig in ignore_keys])}
        missing_keys, unexpected_keys = self.load_state_dict(filtered_state_dict, strict=False)
        print(f"missing_keys: {missing_keys}")
        print(f"unexpected_keys: {unexpected_keys}")

    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        if skip_quantize:
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            if self.l2_norm:
                h = F.normalize(h, dim=-1)
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info

    def decode(self, quant):
        h = self.post_quant_embed(quant)
        x = self.decoder(h)
        return x

    def clamp(self, x):
        if self.logit_laplace:
            dec, _ = x.chunk(2, dim=1)
            x = self.logit_laplace_loss.unmap(F.sigmoid(dec))
        else:
            x = x.clamp_(self.clamp_range[0], self.clamp_range[1])
        return x

    def forward(self, input, skip_quantize=False):
        if self.logit_laplace:
            input = self.logit_laplace_loss.inmap(input)
        quant, loss, info = self.encode(input, skip_quantize=skip_quantize)
        dec = self.decode(quant)
        if self.logit_laplace:
            dec, lnb = dec.chunk(2, dim=1)
            logit_laplace_loss = self.logit_laplace_loss(dec, lnb, input)
            info.update({'logit_laplace_loss': logit_laplace_loss})
            dec = self.logit_laplace_loss.unmap(F.sigmoid(dec))
        else:
            dec = dec.clamp_(self.clamp_range[0], self.clamp_range[1])
        return dec, loss, info

    def get_last_layer(self):
        return self.decoder.conv_out.weight


class VITBSQModel(VITVQModel):
    def __init__(self, vitconfig, embed_dim, embed_group_size=9,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None, beta=0., gamma0=1.0, gamma=1.0, zeta=1.0,
                 persample_entropy_compute='group',
                 cb_entropy_compute='group',
                 post_q_l2_norm=False,
                 inv_temperature=1.,
                 ):
        # set quantizer params
        self.beta = beta  # commit loss
        self.gamma0 = gamma0  # entropy
        self.gamma = gamma  # entropy penalty
        self.zeta = zeta  # lpips
        self.embed_group_size = embed_group_size
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.post_q_l2_norm = post_q_l2_norm
        self.inv_temperature = inv_temperature

        # call init
        super().__init__(
            vitconfig,
            2 ** embed_dim,
            embed_dim,
            l2_norm=l2_norm,
            logit_laplace=logit_laplace,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            grad_checkpointing=grad_checkpointing,
            selective_checkpointing=selective_checkpointing,
            clamp_range=clamp_range,
            dvitconfig=dvitconfig,
        )

    def setup_quantizer(self):
        self.quantize = BinarySphericalQuantizer(
            self.embed_dim, self.beta, self.gamma0, self.gamma, self.zeta,
            group_size=self.embed_group_size,
            persample_entropy_compute=self.persample_entropy_compute,
            cb_entropy_compute=self.cb_entropy_compute,
            input_format='blc',
            l2_norm=self.post_q_l2_norm,
            inv_temperature=self.inv_temperature,
        )

    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        if self.l2_norm:
            h = F.normalize(h, dim=-1)
        if skip_quantize:
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info
ADDED
@@ -0,0 +1,223 @@
|
from einops import rearrange, reduce
import torch
import torch.nn as nn
from torch.autograd import Function


class DifferentiableEntropyFunction(Function):
    @staticmethod
    def forward(ctx, zq, basis, K, eps):
        zb = (zq + 1) / 2
        zi = ((zb * basis).sum(-1)).to(torch.int64)
        cnt = torch.scatter_reduce(torch.zeros(2**K, device=zq.device, dtype=zq.dtype),
                                   0,
                                   zi.flatten(),
                                   torch.ones_like(zi.flatten()).to(zq.dtype),
                                   'sum')
        prob = (cnt + eps) / (cnt + eps).sum()
        H = -(prob * torch.log(prob)).sum()
        ctx.save_for_backward(zq, zi, prob)
        ctx.K = K
        return H

    @staticmethod
    def backward(ctx, grad_output):
        zq, zi, prob = ctx.saved_tensors
        grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
        reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
        grad_input = reord_grad.unsqueeze(-1) * zq
        return grad_input, None, None, None, None


def codebook_entropy(zq, basis, K, eps=1e-4):
    return DifferentiableEntropyFunction.apply(zq, basis, K, eps)


class BinarySphericalQuantizer(nn.Module):
    def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
                 input_format='bchw',
                 soft_entropy=True, group_size=9,
                 persample_entropy_compute='group',
                 cb_entropy_compute='group',
                 l2_norm=False,
                 inv_temperature=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.beta = beta  # loss weight for commit loss
        self.gamma0 = gamma0  # loss weight for entropy penalty
        self.gamma = gamma  # loss weight for entropy penalty
        self.zeta = zeta  # loss weight for entire entropy penalty
        self.input_format = input_format
        assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
        self.num_groups = self.embed_dim // group_size
        self.group_size = group_size
        assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
        assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.l2_norm = l2_norm
        self.inv_temperature = inv_temperature

        self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
        self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))

        self.num_dimensions = 2 ** embed_dim
        self.bits_per_index = embed_dim

        # we only need to keep the codebook portion up to the group size
        # because we approximate the H loss with this subcode
        group_codes = torch.arange(2 ** self.group_size)
        group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
        self.register_buffer('group_codebook', group_codebook, persistent=False)

        self.soft_entropy = soft_entropy  # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf

    def quantize(self, z):
        assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"

        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device))
        return z + (zhat - z).detach()

    def forward(self, z):
        if self.input_format == 'bchw':
            z = rearrange(z, 'b c h w -> b h w c')
        zq = self.quantize(z)

        indices = self.codes_to_indexes(zq.detach())
        group_indices = self.codes_to_group_indexes(zq.detach())
        if not self.training:
            used_codes = torch.unique(indices, return_counts=False)
        else:
            used_codes = None

        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.

        if self.soft_entropy:
            persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
        else:
            zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
            persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
            cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy

        zq = zq * q_scale

        # commit loss
        commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))

        if self.input_format == 'bchw':
            zq = rearrange(zq, 'b h w c -> b c h w')

        return (
            zq,
            commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
            {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
             "avg_prob": avg_prob}
        )

    def soft_entropy_loss(self, z):
        # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
        # the sub-code is the last group_size bits of the full code
        group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
        divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)

        # we calculate the distance between the divided_z and the codebook for each subgroup
        distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
        prob = (-distance * self.inv_temperature).softmax(dim=-1)
        if self.persample_entropy_compute == 'analytical':
            if self.l2_norm:
                p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
            else:
                p = torch.sigmoid(-4 * z * self.inv_temperature)
            prob = torch.stack([p, 1 - p], dim=-1)
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
        else:
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()

        # macro average of the probability of each subgroup
        avg_prob = reduce(prob, '... g d ->g d', 'mean')
        codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)

        # the approximation of the entropy is the sum of the entropy of each subgroup
        return per_sample_entropy, codebook_entropy.sum(), avg_prob

    def get_hard_per_sample_entropy(self, zb_by_sample):
        probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
        persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
        persample_entropy = persample_entropy.sum(-1)
        return persample_entropy.mean()

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
        """
        assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
        return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)

    def codes_to_group_indexes(self, zhat):
        """Converts a `code` to a list of indexes (in groups) in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
        """
        zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
        return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)

    def indexes_to_codes(self, indices):
        """Inverse of `indexes_to_codes`."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(indices, self.basis), 2
        )
        return codes_non_centered * 2 - 1

    def group_indexes_to_codes(self, group_indices):
        """Inverse of `group_indexes_to_codes`."""
        group_indices = group_indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(group_indices, self.group_basis), 2
        )
        codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
        return codes_non_centered * 2 - 1

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
        else:
            probs = count
        H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
        return H

    def get_group_codebook_entry(self, group_indices):
        z_q = self.group_indexes_to_codes(group_indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            h, w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q

    def get_codebook_entry(self, indices):
        z_q = self.indexes_to_codes(indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            h, w = int(z_q.shape[1] ** 0.5)
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q


if __name__ == "__main__":
    K = 8
    # zq = torch.randint(0, 2, (4, 32, K), dtype=torch.bfloat16, device='cuda') * 2 - 1
    zq = torch.zeros((4, 32, K), dtype=torch.bfloat16, device='cuda') * 2 - 1
    basis = (2 ** torch.arange(K - 1, -1, -1)).to(torch.bfloat16).cuda()
    zq.requires_grad = True
    h = codebook_entropy(zq, basis, K)
    h.backward()
    print(zq.grad, zq)
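A small hedged sanity check of the BSQ code/index round trip (the sizes are arbitrary; the only constraint imposed by the class is embed_dim % group_size == 0):

import torch
from src.vqvaes.bsqvit.quantizer.bsq import BinarySphericalQuantizer

bsq = BinarySphericalQuantizer(
    embed_dim=8, beta=0.0, gamma0=1.0, gamma=1.0, zeta=1.0,
    input_format='blc', group_size=4,
)
z = torch.randn(2, 16, 8)
codes = bsq.quantize(z)                      # straight-through codes in {-1, +1}
indices = bsq.codes_to_indexes(codes)        # one integer in [0, 2**8) per token
recovered = bsq.get_codebook_entry(indices)  # back to +-1 codes (rescaled if l2_norm)
print(codes.shape, indices.shape, recovered.shape)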
src/vqvaes/bsqvit/quantizer/vq.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
from einops import rearrange
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class VectorQuantizer(nn.Module):
|
10 |
+
def __init__(self, n_embed, embed_dim, l2_norm, beta, input_format='bchw'):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
self.n_embed = n_embed
|
14 |
+
self.embed_dim = embed_dim
|
15 |
+
self.l2_norm = l2_norm
|
16 |
+
self.beta = beta
|
17 |
+
assert input_format in ['bchw', 'blc']
|
18 |
+
self.input_format = input_format
|
19 |
+
|
20 |
+
self.embedding = nn.Embedding(n_embed, embed_dim)
|
21 |
+
self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed)
|
22 |
+
+        self.bits_per_index = int(np.ceil(np.log2(n_embed)))
+
+    def forward(self, z):
+        batch = z.shape[0]
+        if self.input_format == 'bchw':
+            z = rearrange(z, 'b c h w -> b h w c')
+
+        if self.l2_norm:
+            z = F.normalize(z, dim=-1)
+            z_flatten = z.reshape(-1, self.embed_dim)
+            embedding_weight = F.normalize(self.embedding.weight, dim=-1)
+            d = -z_flatten @ embedding_weight.t()
+        else:
+            z_flatten = z.reshape(-1, self.embed_dim)
+            d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * z_flatten @ self.embedding.weight.t()
+
+        min_encoding_indices = torch.argmin(d.detach(), dim=1)
+        if not self.training:
+            used_codes = torch.unique(min_encoding_indices, return_counts=False)
+        else:
+            used_codes = None
+        cb_usage = F.one_hot(min_encoding_indices, self.n_embed).sum(0)
+        cb_entropy = self.get_entropy(cb_usage)
+
+        z_q = self.embedding(min_encoding_indices).view(z.shape)
+        if self.l2_norm:
+            z_q = F.normalize(z_q, dim=-1)
+
+        # fix the issue with loss scaling
+        # loss weight should not associate with the dimensionality of words
+        # loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+        loss = self.beta * torch.mean(((z_q.detach() - z) ** 2).sum(dim=-1)) + torch.mean(((z_q - z.detach()) ** 2).sum(dim=-1))
+
+        z_q = z + (z_q - z).detach()
+        if self.input_format == 'bchw':
+            z_q = rearrange(z_q, 'b h w c -> b c h w')
+        return z_q, loss, {"H": cb_entropy, "used_codes": used_codes, 'indices': min_encoding_indices.view(batch, -1)}
+
+    def get_entropy(self, count, eps=1e-4):
+        probs = (count + eps) / (count + eps).sum()
+        H = -(probs * torch.log(probs)).sum()
+        return H
+
+
+    def get_codebook_entry(self, indices):
+        z_q = self.embedding(indices)
+        if self.l2_norm:
+            z_q = F.normalize(z_q, dim=-1)
+
+        if self.input_format == 'bchw':
+            h = w = int(z_q.shape[1] ** 0.5)
+            assert h * w == z_q.shape[1], 'Invalid sequence length'
+            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
+        return z_q
+
+
+class EMAVectorQuantizer(nn.Module):
+    def __init__(self, n_embed, embed_dim, l2_norm, beta, decay=0.99, eps=1e-5, random_restart=True, restart_threshold=1.0, input_format='bchw'):
+        super().__init__()
+
+        self.n_embed = n_embed
+        self.embed_dim = embed_dim
+        self.l2_norm = l2_norm
+        self.beta = beta
+        self.decay = decay
+        self.eps = eps
+        self.random_restart = random_restart
+        self.restart_threshold = restart_threshold
+        self.input_format = input_format
+
+        self.embedding = nn.Embedding(n_embed, embed_dim)
+        self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed) # TODO (yzhao): test other initialization methods
+        self.register_buffer("ema_cluster_size", torch.zeros(self.n_embed))
+        self.embedding_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embed_dim))
+        self.embedding_avg.data.copy_(self.embedding.weight.data)
+
+    def _tile(self, z):
+        n_z, embedding_dim = z.shape
+        if n_z < self.n_embed:
+            n_repeats = (self.n_embed + n_z - 1) // n_z
+            std = 0.01 / np.sqrt(embedding_dim)
+            z = z.repeat(n_repeats, 1)
+            z = z + torch.randn_like(z) * std
+        return z
+
+    def forward(self, z):
+        if self.input_format == 'bchw':
+            z = rearrange(z, 'b c h w -> b h w c')
+        z_flatten = z.reshape(-1, self.embed_dim)
+
+        d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * z_flatten @ self.embedding.weight.t()
+
+        encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+        encodings = torch.zeros(encoding_indices.size(0), self.n_embed, device=z.device)
+        encodings.scatter_(1, encoding_indices, 1)
+
+        z_q = self.embedding(encoding_indices).view(z.shape)
+        if self.l2_norm:
+            z = F.normalize(z, dim=-1)
+            z_q = F.normalize(z_q, dim=-1)
+
+        if self.training:
+            # EMA update cluster size
+            encodings_sum = encodings.sum(0)
+            if dist.is_initialized(): dist.all_reduce(encodings_sum)
+            self.ema_cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1-self.decay)
+
+            # EMA update of the embedding vectors
+            dw = encodings.t() @ z_flatten
+            if dist.is_initialized(): dist.all_reduce(dw)
+            self.embedding_avg.data.mul_(self.decay).add_(dw, alpha=1-self.decay)
+
+            # Laplace smoothing of the cluster size
+            n = torch.sum(self.ema_cluster_size)
+            weights = (self.ema_cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
+            self.embedding.weight.data = self.embedding_avg.data / weights.unsqueeze(1)
+
+            if self.random_restart:
+                zz = self._tile(z_flatten)
+                _k_rand = zz[torch.randperm(zz.size(0))][:self.n_embed]
+                if dist.is_initialized(): dist.broadcast(_k_rand, 0)
+                usage = (self.ema_cluster_size.view(-1, 1) > self.restart_threshold).float()
+                self.embedding.weight.data.mul_(usage).add_(_k_rand * (1 - usage))
+
+        loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
+
+        z_q = z + (z_q - z).detach()
+        if self.input_format == 'bchw':
+            z_q = rearrange(z_q, 'b h w c -> b c h w')
+        # TODO (yzhao): monitor utility of the dictionary
+        return z_q, loss, {}
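
Both quantizers in this file pair a nearest-neighbour codebook lookup with a straight-through estimator so gradients still reach the encoder despite the non-differentiable argmin. The sketch below reproduces just that core path outside the classes; the toy shapes and variable names are illustrative only.

    import torch

    n_embed, embed_dim = 8, 4
    codebook = torch.randn(n_embed, embed_dim)                 # stand-in for self.embedding.weight
    z = torch.randn(2, 3, 3, embed_dim, requires_grad=True)    # encoder output in 'b h w c' layout

    # Nearest-neighbour lookup, mirroring the distance computation in forward().
    z_flat = z.reshape(-1, embed_dim)
    d = (z_flat ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * z_flat @ codebook.t()
    idx = d.argmin(dim=1)
    z_q = codebook[idx].view(z.shape)

    # Straight-through estimator: the forward value is z_q, the backward pass copies gradients to z.
    z_st = z + (z_q - z).detach()
    z_st.sum().backward()
    assert z.grad is not None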
src/vqvaes/bsqvit/stylegan_utils/custom_ops.py
ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import glob
+import torch
+import torch.utils.cpp_extension
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+    patterns = [
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+    assert verbosity in ['none', 'brief', 'full']
+
+    # Already cached?
+    if module_name in _cached_plugins:
+        return _cached_plugins[module_name]
+
+    # Print status.
+    if verbosity == 'full':
+        print(f'Setting up PyTorch plugin "{module_name}"...')
+    elif verbosity == 'brief':
+        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+    try: # pylint: disable=too-many-nested-blocks
+        # Make sure we can find the necessary compiler binaries.
+        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+            compiler_bindir = _find_compiler_bindir()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+            os.environ['PATH'] += ';' + compiler_bindir
+
+        # Compile and load.
+        verbose_build = (verbosity == 'full')
+
+        # Incremental build md5sum trickery. Copies all the input source files
+        # into a cached build directory under a combined md5 digest of the input
+        # source files. Copying is done only if the combined digest has changed.
+        # This keeps input file timestamps and filenames the same as in previous
+        # extension builds, allowing for fast incremental rebuilds.
+        #
+        # This optimization is done only in case all the source files reside in
+        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+        # environment variable is set (we take this as a signal that the user
+        # actually cares about this.)
+        source_dirs_set = set(os.path.dirname(source) for source in sources)
+        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+            # Compute a combined hash digest for all source files in the same
+            # custom op directory (usually .cu, .cpp, .py and .h files).
+            hash_md5 = hashlib.md5()
+            for src in all_source_files:
+                with open(src, 'rb') as f:
+                    hash_md5.update(f.read())
+            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+            if not os.path.isdir(digest_build_dir):
+                os.makedirs(digest_build_dir, exist_ok=True)
+                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+                if baton.try_acquire():
+                    try:
+                        for src in all_source_files:
+                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+                    finally:
+                        baton.release()
+                else:
+                    # Someone else is copying source files under the digest dir,
+                    # wait until done and continue.
+                    baton.wait()
+            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+                verbose=verbose_build, sources=digest_sources, **build_kwargs)
+        else:
+            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+        module = importlib.import_module(module_name)
+
+    except:
+        if verbosity == 'brief':
+            print('Failed!')
+        raise
+
+    # Print status and add to cache.
+    if verbosity == 'full':
+        print(f'Done setting up PyTorch plugin "{module_name}".')
+    elif verbosity == 'brief':
+        print('Done.')
+    _cached_plugins[module_name] = module
+    return module
+
+#----------------------------------------------------------------------------
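
The ops modules later in this upload call get_plugin() roughly as sketched below (compare bias_act.py further down), from a module sitting next to its C++/CUDA sources. The import path here is an assumption based on the repo layout.

    import os
    from src.vqvaes.bsqvit.stylegan_utils import custom_ops  # import path assumed

    sources = [os.path.join(os.path.dirname(__file__), s) for s in ['bias_act.cpp', 'bias_act.cu']]
    plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
    # plugin then exposes the bound C++ functions, e.g. plugin.bias_act(...).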
src/vqvaes/bsqvit/stylegan_utils/misc.py
ADDED
@@ -0,0 +1,40 @@
+import torch
+import warnings
+
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+    symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to suppress known warnings in torch.jit.trace().
+
+class suppress_tracer_warnings(warnings.catch_warnings):
+    def __enter__(self):
+        super().__enter__()
+        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+        return self
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+    if tensor.ndim != len(ref_shape):
+        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+        if ref_size is None:
+            pass
+        elif isinstance(ref_size, torch.Tensor):
+            with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+        elif isinstance(size, torch.Tensor):
+            with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+        elif size != ref_size:
+            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
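
assert_shape() is the entry point the other utilities rely on (conv2d_resample below uses it via misc.assert_shape). A small usage sketch, with the import path assumed from the repo layout:

    import torch
    from src.vqvaes.bsqvit.stylegan_utils import misc  # import path assumed

    x = torch.randn(4, 3, 16, 16)
    misc.assert_shape(x, [None, 3, 16, 16])   # None lets the batch size vary
    misc.assert_shape(x, [4, 3, None, None])  # None lets the spatial size vary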
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cpp
ADDED
@@ -0,0 +1,99 @@
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <ATen/cuda/CUDAContext.h>
|
11 |
+
#include <c10/cuda/CUDAGuard.h>
|
12 |
+
#include "bias_act.h"
|
13 |
+
|
14 |
+
//------------------------------------------------------------------------
|
15 |
+
|
16 |
+
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
17 |
+
{
|
18 |
+
if (x.dim() != y.dim())
|
19 |
+
return false;
|
20 |
+
for (int64_t i = 0; i < x.dim(); i++)
|
21 |
+
{
|
22 |
+
if (x.size(i) != y.size(i))
|
23 |
+
return false;
|
24 |
+
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
25 |
+
return false;
|
26 |
+
}
|
27 |
+
return true;
|
28 |
+
}
|
29 |
+
|
30 |
+
//------------------------------------------------------------------------
|
31 |
+
|
32 |
+
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
33 |
+
{
|
34 |
+
// Validate arguments.
|
35 |
+
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
36 |
+
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
37 |
+
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
38 |
+
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
39 |
+
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
40 |
+
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
41 |
+
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
42 |
+
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
43 |
+
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
44 |
+
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
45 |
+
|
46 |
+
// Validate layout.
|
47 |
+
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
48 |
+
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
49 |
+
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
50 |
+
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
51 |
+
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
52 |
+
|
53 |
+
// Create output tensor.
|
54 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
55 |
+
torch::Tensor y = torch::empty_like(x);
|
56 |
+
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
57 |
+
|
58 |
+
// Initialize CUDA kernel parameters.
|
59 |
+
bias_act_kernel_params p;
|
60 |
+
p.x = x.data_ptr();
|
61 |
+
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
62 |
+
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
63 |
+
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
64 |
+
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
65 |
+
p.y = y.data_ptr();
|
66 |
+
p.grad = grad;
|
67 |
+
p.act = act;
|
68 |
+
p.alpha = alpha;
|
69 |
+
p.gain = gain;
|
70 |
+
p.clamp = clamp;
|
71 |
+
p.sizeX = (int)x.numel();
|
72 |
+
p.sizeB = (int)b.numel();
|
73 |
+
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
74 |
+
|
75 |
+
// Choose CUDA kernel.
|
76 |
+
void* kernel;
|
77 |
+
AT_DISPATCH_REDUCED_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_cuda", [&]
|
78 |
+
{
|
79 |
+
kernel = choose_bias_act_kernel<scalar_t>(p);
|
80 |
+
});
|
81 |
+
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
82 |
+
|
83 |
+
// Launch CUDA kernel.
|
84 |
+
p.loopX = 4;
|
85 |
+
int blockSize = 4 * 32;
|
86 |
+
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
87 |
+
void* args[] = {&p};
|
88 |
+
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
89 |
+
return y;
|
90 |
+
}
|
91 |
+
|
92 |
+
//------------------------------------------------------------------------
|
93 |
+
|
94 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
95 |
+
{
|
96 |
+
m.def("bias_act", &bias_act);
|
97 |
+
}
|
98 |
+
|
99 |
+
//------------------------------------------------------------------------
|
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cu
ADDED
@@ -0,0 +1,176 @@
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <c10/util/Half.h>
|
11 |
+
#include "bias_act.h"
|
12 |
+
|
13 |
+
//------------------------------------------------------------------------
|
14 |
+
// Helpers.
|
15 |
+
|
16 |
+
template <class T> struct InternalType;
|
17 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
18 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
19 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
20 |
+
template <> struct InternalType<at::BFloat16> { typedef float scalar_t; };
|
21 |
+
|
22 |
+
//------------------------------------------------------------------------
|
23 |
+
// CUDA kernel.
|
24 |
+
|
25 |
+
template <class T, int A>
|
26 |
+
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
27 |
+
{
|
28 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
29 |
+
int G = p.grad;
|
30 |
+
scalar_t alpha = (scalar_t)p.alpha;
|
31 |
+
scalar_t gain = (scalar_t)p.gain;
|
32 |
+
scalar_t clamp = (scalar_t)p.clamp;
|
33 |
+
scalar_t one = (scalar_t)1;
|
34 |
+
scalar_t two = (scalar_t)2;
|
35 |
+
scalar_t expRange = (scalar_t)80;
|
36 |
+
scalar_t halfExpRange = (scalar_t)40;
|
37 |
+
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
38 |
+
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
39 |
+
|
40 |
+
// Loop over elements.
|
41 |
+
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
42 |
+
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
43 |
+
{
|
44 |
+
// Load.
|
45 |
+
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
46 |
+
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
47 |
+
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
48 |
+
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
49 |
+
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
50 |
+
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
51 |
+
scalar_t y = 0;
|
52 |
+
|
53 |
+
// Apply bias.
|
54 |
+
((G == 0) ? x : xref) += b;
|
55 |
+
|
56 |
+
// linear
|
57 |
+
if (A == 1)
|
58 |
+
{
|
59 |
+
if (G == 0) y = x;
|
60 |
+
if (G == 1) y = x;
|
61 |
+
}
|
62 |
+
|
63 |
+
// relu
|
64 |
+
if (A == 2)
|
65 |
+
{
|
66 |
+
if (G == 0) y = (x > 0) ? x : 0;
|
67 |
+
if (G == 1) y = (yy > 0) ? x : 0;
|
68 |
+
}
|
69 |
+
|
70 |
+
// lrelu
|
71 |
+
if (A == 3)
|
72 |
+
{
|
73 |
+
if (G == 0) y = (x > 0) ? x : x * alpha;
|
74 |
+
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
75 |
+
}
|
76 |
+
|
77 |
+
// tanh
|
78 |
+
if (A == 4)
|
79 |
+
{
|
80 |
+
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
81 |
+
if (G == 1) y = x * (one - yy * yy);
|
82 |
+
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
83 |
+
}
|
84 |
+
|
85 |
+
// sigmoid
|
86 |
+
if (A == 5)
|
87 |
+
{
|
88 |
+
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
89 |
+
if (G == 1) y = x * yy * (one - yy);
|
90 |
+
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
91 |
+
}
|
92 |
+
|
93 |
+
// elu
|
94 |
+
if (A == 6)
|
95 |
+
{
|
96 |
+
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
97 |
+
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
98 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
99 |
+
}
|
100 |
+
|
101 |
+
// selu
|
102 |
+
if (A == 7)
|
103 |
+
{
|
104 |
+
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
105 |
+
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
106 |
+
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
107 |
+
}
|
108 |
+
|
109 |
+
// softplus
|
110 |
+
if (A == 8)
|
111 |
+
{
|
112 |
+
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
113 |
+
if (G == 1) y = x * (one - exp(-yy));
|
114 |
+
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
115 |
+
}
|
116 |
+
|
117 |
+
// swish
|
118 |
+
if (A == 9)
|
119 |
+
{
|
120 |
+
if (G == 0)
|
121 |
+
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
122 |
+
else
|
123 |
+
{
|
124 |
+
scalar_t c = exp(xref);
|
125 |
+
scalar_t d = c + one;
|
126 |
+
if (G == 1)
|
127 |
+
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
128 |
+
else
|
129 |
+
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
130 |
+
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
131 |
+
}
|
132 |
+
}
|
133 |
+
|
134 |
+
// Apply gain.
|
135 |
+
y *= gain * dy;
|
136 |
+
|
137 |
+
// Clamp.
|
138 |
+
if (clamp >= 0)
|
139 |
+
{
|
140 |
+
if (G == 0)
|
141 |
+
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
142 |
+
else
|
143 |
+
y = (yref > -clamp & yref < clamp) ? y : 0;
|
144 |
+
}
|
145 |
+
|
146 |
+
// Store.
|
147 |
+
((T*)p.y)[xi] = (T)y;
|
148 |
+
}
|
149 |
+
}
|
150 |
+
|
151 |
+
//------------------------------------------------------------------------
|
152 |
+
// CUDA kernel selection.
|
153 |
+
|
154 |
+
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
155 |
+
{
|
156 |
+
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
157 |
+
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
158 |
+
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
159 |
+
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
160 |
+
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
161 |
+
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
162 |
+
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
163 |
+
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
164 |
+
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
165 |
+
return NULL;
|
166 |
+
}
|
167 |
+
|
168 |
+
//------------------------------------------------------------------------
|
169 |
+
// Template specializations.
|
170 |
+
|
171 |
+
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
172 |
+
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
173 |
+
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
174 |
+
template void* choose_bias_act_kernel<at::BFloat16> (const bias_act_kernel_params& p);
|
175 |
+
|
176 |
+
//------------------------------------------------------------------------
|
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.h
ADDED
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+    const void* x;      // [sizeX]
+    const void* b;      // [sizeB] or NULL
+    const void* xref;   // [sizeX] or NULL
+    const void* yref;   // [sizeX] or NULL
+    const void* dy;     // [sizeX] or NULL
+    void*       y;      // [sizeX]
+
+    int         grad;
+    int         act;
+    float       alpha;
+    float       gain;
+    float       clamp;
+
+    int         sizeX;
+    int         sizeB;
+    int         stepB;
+    int         loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom PyTorch ops for efficient bias and activation."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import traceback
|
16 |
+
from typing import Any
|
17 |
+
|
18 |
+
from .. import custom_ops
|
19 |
+
|
20 |
+
|
21 |
+
class EasyDict(dict):
|
22 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
23 |
+
|
24 |
+
def __getattr__(self, name: str) -> Any:
|
25 |
+
try:
|
26 |
+
return self[name]
|
27 |
+
except KeyError:
|
28 |
+
raise AttributeError(name)
|
29 |
+
|
30 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
31 |
+
self[name] = value
|
32 |
+
|
33 |
+
def __delattr__(self, name: str) -> None:
|
34 |
+
del self[name]
|
35 |
+
|
36 |
+
#----------------------------------------------------------------------------
|
37 |
+
|
38 |
+
activation_funcs = {
|
39 |
+
'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
40 |
+
'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
41 |
+
'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
42 |
+
'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
43 |
+
'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
44 |
+
'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
45 |
+
'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
46 |
+
'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
47 |
+
'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
48 |
+
}
|
49 |
+
|
50 |
+
#----------------------------------------------------------------------------
|
51 |
+
|
52 |
+
_inited = False
|
53 |
+
_plugin = None
|
54 |
+
_null_tensor = torch.empty([0])
|
55 |
+
|
56 |
+
def _init():
|
57 |
+
global _inited, _plugin
|
58 |
+
if not _inited:
|
59 |
+
_inited = True
|
60 |
+
sources = ['bias_act.cpp', 'bias_act.cu']
|
61 |
+
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
62 |
+
try:
|
63 |
+
_plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
64 |
+
except:
|
65 |
+
warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
66 |
+
return _plugin is not None
|
67 |
+
|
68 |
+
#----------------------------------------------------------------------------
|
69 |
+
|
70 |
+
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
71 |
+
r"""Fused bias and activation function.
|
72 |
+
|
73 |
+
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
74 |
+
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
75 |
+
the fused op is considerably more efficient than performing the same calculation
|
76 |
+
using standard PyTorch ops. It supports first and second order gradients,
|
77 |
+
but not third order gradients.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
x: Input activation tensor. Can be of any shape.
|
81 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
82 |
+
as `x`. The shape must be known, and it must match the dimension of `x`
|
83 |
+
corresponding to `dim`.
|
84 |
+
dim: The dimension in `x` corresponding to the elements of `b`.
|
85 |
+
The value of `dim` is ignored if `b` is not specified.
|
86 |
+
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
87 |
+
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
88 |
+
See `activation_funcs` for a full list. `None` is not allowed.
|
89 |
+
alpha: Shape parameter for the activation function, or `None` to use the default.
|
90 |
+
gain: Scaling factor for the output tensor, or `None` to use default.
|
91 |
+
See `activation_funcs` for the default scaling of each activation function.
|
92 |
+
If unsure, consider specifying 1.
|
93 |
+
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
94 |
+
the clamping (default).
|
95 |
+
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
Tensor of the same shape and datatype as `x`.
|
99 |
+
"""
|
100 |
+
assert isinstance(x, torch.Tensor)
|
101 |
+
assert impl in ['ref', 'cuda']
|
102 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
103 |
+
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
104 |
+
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
105 |
+
|
106 |
+
#----------------------------------------------------------------------------
|
107 |
+
|
108 |
+
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
109 |
+
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
110 |
+
"""
|
111 |
+
assert isinstance(x, torch.Tensor)
|
112 |
+
assert clamp is None or clamp >= 0
|
113 |
+
spec = activation_funcs[act]
|
114 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
115 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
116 |
+
clamp = float(clamp if clamp is not None else -1)
|
117 |
+
|
118 |
+
# Add bias.
|
119 |
+
if b is not None:
|
120 |
+
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
121 |
+
assert 0 <= dim < x.ndim
|
122 |
+
assert b.shape[0] == x.shape[dim]
|
123 |
+
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
124 |
+
|
125 |
+
# Evaluate activation function.
|
126 |
+
alpha = float(alpha)
|
127 |
+
x = spec.func(x, alpha=alpha)
|
128 |
+
|
129 |
+
# Scale by gain.
|
130 |
+
gain = float(gain)
|
131 |
+
if gain != 1:
|
132 |
+
x = x * gain
|
133 |
+
|
134 |
+
# Clamp.
|
135 |
+
if clamp >= 0:
|
136 |
+
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
137 |
+
return x
|
138 |
+
|
139 |
+
#----------------------------------------------------------------------------
|
140 |
+
|
141 |
+
_bias_act_cuda_cache = dict()
|
142 |
+
|
143 |
+
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
144 |
+
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
145 |
+
"""
|
146 |
+
# Parse arguments.
|
147 |
+
assert clamp is None or clamp >= 0
|
148 |
+
spec = activation_funcs[act]
|
149 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
150 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
151 |
+
clamp = float(clamp if clamp is not None else -1)
|
152 |
+
|
153 |
+
# Lookup from cache.
|
154 |
+
key = (dim, act, alpha, gain, clamp)
|
155 |
+
if key in _bias_act_cuda_cache:
|
156 |
+
return _bias_act_cuda_cache[key]
|
157 |
+
|
158 |
+
# Forward op.
|
159 |
+
class BiasActCuda(torch.autograd.Function):
|
160 |
+
@staticmethod
|
161 |
+
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
162 |
+
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
|
163 |
+
x = x.contiguous(memory_format=ctx.memory_format)
|
164 |
+
b = b.contiguous() if b is not None else _null_tensor
|
165 |
+
y = x
|
166 |
+
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
167 |
+
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
168 |
+
ctx.save_for_backward(
|
169 |
+
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
170 |
+
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
171 |
+
y if 'y' in spec.ref else _null_tensor)
|
172 |
+
return y
|
173 |
+
|
174 |
+
@staticmethod
|
175 |
+
def backward(ctx, dy): # pylint: disable=arguments-differ
|
176 |
+
dy = dy.contiguous(memory_format=ctx.memory_format)
|
177 |
+
x, b, y = ctx.saved_tensors
|
178 |
+
dx = None
|
179 |
+
db = None
|
180 |
+
|
181 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
182 |
+
dx = dy
|
183 |
+
if act != 'linear' or gain != 1 or clamp >= 0:
|
184 |
+
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
185 |
+
|
186 |
+
if ctx.needs_input_grad[1]:
|
187 |
+
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
188 |
+
|
189 |
+
return dx, db
|
190 |
+
|
191 |
+
# Backward op.
|
192 |
+
class BiasActCudaGrad(torch.autograd.Function):
|
193 |
+
@staticmethod
|
194 |
+
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
195 |
+
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
|
196 |
+
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
197 |
+
ctx.save_for_backward(
|
198 |
+
dy if spec.has_2nd_grad else _null_tensor,
|
199 |
+
x, b, y)
|
200 |
+
return dx
|
201 |
+
|
202 |
+
@staticmethod
|
203 |
+
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
204 |
+
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
205 |
+
dy, x, b, y = ctx.saved_tensors
|
206 |
+
d_dy = None
|
207 |
+
d_x = None
|
208 |
+
d_b = None
|
209 |
+
d_y = None
|
210 |
+
|
211 |
+
if ctx.needs_input_grad[0]:
|
212 |
+
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
213 |
+
|
214 |
+
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
215 |
+
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
216 |
+
|
217 |
+
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
218 |
+
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
219 |
+
|
220 |
+
return d_dy, d_x, d_b, d_y
|
221 |
+
|
222 |
+
# Add to cache.
|
223 |
+
_bias_act_cuda_cache[key] = BiasActCuda
|
224 |
+
return BiasActCuda
|
225 |
+
|
226 |
+
#----------------------------------------------------------------------------
|
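
A usage sketch for bias_act(): on CPU, or whenever the CUDA plugin cannot be built, impl='ref' (or the automatic fallback) goes through _bias_act_ref() above. The import path is assumed from the repo layout.

    import torch
    from src.vqvaes.bsqvit.stylegan_utils.ops import bias_act  # import path assumed

    x = torch.randn(2, 8, 4, 4)
    b = torch.zeros(8)                                   # one bias element per channel (dim=1)
    y = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')
    assert y.shape == x.shape                            # bias + leaky ReLU + default gain, shape preserved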
src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_gradfix.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
10 |
+
arbitrarily high order gradients with zero performance penalty."""
|
11 |
+
|
12 |
+
import warnings
|
13 |
+
import contextlib
|
14 |
+
import torch
|
15 |
+
|
16 |
+
# pylint: disable=redefined-builtin
|
17 |
+
# pylint: disable=arguments-differ
|
18 |
+
# pylint: disable=protected-access
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
enabled = False # Enable the custom op by setting this to true.
|
23 |
+
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
24 |
+
|
25 |
+
@contextlib.contextmanager
|
26 |
+
def no_weight_gradients():
|
27 |
+
global weight_gradients_disabled
|
28 |
+
old = weight_gradients_disabled
|
29 |
+
weight_gradients_disabled = True
|
30 |
+
yield
|
31 |
+
weight_gradients_disabled = old
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
|
35 |
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
36 |
+
if _should_use_custom_op(input):
|
37 |
+
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
38 |
+
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
39 |
+
|
40 |
+
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
41 |
+
if _should_use_custom_op(input):
|
42 |
+
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
43 |
+
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
44 |
+
|
45 |
+
#----------------------------------------------------------------------------
|
46 |
+
|
47 |
+
def _should_use_custom_op(input):
|
48 |
+
assert isinstance(input, torch.Tensor)
|
49 |
+
if (not enabled) or (not torch.backends.cudnn.enabled):
|
50 |
+
return False
|
51 |
+
if input.device.type != 'cuda':
|
52 |
+
return False
|
53 |
+
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
54 |
+
return True
|
55 |
+
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
|
56 |
+
return False
|
57 |
+
|
58 |
+
def _tuple_of_ints(xs, ndim):
|
59 |
+
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
60 |
+
assert len(xs) == ndim
|
61 |
+
assert all(isinstance(x, int) for x in xs)
|
62 |
+
return xs
|
63 |
+
|
64 |
+
#----------------------------------------------------------------------------
|
65 |
+
|
66 |
+
_conv2d_gradfix_cache = dict()
|
67 |
+
|
68 |
+
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
69 |
+
# Parse arguments.
|
70 |
+
ndim = 2
|
71 |
+
weight_shape = tuple(weight_shape)
|
72 |
+
stride = _tuple_of_ints(stride, ndim)
|
73 |
+
padding = _tuple_of_ints(padding, ndim)
|
74 |
+
output_padding = _tuple_of_ints(output_padding, ndim)
|
75 |
+
dilation = _tuple_of_ints(dilation, ndim)
|
76 |
+
|
77 |
+
# Lookup from cache.
|
78 |
+
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
79 |
+
if key in _conv2d_gradfix_cache:
|
80 |
+
return _conv2d_gradfix_cache[key]
|
81 |
+
|
82 |
+
# Validate arguments.
|
83 |
+
assert groups >= 1
|
84 |
+
assert len(weight_shape) == ndim + 2
|
85 |
+
assert all(stride[i] >= 1 for i in range(ndim))
|
86 |
+
assert all(padding[i] >= 0 for i in range(ndim))
|
87 |
+
assert all(dilation[i] >= 0 for i in range(ndim))
|
88 |
+
if not transpose:
|
89 |
+
assert all(output_padding[i] == 0 for i in range(ndim))
|
90 |
+
else: # transpose
|
91 |
+
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
92 |
+
|
93 |
+
# Helpers.
|
94 |
+
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
95 |
+
def calc_output_padding(input_shape, output_shape):
|
96 |
+
if transpose:
|
97 |
+
return [0, 0]
|
98 |
+
return [
|
99 |
+
input_shape[i + 2]
|
100 |
+
- (output_shape[i + 2] - 1) * stride[i]
|
101 |
+
- (1 - 2 * padding[i])
|
102 |
+
- dilation[i] * (weight_shape[i + 2] - 1)
|
103 |
+
for i in range(ndim)
|
104 |
+
]
|
105 |
+
|
106 |
+
# Forward & backward.
|
107 |
+
class Conv2d(torch.autograd.Function):
|
108 |
+
@staticmethod
|
109 |
+
def forward(ctx, input, weight, bias):
|
110 |
+
assert weight.shape == weight_shape
|
111 |
+
if not transpose:
|
112 |
+
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
113 |
+
else: # transpose
|
114 |
+
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
115 |
+
ctx.save_for_backward(input, weight)
|
116 |
+
return output
|
117 |
+
|
118 |
+
@staticmethod
|
119 |
+
def backward(ctx, grad_output):
|
120 |
+
input, weight = ctx.saved_tensors
|
121 |
+
grad_input = None
|
122 |
+
grad_weight = None
|
123 |
+
grad_bias = None
|
124 |
+
|
125 |
+
if ctx.needs_input_grad[0]:
|
126 |
+
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
127 |
+
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
|
128 |
+
assert grad_input.shape == input.shape
|
129 |
+
|
130 |
+
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
131 |
+
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
132 |
+
assert grad_weight.shape == weight_shape
|
133 |
+
|
134 |
+
if ctx.needs_input_grad[2]:
|
135 |
+
grad_bias = grad_output.sum([0, 2, 3])
|
136 |
+
|
137 |
+
return grad_input, grad_weight, grad_bias
|
138 |
+
|
139 |
+
# Gradient with respect to the weights.
|
140 |
+
class Conv2dGradWeight(torch.autograd.Function):
|
141 |
+
@staticmethod
|
142 |
+
def forward(ctx, grad_output, input):
|
143 |
+
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
|
144 |
+
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
145 |
+
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
146 |
+
assert grad_weight.shape == weight_shape
|
147 |
+
ctx.save_for_backward(grad_output, input)
|
148 |
+
return grad_weight
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def backward(ctx, grad2_grad_weight):
|
152 |
+
grad_output, input = ctx.saved_tensors
|
153 |
+
grad2_grad_output = None
|
154 |
+
grad2_input = None
|
155 |
+
|
156 |
+
if ctx.needs_input_grad[0]:
|
157 |
+
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
158 |
+
assert grad2_grad_output.shape == grad_output.shape
|
159 |
+
|
160 |
+
if ctx.needs_input_grad[1]:
|
161 |
+
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
162 |
+
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
|
163 |
+
assert grad2_input.shape == input.shape
|
164 |
+
|
165 |
+
return grad2_grad_output, grad2_input
|
166 |
+
|
167 |
+
_conv2d_gradfix_cache[key] = Conv2d
|
168 |
+
return Conv2d
|
169 |
+
|
170 |
+
#----------------------------------------------------------------------------
|
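
With enabled = False (the default above), conv2d() simply dispatches to torch.nn.functional.conv2d; setting it to True on a supported PyTorch + CUDA setup routes through the custom autograd Functions. A short sketch, with the import path assumed:

    import torch
    from src.vqvaes.bsqvit.stylegan_utils.ops import conv2d_gradfix  # import path assumed

    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    w = torch.randn(4, 3, 3, 3, requires_grad=True)
    y = conv2d_gradfix.conv2d(x, w, padding=1)

    # Inside this context the custom op skips the weight gradient
    # (it has no effect on the plain torch.nn.functional fallback).
    with conv2d_gradfix.no_weight_gradients():
        y = conv2d_gradfix.conv2d(x, w, padding=1)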
src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_resample.py
ADDED
@@ -0,0 +1,155 @@
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""2D convolution with optional up/downsampling."""
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from .. import misc
|
14 |
+
from . import conv2d_gradfix
|
15 |
+
from . import upfirdn2d
|
16 |
+
from .upfirdn2d import _parse_padding
|
17 |
+
from .upfirdn2d import _get_filter_size
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
def _get_weight_shape(w):
|
22 |
+
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
23 |
+
shape = [int(sz) for sz in w.shape]
|
24 |
+
misc.assert_shape(w, shape)
|
25 |
+
return shape
|
26 |
+
|
27 |
+
#----------------------------------------------------------------------------
|
28 |
+
|
29 |
+
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
30 |
+
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
31 |
+
"""
|
32 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
33 |
+
|
34 |
+
# Flip weight if requested.
|
35 |
+
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
36 |
+
w = w.flip([2, 3])
|
37 |
+
|
38 |
+
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
39 |
+
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
40 |
+
if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
|
41 |
+
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
42 |
+
if out_channels <= 4 and groups == 1:
|
43 |
+
in_shape = x.shape
|
44 |
+
x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
|
45 |
+
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
46 |
+
else:
|
47 |
+
x = x.to(memory_format=torch.contiguous_format)
|
48 |
+
w = w.to(memory_format=torch.contiguous_format)
|
49 |
+
x = conv2d_gradfix.conv2d(x, w, groups=groups)
|
50 |
+
return x.to(memory_format=torch.channels_last)
|
51 |
+
|
52 |
+
# Otherwise => execute using conv2d_gradfix.
|
53 |
+
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
54 |
+
return op(x, w, stride=stride, padding=padding, groups=groups)
|
55 |
+
|
56 |
+
#----------------------------------------------------------------------------
|
57 |
+
|
58 |
+
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
59 |
+
r"""2D convolution with optional up/downsampling.
|
60 |
+
|
61 |
+
Padding is performed only once at the beginning, not between the operations.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
x: Input tensor of shape
|
65 |
+
`[batch_size, in_channels, in_height, in_width]`.
|
66 |
+
w: Weight tensor of shape
|
67 |
+
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
68 |
+
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
69 |
+
calling upfirdn2d.setup_filter(). None = identity (default).
|
70 |
+
up: Integer upsampling factor (default: 1).
|
71 |
+
down: Integer downsampling factor (default: 1).
|
72 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
73 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
74 |
+
(default: 0).
|
75 |
+
groups: Split input channels into N groups (default: 1).
|
76 |
+
flip_weight: False = convolution, True = correlation (default: True).
|
77 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
81 |
+
"""
|
82 |
+
# Validate arguments.
|
83 |
+
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
84 |
+
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
85 |
+
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
86 |
+
assert isinstance(up, int) and (up >= 1)
|
87 |
+
assert isinstance(down, int) and (down >= 1)
|
88 |
+
assert isinstance(groups, int) and (groups >= 1)
|
89 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
90 |
+
fw, fh = _get_filter_size(f)
|
91 |
+
px0, px1, py0, py1 = _parse_padding(padding)
|
92 |
+
|
93 |
+
# Adjust padding to account for up/downsampling.
|
94 |
+
if up > 1:
|
95 |
+
px0 += (fw + up - 1) // 2
|
96 |
+
px1 += (fw - up) // 2
|
97 |
+
py0 += (fh + up - 1) // 2
|
98 |
+
py1 += (fh - up) // 2
|
99 |
+
if down > 1:
|
100 |
+
px0 += (fw - down + 1) // 2
|
101 |
+
px1 += (fw - down) // 2
|
102 |
+
py0 += (fh - down + 1) // 2
|
103 |
+
py1 += (fh - down) // 2
|
104 |
+
|
105 |
+
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
106 |
+
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
107 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
108 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
109 |
+
return x
|
110 |
+
|
111 |
+
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
112 |
+
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
113 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
114 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
115 |
+
return x
|
116 |
+
|
117 |
+
# Fast path: downsampling only => use strided convolution.
|
118 |
+
if down > 1 and up == 1:
|
119 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
120 |
+
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
121 |
+
return x
|
122 |
+
|
123 |
+
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
124 |
+
if up > 1:
|
125 |
+
if groups == 1:
|
126 |
+
w = w.transpose(0, 1)
|
127 |
+
else:
|
128 |
+
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
129 |
+
w = w.transpose(1, 2)
|
130 |
+
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
131 |
+
px0 -= kw - 1
|
132 |
+
px1 -= kw - up
|
133 |
+
py0 -= kh - 1
|
134 |
+
py1 -= kh - up
|
135 |
+
pxt = max(min(-px0, -px1), 0)
|
136 |
+
pyt = max(min(-py0, -py1), 0)
|
137 |
+
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
138 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
139 |
+
if down > 1:
|
140 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
141 |
+
return x
|
142 |
+
|
143 |
+
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
144 |
+
if up == 1 and down == 1:
|
145 |
+
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
146 |
+
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
147 |
+
|
148 |
+
# Fallback: Generic reference implementation.
|
149 |
+
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
150 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
151 |
+
if down > 1:
|
152 |
+
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
153 |
+
return x
|
154 |
+
|
155 |
+
#----------------------------------------------------------------------------
|
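
A usage sketch for conv2d_resample(); per the docstring above, the low-pass filter must first be prepared with upfirdn2d.setup_filter(). Import paths are assumed from the repo layout.

    import torch
    from src.vqvaes.bsqvit.stylegan_utils.ops import conv2d_resample, upfirdn2d  # import paths assumed

    x = torch.randn(1, 3, 16, 16)
    w = torch.randn(8, 3, 3, 3)
    f = upfirdn2d.setup_filter([1, 3, 3, 1])

    y_up = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)    # convolve, then 2x upsample
    y_dn = conv2d_resample.conv2d_resample(x=x, w=w, f=f, down=2, padding=1)  # 2x downsample path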
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cpp
ADDED
@@ -0,0 +1,103 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");

    // Initialize CUDA kernel parameters.
    upfirdn2d_kernel_params p;
    p.x = x.data_ptr();
    p.f = f.data_ptr<float>();
    p.y = y.data_ptr();
    p.up = make_int2(upx, upy);
    p.down = make_int2(downx, downy);
    p.pad0 = make_int2(padx0, pady0);
    p.flip = (flip) ? 1 : 0;
    p.gain = gain;
    p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
    p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel.
    upfirdn2d_kernel_spec spec;
    AT_DISPATCH_REDUCED_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
    p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor = spec.loopMinor;
    p.loopX = spec.loopX;
    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size.
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel.
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}

//------------------------------------------------------------------------
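The output resolution computed above is a ceiling-style division over the padded, upsampled extent. A small Python cross-check of the same arithmetic (values chosen purely for illustration):

    # outW = (inW*upx + padx0 + padx1 - filterW + downx) // downx, mirroring the C++ above
    in_w, up_x, pad_x0, pad_x1, filter_w, down_x = 32, 2, 2, 1, 4, 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - filter_w + down_x) // down_x
    assert out_w == 64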
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cu
ADDED
@@ -0,0 +1,353 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <ATen/ATen.h>
#include <c10/util/Half.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double> { typedef double scalar_t; };
template <> struct InternalType<float> { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
template <> struct InternalType<at::BFloat16> { typedef float scalar_t; };

static __device__ __forceinline__ int floor_div(int a, int b)
{
    int t = 1 - a / b;
    return (a + t * b) / b - t;
}

//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.

template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;

    // Calculate thread index.
    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorBase / p.launchMinor;
    minorBase -= outY * p.launchMinor;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Setup Y receptive field.
    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
    if (p.flip)
        filterY = p.filterSize.y - 1 - filterY;

    // Loop over major, minor, and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
    {
        int nc = major * p.sizeMinor + minor;
        int n = nc / p.inSize.z;
        int c = nc - n * p.inSize.z;
        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
        {
            // Setup X receptive field.
            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
            if (p.flip)
                filterX = p.filterSize.x - 1 - filterX;

            // Initialize pointers.
            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
            int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

            // Inner loop.
            scalar_t v = 0;
            for (int y = 0; y < h; y++)
            {
                for (int x = 0; x < w; x++)
                {
                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
                    xp += p.inStride.x;
                    fp += filterStepX;
                }
                xp += p.inStride.y - w * p.inStride.x;
                fp += filterStepY - w * filterStepX;
            }

            // Store result.
            v *= p.gain;
            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.

template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels.
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
            int tileInX = floor_div(tileMidX, upx);
            int tileInY = floor_div(tileMidY, upy);
            __syncthreads();
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels.
            __syncthreads();
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field.
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop.
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;

    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last

    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 7 && fy <= 7) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
        if (fx <= 5 && fy <= 5) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
        if (fx <= 3 && fy <= 3) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 7 && fy <= 7) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 5 && fy <= 5) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 3 && fy <= 3) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
    {
        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
    {
        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
    }
    return spec;
}

//------------------------------------------------------------------------
// Template specializations.

template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<at::BFloat16> (const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
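The floor_div helper above exists because C integer division truncates toward zero, while the receptive-field setup needs flooring whenever the padded midpoint coordinate goes negative near the top/left border. A quick Python illustration of the same formula (assuming b > 0, as the up factors always are):

    def c_div(a, b):
        # emulate C's truncate-toward-zero integer division
        q = abs(a) // abs(b)
        return q if (a >= 0) == (b >= 0) else -q

    def floor_div(a, b):
        # same arithmetic as the CUDA helper
        t = 1 - c_div(a, b)
        return c_div(a + t * b, b) - t

    assert floor_div(-3, 2) == -2   # floor(-1.5), not the truncated -1
    assert floor_div(5, 2) == 2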
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.h
ADDED
@@ -0,0 +1,59 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct upfirdn2d_kernel_params
{
    const void* x;
    const float* f;
    void* y;

    int2 up;
    int2 down;
    int2 pad0;
    int flip;
    float gain;

    int4 inSize; // [width, height, channel, batch]
    int4 inStride;
    int2 filterSize; // [width, height]
    int2 filterStride;
    int4 outSize; // [width, height, channel, batch]
    int4 outStride;
    int sizeMinor;
    int sizeMajor;

    int loopMinor;
    int loopMajor;
    int loopX;
    int launchMinor;
    int launchMajor;
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct upfirdn2d_kernel_spec
{
    void* kernel;
    int tileOutW;
    int tileOutH;
    int loopMinor;
    int loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.py
ADDED
@@ -0,0 +1,382 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient resampling of 2D images."""

import os
import warnings
import numpy as np
import torch
import traceback

from .. import custom_ops, misc
from . import conv2d_gradfix

#----------------------------------------------------------------------------

_inited = False
_plugin = None

def _init():
    global _inited, _plugin
    if not _inited:
        sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except:
            warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

def _parse_scaling(scaling):
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]
    with misc.suppress_tracer_warnings():
        fw = int(fw)
        fh = int(fh)
    misc.assert_shape(f, [fh, fw][:f.ndim])
    assert fw >= 1 and fh >= 1
    return fw, fh

#----------------------------------------------------------------------------

def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f: Torch tensor, numpy array, or python list of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable),
            `[]` (impulse), or
            `None` (identity).
        device: Result device (default: cpu).
        normalize: Normalize the filter so that it retains the magnitude
            for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        separable: Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f

#----------------------------------------------------------------------------

def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by keeping every Nth pixel (`down`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or
            `None` (identity).
        up: Integer upsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        down: Integer downsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        padding: Padding with respect to the upsampled image. Can be a single number
            or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)

#----------------------------------------------------------------------------

def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
    else:
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x

#----------------------------------------------------------------------------

_upfirdn2d_cuda_cache = dict()

def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Fast CUDA implementation of `upfirdn2d()` using custom ops.
    """
    # Parse arguments.
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Lookup from cache.
    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
    if key in _upfirdn2d_cuda_cache:
        return _upfirdn2d_cuda_cache[key]

    # Forward op.
    class Upfirdn2dCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, f): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4
            if f is None:
                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
            y = x
            if f.ndim == 2:
                y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
            else:
                y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
                y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
            ctx.save_for_backward(f)
            ctx.x_shape = x.shape
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            f, = ctx.saved_tensors
            _, _, ih, iw = ctx.x_shape
            _, _, oh, ow = dy.shape
            fw, fh = _get_filter_size(f)
            p = [
                fw - padx0 - 1,
                iw * upx - ow * downx + padx0 - upx + 1,
                fh - pady0 - 1,
                ih * upy - oh * downy + pady0 - upy + 1,
            ]
            dx = None
            df = None

            if ctx.needs_input_grad[0]:
                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)

            assert not ctx.needs_input_grad[1]
            return dx, df

    # Add to cache.
    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
    return Upfirdn2dCuda

#----------------------------------------------------------------------------

def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Filter a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape matches the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or
            `None` (identity).
        padding: Padding with respect to the output. Can be a single number or a
            list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + fw // 2,
        padx1 + (fw - 1) // 2,
        pady0 + fh // 2,
        pady1 + (fh - 1) // 2,
    ]
    return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)

#----------------------------------------------------------------------------

def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or
            `None` (identity).
        up: Integer upsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        padding: Padding with respect to the output. Can be a single number or a
            list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)

#----------------------------------------------------------------------------

def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or
            `None` (identity).
        down: Integer downsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        padding: Padding with respect to the input. Can be a single number or a
            list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)

#----------------------------------------------------------------------------
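A short end-to-end sketch of the public helpers above. The repo-internal import path is an assumption, and impl='ref' keeps the sketch runnable without building the CUDA extension:

    import torch
    from src.vqvaes.bsqvit.stylegan_utils.ops import upfirdn2d  # assumed path

    x = torch.randn(2, 3, 64, 64)
    f = upfirdn2d.setup_filter([1, 3, 3, 1])   # 4 taps -> normalized 4x4 non-separable filter
    y_up = upfirdn2d.upsample2d(x, f, up=2, impl='ref')            # [2, 3, 128, 128]
    y_down = upfirdn2d.downsample2d(y_up, f, down=2, impl='ref')   # [2, 3, 64, 64]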
src/vqvaes/bsqvit/transformer.py
ADDED
@@ -0,0 +1,416 @@
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Callable, Optional, Union
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.utils.checkpoint import checkpoint
|
7 |
+
from timm.models.layers import to_2tuple
|
8 |
+
from timm.models.layers import trunc_normal_
|
9 |
+
from timm.models.layers import DropPath
|
10 |
+
|
11 |
+
from .attention_mask import get_attention_mask
|
12 |
+
|
13 |
+
|
14 |
+
class LayerScale(nn.Module):
|
15 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
16 |
+
super().__init__()
|
17 |
+
self.inplace = inplace
|
18 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
22 |
+
|
23 |
+
|
24 |
+
class ResidualAttentionBlock(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
d_model: int,
|
28 |
+
n_head: int,
|
29 |
+
mlp_ratio: float = 4.0,
|
30 |
+
ls_init_value: float = None,
|
31 |
+
drop: float = 0.,
|
32 |
+
attn_drop: float = 0.,
|
33 |
+
drop_path: float = 0.,
|
34 |
+
act_layer: Callable = nn.GELU,
|
35 |
+
norm_layer: Callable = nn.LayerNorm,
|
36 |
+
use_preln: bool = True,
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.ln_1 = norm_layer(d_model)
|
41 |
+
self.attn = nn.MultiheadAttention(d_model, n_head, dropout=attn_drop)
|
42 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
43 |
+
|
44 |
+
self.ln_2 = norm_layer(d_model)
|
45 |
+
mlp_width = int(d_model * mlp_ratio)
|
46 |
+
self.mlp = nn.Sequential(OrderedDict([
|
47 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
48 |
+
("gelu", act_layer()),
|
49 |
+
# disable this following JAX implementation.
|
50 |
+
# Reference: https://github.com/google-research/magvit/blob/main/videogvt/models/simplified_bert.py#L112
|
51 |
+
# ("drop1", nn.Dropout(drop)),
|
52 |
+
("c_proj", nn.Linear(mlp_width, d_model)),
|
53 |
+
("drop2", nn.Dropout(drop)),
|
54 |
+
]))
|
55 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
56 |
+
|
57 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
58 |
+
|
59 |
+
self.use_preln = use_preln
|
60 |
+
|
61 |
+
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False):
|
62 |
+
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
|
63 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, is_causal=is_causal)[0]
|
64 |
+
|
65 |
+
def checkpoint_forward(self, x: torch.Tensor,
|
66 |
+
attn_mask: Optional[torch.Tensor] = None,
|
67 |
+
is_causal: bool = False):
|
68 |
+
state = x
|
69 |
+
if self.use_preln:
|
70 |
+
x = checkpoint(self.ln_1, x, use_reentrant=False)
|
71 |
+
x = self.attention(x, attn_mask, is_causal)
|
72 |
+
x = checkpoint(self.ls_1, x, use_reentrant=False)
|
73 |
+
state = state + self.drop_path(x)
|
74 |
+
x = checkpoint(self.ln_2, state, use_reentrant=False)
|
75 |
+
x = self.mlp(x)
|
76 |
+
x = checkpoint(self.ls_2, x, use_reentrant=False)
|
77 |
+
state = state + self.drop_path(x)
|
78 |
+
else:
|
79 |
+
x = self.attention(x, attn_mask, is_causal)
|
80 |
+
x = state + self.drop_path(x)
|
81 |
+
state = checkpoint(self.ln_1, x, use_reentrant=False)
|
82 |
+
x = self.mlp(state)
|
83 |
+
state = state + self.drop_path(x)
|
84 |
+
state = checkpoint(self.ln_2, state, use_reentrant=False)
|
85 |
+
return state
|
86 |
+
|
87 |
+
def forward(self, x: torch.Tensor,
|
88 |
+
attn_mask: Optional[torch.Tensor] = None, is_causal: bool =False,
|
89 |
+
selective_checkpointing: bool = False):
|
90 |
+
if selective_checkpointing:
|
91 |
+
return self.checkpoint_forward(x, attn_mask, is_causal=is_causal)
|
92 |
+
if self.use_preln:
|
93 |
+
x = x + self.drop_path(self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)))
|
94 |
+
x = x + self.drop_path(self.ls_2(self.mlp(self.ln_2(x))))
|
95 |
+
else:
|
96 |
+
x = x + self.drop_path(self.attention(x, attn_mask=attn_mask, is_causal=is_causal))
|
97 |
+
x = self.ln_1(x)
|
98 |
+
x = x + self.drop_path(self.mlp(x))
|
99 |
+
x = self.ln_2(x)
|
100 |
+
return x
|
101 |
+
|
102 |
+
|
103 |
+
class Transformer(nn.Module):
|
104 |
+
def __init__(self,
|
105 |
+
width: int,
|
106 |
+
layers: int,
|
107 |
+
heads: int,
|
108 |
+
mlp_ratio: float = 4.0,
|
109 |
+
ls_init_value: float = None,
|
110 |
+
drop: float = 0.,
|
111 |
+
attn_drop: float = 0.,
|
112 |
+
drop_path: float = 0.,
|
113 |
+
act_layer: nn.Module = nn.GELU,
|
114 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
115 |
+
use_preln: bool = True,
|
116 |
+
):
|
117 |
+
super().__init__()
|
118 |
+
self.width = width
|
119 |
+
self.layers = layers
|
120 |
+
self.grad_checkpointing = False
|
121 |
+
self.selective_checkpointing = False
|
122 |
+
self.grad_checkpointing_params = {'use_reentrant': False}
|
123 |
+
if attn_drop == 0 and drop_path == 0 and drop_path == 0:
|
124 |
+
self.grad_checkpointing_params.update({'preserve_rng_state': False})
|
125 |
+
else:
|
126 |
+
self.grad_checkpointing_params.update({'preserve_rng_state': True})
|
127 |
+
|
128 |
+
self.resblocks = nn.ModuleList([
|
129 |
+
ResidualAttentionBlock(
|
130 |
+
width, heads, mlp_ratio, ls_init_value=ls_init_value,
|
131 |
+
drop=drop, attn_drop=attn_drop, drop_path=drop_path,
|
132 |
+
act_layer=act_layer, norm_layer=norm_layer,
|
133 |
+
use_preln=use_preln)
|
134 |
+
for _ in range(layers)
|
135 |
+
])
|
136 |
+
|
137 |
+
def forward(self, x: torch.Tensor,
|
138 |
+
attn_mask: Optional[torch.Tensor] = None,
|
139 |
+
is_causal: bool =False):
|
140 |
+
for r in self.resblocks:
|
141 |
+
if self.training and self.grad_checkpointing and not torch.jit.is_scripting():
|
142 |
+
if not self.selective_checkpointing:
|
143 |
+
x = checkpoint(r, x, attn_mask, is_causal=is_causal, **self.grad_checkpointing_params)
|
144 |
+
else:
|
145 |
+
x = r(x, attn_mask=attn_mask, is_causal=is_causal, selective_checkpointing=True)
|
146 |
+
else:
|
147 |
+
x = r(x, attn_mask=attn_mask)
|
148 |
+
return x
|
149 |
+
|
150 |
+
|
151 |
+
class TransformerEncoder(nn.Module):
|
152 |
+
def __init__(self,
|
153 |
+
image_size: int,
|
154 |
+
patch_size: int,
|
155 |
+
width: int,
|
156 |
+
layers: int,
|
157 |
+
heads: int,
|
158 |
+
mlp_ratio: float,
|
159 |
+
num_frames: int = 1,
|
160 |
+
cross_frames: bool = True,
|
161 |
+
ls_init_value: float = None,
|
162 |
+
drop_rate: float = 0.,
|
163 |
+
attn_drop_rate: float = 0.,
|
164 |
+
drop_path_rate: float = 0.,
|
165 |
+
ln_pre: bool = True,
|
166 |
+
ln_post: bool = True,
|
167 |
+
act_layer: str = 'gelu',
|
168 |
+
norm_layer: str = 'layer_norm',
|
169 |
+
mask_type: Union[str, None] = 'none',
|
170 |
+
mask_block_size: int = -1
|
171 |
+
):
|
172 |
+
super().__init__()
|
173 |
+
self.image_size = to_2tuple(image_size)
|
174 |
+
self.patch_size = to_2tuple(patch_size)
|
175 |
+
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
|
176 |
+
self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
|
177 |
+
self.mask_type = mask_type
|
178 |
+
self.mask_block_size = mask_block_size
|
179 |
+
|
180 |
+
if act_layer.lower() == 'gelu':
|
181 |
+
self.act_layer = nn.GELU
|
182 |
+
else:
|
183 |
+
raise ValueError(f"Unsupported activation function: {act_layer}")
|
184 |
+
if norm_layer.lower() == 'layer_norm':
|
185 |
+
self.norm_layer = nn.LayerNorm
|
186 |
+
else:
|
187 |
+
raise ValueError(f"Unsupported normalization: {norm_layer}")
|
188 |
+
|
189 |
+
self.conv1 = nn.Linear(
|
190 |
+
in_features=3 * self.patch_size[0] * self.patch_size[1],
|
191 |
+
out_features=width,
|
192 |
+
bias=not ln_pre
|
193 |
+
)
|
194 |
+
|
195 |
+
scale = width ** -0.5
|
196 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
|
197 |
+
assert num_frames >= 1
|
198 |
+
self.num_frames = num_frames
|
199 |
+
self.cross_frames = cross_frames
|
200 |
+
if num_frames > 1 and cross_frames:
|
201 |
+
self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
|
202 |
+
else:
|
203 |
+
self.temporal_positional_embedding = None
|
204 |
+
|
205 |
+
self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()
|
206 |
+
|
207 |
+
self.transformer = Transformer(
|
208 |
+
width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
|
209 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
|
210 |
+
act_layer=self.act_layer, norm_layer=self.norm_layer,
|
211 |
+
)
|
212 |
+
|
213 |
+
self.ln_post = self.norm_layer(width)
|
214 |
+
|
215 |
+
self.init_parameters()
|
216 |
+
|
217 |
+
def init_parameters(self):
|
218 |
+
if self.positional_embedding is not None:
|
219 |
+
nn.init.normal_(self.positional_embedding, std=0.02)
|
220 |
+
trunc_normal_(self.conv1.weight, std=0.02)
|
221 |
+
for block in self.transformer.resblocks:
|
222 |
+
for n, p in block.named_parameters():
|
223 |
+
if 'weight' in n:
|
224 |
+
if 'ln' not in n:
|
225 |
+
trunc_normal_(p, std=0.02)
|
226 |
+
elif 'bias' in n:
|
227 |
+
nn.init.zeros_(p)
|
228 |
+
else:
|
229 |
+
raise NotImplementedError(f'Unknown parameters named {n}')
|
230 |
+
|
231 |
+
@torch.jit.ignore
|
232 |
+
def set_grad_checkpointing(self, enable=True, selective=False):
|
233 |
+
self.transformer.grad_checkpointing = enable
|
234 |
+
self.transformer.selective_checkpointing = selective
|
235 |
+
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
if self.num_frames == 1:
|
239 |
+
x = rearrange(
|
240 |
+
x, "b c (hh sh) (ww sw) -> b (hh ww) (c sh sw)",
|
241 |
+
sh=self.patch_size[0], sw=self.patch_size[1]
|
242 |
+
)
|
243 |
+
x = self.conv1(x)
|
244 |
+
x = x + self.positional_embedding.to(x.dtype)
|
245 |
+
elif self.cross_frames:
|
246 |
+
num_frames = x.shape[2]
|
247 |
+
assert num_frames <= self.num_frames, 'Number of frames should be less than or equal to the model setting'
|
248 |
+
x = rearrange(
|
249 |
+
x, "b c t (hh sh) (ww sw) -> b (t hh ww) (c sh sw)",
|
250 |
+
sh=self.patch_size[0], sw=self.patch_size[1]
|
251 |
+
)
|
252 |
+
x = self.conv1(x)
|
253 |
+
tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
|
254 |
+
tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
|
255 |
+
total_pos_embed = tile_pos_embed + tile_tem_embed
|
256 |
+
x = x + total_pos_embed.to(x.dtype).squeeze(0)
|
257 |
+
else:
|
258 |
+
x = rearrange(
|
259 |
+
x, "b c t (hh sh) (ww sw) -> (b t) (hh ww) (c sh sw)",
|
260 |
+
sh=self.patch_size[0], sw=self.patch_size[1]
|
261 |
+
)
|
262 |
+
x = self.conv1(x)
|
263 |
+
x = x + self.positional_embedding.to(x.dtype)
|
264 |
+
|
265 |
+
x = self.ln_pre(x)
|
266 |
+
x = x.permute(1, 0, 2)
|
267 |
+
block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
|
268 |
+
attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
|
269 |
+
x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
|
270 |
+
x = x.permute(1, 0, 2)
|
271 |
+
x = self.ln_post(x)
|
272 |
+
|
273 |
+
return x
|
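A minimal usage sketch for the encoder above. This is an assumption-laden illustration, not part of the diff: it presumes the helpers used in this file (to_2tuple, trunc_normal_, get_attention_mask, ResidualAttentionBlock) are importable from the same module and that mask_type='none' yields no attention mask.

import torch

enc = TransformerEncoder(image_size=256, patch_size=16, width=768,
                         layers=12, heads=12, mlp_ratio=4.0)
imgs = torch.randn(2, 3, 256, 256)   # batch of images in [-1, 1]
tokens = enc(imgs)                   # (2, 256, 768): one token per 16x16 patch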
274 |
+
|
275 |
+
|
276 |
+
class TransformerDecoder(nn.Module):
|
277 |
+
def __init__(self,
|
278 |
+
image_size: int,
|
279 |
+
patch_size: int,
|
280 |
+
width: int,
|
281 |
+
layers: int,
|
282 |
+
heads: int,
|
283 |
+
mlp_ratio: float,
|
284 |
+
num_frames: int = 1,
|
285 |
+
cross_frames: bool = True,
|
286 |
+
ls_init_value: float = None,
|
287 |
+
drop_rate: float = 0.,
|
288 |
+
attn_drop_rate: float = 0.,
|
289 |
+
drop_path_rate: float = 0.,
|
290 |
+
ln_pre: bool = True,
|
291 |
+
ln_post: bool = True,
|
292 |
+
act_layer: str = 'gelu',
|
293 |
+
norm_layer: str = 'layer_norm',
|
294 |
+
use_ffn_output: bool = True,
|
295 |
+
dim_ffn_output: int = 3072,
|
296 |
+
logit_laplace: bool = False,
|
297 |
+
mask_type: Union[str, None] = 'none',
|
298 |
+
mask_block_size: int = -1
|
299 |
+
):
|
300 |
+
super().__init__()
|
301 |
+
self.image_size = to_2tuple(image_size)
|
302 |
+
self.patch_size = to_2tuple(patch_size)
|
303 |
+
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
|
304 |
+
self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
|
305 |
+
self.mask_type = mask_type
|
306 |
+
self.mask_block_size = mask_block_size
|
307 |
+
|
308 |
+
if act_layer.lower() == 'gelu':
|
309 |
+
self.act_layer = nn.GELU
|
310 |
+
else:
|
311 |
+
raise ValueError(f"Unsupported activation function: {act_layer}")
|
312 |
+
if norm_layer.lower() == 'layer_norm':
|
313 |
+
self.norm_layer = nn.LayerNorm
|
314 |
+
else:
|
315 |
+
raise ValueError(f"Unsupported normalization: {norm_layer}")
|
316 |
+
|
317 |
+
self.use_ffn_output = use_ffn_output
|
318 |
+
if use_ffn_output:
|
319 |
+
self.ffn = nn.Sequential(
|
320 |
+
nn.Linear(width, dim_ffn_output),
|
321 |
+
nn.Tanh(),
|
322 |
+
)
|
323 |
+
self.conv_out = nn.Linear(
|
324 |
+
in_features=dim_ffn_output,
|
325 |
+
out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
|
326 |
+
)
|
327 |
+
else:
|
328 |
+
self.ffn = nn.Identity()
|
329 |
+
self.conv_out = nn.Linear(
|
330 |
+
in_features=width,
|
331 |
+
out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
|
332 |
+
)
|
333 |
+
|
334 |
+
scale = width ** -0.5
|
335 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
|
336 |
+
assert num_frames >= 1
|
337 |
+
self.num_frames = num_frames
|
338 |
+
self.cross_frames = cross_frames
|
339 |
+
if num_frames > 1 and cross_frames:
|
340 |
+
self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
|
341 |
+
else:
|
342 |
+
self.temporal_positional_embedding = None
|
343 |
+
|
344 |
+
self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()
|
345 |
+
|
346 |
+
self.transformer = Transformer(
|
347 |
+
width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
|
348 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
|
349 |
+
act_layer=self.act_layer, norm_layer=self.norm_layer,
|
350 |
+
)
|
351 |
+
|
352 |
+
self.ln_post = self.norm_layer(width) if ln_post else nn.Identity()
|
353 |
+
|
354 |
+
self.init_parameters()
|
355 |
+
|
356 |
+
def init_parameters(self):
|
357 |
+
if self.positional_embedding is not None:
|
358 |
+
nn.init.normal_(self.positional_embedding, std=0.02)
|
359 |
+
|
360 |
+
for block in self.transformer.resblocks:
|
361 |
+
for n, p in block.named_parameters():
|
362 |
+
if 'weight' in n:
|
363 |
+
if 'ln' not in n:
|
364 |
+
trunc_normal_(p, std=0.02)
|
365 |
+
elif 'bias' in n:
|
366 |
+
nn.init.zeros_(p)
|
367 |
+
else:
|
368 |
+
raise NotImplementedError(f'Unknown parameters named {n}')
|
369 |
+
if self.use_ffn_output:
|
370 |
+
trunc_normal_(self.ffn[0].weight, std=0.02)
|
371 |
+
trunc_normal_(self.conv_out.weight, std=0.02)
|
372 |
+
|
373 |
+
@torch.jit.ignore
|
374 |
+
def set_grad_checkpointing(self, enable=True, selective=False):
|
375 |
+
self.transformer.grad_checkpointing = enable
|
376 |
+
self.transformer.selective_checkpointing = selective
|
377 |
+
|
378 |
+
def forward(self, x):
|
379 |
+
if self.num_frames == 1 or not self.cross_frames:
|
380 |
+
x = x + self.positional_embedding.to(x.dtype)
|
381 |
+
else:
|
382 |
+
num_frames = x.shape[1] // self.patches_per_frame
|
383 |
+
assert num_frames <= self.num_frames, 'Number of frames should be less than or equal to the model setting'
|
384 |
+
tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
|
385 |
+
tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
|
386 |
+
total_pos_embed = tile_pos_embed + tile_tem_embed
|
387 |
+
x = x + total_pos_embed.to(x.dtype).squeeze(0)
|
388 |
+
x = self.ln_pre(x)
|
389 |
+
x = x.permute(1, 0, 2)
|
390 |
+
block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
|
391 |
+
attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
|
392 |
+
x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
|
393 |
+
x = x.permute(1, 0, 2)
|
394 |
+
x = self.ln_post(x)
|
395 |
+
x = self.ffn(x)
|
396 |
+
x = self.conv_out(x)
|
397 |
+
if self.num_frames == 1:
|
398 |
+
x = rearrange(
|
399 |
+
x, "b (hh ww) (c sh sw) -> b c (hh sh) (ww sw)",
|
400 |
+
hh = self.grid_size[0], ww=self.grid_size[1],
|
401 |
+
sh=self.patch_size[0], sw=self.patch_size[1]
|
402 |
+
)
|
403 |
+
elif self.cross_frames:
|
404 |
+
x = rearrange(
|
405 |
+
x, "b (t hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
|
406 |
+
t = num_frames, hh = self.grid_size[0], ww=self.grid_size[1],
|
407 |
+
sh=self.patch_size[0], sw=self.patch_size[1]
|
408 |
+
)
|
409 |
+
else:
|
410 |
+
x = rearrange(
|
411 |
+
x, "(b t) (hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
|
412 |
+
t = num_frames, hh = self.grid_size[0], ww=self.grid_size[1],
|
413 |
+
sh=self.patch_size[0], sw=self.patch_size[1]
|
414 |
+
)
|
415 |
+
|
416 |
+
return x
|
src/vqvaes/flowmo/flowmo.py
ADDED
@@ -0,0 +1,945 @@
1 |
+
"""Model code for FlowMo.
|
2 |
+
|
3 |
+
Sources: https://github.com/feizc/FluxMusic/blob/main/train.py
|
4 |
+
https://github.com/black-forest-labs/flux/tree/main/src/flux
|
5 |
+
"""
|
6 |
+
|
7 |
+
import ast
|
8 |
+
import itertools
|
9 |
+
import math
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from typing import List, Tuple
|
12 |
+
|
13 |
+
import einops
|
14 |
+
import torch
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
from mup import MuReadout
|
17 |
+
from torch import Tensor, nn
|
18 |
+
import argparse
|
19 |
+
import contextlib
|
20 |
+
import copy
|
21 |
+
import glob
|
22 |
+
import os
|
23 |
+
import subprocess
|
24 |
+
import tempfile
|
25 |
+
import time
|
26 |
+
|
27 |
+
import fsspec
|
28 |
+
import psutil
|
29 |
+
import torch
|
30 |
+
import torch.distributed as dist
|
31 |
+
from mup import MuReadout, set_base_shapes
|
32 |
+
from omegaconf import OmegaConf
|
33 |
+
from torch.utils.data import DataLoader
|
34 |
+
|
35 |
+
from .lookup_free_quantize import LFQ
|
36 |
+
|
37 |
+
MUP_ENABLED = True
|
38 |
+
|
39 |
+
|
40 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
41 |
+
b, h, l, d = q.shape
|
42 |
+
q, k = apply_rope(q, k, pe)
|
43 |
+
|
44 |
+
if torch.__version__ == "2.0.1+cu117": # tmp workaround
|
45 |
+
if d != 64:
|
46 |
+
print("MUP is broken in this setting! Be careful!")
|
47 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
48 |
+
q,
|
49 |
+
k,
|
50 |
+
v,
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
54 |
+
q,
|
55 |
+
k,
|
56 |
+
v,
|
57 |
+
scale=8.0 / d if MUP_ENABLED else None,
|
58 |
+
)
|
59 |
+
assert x.shape == q.shape
|
60 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
61 |
+
return x
|
62 |
+
|
63 |
+
|
64 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
65 |
+
assert dim % 2 == 0
|
66 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
67 |
+
omega = 1.0 / (theta**scale)
|
68 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
69 |
+
out = torch.stack(
|
70 |
+
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)],
|
71 |
+
dim=-1,
|
72 |
+
)
|
73 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
74 |
+
return out.float()
|
75 |
+
|
76 |
+
|
77 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
78 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
79 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
80 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
81 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
82 |
+
|
83 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
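A hedged usage sketch of how rope, apply_rope and attention above fit together; the shapes follow the einops patterns in this file, and it assumes a PyTorch version recent enough to accept the scale= argument of scaled_dot_product_attention.

import torch

B, H, L, D = 2, 4, 16, 64                           # batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

pos = torch.arange(L, dtype=torch.float32)[None]    # (1, L) token positions
pe = rope(pos, dim=D, theta=10_000).unsqueeze(1)    # (1, 1, L, D/2, 2, 2) rotation matrices
out = attention(q, k, v, pe=pe)                     # (B, L, H*D)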
84 |
+
|
85 |
+
|
86 |
+
def _get_diagonal_gaussian(parameters):
|
87 |
+
mean, logvar = torch.chunk(parameters, 2, dim=1)
|
88 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
89 |
+
return mean, logvar
|
90 |
+
|
91 |
+
|
92 |
+
def _sample_diagonal_gaussian(mean, logvar):
|
93 |
+
std = torch.exp(0.5 * logvar)
|
94 |
+
x = mean + std * torch.randn(mean.shape, device=mean.device)
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
def _kl_diagonal_gaussian(mean, logvar):
|
99 |
+
var = torch.exp(logvar)
|
100 |
+
return 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=1).mean()
|
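For reference, the three diagonal-Gaussian helpers above implement the usual VAE terms; the KL penalty computed here, taken against a standard normal and averaged over the batch, is

\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big) = \tfrac{1}{2}\sum_i \big(\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2\big),

with the log-variances clamped to [-30, 20] for numerical stability.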
101 |
+
|
102 |
+
|
103 |
+
class EmbedND(nn.Module):
|
104 |
+
def __init__(self, dim: int, theta: int, axes_dim):
|
105 |
+
super().__init__()
|
106 |
+
self.dim = dim
|
107 |
+
self.theta = theta
|
108 |
+
self.axes_dim = axes_dim
|
109 |
+
|
110 |
+
def forward(self, ids: Tensor) -> Tensor:
|
111 |
+
n_axes = ids.shape[-1]
|
112 |
+
emb = torch.cat(
|
113 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
114 |
+
dim=-3,
|
115 |
+
)
|
116 |
+
|
117 |
+
return emb.unsqueeze(1)
|
118 |
+
|
119 |
+
|
120 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
121 |
+
"""
|
122 |
+
Create sinusoidal timestep embeddings.
|
123 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
124 |
+
These may be fractional.
|
125 |
+
:param dim: the dimension of the output.
|
126 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
127 |
+
:return: an (N, D) Tensor of positional embeddings.
|
128 |
+
"""
|
129 |
+
t = time_factor * t
|
130 |
+
half = dim // 2
|
131 |
+
freqs = torch.exp(
|
132 |
+
-math.log(max_period)
|
133 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
134 |
+
/ half
|
135 |
+
).to(t.device)
|
136 |
+
|
137 |
+
args = t[:, None].float() * freqs[None]
|
138 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
139 |
+
if dim % 2:
|
140 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
141 |
+
if torch.is_floating_point(t):
|
142 |
+
embedding = embedding.to(t)
|
143 |
+
return embedding
|
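A quick, hedged example of calling the embedding helper above on a batch of flow-matching timesteps in [0, 1]:

import torch

t = torch.tensor([0.0, 0.25, 1.0])
emb = timestep_embedding(t, dim=256)   # (3, 256) cos/sin features; t is scaled by time_factor=1000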
144 |
+
|
145 |
+
|
146 |
+
class MLPEmbedder(nn.Module):
|
147 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
148 |
+
super().__init__()
|
149 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
150 |
+
self.silu = nn.SiLU()
|
151 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
152 |
+
|
153 |
+
def forward(self, x: Tensor) -> Tensor:
|
154 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
155 |
+
|
156 |
+
|
157 |
+
class RMSNorm(torch.nn.Module):
|
158 |
+
def __init__(self, dim: int):
|
159 |
+
super().__init__()
|
160 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
161 |
+
|
162 |
+
def forward(self, x: Tensor):
|
163 |
+
x_dtype = x.dtype
|
164 |
+
x = x.float()
|
165 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
166 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
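In equation form, the forward pass above is standard RMSNorm with a learned per-channel gain g and eps = 1e-6:

\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_i x_i^2 + \varepsilon}} \cdot g,

where the mean is taken over the last (feature) dimension and the normalization is computed in float32 before casting back to the input dtype.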
167 |
+
|
168 |
+
|
169 |
+
class QKNorm(torch.nn.Module):
|
170 |
+
def __init__(self, dim: int):
|
171 |
+
super().__init__()
|
172 |
+
self.query_norm = RMSNorm(dim)
|
173 |
+
self.key_norm = RMSNorm(dim)
|
174 |
+
|
175 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor):
|
176 |
+
q = self.query_norm(q)
|
177 |
+
k = self.key_norm(k)
|
178 |
+
return q.to(v), k.to(v)
|
179 |
+
|
180 |
+
|
181 |
+
class SelfAttention(nn.Module):
|
182 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
183 |
+
super().__init__()
|
184 |
+
self.num_heads = num_heads
|
185 |
+
head_dim = dim // num_heads
|
186 |
+
|
187 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
188 |
+
self.norm = QKNorm(head_dim)
|
189 |
+
self.proj = nn.Linear(dim, dim)
|
190 |
+
|
191 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
192 |
+
qkv = self.qkv(x)
|
193 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
194 |
+
q, k = self.norm(q, k, v)
|
195 |
+
x = attention(q, k, v, pe=pe)
|
196 |
+
x = self.proj(x)
|
197 |
+
return x
|
198 |
+
|
199 |
+
|
200 |
+
@dataclass
|
201 |
+
class ModulationOut:
|
202 |
+
shift: Tensor
|
203 |
+
scale: Tensor
|
204 |
+
gate: Tensor
|
205 |
+
|
206 |
+
|
207 |
+
class Modulation(nn.Module):
|
208 |
+
def __init__(self, dim: int, double: bool):
|
209 |
+
super().__init__()
|
210 |
+
self.is_double = double
|
211 |
+
self.multiplier = 6 if double else 3
|
212 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
213 |
+
|
214 |
+
self.lin.weight[dim * 2 : dim * 3].data[:] = 0.0
|
215 |
+
self.lin.bias[dim * 2 : dim * 3].data[:] = 0.0
|
216 |
+
self.lin.weight[dim * 5 : dim * 6].data[:] = 0.0
|
217 |
+
self.lin.bias[dim * 5 : dim * 6].data[:] = 0.0
|
218 |
+
|
219 |
+
def forward(self, vec: Tensor) -> Tuple[ModulationOut, ModulationOut]:
|
220 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
|
221 |
+
self.multiplier, dim=-1
|
222 |
+
)
|
223 |
+
return (
|
224 |
+
ModulationOut(*out[:3]),
|
225 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
226 |
+
)
|
227 |
+
|
228 |
+
|
229 |
+
class DoubleStreamBlock(nn.Module):
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
hidden_size: int,
|
233 |
+
num_heads: int,
|
234 |
+
mlp_ratio: float,
|
235 |
+
qkv_bias: bool = False,
|
236 |
+
):
|
237 |
+
super().__init__()
|
238 |
+
|
239 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
240 |
+
self.num_heads = num_heads
|
241 |
+
self.hidden_size = hidden_size
|
242 |
+
|
243 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
244 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
245 |
+
self.img_attn = SelfAttention(
|
246 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
247 |
+
)
|
248 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
249 |
+
self.img_mlp = nn.Sequential(
|
250 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
251 |
+
nn.GELU(approximate="tanh"),
|
252 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
253 |
+
)
|
254 |
+
|
255 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
256 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
257 |
+
self.txt_attn = SelfAttention(
|
258 |
+
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
259 |
+
)
|
260 |
+
|
261 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
262 |
+
self.txt_mlp = nn.Sequential(
|
263 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
264 |
+
nn.GELU(approximate="tanh"),
|
265 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
266 |
+
)
|
267 |
+
|
268 |
+
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
269 |
+
pe_single, pe_double = pe
|
270 |
+
p = 1
|
271 |
+
if vec is None:
|
272 |
+
img_mod1, img_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1)
|
273 |
+
txt_mod1, txt_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1)
|
274 |
+
else:
|
275 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
276 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
277 |
+
|
278 |
+
# prepare image for attention
|
279 |
+
img_modulated = self.img_norm1(img)
|
280 |
+
img_modulated = (p + img_mod1.scale) * img_modulated + img_mod1.shift
|
281 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
282 |
+
img_q, img_k, img_v = rearrange(
|
283 |
+
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
284 |
+
)
|
285 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
286 |
+
|
287 |
+
# prepare txt for attention
|
288 |
+
txt_modulated = self.txt_norm1(txt)
|
289 |
+
txt_modulated = (p + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
290 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
291 |
+
txt_q, txt_k, txt_v = rearrange(
|
292 |
+
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
293 |
+
)
|
294 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
295 |
+
|
296 |
+
# run actual attention
|
297 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
298 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
299 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
300 |
+
|
301 |
+
attn = attention(q, k, v, pe=pe_double)
|
302 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
303 |
+
|
304 |
+
# calculate the img blocks
|
305 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
306 |
+
img = img + img_mod2.gate * self.img_mlp(
|
307 |
+
(p + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
308 |
+
)
|
309 |
+
|
310 |
+
# calculate the txt blocks
|
311 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
312 |
+
txt = txt + txt_mod2.gate * self.txt_mlp(
|
313 |
+
(p + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
314 |
+
)
|
315 |
+
return img, txt
|
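One detail worth making explicit in the forward pass above: when vec is None, ModulationOut(0, 1 - p, 1) with p = 1 gives shift 0, scale 0 and gate 1, so (p + mod.scale) * norm(x) + mod.shift reduces to norm(x) and the gated residual updates pass through at full strength, i.e. the block runs without any timestep conditioning.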
316 |
+
|
317 |
+
|
318 |
+
class LastLayer(nn.Module):
|
319 |
+
def __init__(
|
320 |
+
self,
|
321 |
+
hidden_size: int,
|
322 |
+
patch_size: int,
|
323 |
+
out_channels: int,
|
324 |
+
readout_zero_init=False,
|
325 |
+
):
|
326 |
+
super().__init__()
|
327 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
328 |
+
|
329 |
+
if MUP_ENABLED:
|
330 |
+
self.linear = MuReadout(
|
331 |
+
hidden_size,
|
332 |
+
patch_size * patch_size * out_channels,
|
333 |
+
bias=True,
|
334 |
+
readout_zero_init=readout_zero_init,
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
self.linear = nn.Linear(
|
338 |
+
hidden_size, patch_size * patch_size * out_channels, bias=True
|
339 |
+
)
|
340 |
+
|
341 |
+
self.adaLN_modulation = nn.Sequential(
|
342 |
+
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
343 |
+
)
|
344 |
+
|
345 |
+
def forward(self, x: Tensor, vec) -> Tensor:
|
346 |
+
if vec is None:
|
347 |
+
pass
|
348 |
+
else:
|
349 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
350 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
351 |
+
x = self.norm_final(x)
|
352 |
+
x = self.linear(x)
|
353 |
+
return x
|
354 |
+
|
355 |
+
|
356 |
+
@dataclass
|
357 |
+
class FluxParams:
|
358 |
+
in_channels: int
|
359 |
+
patch_size: int
|
360 |
+
context_dim: int
|
361 |
+
hidden_size: int
|
362 |
+
mlp_ratio: float
|
363 |
+
num_heads: int
|
364 |
+
depth: int
|
365 |
+
axes_dim: List[int]
|
366 |
+
theta: int
|
367 |
+
qkv_bias: bool
|
368 |
+
|
369 |
+
|
370 |
+
DIT_ZOO = dict(
|
371 |
+
dit_xl_4=dict(
|
372 |
+
hidden_size=1152,
|
373 |
+
mlp_ratio=4.0,
|
374 |
+
num_heads=16,
|
375 |
+
axes_dim=[8, 28, 28],
|
376 |
+
theta=10_000,
|
377 |
+
qkv_bias=True,
|
378 |
+
),
|
379 |
+
dit_l_4=dict(
|
380 |
+
hidden_size=1024,
|
381 |
+
mlp_ratio=4.0,
|
382 |
+
num_heads=16,
|
383 |
+
axes_dim=[8, 28, 28],
|
384 |
+
theta=10_000,
|
385 |
+
qkv_bias=True,
|
386 |
+
),
|
387 |
+
dit_b_4=dict(
|
388 |
+
hidden_size=768,
|
389 |
+
mlp_ratio=4.0,
|
390 |
+
num_heads=12,
|
391 |
+
axes_dim=[8, 28, 28],
|
392 |
+
theta=10_000,
|
393 |
+
qkv_bias=True,
|
394 |
+
),
|
395 |
+
dit_s_4=dict(
|
396 |
+
hidden_size=384,
|
397 |
+
mlp_ratio=4.0,
|
398 |
+
num_heads=6,
|
399 |
+
axes_dim=[8, 28, 28],
|
400 |
+
theta=10_000,
|
401 |
+
qkv_bias=True,
|
402 |
+
),
|
403 |
+
dit_mup_test=dict(
|
404 |
+
hidden_size=768,
|
405 |
+
mlp_ratio=4.0,
|
406 |
+
num_heads=12,
|
407 |
+
axes_dim=[8, 28, 28],
|
408 |
+
theta=10_000,
|
409 |
+
qkv_bias=True,
|
410 |
+
),
|
411 |
+
)
|
412 |
+
|
413 |
+
|
414 |
+
def prepare_idxs(img, code_length, patch_size):
|
415 |
+
bs, c, h, w = img.shape
|
416 |
+
|
417 |
+
img_ids = torch.zeros(h // patch_size, w // patch_size, 3, device=img.device)
|
418 |
+
img_ids[..., 1] = (
|
419 |
+
img_ids[..., 1] + torch.arange(h // patch_size, device=img.device)[:, None]
|
420 |
+
)
|
421 |
+
img_ids[..., 2] = (
|
422 |
+
img_ids[..., 2] + torch.arange(w // patch_size, device=img.device)[None, :]
|
423 |
+
)
|
424 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
425 |
+
|
426 |
+
txt_ids = (
|
427 |
+
torch.zeros((bs, code_length, 3), device=img.device)
|
428 |
+
+ torch.arange(code_length, device=img.device)[None, :, None]
|
429 |
+
)
|
430 |
+
return img_ids, txt_ids
|
431 |
+
|
432 |
+
|
433 |
+
class Flux(nn.Module):
|
434 |
+
"""
|
435 |
+
Transformer model for flow matching on sequences.
|
436 |
+
"""
|
437 |
+
|
438 |
+
def __init__(self, params: FluxParams, name="", lsg=False):
|
439 |
+
super().__init__()
|
440 |
+
|
441 |
+
self.name = name
|
442 |
+
self.lsg = lsg
|
443 |
+
self.params = params
|
444 |
+
self.in_channels = params.in_channels
|
445 |
+
self.patch_size = params.patch_size
|
446 |
+
self.out_channels = self.in_channels
|
447 |
+
if params.hidden_size % params.num_heads != 0:
|
448 |
+
raise ValueError(
|
449 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
450 |
+
)
|
451 |
+
pe_dim = params.hidden_size // params.num_heads
|
452 |
+
if sum(params.axes_dim) != pe_dim:
|
453 |
+
raise ValueError(
|
454 |
+
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
455 |
+
)
|
456 |
+
self.hidden_size = params.hidden_size
|
457 |
+
self.num_heads = params.num_heads
|
458 |
+
self.pe_embedder = EmbedND(
|
459 |
+
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
460 |
+
)
|
461 |
+
|
462 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
463 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
464 |
+
self.txt_in = nn.Linear(params.context_dim, self.hidden_size)
|
465 |
+
|
466 |
+
self.double_blocks = nn.ModuleList(
|
467 |
+
[
|
468 |
+
DoubleStreamBlock(
|
469 |
+
self.hidden_size,
|
470 |
+
self.num_heads,
|
471 |
+
mlp_ratio=params.mlp_ratio,
|
472 |
+
qkv_bias=params.qkv_bias,
|
473 |
+
)
|
474 |
+
for idx in range(params.depth)
|
475 |
+
]
|
476 |
+
)
|
477 |
+
|
478 |
+
self.final_layer_img = LastLayer(
|
479 |
+
self.hidden_size, 1, self.out_channels, readout_zero_init=False
|
480 |
+
)
|
481 |
+
self.final_layer_txt = LastLayer(
|
482 |
+
self.hidden_size, 1, params.context_dim, readout_zero_init=False
|
483 |
+
)
|
484 |
+
|
485 |
+
def forward(
|
486 |
+
self,
|
487 |
+
img: Tensor,
|
488 |
+
img_ids: Tensor,
|
489 |
+
txt: Tensor,
|
490 |
+
txt_ids: Tensor,
|
491 |
+
timesteps: Tensor,
|
492 |
+
) -> Tensor:
|
493 |
+
b, c, h, w = img.shape
|
494 |
+
|
495 |
+
img = rearrange(
|
496 |
+
img,
|
497 |
+
"b c (gh ph) (gw pw) -> b (gh gw) (ph pw c)",
|
498 |
+
ph=self.patch_size,
|
499 |
+
pw=self.patch_size,
|
500 |
+
)
|
501 |
+
if img.ndim != 3 or txt.ndim != 3:
|
502 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
503 |
+
img = self.img_in(img)
|
504 |
+
|
505 |
+
if timesteps is None:
|
506 |
+
vec = None
|
507 |
+
else:
|
508 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
509 |
+
|
510 |
+
txt = self.txt_in(txt)
|
511 |
+
pe_single = self.pe_embedder(torch.cat((txt_ids,), dim=1))
|
512 |
+
pe_double = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))
|
513 |
+
|
514 |
+
for block in self.double_blocks:
|
515 |
+
img, txt = block(img=img, txt=txt, pe=(pe_single, pe_double), vec=vec)
|
516 |
+
|
517 |
+
img = self.final_layer_img(img, vec=vec)
|
518 |
+
img = rearrange(
|
519 |
+
img,
|
520 |
+
"b (gh gw) (ph pw c) -> b c (gh ph) (gw pw)",
|
521 |
+
ph=self.patch_size,
|
522 |
+
pw=self.patch_size,
|
523 |
+
gh=h // self.patch_size,
|
524 |
+
gw=w // self.patch_size,
|
525 |
+
)
|
526 |
+
|
527 |
+
txt = self.final_layer_txt(txt, vec=vec)
|
528 |
+
return img, txt, {"final_txt": txt}
|
529 |
+
|
530 |
+
|
531 |
+
def get_weights_to_fix(model):
|
532 |
+
with torch.no_grad():
|
533 |
+
for name, module in itertools.chain(model.named_modules()):
|
534 |
+
if "double_blocks" in name and isinstance(module, torch.nn.Linear):
|
535 |
+
yield name, module.weight
|
536 |
+
|
537 |
+
|
538 |
+
class FlowMo(nn.Module):
|
539 |
+
def __init__(self, width, config):
|
540 |
+
super().__init__()
|
541 |
+
code_length = config.model.code_length
|
542 |
+
context_dim = config.model.context_dim
|
543 |
+
enc_depth = config.model.enc_depth
|
544 |
+
dec_depth = config.model.dec_depth
|
545 |
+
|
546 |
+
patch_size = config.model.patch_size
|
547 |
+
self.config = config
|
548 |
+
|
549 |
+
self.image_size = config.data.image_size
|
550 |
+
self.patch_size = config.model.patch_size
|
551 |
+
self.code_length = code_length
|
552 |
+
self.dit_mode = "dit_b_4"
|
553 |
+
self.context_dim = context_dim
|
554 |
+
self.encoder_context_dim = context_dim * (
|
555 |
+
1 + (self.config.model.quantization_type == "kl")
|
556 |
+
)
|
557 |
+
|
558 |
+
if config.model.quantization_type == "lfq":
|
559 |
+
self.quantizer = LFQ(
|
560 |
+
codebook_size=2**self.config.model.codebook_size_for_entropy,
|
561 |
+
dim=self.config.model.codebook_size_for_entropy,
|
562 |
+
num_codebooks=1,
|
563 |
+
token_factorization=False,
|
564 |
+
)
|
565 |
+
|
566 |
+
if self.config.model.enc_mup_width is not None:
|
567 |
+
enc_width = self.config.model.enc_mup_width
|
568 |
+
else:
|
569 |
+
enc_width = width
|
570 |
+
|
571 |
+
encoder_params = FluxParams(
|
572 |
+
in_channels=3 * patch_size**2,
|
573 |
+
context_dim=self.encoder_context_dim,
|
574 |
+
patch_size=patch_size,
|
575 |
+
depth=enc_depth,
|
576 |
+
**DIT_ZOO[self.dit_mode],
|
577 |
+
)
|
578 |
+
decoder_params = FluxParams(
|
579 |
+
in_channels=3 * patch_size**2,
|
580 |
+
context_dim=context_dim + 1,
|
581 |
+
patch_size=patch_size,
|
582 |
+
depth=dec_depth,
|
583 |
+
**DIT_ZOO[self.dit_mode],
|
584 |
+
)
|
585 |
+
|
586 |
+
# width=4, dit_b_4 is the usual model
|
587 |
+
encoder_params.hidden_size = enc_width * (encoder_params.hidden_size // 4)
|
588 |
+
decoder_params.hidden_size = width * (decoder_params.hidden_size // 4)
|
589 |
+
encoder_params.axes_dim = [
|
590 |
+
(d // 4) * enc_width for d in encoder_params.axes_dim
|
591 |
+
]
|
592 |
+
decoder_params.axes_dim = [(d // 4) * width for d in decoder_params.axes_dim]
|
593 |
+
|
594 |
+
self.encoder = Flux(encoder_params, name="encoder")
|
595 |
+
self.decoder = Flux(decoder_params, name="decoder")
|
596 |
+
|
597 |
+
@torch.compile
|
598 |
+
def encode(self, img):
|
599 |
+
b, c, h, w = img.shape
|
600 |
+
|
601 |
+
img_idxs, txt_idxs = prepare_idxs(img, self.code_length, self.patch_size)
|
602 |
+
txt = torch.zeros(
|
603 |
+
(b, self.code_length, self.encoder_context_dim), device=img.device
|
604 |
+
)
|
605 |
+
|
606 |
+
_, code, aux = self.encoder(img, img_idxs, txt, txt_idxs, timesteps=None)
|
607 |
+
|
608 |
+
return code, aux
|
609 |
+
|
610 |
+
def _decode(self, img, code, timesteps):
|
611 |
+
b, c, h, w = img.shape
|
612 |
+
|
613 |
+
img_idxs, txt_idxs = prepare_idxs(
|
614 |
+
img,
|
615 |
+
self.code_length,
|
616 |
+
self.patch_size,
|
617 |
+
)
|
618 |
+
pred, _, decode_aux = self.decoder(
|
619 |
+
img, img_idxs, code, txt_idxs, timesteps=timesteps
|
620 |
+
)
|
621 |
+
return pred, decode_aux
|
622 |
+
|
623 |
+
@torch.compile
|
624 |
+
def decode(self, *args, **kwargs):
|
625 |
+
return self._decode(*args, **kwargs)
|
626 |
+
|
627 |
+
@torch.compile
|
628 |
+
def decode_checkpointed(self, *args, **kwargs):
|
629 |
+
# Need to compile(checkpoint), not checkpoint(compile)
|
630 |
+
assert not kwargs, kwargs
|
631 |
+
return torch.utils.checkpoint.checkpoint(
|
632 |
+
self._decode,
|
633 |
+
*args,
|
634 |
+
# WARNING: Do not use_reentrant=True with compile, it will silently
|
635 |
+
# produce incorrect gradients!
|
636 |
+
use_reentrant=False,
|
637 |
+
)
|
638 |
+
|
639 |
+
@torch.compile
|
640 |
+
def _quantize(self, code):
|
641 |
+
"""
|
642 |
+
Args:
|
643 |
+
code: [b codelength context dim]
|
644 |
+
|
645 |
+
Returns:
|
646 |
+
quantized code of the same shape
|
647 |
+
"""
|
648 |
+
b, t, f = code.shape
|
649 |
+
indices = None
|
650 |
+
if self.config.model.quantization_type == "noop":
|
651 |
+
quantized = code
|
652 |
+
quantizer_loss = torch.tensor(0.0).to(code.device)
|
653 |
+
elif self.config.model.quantization_type == "kl":
|
654 |
+
# colocating features of same token before split is maybe slightly
|
655 |
+
# better?
|
656 |
+
mean, logvar = _get_diagonal_gaussian(
|
657 |
+
einops.rearrange(code, "b t f -> b (f t)")
|
658 |
+
)
|
659 |
+
code = einops.rearrange(
|
660 |
+
_sample_diagonal_gaussian(mean, logvar),
|
661 |
+
"b (f t) -> b t f",
|
662 |
+
f=f // 2,
|
663 |
+
t=t,
|
664 |
+
)
|
665 |
+
quantizer_loss = _kl_diagonal_gaussian(mean, logvar)
|
666 |
+
elif self.config.model.quantization_type == "lfq":
|
667 |
+
assert f % self.config.model.codebook_size_for_entropy == 0, f
|
668 |
+
code = einops.rearrange(
|
669 |
+
code,
|
670 |
+
"b t (fg fh) -> b fg (t fh)",
|
671 |
+
fg=self.config.model.codebook_size_for_entropy,
|
672 |
+
)
|
673 |
+
|
674 |
+
(quantized, entropy_aux_loss, indices), breakdown = self.quantizer(
|
675 |
+
code, return_loss_breakdown=True
|
676 |
+
)
|
677 |
+
assert quantized.shape == code.shape
|
678 |
+
quantized = einops.rearrange(quantized, "b fg (t fh) -> b t (fg fh)", t=t)
|
679 |
+
|
680 |
+
quantizer_loss = (
|
681 |
+
entropy_aux_loss * self.config.model.entropy_loss_weight
|
682 |
+
+ breakdown.commitment * self.config.model.commit_loss_weight
|
683 |
+
)
|
684 |
+
code = quantized
|
685 |
+
else:
|
686 |
+
raise NotImplementedError
|
687 |
+
return code, indices, quantizer_loss
|
688 |
+
|
689 |
+
# def forward(
|
690 |
+
# self,
|
691 |
+
# img,
|
692 |
+
# noised_img,
|
693 |
+
# timesteps,
|
694 |
+
# enable_cfg=True,
|
695 |
+
# ):
|
696 |
+
# aux = {}
|
697 |
+
#
|
698 |
+
# code, encode_aux = self.encode(img)
|
699 |
+
#
|
700 |
+
# aux["original_code"] = code
|
701 |
+
#
|
702 |
+
# b, t, f = code.shape
|
703 |
+
#
|
704 |
+
# code, _, aux["quantizer_loss"] = self._quantize(code)
|
705 |
+
#
|
706 |
+
# mask = torch.ones_like(code[..., :1])
|
707 |
+
# code = torch.concatenate([code, mask], axis=-1)
|
708 |
+
# code_pre_cfg = code
|
709 |
+
#
|
710 |
+
# if self.config.model.enable_cfg and enable_cfg:
|
711 |
+
# cfg_mask = (torch.rand((b,), device=code.device) > 0.1)[:, None, None]
|
712 |
+
# code = code * cfg_mask
|
713 |
+
#
|
714 |
+
# v_est, decode_aux = self.decode(noised_img, code, timesteps)
|
715 |
+
# aux.update(decode_aux)
|
716 |
+
#
|
717 |
+
# if self.config.model.posttrain_sample:
|
718 |
+
# aux["posttrain_sample"] = self.reconstruct_checkpoint(code_pre_cfg)
|
719 |
+
#
|
720 |
+
# return v_est, aux
|
721 |
+
|
722 |
+
def forward(self, img):
|
723 |
+
return self.reconstruct(img)
|
724 |
+
|
725 |
+
def reconstruct_checkpoint(self, code):
|
726 |
+
with torch.autocast(
|
727 |
+
"cuda",
|
728 |
+
dtype=torch.bfloat16,
|
729 |
+
):
|
730 |
+
bs, *_ = code.shape
|
731 |
+
|
732 |
+
z = torch.randn((bs, 3, self.image_size, self.image_size)).cuda()
|
733 |
+
ts = (
|
734 |
+
torch.rand((bs, self.config.model.posttrain_sample_k + 1))
|
735 |
+
.cumsum(dim=1)
|
736 |
+
.cuda()
|
737 |
+
)
|
738 |
+
ts = ts - ts[:, :1]
|
739 |
+
ts = (ts / ts[:, -1:]).flip(dims=(1,))
|
740 |
+
dts = ts[:, :-1] - ts[:, 1:]
|
741 |
+
|
742 |
+
for i, (t, dt) in enumerate((zip(ts.T, dts.T))):
|
743 |
+
if self.config.model.posttrain_sample_enable_cfg:
|
744 |
+
mask = (torch.rand((bs,), device=code.device) > 0.1)[
|
745 |
+
:, None, None
|
746 |
+
].to(code.dtype)
|
747 |
+
code_t = code * mask
|
748 |
+
else:
|
749 |
+
code_t = code
|
750 |
+
|
751 |
+
vc, _ = self.decode_checkpointed(z, code_t, t)
|
752 |
+
|
753 |
+
z = z - dt[:, None, None, None] * vc
|
754 |
+
return z
|
755 |
+
|
756 |
+
@torch.no_grad()
|
757 |
+
def reconstruct(self, images, dtype=torch.bfloat16, code=None):
|
758 |
+
"""
|
759 |
+
Args:
|
760 |
+
images in [bchw] [-1, 1]
|
761 |
+
|
762 |
+
Returns:
|
763 |
+
images in [bchw] [-1, 1]
|
764 |
+
"""
|
765 |
+
model = self
|
766 |
+
config = self.config.eval.sampling
|
767 |
+
|
768 |
+
with torch.autocast(
|
769 |
+
"cuda",
|
770 |
+
dtype=dtype,
|
771 |
+
):
|
772 |
+
bs, c, h, w = images.shape
|
773 |
+
if code is None:
|
774 |
+
x = images.cuda()
|
775 |
+
prequantized_code = model.encode(x)[0].cuda()
|
776 |
+
code, indices, _ = model._quantize(prequantized_code)
|
777 |
+
|
778 |
+
z = torch.randn((bs, 3, h, w)).cuda()
|
779 |
+
|
780 |
+
mask = torch.ones_like(code[..., :1])
|
781 |
+
code = torch.concatenate([code * mask, mask], axis=-1)
|
782 |
+
|
783 |
+
cfg_mask = 0.0
|
784 |
+
null_code = code * cfg_mask if config.cfg != 1.0 else None
|
785 |
+
|
786 |
+
samples = rf_sample(
|
787 |
+
model,
|
788 |
+
z,
|
789 |
+
code,
|
790 |
+
null_code=null_code,
|
791 |
+
sample_steps=config.sample_steps,
|
792 |
+
cfg=config.cfg,
|
793 |
+
schedule=config.schedule,
|
794 |
+
)[-1].clip(-1, 1)
|
795 |
+
return samples.to(torch.float32), code, prequantized_code
|
796 |
+
|
797 |
+
|
798 |
+
def rf_loss(config, model, batch, aux_state):
|
799 |
+
x = batch["image"]
|
800 |
+
b = x.size(0)
|
801 |
+
|
802 |
+
if config.opt.schedule == "lognormal":
|
803 |
+
nt = torch.randn((b,)).to(x.device)
|
804 |
+
t = torch.sigmoid(nt)
|
805 |
+
elif config.opt.schedule == "fat_lognormal":
|
806 |
+
nt = torch.randn((b,)).to(x.device)
|
807 |
+
t = torch.sigmoid(nt)
|
808 |
+
t = torch.where(torch.rand_like(t) <= 0.9, t, torch.rand_like(t))
|
809 |
+
elif config.opt.schedule == "uniform":
|
810 |
+
t = torch.rand((b,), device=x.device)
|
811 |
+
elif config.opt.schedule.startswith("debug"):
|
812 |
+
p = float(config.opt.schedule.split("_")[1])
|
813 |
+
t = torch.ones((b,), device=x.device) * p
|
814 |
+
else:
|
815 |
+
raise NotImplementedError
|
816 |
+
|
817 |
+
t = t.view([b, *([1] * len(x.shape[1:]))])
|
818 |
+
z1 = torch.randn_like(x)
|
819 |
+
zt = (1 - t) * x + t * z1
|
820 |
+
|
821 |
+
zt, t = zt.to(x.dtype), t.to(x.dtype)
|
822 |
+
|
823 |
+
vtheta, aux = model(
|
824 |
+
img=x,
|
825 |
+
noised_img=zt,
|
826 |
+
timesteps=t.reshape((b,)),
|
827 |
+
)
|
828 |
+
|
829 |
+
diff = z1 - vtheta - x
|
830 |
+
x_pred = zt - vtheta * t
|
831 |
+
|
832 |
+
loss = ((diff) ** 2).mean(dim=list(range(1, len(x.shape))))
|
833 |
+
loss = loss.mean()
|
834 |
+
|
835 |
+
aux["loss_dict"] = {}
|
836 |
+
aux["loss_dict"]["diffusion_loss"] = loss
|
837 |
+
aux["loss_dict"]["quantizer_loss"] = aux["quantizer_loss"]
|
838 |
+
|
839 |
+
if config.opt.lpips_weight != 0.0:
|
840 |
+
aux_loss = 0.0
|
841 |
+
if config.model.posttrain_sample:
|
842 |
+
x_pred = aux["posttrain_sample"]
|
843 |
+
|
844 |
+
lpips_dist = aux_state["lpips_model"](x, x_pred)
|
845 |
+
lpips_dist = (config.opt.lpips_weight * lpips_dist).mean() + aux_loss
|
846 |
+
aux["loss_dict"]["lpips_loss"] = lpips_dist
|
847 |
+
else:
|
848 |
+
lpips_dist = 0.0
|
849 |
+
|
850 |
+
loss = loss + aux["quantizer_loss"] + lpips_dist
|
851 |
+
aux["loss_dict"]["total_loss"] = loss
|
852 |
+
return loss, aux
|
853 |
+
|
854 |
+
|
855 |
+
def _edm_to_flow_convention(noise_level):
|
856 |
+
# z = x + \sigma z'
|
857 |
+
return noise_level / (1 + noise_level)
|
858 |
+
|
859 |
+
|
860 |
+
def rf_sample(
|
861 |
+
model,
|
862 |
+
z,
|
863 |
+
code,
|
864 |
+
null_code=None,
|
865 |
+
sample_steps=25,
|
866 |
+
cfg=2.0,
|
867 |
+
schedule="linear",
|
868 |
+
):
|
869 |
+
b = z.size(0)
|
870 |
+
if schedule == "linear":
|
871 |
+
ts = torch.arange(1, sample_steps + 1).flip(0) / sample_steps
|
872 |
+
dts = torch.ones_like(ts) * (1.0 / sample_steps)
|
873 |
+
elif schedule.startswith("pow"):
|
874 |
+
p = float(schedule.split("_")[1])
|
875 |
+
ts = torch.arange(0, sample_steps + 1).flip(0) ** (1 / p) / sample_steps ** (
|
876 |
+
1 / p
|
877 |
+
)
|
878 |
+
dts = ts[:-1] - ts[1:]
|
879 |
+
else:
|
880 |
+
raise NotImplementedError
|
881 |
+
|
882 |
+
if model.config.eval.sampling.cfg_interval is None:
|
883 |
+
interval = None
|
884 |
+
else:
|
885 |
+
cfg_lo, cfg_hi = ast.literal_eval(model.config.eval.sampling.cfg_interval)
|
886 |
+
interval = _edm_to_flow_convention(cfg_lo), _edm_to_flow_convention(cfg_hi)
|
887 |
+
|
888 |
+
images = []
|
889 |
+
for i, (t, dt) in enumerate((zip(ts, dts))):
|
890 |
+
timesteps = torch.tensor([t] * b).to(z.device)
|
891 |
+
vc, decode_aux = model.decode(img=z, timesteps=timesteps, code=code)
|
892 |
+
|
893 |
+
if null_code is not None and (
|
894 |
+
interval is None
|
895 |
+
or ((t.item() >= interval[0]) and (t.item() <= interval[1]))
|
896 |
+
):
|
897 |
+
vu, _ = model.decode(img=z, timesteps=timesteps, code=null_code)
|
898 |
+
vc = vu + cfg * (vc - vu)
|
899 |
+
|
900 |
+
z = z - dt * vc
|
901 |
+
images.append(z)
|
902 |
+
return images
|
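The sampler above is a plain Euler integration of the learned velocity field from t = 1 (noise) down to t = 0, with optional classifier-free guidance:

v = v_u + \mathrm{cfg}\,(v_c - v_u), \qquad z_{t - \Delta t} = z_t - \Delta t\, v,

where v_c and v_u are the code-conditioned and null-code predictions; guidance is applied only inside the configured cfg_interval, when one is set.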
903 |
+
|
904 |
+
|
905 |
+
def build_model(config):
|
906 |
+
with tempfile.TemporaryDirectory() as log_dir:
|
907 |
+
MUP_ENABLED = config.model.enable_mup
|
908 |
+
model_partial = FlowMo
|
909 |
+
|
910 |
+
shared_kwargs = dict(config=config)
|
911 |
+
model = model_partial(
|
912 |
+
**shared_kwargs,
|
913 |
+
width=config.model.mup_width,
|
914 |
+
).cuda()
|
915 |
+
|
916 |
+
if config.model.enable_mup:
|
917 |
+
print("Mup enabled!")
|
918 |
+
with torch.device("cpu"):
|
919 |
+
base_model = model_partial(
|
920 |
+
**shared_kwargs, width=config.model.mup_width
|
921 |
+
)
|
922 |
+
delta_model = model_partial(
|
923 |
+
**shared_kwargs,
|
924 |
+
width=(
|
925 |
+
config.model.mup_width * 4 if config.model.mup_width == 1 else 1
|
926 |
+
),
|
927 |
+
)
|
928 |
+
true_model = model_partial(
|
929 |
+
**shared_kwargs, width=config.model.mup_width
|
930 |
+
)
|
931 |
+
|
932 |
+
if torch.distributed.is_initialized():
|
933 |
+
bsh_path = os.path.join(log_dir, f"{dist.get_rank()}.bsh")
|
934 |
+
else:
|
935 |
+
bsh_path = os.path.join(log_dir, "0.bsh")
|
936 |
+
set_base_shapes(
|
937 |
+
true_model, base_model, delta=delta_model, savefile=bsh_path
|
938 |
+
)
|
939 |
+
|
940 |
+
model = set_base_shapes(model, base=bsh_path)
|
941 |
+
|
942 |
+
for module in model.modules():
|
943 |
+
if isinstance(module, MuReadout):
|
944 |
+
module.width_mult = lambda: module.weight.infshape.width_mult()
|
945 |
+
return model
|
src/vqvaes/flowmo/lookup_free_quantize.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
1 |
+
"""
|
2 |
+
Code is from https://github.com/TencentARC/SEED-Voken. Thanks!
|
3 |
+
|
4 |
+
Lookup Free Quantization
|
5 |
+
Proposed in https://arxiv.org/abs/2310.05737
|
6 |
+
|
7 |
+
In the simplest setup, each dimension is quantized into {-1, 1}.
|
8 |
+
An entropy penalty is used to encourage utilization.
|
9 |
+
|
10 |
+
Refer to
|
11 |
+
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py
|
12 |
+
https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py
|
13 |
+
"""
|
14 |
+
|
15 |
+
from collections import namedtuple
|
16 |
+
from math import ceil, log2
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from einops import pack, rearrange, reduce, unpack
|
21 |
+
from torch import einsum
|
22 |
+
from torch.nn import Module
|
23 |
+
|
24 |
+
# constants
|
25 |
+
|
26 |
+
LossBreakdown = namedtuple(
|
27 |
+
"LossBreakdown",
|
28 |
+
["per_sample_entropy", "codebook_entropy", "commitment", "avg_probs"],
|
29 |
+
)
|
30 |
+
|
31 |
+
# helper functions
|
32 |
+
|
33 |
+
|
34 |
+
def exists(v):
|
35 |
+
return v is not None
|
36 |
+
|
37 |
+
|
38 |
+
def default(*args):
|
39 |
+
for arg in args:
|
40 |
+
if exists(arg):
|
41 |
+
return arg() if callable(arg) else arg
|
42 |
+
return None
|
43 |
+
|
44 |
+
|
45 |
+
def pack_one(t, pattern):
|
46 |
+
return pack([t], pattern)
|
47 |
+
|
48 |
+
|
49 |
+
def unpack_one(t, ps, pattern):
|
50 |
+
return unpack(t, ps, pattern)[0]
|
51 |
+
|
52 |
+
|
53 |
+
# entropy
|
54 |
+
|
55 |
+
# def log(t, eps = 1e-5):
|
56 |
+
# return t.clamp(min = eps).log()
|
57 |
+
|
58 |
+
|
59 |
+
def entropy(prob):
|
60 |
+
return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)
|
61 |
+
|
62 |
+
|
63 |
+
# class
|
64 |
+
|
65 |
+
|
66 |
+
def mult_along_first_dims(x, y):
|
67 |
+
"""
|
68 |
+
returns x * y elementwise along the leading dimensions of y
|
69 |
+
"""
|
70 |
+
ndim_to_expand = x.ndim - y.ndim
|
71 |
+
for _ in range(ndim_to_expand):
|
72 |
+
y = y.unsqueeze(-1)
|
73 |
+
return x * y
|
74 |
+
|
75 |
+
|
76 |
+
def masked_mean(x, m):
|
77 |
+
"""
|
78 |
+
takes the mean of the elements of x that are not masked
|
79 |
+
the mean is taken along the shared leading dims of m
|
80 |
+
equivalent to: x[m].mean(tuple(range(m.ndim)))
|
81 |
+
|
82 |
+
The benefit of using masked_mean rather than using
|
83 |
+
tensor indexing is that masked_mean is much faster
|
84 |
+
for torch-compile on batches.
|
85 |
+
|
86 |
+
The drawback is larger floating point errors
|
87 |
+
"""
|
88 |
+
x = mult_along_first_dims(x, m)
|
89 |
+
x = x / m.sum()
|
90 |
+
return x.sum(tuple(range(m.ndim)))
|
91 |
+
|
92 |
+
|
93 |
+
def entropy_loss(
|
94 |
+
logits,
|
95 |
+
mask=None,
|
96 |
+
# temperature=0.01,
|
97 |
+
sample_minimization_weight=1.0,
|
98 |
+
batch_maximization_weight=1.0,
|
99 |
+
eps=1e-5,
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
Entropy loss of unnormalized logits
|
103 |
+
|
104 |
+
logits: Affinities are over the last dimension
|
105 |
+
|
106 |
+
https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
|
107 |
+
LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)
|
108 |
+
"""
|
109 |
+
# import pdb
|
110 |
+
# pdb.set_trace()
|
111 |
+
# print(logits.shape)
|
112 |
+
# raise
|
113 |
+
|
114 |
+
temperature = 0.1
|
115 |
+
probs = F.softmax(logits / temperature, -1)
|
116 |
+
log_probs = F.log_softmax(logits / temperature + eps, -1)
|
117 |
+
|
118 |
+
if mask is not None:
|
119 |
+
# avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1)))
|
120 |
+
# avg_probs = einx.mean("... D -> D", probs[mask])
|
121 |
+
|
122 |
+
avg_probs = masked_mean(probs, mask)
|
123 |
+
# avg_probs = einx.mean("... D -> D", avg_probs)
|
124 |
+
else:
|
125 |
+
avg_probs = reduce(probs, "... D -> D", "mean")
|
126 |
+
|
127 |
+
avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))
|
128 |
+
|
129 |
+
sample_entropy = -torch.sum(probs * log_probs, -1)
|
130 |
+
if mask is not None:
|
131 |
+
# sample_entropy = sample_entropy[mask].mean()
|
132 |
+
sample_entropy = masked_mean(sample_entropy, mask).mean()
|
133 |
+
else:
|
134 |
+
sample_entropy = torch.mean(sample_entropy)
|
135 |
+
|
136 |
+
loss = (sample_minimization_weight * sample_entropy) - (
|
137 |
+
batch_maximization_weight * avg_entropy
|
138 |
+
)
|
139 |
+
|
140 |
+
return sample_entropy, avg_entropy, loss
|
141 |
+
|
142 |
+
|
143 |
+
class LFQ(Module):
|
144 |
+
def __init__(
|
145 |
+
self,
|
146 |
+
*,
|
147 |
+
dim=None,
|
148 |
+
codebook_size=None,
|
149 |
+
num_codebooks=1,
|
150 |
+
sample_minimization_weight=1.0,
|
151 |
+
batch_maximization_weight=1.0,
|
152 |
+
token_factorization=False,
|
153 |
+
factorized_bits=[9, 9],
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
# some assert validations
|
158 |
+
|
159 |
+
assert exists(dim) or exists(
|
160 |
+
codebook_size
|
161 |
+
), "either dim or codebook_size must be specified for LFQ"
|
162 |
+
assert (
|
163 |
+
not exists(codebook_size) or log2(codebook_size).is_integer()
|
164 |
+
), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
|
165 |
+
|
166 |
+
self.codebook_size = default(codebook_size, lambda: 2**dim)
|
167 |
+
self.codebook_dim = int(log2(codebook_size))
|
168 |
+
|
169 |
+
codebook_dims = self.codebook_dim * num_codebooks
|
170 |
+
dim = default(dim, codebook_dims)
|
171 |
+
|
172 |
+
has_projections = dim != codebook_dims
|
173 |
+
self.has_projections = has_projections
|
174 |
+
|
175 |
+
self.dim = dim
|
176 |
+
self.codebook_dim = self.codebook_dim
|
177 |
+
self.num_codebooks = num_codebooks
|
178 |
+
|
179 |
+
# for entropy loss
|
180 |
+
self.sample_minimization_weight = sample_minimization_weight
|
181 |
+
self.batch_maximization_weight = batch_maximization_weight
|
182 |
+
|
183 |
+
# for no auxiliary loss, during inference
|
184 |
+
self.token_factorization = token_factorization
|
185 |
+
if not self.token_factorization: # for first stage model
|
186 |
+
self.register_buffer(
|
187 |
+
"mask", 2 ** torch.arange(self.codebook_dim), persistent=False
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
self.factorized_bits = factorized_bits
|
191 |
+
self.register_buffer(
|
192 |
+
"pre_mask", 2 ** torch.arange(factorized_bits[0]), persistent=False
|
193 |
+
)
|
194 |
+
self.register_buffer(
|
195 |
+
"post_mask", 2 ** torch.arange(factorized_bits[1]), persistent=False
|
196 |
+
)
|
197 |
+
|
198 |
+
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
|
199 |
+
|
200 |
+
# codes
|
201 |
+
|
202 |
+
all_codes = torch.arange(codebook_size)
|
203 |
+
bits = self.indices_to_bits(all_codes)
|
204 |
+
codebook = bits * 2.0 - 1.0
|
205 |
+
|
206 |
+
self.register_buffer("codebook", codebook, persistent=False)
|
207 |
+
|
208 |
+
@property
|
209 |
+
def dtype(self):
|
210 |
+
return self.codebook.dtype
|
211 |
+
|
212 |
+
def indices_to_bits(self, x):
|
213 |
+
"""
|
214 |
+
x: long tensor of indices
|
215 |
+
|
216 |
+
returns big endian bits
|
217 |
+
"""
|
218 |
+
mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
|
219 |
+
# x is now big endian bits, the last dimension being the bits
|
220 |
+
x = (x.unsqueeze(-1) & mask) != 0
|
221 |
+
return x
|
222 |
+
|
223 |
+
def get_codebook_entry(self, x, bhwc, order): # 0610
|
224 |
+
if self.token_factorization:
|
225 |
+
if order == "pre":
|
226 |
+
mask = 2 ** torch.arange(
|
227 |
+
self.factorized_bits[0], device=x.device, dtype=torch.long
|
228 |
+
)
|
229 |
+
else:
|
230 |
+
mask = 2 ** torch.arange(
|
231 |
+
self.factorized_bits[1], device=x.device, dtype=torch.long
|
232 |
+
)
|
233 |
+
else:
|
234 |
+
mask = 2 ** torch.arange(
|
235 |
+
self.codebook_dim, device=x.device, dtype=torch.long
|
236 |
+
)
|
237 |
+
|
238 |
+
x = (x.unsqueeze(-1) & mask) != 0
|
239 |
+
x = x * 2.0 - 1.0 # back to the float
|
240 |
+
## scale back to the
|
241 |
+
b, h, w, c = bhwc
|
242 |
+
x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
|
243 |
+
x = rearrange(x, "b h w c -> b c h w")
|
244 |
+
return x
|
245 |
+
|
246 |
+
def bits_to_indices(self, bits):
|
247 |
+
"""
|
248 |
+
bits: bool tensor of big endian bits, where the last dimension is the bit dimension
|
249 |
+
|
250 |
+
returns indices, which are long integers from 0 to self.codebook_size
|
251 |
+
"""
|
252 |
+
assert bits.shape[-1] == self.codebook_dim
|
253 |
+
indices = 2 ** torch.arange(
|
254 |
+
0,
|
255 |
+
self.codebook_dim,
|
256 |
+
1,
|
257 |
+
dtype=torch.long,
|
258 |
+
device=bits.device,
|
259 |
+
)
|
260 |
+
return (bits * indices).sum(-1)
|
261 |
+
|
262 |
+
def decode(self, x):
|
263 |
+
"""
|
264 |
+
x: ... NH
|
265 |
+
where NH is number of codebook heads
|
266 |
+
A longtensor of codebook indices, containing values from
|
267 |
+
0 to self.codebook_size
|
268 |
+
"""
|
269 |
+
x = self.indices_to_bits(x)
|
270 |
+
# to some sort of float
|
271 |
+
x = x.to(self.dtype)
|
272 |
+
# -1 or 1
|
273 |
+
x = x * 2 - 1
|
274 |
+
x = rearrange(x, "... NC Z-> ... (NC Z)")
|
275 |
+
return x
|
276 |
+
|
277 |
+
def forward(
|
278 |
+
self,
|
279 |
+
x,
|
280 |
+
inv_temperature=100.0,
|
281 |
+
return_loss_breakdown=False,
|
282 |
+
mask=None,
|
283 |
+
return_loss=True,
|
284 |
+
):
|
285 |
+
"""
|
286 |
+
einstein notation
|
287 |
+
b - batch
|
288 |
+
n - sequence (or flattened spatial dimensions)
|
289 |
+
d - feature dimension, which is also log2(codebook size)
|
290 |
+
c - number of codebook dim
|
291 |
+
"""
|
292 |
+
# x = x.tanh() * 1.5
|
293 |
+
|
294 |
+
x = rearrange(x, "b d ... -> b ... d")
|
295 |
+
x, ps = pack_one(x, "b * d")
|
296 |
+
# split out number of codebooks
|
297 |
+
|
298 |
+
x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)
|
299 |
+
|
300 |
+
codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
|
301 |
+
quantized = torch.where(
|
302 |
+
x > 0, codebook_value, -codebook_value
|
303 |
+
) # higher than 0 filled
|
304 |
+
|
305 |
+
# calculate indices
|
306 |
+
if self.token_factorization:
|
307 |
+
indices_pre = reduce(
|
308 |
+
(quantized[..., : self.factorized_bits[0]] > 0).int()
|
309 |
+
* self.pre_mask.int(),
|
310 |
+
"b n c d -> b n c",
|
311 |
+
"sum",
|
312 |
+
)
|
313 |
+
indices_post = reduce(
|
314 |
+
(quantized[..., self.factorized_bits[0] :] > 0).int()
|
315 |
+
* self.post_mask.int(),
|
316 |
+
"b n c d -> b n c",
|
317 |
+
"sum",
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
# print(quantized.shape)
|
321 |
+
indices = reduce(
|
322 |
+
(quantized > 0).int() * self.mask.int(), "b n c d -> b n c", "sum"
|
323 |
+
)
|
324 |
+
# print(indices.shape)
|
325 |
+
|
326 |
+
# entropy aux loss
|
327 |
+
|
328 |
+
if self.training and return_loss:
|
329 |
+
logits = 2 * einsum("... i d, j d -> ... i j", x, self.codebook)
|
330 |
+
# the same as euclidean distance up to a constant
|
331 |
+
# import pdb
|
332 |
+
# pdb.set_trace()
|
333 |
+
|
334 |
+
per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
|
335 |
+
logits=logits,
|
336 |
+
sample_minimization_weight=self.sample_minimization_weight,
|
337 |
+
batch_maximization_weight=self.batch_maximization_weight,
|
338 |
+
)
|
339 |
+
|
340 |
+
avg_probs = self.zero
|
341 |
+
else:
|
342 |
+
# logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
|
343 |
+
# probs = F.softmax(logits / 0.01, -1)
|
344 |
+
# avg_probs = reduce(probs, "b n c d -> b d", "mean")
|
345 |
+
# avg_probs = torch.sum(avg_probs, 0) #batch dimension
|
346 |
+
# if not training, just return dummy 0
|
347 |
+
per_sample_entropy = codebook_entropy = self.zero
|
348 |
+
## calculate the codebook_entropy needed for one batch evaluation
|
349 |
+
entropy_aux_loss = self.zero
|
350 |
+
avg_probs = self.zero
|
351 |
+
|
352 |
+
# commit loss
|
353 |
+
|
354 |
+
if self.training:
|
355 |
+
commit_loss = F.mse_loss(x, quantized.detach(), reduction="none")
|
356 |
+
|
357 |
+
if exists(mask):
|
358 |
+
commit_loss = commit_loss[mask]
|
359 |
+
|
360 |
+
commit_loss = commit_loss.mean()
|
361 |
+
else:
|
362 |
+
commit_loss = self.zero
|
363 |
+
|
364 |
+
# use straight-through gradients (optionally with custom activation fn) if training
|
365 |
+
|
366 |
+
quantized = x + (quantized - x).detach() # transfer to quantized
|
367 |
+
|
368 |
+
# merge back codebook dim
|
369 |
+
|
370 |
+
quantized = rearrange(quantized, "b n c d -> b n (c d)")
|
371 |
+
|
372 |
+
# reconstitute image or video dimensions
|
373 |
+
|
374 |
+
quantized = unpack_one(quantized, ps, "b * d")
|
375 |
+
quantized = rearrange(quantized, "b ... d -> b d ...")
|
376 |
+
|
377 |
+
if self.token_factorization:
|
378 |
+
indices_pre = unpack_one(indices_pre, ps, "b * c")
|
379 |
+
indices_post = unpack_one(indices_post, ps, "b * c")
|
380 |
+
indices_pre = indices_pre.flatten()
|
381 |
+
indices_post = indices_post.flatten()
|
382 |
+
indices = (indices_pre, indices_post)
|
383 |
+
else:
|
384 |
+
# print(indices.shape, ps)
|
385 |
+
indices = unpack_one(indices, ps, "b * c")
|
386 |
+
# print(indices.shape)
|
387 |
+
indices = indices.flatten()
|
388 |
+
|
389 |
+
ret = (quantized, entropy_aux_loss, indices)
|
390 |
+
|
391 |
+
if not return_loss_breakdown:
|
392 |
+
return ret
|
393 |
+
|
394 |
+
return ret, LossBreakdown(
|
395 |
+
per_sample_entropy, codebook_entropy, commit_loss, avg_probs
|
396 |
+
)
|
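Usage note: the LFQ module above implements binary lookup-free quantization; each feature channel is mapped to +1 or -1 by its sign, the sign pattern is read as a big-endian bit string, and the bits are packed into an integer index, with optional entropy and commitment losses during training. A minimal sketch of calling it (the import path follows the file list of this commit; the bit width and shapes are illustrative):

import torch
from src.vqvaes.flowmo.lookup_free_quantize import LFQ  # path as listed in this commit

quantizer = LFQ(codebook_size=2 ** 12, dim=12).eval()  # 12 bits per token -> 4096 codes
feats = torch.randn(2, 12, 16, 16)                     # B x d x H x W, d == log2(codebook_size)
quantized, entropy_aux_loss, indices = quantizer(feats)
# quantized keeps the input shape with entries in {-1, +1} (straight-through in training mode);
# indices holds one integer code id per token, flattened over batch and spatial positions.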
src/vqvaes/infinity/conv.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F


class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        cnn_type="2d",
        causal_offset=0,
        temporal_down=False,
    ):
        super().__init__()
        self.cnn_type = cnn_type
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=padding
            )
        if cnn_type == "3d":
            if temporal_down == False:
                stride = (1, stride, stride)
            else:
                stride = (stride, stride, stride)
            self.conv = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride=stride, padding=0
            )
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # Temporal causal padding
                padding,  # Height padding
                padding,  # Width padding
            )
        self.causal_offset = causal_offset
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        if self.cnn_type == "2d":
            if x.ndim == 5:
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.conv(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
                return x
            else:
                return self.conv(x)
        if self.cnn_type == "3d":
            assert (
                self.stride[0] == 1 or self.stride[0] == 2
            ), f"only temporal stride = 1 or 2 are supported"
            xs = []
            for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1):
                st = i
                en = min(i + self.slice_seq_len, x.shape[2])
                _x = x[:, :, st:en, :, :]
                if i == 0:
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2],
                            self.padding[2],  # Width
                            self.padding[1],
                            self.padding[1],  # Height
                            self.padding[0],
                            0,
                        ),
                    )  # Temporal
                else:
                    padding_0 = self.kernel_size[0] - 1
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2],
                            self.padding[2],  # Width
                            self.padding[1],
                            self.padding[1],  # Height
                            padding_0,
                            0,
                        ),
                    )  # Temporal
                    _x[
                        :,
                        :,
                        :padding_0,
                        self.padding[1] : _x.shape[-2] - self.padding[1],
                        self.padding[2] : _x.shape[-1] - self.padding[2],
                    ] += x[:, :, i - padding_0 : i, :, :]
                _x = self.conv(_x)
                xs.append(_x)
            try:
                x = torch.cat(xs, dim=2)
            except:
                device = x.device
                del x
                xs = [_x.cpu().pin_memory() for _x in xs]
                torch.cuda.empty_cache()
                x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device)
            return x
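Usage note: Conv wraps nn.Conv2d / nn.Conv3d; in the "3d" case it zero-pads only on the left of the time axis (causal padding) and processes long clips in slices of slice_seq_len frames before concatenating the results. A small sketch of the causal behaviour (shapes are illustrative):

import torch
from src.vqvaes.infinity.conv import Conv  # path as listed in this commit

conv3d = Conv(8, 8, kernel_size=3, stride=1, padding=1, cnn_type="3d")
clip = torch.randn(1, 8, 5, 32, 32)  # B C T H W
out = conv3d(clip)
print(out.shape)  # torch.Size([1, 8, 5, 32, 32]); T is preserved because the temporal padding is causal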
src/vqvaes/infinity/dynamic_resolution.py
ADDED
@@ -0,0 +1,147 @@
import json
import numpy as np
import tqdm

vae_stride = 16
ratio2hws = {
    1.000: [
        (1, 1),
        (2, 2),
        (4, 4),
        (6, 6),
        (8, 8),
        (12, 12),
        (16, 16),
        (20, 20),
        (24, 24),
        (32, 32),
        (40, 40),
        (48, 48),
        (64, 64),
    ],
    1.250: [
        (1, 1),
        (2, 2),
        (3, 3),
        (5, 4),
        (10, 8),
        (15, 12),
        (20, 16),
        (25, 20),
        (30, 24),
        (35, 28),
        (45, 36),
        (55, 44),
        (70, 56),
    ],
    1.333: [
        (1, 1),
        (2, 2),
        (4, 3),
        (8, 6),
        (12, 9),
        (16, 12),
        (20, 15),
        (24, 18),
        (28, 21),
        (36, 27),
        (48, 36),
        (60, 45),
        (72, 54),
    ],
    1.500: [
        (1, 1),
        (2, 2),
        (3, 2),
        (6, 4),
        (9, 6),
        (15, 10),
        (21, 14),
        (27, 18),
        (33, 22),
        (39, 26),
        (48, 32),
        (63, 42),
        (78, 52),
    ],
    1.750: [
        (1, 1),
        (2, 2),
        (3, 3),
        (7, 4),
        (11, 6),
        (14, 8),
        (21, 12),
        (28, 16),
        (35, 20),
        (42, 24),
        (56, 32),
        (70, 40),
        (84, 48),
    ],
    2.000: [
        (1, 1),
        (2, 2),
        (4, 2),
        (6, 3),
        (10, 5),
        (16, 8),
        (22, 11),
        (30, 15),
        (38, 19),
        (46, 23),
        (60, 30),
        (74, 37),
        (90, 45),
    ],
    2.500: [
        (1, 1),
        (2, 2),
        (5, 2),
        (10, 4),
        (15, 6),
        (20, 8),
        (25, 10),
        (30, 12),
        (40, 16),
        (50, 20),
        (65, 26),
        (80, 32),
        (100, 40),
    ],
    3.000: [
        (1, 1),
        (2, 2),
        (6, 2),
        (9, 3),
        (15, 5),
        (21, 7),
        (27, 9),
        (36, 12),
        (45, 15),
        (54, 18),
        (72, 24),
        (90, 30),
        (111, 37),
    ],
}
full_ratio2hws = {}
for ratio, hws in ratio2hws.items():
    full_ratio2hws[ratio] = hws
    full_ratio2hws[int(1 / ratio * 1000) / 1000] = [(item[1], item[0]) for item in hws]

dynamic_resolution_h_w = {}
predefined_HW_Scales_dynamic = {}
for ratio in full_ratio2hws:
    dynamic_resolution_h_w[ratio] = {}
    for ind, leng in enumerate([7, 10, 13]):
        h, w = (
            full_ratio2hws[ratio][leng - 1][0],
            full_ratio2hws[ratio][leng - 1][1],
        )  # feature map size
        pixel = (h * vae_stride, w * vae_stride)  # The original image (H, W)
        dynamic_resolution_h_w[ratio][pixel[1]] = {
            "pixel": pixel,
            "scales": full_ratio2hws[ratio][:leng],
        }  # W as key
        predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng]
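Usage note: dynamic_resolution.py builds, for every supported aspect ratio, the list of multi-scale feature-map sizes and the matching pixel resolution (feature size times vae_stride = 16), keyed by the target width. A quick lookup sketch:

from src.vqvaes.infinity.dynamic_resolution import dynamic_resolution_h_w, vae_stride

entry = dynamic_resolution_h_w[1.000][1024]  # square images: 64x64 final latents * stride 16 = 1024 px
print(entry["pixel"])                         # (1024, 1024)
print(len(entry["scales"]), entry["scales"][0], entry["scales"][-1])  # 13 (1, 1) (64, 64)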
src/vqvaes/infinity/flux_vqgan.py
ADDED
@@ -0,0 +1,771 @@
import argparse
import os
import imageio
import torch
import numpy as np
from einops import rearrange
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from safetensors.torch import load_file
import torch.utils.checkpoint as checkpoint

from .conv import Conv
from .multiscale_bsq import MultiScaleBSQ

ptdtype = {None: torch.float32, "fp32": torch.float32, "bf16": torch.bfloat16}


class Normalize(nn.Module):
    def __init__(self, in_channels, norm_type, norm_axis="spatial"):
        super().__init__()
        self.norm_axis = norm_axis
        assert norm_type in ["group", "batch", "no"]
        if norm_type == "group":
            if in_channels % 32 == 0:
                self.norm = nn.GroupNorm(
                    num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
                )
            elif in_channels % 24 == 0:
                self.norm = nn.GroupNorm(
                    num_groups=24, num_channels=in_channels, eps=1e-6, affine=True
                )
            else:
                raise NotImplementedError
        elif norm_type == "batch":
            self.norm = nn.SyncBatchNorm(
                in_channels, track_running_stats=False
            )  # Runtime Error: grad inplace if set track_running_stats to True
        elif norm_type == "no":
            self.norm = nn.Identity()

    def forward(self, x):
        if self.norm_axis == "spatial":
            if x.ndim == 4:
                x = self.norm(x)
            else:
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.norm(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
        elif self.norm_axis == "spatial-temporal":
            x = self.norm(x)
        else:
            raise NotImplementedError
        return x


def swish(x: Tensor) -> Tensor:
    try:
        return x * torch.sigmoid(x)
    except:
        device = x.device
        x = x.cpu().pin_memory()
        return (x * torch.sigmoid(x)).to(device=device)


class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type="group", cnn_param=None):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(
            in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )

        self.q = Conv(in_channels, in_channels, kernel_size=1)
        self.k = Conv(in_channels, in_channels, kernel_size=1)
        self.v = Conv(in_channels, in_channels, kernel_size=1)
        self.proj_out = Conv(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        B, _, T, _, _ = h_.shape
        h_ = self.norm(h_)
        h_ = rearrange(h_, "B C T H W -> (B T) C H W")  # spatial attention only
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, norm_type="group", cnn_param=None
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(
            in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )
        if cnn_param["res_conv_2d"] in ["half", "full"]:
            self.conv1 = Conv(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type="2d",
            )
        else:
            self.conv1 = Conv(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )
        self.norm2 = Normalize(
            out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )
        if cnn_param["res_conv_2d"] in ["full"]:
            self.conv2 = Conv(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type="2d",
            )
        else:
            self.conv2 = Conv(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = Conv(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(
        self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False
    ):
        super().__init__()
        assert spatial_down == True
        if cnn_type == "2d":
            self.pad = (0, 1, 0, 1)
        if cnn_type == "3d":
            self.pad = (
                0,
                1,
                0,
                1,
                0,
                0,
            )  # add padding to the right for h-axis and w-axis. No padding for t-axis
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = Conv(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=0,
            cnn_type=cnn_type,
            temporal_down=temporal_down,
        )

    def forward(self, x: Tensor):
        x = nn.functional.pad(x, self.pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(
        self,
        in_channels,
        cnn_type="2d",
        spatial_up=False,
        temporal_up=False,
        use_pxsl=False,
    ):
        super().__init__()
        if cnn_type == "2d":
            self.scale_factor = 2
            self.causal_offset = 0
        else:
            assert spatial_up == True
            if temporal_up:
                self.scale_factor = (2, 2, 2)
                self.causal_offset = -1
            else:
                self.scale_factor = (1, 2, 2)
                self.causal_offset = 0
        self.use_pxsl = use_pxsl
        if self.use_pxsl:
            self.conv = Conv(
                in_channels,
                in_channels * 4,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_type,
                causal_offset=self.causal_offset,
            )
            self.pxsl = nn.PixelShuffle(2)
        else:
            self.conv = Conv(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_type,
                causal_offset=self.causal_offset,
            )

    def forward(self, x: Tensor):
        if self.use_pxsl:
            x = self.conv(x)
            x = self.pxsl(x)
        else:
            try:
                x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
            except:
                # shard across channel
                _xs = []
                for i in range(x.shape[1]):
                    _x = F.interpolate(
                        x[:, i : i + 1, ...],
                        scale_factor=self.scale_factor,
                        mode="nearest",
                    )
                    _xs.append(_x)
                x = torch.cat(_xs, dim=1)
            x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        in_channels=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type="group",
        cnn_param=None,
        use_checkpoint=False,
        use_vae=True,
    ):
        super().__init__()
        self.max_down = np.log2(patch_size)
        self.temporal_max_down = np.log2(temporal_patch_size)
        self.temporal_down_offset = self.max_down - self.temporal_max_down
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        # downsampling
        # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        # cnn_param["cnn_type"] = "2d" for images, cnn_param["cnn_type"] = "3d" for videos
        if cnn_param["conv_in_out_2d"] == "yes":  # "yes" for video
            self.conv_in = Conv(
                in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d"
            )
        else:
            self.conv_in = Conv(
                in_channels,
                ch,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        norm_type=norm_type,
                        cnn_param=cnn_param,
                    )
                )
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # downsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE
            spatial_down = True if i_level < self.max_down else False
            temporal_down = (
                True
                if i_level < self.max_down and i_level >= self.temporal_down_offset
                else False
            )
            if spatial_down or temporal_down:
                down.downsample = Downsample(
                    block_in,
                    cnn_type=cnn_param["cnn_type"],
                    spatial_down=spatial_down,
                    temporal_down=temporal_down,
                )
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )

        # end
        self.norm_out = Normalize(
            block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_out = Conv(
                block_in,
                (int(use_vae) + 1) * z_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type="2d",
            )
        else:
            self.conv_out = Conv(
                block_in,
                (int(use_vae) + 1) * z_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

    def forward(self, x, return_hidden=False):
        if not self.use_checkpoint:
            return self._forward(x, return_hidden=return_hidden)
        else:
            return checkpoint.checkpoint(
                self._forward, x, return_hidden, use_reentrant=False
            )

    def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
        # downsampling
        h0 = self.conv_in(x)
        hs = [h0]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        hs_mid = [h]
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        hs_mid.append(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        if return_hidden:
            return h, hs, hs_mid
        else:
            return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        out_ch=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type="group",
        cnn_param=None,
        use_checkpoint=False,
        use_freq_dec=False,  # use frequency features for decoder
        use_pxsf=False,
    ):
        super().__init__()
        self.max_up = np.log2(patch_size)
        self.temporal_max_up = np.log2(temporal_patch_size)
        self.temporal_up_offset = self.max_up - self.temporal_max_up
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.ffactor = 2 ** (self.num_resolutions - 1)
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        self.use_freq_dec = use_freq_dec
        self.use_pxsf = use_pxsf

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_in = Conv(
                z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d"
            )
        else:
            self.conv_in = Conv(
                z_channels,
                block_in,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(
                block_in, norm_type=norm_type, cnn_param=cnn_param
            )
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        norm_type=norm_type,
                        cnn_param=cnn_param,
                    )
                )
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            # upsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE, offset 1 compared with encoder
            # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
            spatial_up = True if 1 <= i_level <= self.max_up else False
            temporal_up = (
                True
                if 1 <= i_level <= self.max_up
                and i_level >= self.temporal_up_offset + 1
                else False
            )
            if spatial_up or temporal_up:
                up.upsample = Upsample(
                    block_in,
                    cnn_type=cnn_param["cnn_type"],
                    spatial_up=spatial_up,
                    temporal_up=temporal_up,
                    use_pxsl=self.use_pxsf,
                )
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(
            block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )
        if cnn_param["conv_in_out_2d"] == "yes":
            self.conv_out = Conv(
                block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d"
            )
        else:
            self.conv_out = Conv(
                block_in,
                out_ch,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

    def forward(self, z):
        if not self.use_checkpoint:
            return self._forward(z)
        else:
            return checkpoint.checkpoint(self._forward, z, use_reentrant=False)

    def _forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class AutoEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        cnn_param = dict(
            cnn_type=args.cnn_type,
            conv_in_out_2d=args.conv_in_out_2d,
            res_conv_2d=args.res_conv_2d,
            cnn_attention=args.cnn_attention,
            cnn_norm_axis=args.cnn_norm_axis,
            conv_inner_2d=args.conv_inner_2d,
        )
        self.encoder = Encoder(
            ch=args.base_ch,
            ch_mult=args.encoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_vae=args.use_vae,
        )
        self.decoder = Decoder(
            ch=args.base_ch,
            ch_mult=args.decoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_freq_dec=args.use_freq_dec,
            use_pxsf=args.use_pxsf,  # pixelshuffle for upsampling
        )
        self.z_drop = nn.Dropout(args.z_drop)
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
        self.codebook_dim = self.embed_dim = args.codebook_dim

        self.gan_feat_weight = args.gan_feat_weight
        self.video_perceptual_weight = args.video_perceptual_weight
        self.recon_loss_type = args.recon_loss_type
        self.l1_weight = args.l1_weight
        self.use_vae = args.use_vae
        self.kl_weight = args.kl_weight
        self.lfq_weight = args.lfq_weight
        self.image_gan_weight = args.image_gan_weight  # image GAN loss weight
        self.video_gan_weight = args.video_gan_weight  # video GAN loss weight
        self.perceptual_weight = args.perceptual_weight
        self.flux_weight = args.flux_weight
        self.cycle_weight = args.cycle_weight
        self.cycle_feat_weight = args.cycle_feat_weight
        self.cycle_gan_weight = args.cycle_gan_weight

        self.flux_image_encoder = None

        if not args.use_vae:
            if args.quantizer_type == "MultiScaleBSQ":
                self.quantizer = MultiScaleBSQ(
                    dim=args.codebook_dim,  # this is the input feature dimension, defaults to log2(codebook_size) if not defined
                    codebook_size=args.codebook_size,  # codebook size, must be a power of 2
                    entropy_loss_weight=args.entropy_loss_weight,  # how much weight to place on entropy loss
                    diversity_gamma=args.diversity_gamma,  # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
                    preserve_norm=args.preserve_norm,  # preserve norm of the input for BSQ
                    ln_before_quant=args.ln_before_quant,  # use layer norm before quantization
                    ln_init_by_sqrt=args.ln_init_by_sqrt,  # layer norm init value 1/sqrt(d)
                    commitment_loss_weight=args.commitment_loss_weight,  # loss weight of commitment loss
                    new_quant=args.new_quant,
                    use_decay_factor=args.use_decay_factor,
                    mask_out=args.mask_out,
                    use_stochastic_depth=args.use_stochastic_depth,
                    drop_rate=args.drop_rate,
                    schedule_mode=args.schedule_mode,
                    keep_first_quant=args.keep_first_quant,
                    keep_last_quant=args.keep_last_quant,
                    remove_residual_detach=args.remove_residual_detach,
                    use_out_phi=args.use_out_phi,
                    use_out_phi_res=args.use_out_phi_res,
                    random_flip=args.random_flip,
                    flip_prob=args.flip_prob,
                    flip_mode=args.flip_mode,
                    max_flip_lvl=args.max_flip_lvl,
                    random_flip_1lvl=args.random_flip_1lvl,
                    flip_lvl_idx=args.flip_lvl_idx,
                    drop_when_test=args.drop_when_test,
                    drop_lvl_idx=args.drop_lvl_idx,
                    drop_lvl_num=args.drop_lvl_num,
                )
                self.quantize = self.quantizer
                self.vocab_size = args.codebook_size
            else:
                raise NotImplementedError(f"{args.quantizer_type} not supported")

    def forward(self, x):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1
        enc_dtype = ptdtype[self.args.encoder_dtype]

        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        # print(z.shape)
        # Multiscale LFQ
        z, all_indices, _, _, all_loss, _ = self.quantizer(h)
        x_recon = self.decoder(z)
        vq_output = {
            "commitment_loss": torch.mean(all_loss)
            * self.lfq_weight,  # here commitment loss is sum of commitment loss and entropy penalty
            "encodings": all_indices,
        }
        # return x_recon, vq_output
        return x_recon, None, z

    def encode_for_raw_features(
        self, x, scale_schedule, return_residual_norm_per_scale=False
    ):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1

        enc_dtype = ptdtype[self.args.encoder_dtype]
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W

        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        return h, hs, hs_mid

    def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
        h, hs, hs_mid = self.encode_for_raw_features(
            x, scale_schedule, return_residual_norm_per_scale
        )
        # Multiscale LFQ
        (
            z,
            all_indices,
            all_bit_indices,
            residual_norm_per_scale,
            all_loss,
            var_input,
        ) = self.quantizer(
            h,
            scale_schedule=scale_schedule,
            return_residual_norm_per_scale=return_residual_norm_per_scale,
        )
        return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input

    def decode(self, z):
        x_recon = self.decoder(z)
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return x_recon

    def decode_from_indices(self, all_indices, scale_schedule, label_type):
        summed_codes = 0
        for idx_Bl in all_indices:
            codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
            summed_codes += F.interpolate(
                codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up
            )
        assert summed_codes.shape[-3] == 1
        x_recon = self.decoder(summed_codes.squeeze(-3))
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return summed_codes, x_recon

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--flux_weight", type=float, default=0)
        parser.add_argument("--cycle_weight", type=float, default=0)
        parser.add_argument("--cycle_feat_weight", type=float, default=0)
        parser.add_argument("--cycle_gan_weight", type=float, default=0)
        parser.add_argument("--cycle_loop", type=int, default=0)
        parser.add_argument("--z_drop", type=float, default=0.0)
        return parser
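Usage note: decode_from_indices above reconstructs the latent by converting each scale's indices back to codes, upsampling every scale to the final latent grid and summing them before running the decoder. A standalone sketch of that summation step (the channel count and the "trilinear" interpolation mode are illustrative stand-ins for args.codebook_dim and self.quantizer.z_interplote_up):

import torch
import torch.nn.functional as F

scale_schedule = [(1, 1, 1), (1, 2, 2), (1, 4, 4), (1, 8, 8), (1, 16, 16)]  # (T, H, W) per scale
codes_per_scale = [torch.randn(1, 32, *s) for s in scale_schedule]          # B C T H W, 32 latent channels

summed_codes = 0
for codes in codes_per_scale:
    # each scale contributes a residual; all are resized to the final grid and accumulated
    summed_codes = summed_codes + F.interpolate(codes, size=scale_schedule[-1], mode="trilinear")
print(summed_codes.shape)  # torch.Size([1, 32, 1, 16, 16]); the decoder then consumes this (after squeezing T)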