Spaces: Running on Zero
刘虹雨 committed · Commit 8ed2f16
Parent(s): 8f481d2
update

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
- .DS_Store +0 -0
- .gitattributes +8 -32
- DiT_VAE/.DS_Store +0 -0
- DiT_VAE/__init__.py +0 -0
- DiT_VAE/diffusion/__init__.py +8 -0
- DiT_VAE/diffusion/configs/PixArt_xl2_4D_Triplane.py +64 -0
- DiT_VAE/diffusion/configs/PixArt_xl2_img256_4D_Triplane.py +41 -0
- DiT_VAE/diffusion/configs/__init__.py +0 -0
- DiT_VAE/diffusion/configs/vae_model.yaml +24 -0
- DiT_VAE/diffusion/data/__init__.py +2 -0
- DiT_VAE/diffusion/data/builder.py +67 -0
- DiT_VAE/diffusion/data/transforms.py +29 -0
- DiT_VAE/diffusion/dpm_solver.py +28 -0
- DiT_VAE/diffusion/iddpm.py +51 -0
- DiT_VAE/diffusion/lcm_scheduler.py +455 -0
- DiT_VAE/diffusion/model/__init__.py +2 -0
- DiT_VAE/diffusion/model/builder.py +14 -0
- DiT_VAE/diffusion/model/diffusion_utils.py +92 -0
- DiT_VAE/diffusion/model/dpm_solver.py +1337 -0
- DiT_VAE/diffusion/model/edm_sample.py +168 -0
- DiT_VAE/diffusion/model/gaussian_diffusion.py +1006 -0
- DiT_VAE/diffusion/model/hed.py +150 -0
- DiT_VAE/diffusion/model/image_embedding.py +15 -0
- DiT_VAE/diffusion/model/nets/PixArt_blocks.py +655 -0
- DiT_VAE/diffusion/model/nets/TriDitCLIPDINO.py +315 -0
- DiT_VAE/diffusion/model/nets/__init__.py +1 -0
- DiT_VAE/diffusion/model/respace.py +131 -0
- DiT_VAE/diffusion/model/sa_solver.py +1129 -0
- DiT_VAE/diffusion/model/timestep_sampler.py +150 -0
- DiT_VAE/diffusion/model/utils.py +510 -0
- DiT_VAE/diffusion/sa_sampler.py +66 -0
- DiT_VAE/diffusion/sa_solver_diffusers.py +840 -0
- DiT_VAE/diffusion/utils/__init__.py +1 -0
- DiT_VAE/diffusion/utils/checkpoint.py +80 -0
- DiT_VAE/diffusion/utils/data_sampler.py +138 -0
- DiT_VAE/diffusion/utils/dist_utils.py +303 -0
- DiT_VAE/diffusion/utils/logger.py +94 -0
- DiT_VAE/diffusion/utils/lr_scheduler.py +89 -0
- DiT_VAE/diffusion/utils/misc.py +366 -0
- DiT_VAE/diffusion/utils/optimizer.py +237 -0
- DiT_VAE/train_diffusion.py +5 -0
- DiT_VAE/train_vae.py +369 -0
- DiT_VAE/util.py +217 -0
- DiT_VAE/vae/__init__.py +0 -0
- DiT_VAE/vae/aemodules3d.py +368 -0
- DiT_VAE/vae/attention_vae.py +620 -0
- DiT_VAE/vae/data/__init__.py +0 -0
- DiT_VAE/vae/data/dataset_online_vae.py +108 -0
- DiT_VAE/vae/distributions.py +94 -0
- DiT_VAE/vae/losses/__init__.py +1 -0
.DS_Store
ADDED
Binary file (10.2 kB)
.gitattributes
CHANGED
@@ -1,36 +1,12 @@
-*. … (entry truncated in view)
-… (entry truncated in view)
+*.gif filter=lfs diff=lfs merge=lfs -text
+data_process/lib/FaceVerse/v3/faceverse_v3_1.npy filter=lfs diff=lfs merge=lfs -text
+data_process/lib/faceverse_process/BgMatting_models/rvm_resnet50_fp32.torchscript filter=lfs diff=lfs merge=lfs -text
+data_process/lib/faceverse_process/metamodel/v3/faceverse_v3_1.npy filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-… (entry truncated in view)
-*. … (entry truncated in view)
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.torchscript filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
DiT_VAE/.DS_Store
ADDED
Binary file (6.15 kB)
DiT_VAE/__init__.py
ADDED
File without changes
DiT_VAE/diffusion/__init__.py
ADDED
@@ -0,0 +1,8 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from .iddpm import IDDPM
+from .dpm_solver import DPMS
+from .sa_sampler import SASolverSampler
DiT_VAE/diffusion/configs/PixArt_xl2_4D_Triplane.py
ADDED
@@ -0,0 +1,64 @@
+data_root = '/data/data'
+data = dict(type='TriplaneData', data_base_dir='triplane', data_json_file='/nas8/liuhongyu/HeadGallery_Data/data.json', model_names='configs/gan_model.yaml')
+image_size = 256  # the generated image resolution
+train_batch_size = 32
+eval_batch_size = 16
+use_fsdp = False  # whether to use FSDP mode
+valid_num = 0  # treat an aspect ratio as valid when its sample count >= valid_num
+triplane_size = (256 * 4, 256)
+# model settings
+image_encoder_path = "/home/liuhongyu/code/IP-Adapter/models/image_encoder"
+vae_triplane_config_path = "vae_model.yaml"
+std_dir = '/nas8/liuhongyu/HeadGallery_Data/final_std.pt'
+mean_dir = '/nas8/liuhongyu/HeadGallery_Data/final_mean.pt'
+conditioning_params_dir = '/nas8/liuhongyu/HeadGallery_Data/conditioning_params.pkl'
+gan_model_base_dir = '/nas8/liuhongyu/HeadGallery_Data/gan_models'
+model = 'PixArt_XL_2'
+aspect_ratio_type = None  # base aspect ratio (ASPECT_RATIO_512 or ASPECT_RATIO_256)
+multi_scale = False  # whether to train with a multiscale dataset
+lewei_scale = 1.0  # lewei_scale for positional embedding interpolation
+# training settings
+num_workers = 4
+train_sampling_steps = 1000
+eval_sampling_steps = 250
+model_max_length = 8
+lora_rank = 4
+
+num_epochs = 80
+gradient_accumulation_steps = 1
+grad_checkpointing = False
+gradient_clip = 1.0
+gc_step = 1
+auto_lr = dict(rule='sqrt')
+
+# we use a different weight decay from the official implementation since it gives better results
+optimizer = dict(type='AdamW', lr=1e-4, weight_decay=3e-2, eps=1e-10)
+lr_schedule = 'constant'
+lr_schedule_args = dict(num_warmup_steps=500)
+
+save_image_epochs = 1
+save_model_epochs = 1
+save_model_steps = 1000000
+
+sample_posterior = True
+mixed_precision = 'fp16'
+scale_factor = 0.3994218
+ema_rate = 0.9999
+tensorboard_mox_interval = 50
+log_interval = 50
+cfg_scale = 4
+mask_type = 'null'
+num_group_tokens = 0
+mask_loss_coef = 0.
+load_mask_index = False  # load a prepared mask_type index
+# model loading settings
+vae_pretrained = "/cache/pretrained_models/sd-vae-ft-ema"
+load_from = None
+resume_from = dict(checkpoint=None, load_ema=False, resume_optimizer=True, resume_lr_scheduler=True)
+snr_loss = False
+
+# work dir settings
+work_dir = '/cache/exps/'
+s3_work_dir = None
+
+seed = 43
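These config files are plain Python modules. A minimal loading sketch, assuming the repo's mmcv dependency (the data builder below imports Registry/build_from_cfg from mmcv, and mmcv's Config.fromfile parses exactly this style of file):

# Sketch only: mmcv availability is an assumption inferred from the imports
# in DiT_VAE/diffusion/data/builder.py below.
from mmcv import Config

cfg = Config.fromfile('DiT_VAE/diffusion/configs/PixArt_xl2_4D_Triplane.py')
print(cfg.image_size)        # 256
print(cfg.optimizer['lr'])   # 0.0001
print(cfg.data['type'])      # 'TriplaneData'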
DiT_VAE/diffusion/configs/PixArt_xl2_img256_4D_Triplane.py
ADDED
@@ -0,0 +1,41 @@
+_base_ = ['./PixArt_xl2_4D_Triplane.py']
+data_root = 'data'
+
+data = dict(type='TriplaneData', data_base_dir='/nas8/liuhongyu/HeadGallery_Data',
+            data_json_file='/nas8/liuhongyu/HeadGallery_Data/data_combine.json', model_names='configs/gan_model.yaml',
+            dino_path='/nas8/liuhongyu/model/dinov2-base')
+image_size = 256
+
+# model settings
+gan_model_config = "./configs/gan_model.yaml"
+image_encoder_path = "/home/liuhongyu/code/IP-Adapter/models/image_encoder"
+vae_triplane_config_path = "./vae_model.yaml"
+std_dir = '/nas8/liuhongyu/HeadGallery_Data/final_std.pt'
+mean_dir = '/nas8/liuhongyu/HeadGallery_Data/final_mean.pt'
+conditioning_params_dir = '/nas8/liuhongyu/HeadGallery_Data/conditioning_params.pkl'
+gan_model_base_dir = '/nas8/liuhongyu/HeadGallery_Data/gan_models'
+dino_pretrained = '/nas8/liuhongyu/HeadGallery_Data/dinov2-base'
+window_block_indexes = []
+window_size = 0
+use_rel_pos = False
+model = 'PixArt_XL_2'
+fp32_attention = True
+dino_norm = False
+img_feature_self_attention = False
+load_from = None
+vae_pretrained = "/nas8/liuhongyu/all_training_results/VAE/checkpoint-140000"
+# training settings
+eval_sampling_steps = 200
+save_model_steps = 10000
+num_workers = 2
+train_batch_size = 8  # 32 # max 96 for PixArt-L/4 when grad_checkpoint
+num_epochs = 200  # 3
+gradient_accumulation_steps = 1
+grad_checkpointing = True
+gradient_clip = 0.01
+optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
+lr_schedule_args = dict(num_warmup_steps=1000)
+
+log_interval = 20
+save_model_epochs = 5
+work_dir = 'output/debug'
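This child config inherits from the base file above via the `_base_` list; keys defined here shadow the base, everything else falls through. A sketch under the assumption that mmcv's standard config inheritance resolves the merge:

from mmcv import Config

cfg = Config.fromfile('DiT_VAE/diffusion/configs/PixArt_xl2_img256_4D_Triplane.py')
# Keys redefined here shadow the base config:
assert cfg.train_batch_size == 8        # overrides the base value of 32
assert cfg.grad_checkpointing is True   # overrides the base value of False
# Keys not redefined fall through from the base:
assert cfg.ema_rate == 0.9999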
DiT_VAE/diffusion/configs/__init__.py
ADDED
File without changes
DiT_VAE/diffusion/configs/vae_model.yaml
ADDED
@@ -0,0 +1,24 @@
+embed_dim: 8
+ddconfig:
+  double_z: True
+  z_channels: 8
+  encoder:
+    target: DiT_VAE.vae.aemodules3d.Encoder
+    params:
+      n_hiddens: 128
+      downsample: [4, 8, 8]
+      image_channel: 32
+      norm_type: group
+      padding_type: replicate
+      double_z: True
+      z_channels: 8
+
+  decoder:
+    target: DiT_VAE.vae.aemodules3d.Decoder
+    params:
+      n_hiddens: 128
+      upsample: [4, 8, 8]
+      z_channels: 8
+      image_channel: 32
+      norm_type: group
+
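The `target:`/`params:` blocks name a class by import path plus its constructor arguments. A minimal sketch of how such a block is typically resolved; `instantiate_from_config` is an illustrative helper written here for the example, not necessarily the repo's own:

import importlib
import yaml

def instantiate_from_config(cfg: dict):
    # Split "DiT_VAE.vae.aemodules3d.Encoder" into module path and class name,
    # import the module, and call the class with the params dict.
    module_path, cls_name = cfg["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**cfg.get("params", {}))

with open("DiT_VAE/diffusion/configs/vae_model.yaml") as f:
    vae_cfg = yaml.safe_load(f)
encoder = instantiate_from_config(vae_cfg["ddconfig"]["encoder"])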
DiT_VAE/diffusion/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .datasets import *
+from .transforms import get_transform
DiT_VAE/diffusion/data/builder.py
ADDED
@@ -0,0 +1,67 @@
+import os
+import time
+
+from mmcv import Registry, build_from_cfg
+from torch.utils.data import DataLoader
+
+from DiT_VAE.diffusion.data.transforms import get_transform
+from DiT_VAE.diffusion.utils.logger import get_root_logger
+
+DATASETS = Registry('datasets')
+
+DATA_ROOT = '/cache/data'
+
+
+def set_data_root(data_root):
+    global DATA_ROOT
+    DATA_ROOT = data_root
+
+
+def get_data_path(data_dir):
+    if os.path.isabs(data_dir):
+        return data_dir
+    global DATA_ROOT
+    return os.path.join(DATA_ROOT, data_dir)
+
+
+def build_dataset_triplane(cfg, resolution=256, **kwargs):
+    logger = get_root_logger()
+
+    dataset_type = cfg.get('type')
+    logger.info(f"Constructing dataset {dataset_type}...")
+    t = time.time()
+
+    dataset = build_from_cfg(cfg, DATASETS, default_args=dict(image_size=resolution, **kwargs))
+    logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length: {len(dataset)}")
+    return dataset
+
+
+def build_dataset(cfg, resolution=224, **kwargs):
+    logger = get_root_logger()
+
+    dataset_type = cfg.get('type')
+    logger.info(f"Constructing dataset {dataset_type}...")
+    t = time.time()
+    transform = cfg.pop('transform', 'default_train')
+    transform = get_transform(transform, resolution)
+    dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
+    logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}")
+    return dataset
+
+
+def build_dataloader(dataset, batch_size=256, num_workers=2, shuffle=True, **kwargs):
+    return (
+        DataLoader(
+            dataset,
+            batch_sampler=kwargs['batch_sampler'],
+            num_workers=num_workers,
+            pin_memory=True,
+        )
+        if 'batch_sampler' in kwargs
+        else DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=num_workers,
+            pin_memory=True,
+            **kwargs
+        )
+    )
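A minimal usage sketch of this builder; it assumes a dataset class (e.g. TriplaneData) has been registered in DATASETS via @DATASETS.register_module(), and the cfg values below simply reuse the ones from the config above:

from DiT_VAE.diffusion.data.builder import (
    set_data_root, build_dataset_triplane, build_dataloader)

set_data_root('/data/data')  # relative data dirs resolve against this root
cfg = dict(type='TriplaneData', data_base_dir='triplane',
           data_json_file='/nas8/liuhongyu/HeadGallery_Data/data.json',
           model_names='configs/gan_model.yaml')
dataset = build_dataset_triplane(cfg, resolution=256)
loader = build_dataloader(dataset, batch_size=32, num_workers=4)  # plain shuffled loader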
DiT_VAE/diffusion/data/transforms.py
ADDED
@@ -0,0 +1,29 @@
+import torchvision.transforms as T
+
+TRANSFORMS = {}
+
+
+def register_transform(transform):
+    name = transform.__name__
+    if name in TRANSFORMS:
+        raise RuntimeError(f'Transform {name} has already been registered.')
+    TRANSFORMS.update({name: transform})
+    return transform  # return the function so the decorated name stays usable
+
+
+def get_transform(type, resolution):
+    transform = TRANSFORMS[type](resolution)
+    transform = T.Compose(transform)
+    transform.image_size = resolution
+    return transform
+
+
+@register_transform
+def default_train(n_px):
+    return [
+        T.Lambda(lambda img: img.convert('RGB')),
+        T.Resize(n_px),  # Image.BICUBIC
+        T.CenterCrop(n_px),
+        # T.RandomHorizontalFlip(),
+        T.ToTensor(),
+        T.Normalize([0.5], [0.5]),
+    ]
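Usage sketch: build the registered 'default_train' pipeline at 256 px and apply it to a PIL image (the file path is illustrative):

from PIL import Image
from DiT_VAE.diffusion.data.transforms import get_transform

transform = get_transform('default_train', 256)
img = Image.open('example.png')
x = transform(img)  # tensor of shape (3, 256, 256), values in [-1, 1]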
DiT_VAE/diffusion/dpm_solver.py
ADDED
@@ -0,0 +1,28 @@
+import torch
+from .model import gaussian_diffusion as gd
+from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
+
+
+def DPMS(model, condition, uncondition, cfg_scale, model_type='noise', noise_schedule="linear", guidance_type='classifier-free', model_kwargs=None, diffusion_steps=1000):
+    if model_kwargs is None:
+        model_kwargs = {}
+    betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
+
+    ## 1. Define the noise schedule.
+    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
+
+    ## 2. Convert your discrete-time `model` to the continuous-time
+    ## noise prediction model. Here is an example for a diffusion model
+    ## `model` with the noise prediction type ("noise").
+    model_fn = model_wrapper(
+        model,
+        noise_schedule,
+        model_type=model_type,
+        model_kwargs=model_kwargs,
+        guidance_type=guidance_type,
+        condition=condition,
+        unconditional_condition=uncondition,
+        guidance_scale=cfg_scale,
+    )
+    ## 3. Define dpm-solver and sample by multistep DPM-Solver.
+    return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
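A minimal sampling sketch with the returned solver. The `.sample` call and its arguments follow the upstream DPM-Solver API; `net`, `cond`, `uncond`, and the latent shape are illustrative placeholders:

import torch

solver = DPMS(net, condition=cond, uncondition=uncond, cfg_scale=4.0)
x_T = torch.randn(1, 8, 64, 64)  # start from Gaussian noise
x_0 = solver.sample(x_T, steps=20, order=2, skip_type="time_uniform",
                    method="multistep")  # 20-step DPM-Solver++ sampling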
DiT_VAE/diffusion/iddpm.py
ADDED
@@ -0,0 +1,51 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+from .model.respace import SpacedDiffusion, space_timesteps
+from .model import gaussian_diffusion as gd
+
+
+def IDDPM(
+    timestep_respacing,
+    noise_schedule="linear",
+    use_kl=False,
+    sigma_small=False,
+    predict_xstart=False,
+    learn_sigma=True,
+    pred_sigma=True,
+    rescale_learned_sigmas=False,
+    diffusion_steps=1000,
+    snr=False,
+    return_startx=False,
+):
+    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+    if use_kl:
+        loss_type = gd.LossType.RESCALED_KL
+    elif rescale_learned_sigmas:
+        loss_type = gd.LossType.RESCALED_MSE
+    else:
+        loss_type = gd.LossType.MSE
+    if timestep_respacing is None or timestep_respacing == "":
+        timestep_respacing = [diffusion_steps]
+    return SpacedDiffusion(
+        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+        betas=betas,
+        model_mean_type=(
+            gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
+        ),
+        model_var_type=(
+            (
+                gd.ModelVarType.LEARNED_RANGE
+                if learn_sigma
+                else (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
+            )
+            if pred_sigma
+            else None
+        ),
+        loss_type=loss_type,
+        snr=snr,
+        return_startx=return_startx,
+        # rescale_timesteps=rescale_timesteps,
+    )
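Usage sketch: the returned SpacedDiffusion follows the OpenAI improved-diffusion API, so training losses come from training_losses(); `net` and the latent shape are illustrative placeholders:

import torch

diffusion = IDDPM(timestep_respacing="")  # full 1000-step schedule
x_0 = torch.randn(4, 8, 64, 64)           # a batch of clean latents
t = torch.randint(0, diffusion.num_timesteps, (4,))
losses = diffusion.training_losses(net, x_0, t)  # dict with a "loss" entry
loss = losses["loss"].mean()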
DiT_VAE/diffusion/lcm_scheduler.py
ADDED
@@ -0,0 +1,455 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers import ConfigMixin, SchedulerMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class LCMSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model input
+            in the denoising loop.
+        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `denoised` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    denoised: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+    num_diffusion_timesteps,
+    max_beta=0.999,
+    alpha_transform_type="cosine",
+):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+            prevent singularities.
+        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+            Choose from `cosine` or `exp`
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+    if alpha_transform_type == "cosine":
+
+        def alpha_bar_fn(t):
+            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    elif alpha_transform_type == "exp":
+
+        def alpha_bar_fn(t):
+            return math.exp(t * -12.0)
+
+    else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt ** 2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
+class LCMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+    non-Markovian guidance.
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        beta_start (`float`, defaults to 0.0001):
+            The starting `beta` value of inference.
+        beta_end (`float`, defaults to 0.02):
+            The final `beta` value.
+        beta_schedule (`str`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, *optional*):
+            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+        clip_sample (`bool`, defaults to `True`):
+            Clip the predicted sample for numerical stability.
+        clip_sample_range (`float`, defaults to 1.0):
+            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        set_alpha_to_one (`bool`, defaults to `True`):
+            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the alpha value at step 0.
+        steps_offset (`int`, defaults to 0):
+            An offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+            Diffusion.
+        prediction_type (`str`, defaults to `epsilon`, *optional*):
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper).
+        thresholding (`bool`, defaults to `False`):
+            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+            as Stable Diffusion.
+        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+        sample_max_value (`float`, defaults to 1.0):
+            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, defaults to `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+    """
+
+    # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
+        prediction_type: str = "epsilon",
+        thresholding: bool = False,
+        dynamic_thresholding_ratio: float = 0.995,
+        clip_sample_range: float = 1.0,
+        sample_max_value: float = 1.0,
+        timestep_spacing: str = "leading",
+        rescale_betas_zero_snr: bool = False,
+    ):
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+        Args:
+            sample (`torch.FloatTensor`):
+                The input sample.
+            timestep (`int`, *optional*):
+                The current timestep in the diffusion chain.
+        Returns:
+            `torch.FloatTensor`:
+                A scaled input sample.
+        """
+        return sample
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample
+
+    def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+        Args:
+            num_inference_steps (`int`):
+                The number of diffusion steps used when generating samples with a pre-trained model.
+            lcm_origin_steps (`int`):
+                The number of timesteps in the original LCM training schedule.
+        """
+
+        if num_inference_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+
+        # LCM Timesteps Setting: Linear Spacing
+        c = self.config.num_train_timesteps // lcm_origin_steps
+        lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1  # LCM Training Steps Schedule
+        skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+        timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]  # LCM Inference Steps Schedule
+
+        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
+
+    def get_scalings_for_boundary_condition_discrete(self, t):
+        self.sigma_data = 0.5  # Default: 0.5
+
+        # By dividing by 0.1: this is almost a delta function at t=0.
+        c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2)
+        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5
+        return c_skip, c_out
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timeindex: int,
+        timestep: int,
+        sample: torch.FloatTensor,
+        eta: float = 0.0,
+        use_clipped_model_output: bool = False,
+        generator=None,
+        variance_noise: Optional[torch.FloatTensor] = None,
+        return_dict: bool = True,
+    ) -> Union[LCMSchedulerOutput, Tuple]:
+        """
+        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+        process from the learned model outputs (most often the predicted noise).
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from the learned diffusion model.
+            timeindex (`int`):
+                The index of the current timestep in the `timesteps` schedule.
+            timestep (`int`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            eta (`float`):
+                The weight of noise for added noise in the diffusion step.
+            use_clipped_model_output (`bool`, defaults to `False`):
+                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+                because the predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`.
+                If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input
+                and `use_clipped_model_output` has no effect.
+            generator (`torch.Generator`, *optional*):
+                A random number generator.
+            variance_noise (`torch.FloatTensor`):
+                Alternative to generating noise with `generator` by directly providing the noise for the variance
+                itself. Useful for methods such as [`CycleDiffusion`].
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+        Returns:
+            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+                tuple is returned where the first element is the sample tensor.
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        # 1. get previous step value
+        prev_timeindex = timeindex + 1
+        if prev_timeindex < len(self.timesteps):
+            prev_timestep = self.timesteps[prev_timeindex]
+        else:
+            prev_timestep = timestep
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        # 3. Get scalings for boundary conditions
+        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+        # 4. Different parameterizations:
+        parameterization = self.config.prediction_type
+
+        if parameterization == "epsilon":  # noise-prediction
+            pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+
+        elif parameterization == "sample":  # x-prediction
+            pred_x0 = model_output
+
+        elif parameterization == "v_prediction":  # v-prediction
+            pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+
+        # 5. Denoise model output using boundary conditions
+        denoised = c_out * pred_x0 + c_skip * sample
+
+        # 6. Sample z ~ N(0, I) for multistep inference.
+        # Noise is not used for one-step sampling.
+        if len(self.timesteps) > 1:
+            noise = torch.randn(model_output.shape).to(model_output.device)
+            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+        else:
+            prev_sample = denoised
+
+        if not return_dict:
+            return (prev_sample, denoised)
+
+        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
+
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
+
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+    def get_velocity(
+        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timesteps have the same device and dtype as sample
+        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+
+    def __len__(self):
+        return self.config.num_train_timesteps
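A usage sketch of the scheduler loop. Note that step() takes the loop index as `timeindex` in addition to the timestep value; `net`, the latent shape, and the step counts are illustrative placeholders:

import torch

scheduler = LCMScheduler(beta_schedule="scaled_linear")
scheduler.set_timesteps(num_inference_steps=4, lcm_origin_steps=50)

x = torch.randn(1, 4, 64, 64)  # initial latent noise
for i, t in enumerate(scheduler.timesteps):
    eps = net(x, t)                     # model's noise prediction
    out = scheduler.step(eps, i, t, x)  # one consistency step
    x = out.prev_sample
result = out.denoised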
DiT_VAE/diffusion/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .nets import *
+# import utils
DiT_VAE/diffusion/model/builder.py
ADDED
@@ -0,0 +1,14 @@
+from mmcv import Registry
+
+from DiT_VAE.diffusion.model.utils import set_grad_checkpoint
+
+MODELS = Registry('vae')
+
+
+def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
+    if isinstance(cfg, str):
+        cfg = dict(type=cfg)
+    model = MODELS.build(cfg, default_args=kwargs)
+    if use_grad_checkpoint:
+        set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
+    return model
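Registration/usage sketch for the MODELS registry; the decorated class is a hypothetical stand-in for the repo's network definitions in model/nets:

from DiT_VAE.diffusion.model.builder import MODELS, build_model

@MODELS.register_module()
class MyNet:  # illustrative placeholder class
    def __init__(self, depth=28):
        self.depth = depth

net = build_model('MyNet', depth=28)            # a bare string becomes dict(type=...)
net = build_model(dict(type='MyNet', depth=28)) # equivalent explicit form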
DiT_VAE/diffusion/model/diffusion_utils.py
ADDED
@@ -0,0 +1,92 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import numpy as np
+import torch as th
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = next(
+        (
+            obj
+            for obj in (mean1, logvar1, mean2, logvar2)
+            if isinstance(obj, th.Tensor)
+        ),
+        None,
+    )
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for th.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + th.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+    )
+
+
+def approx_standard_normal_cdf(x):
+    """
+    A fast approximation of the cumulative distribution function of the
+    standard normal.
+    """
+    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+    """
+    Compute the log-likelihood of a continuous Gaussian distribution.
+    :param x: the targets
+    :param means: the Gaussian mean Tensor.
+    :param log_scales: the Gaussian log stddev Tensor.
+    :return: a tensor like x of log probabilities (in nats).
+    """
+    centered_x = x - means
+    inv_stdv = th.exp(-log_scales)
+    normalized_x = centered_x * inv_stdv
+    return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
+        normalized_x
+    )
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+    """
+    Compute the log-likelihood of a Gaussian distribution discretizing to a
+    given image.
+    :param x: the target images. It is assumed that this was uint8 values,
+              rescaled to the range [-1, 1].
+    :param means: the Gaussian mean Tensor.
+    :param log_scales: the Gaussian log stddev Tensor.
+    :return: a tensor like x of log probabilities (in nats).
+    """
+    assert x.shape == means.shape == log_scales.shape
+    centered_x = x - means
+    inv_stdv = th.exp(-log_scales)
+    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+    cdf_plus = approx_standard_normal_cdf(plus_in)
+    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+    cdf_min = approx_standard_normal_cdf(min_in)
+    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+    cdf_delta = cdf_plus - cdf_min
+    log_probs = th.where(
+        x < -0.999,
+        log_cdf_plus,
+        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+    )
+    assert log_probs.shape == x.shape
+    return log_probs
ADDED
@@ -0,0 +1,1337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
|
5 |
+
class NoiseScheduleVP:
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
schedule='discrete',
|
9 |
+
betas=None,
|
10 |
+
alphas_cumprod=None,
|
11 |
+
continuous_beta_0=0.1,
|
12 |
+
continuous_beta_1=20.,
|
13 |
+
dtype=torch.float32,
|
14 |
+
):
|
15 |
+
"""Create a wrapper class for the forward SDE (VP type).
|
16 |
+
|
17 |
+
***
|
18 |
+
Update: We support discrete-time diffusion vae by implementing a picewise linear interpolation for log_alpha_t.
|
19 |
+
We recommend to use schedule='discrete' for the discrete-time diffusion vae, especially for high-resolution images.
|
20 |
+
***
|
21 |
+
|
22 |
+
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
23 |
+
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
24 |
+
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
25 |
+
|
26 |
+
log_alpha_t = self.marginal_log_mean_coeff(t)
|
27 |
+
sigma_t = self.marginal_std(t)
|
28 |
+
lambda_t = self.marginal_lambda(t)
|
29 |
+
|
30 |
+
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
31 |
+
|
32 |
+
t = self.inverse_lambda(lambda_t)
|
33 |
+
|
34 |
+
===============================================================
|
35 |
+
|
36 |
+
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
37 |
+
|
38 |
+
1. For discrete-time DPMs:
|
39 |
+
|
40 |
+
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
41 |
+
t_i = (i + 1) / N
|
42 |
+
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
43 |
+
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
47 |
+
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
48 |
+
|
49 |
+
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
50 |
+
|
51 |
+
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
52 |
+
The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
|
53 |
+
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
54 |
+
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
55 |
+
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
56 |
+
and
|
57 |
+
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
58 |
+
|
59 |
+
|
60 |
+
2. For continuous-time DPMs:
|
61 |
+
|
62 |
+
We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
|
63 |
+
schedule are the default settings in Yang Song's ScoreSDE:
|
64 |
+
|
65 |
+
Args:
|
66 |
+
beta_min: A `float` number. The smallest beta for the linear schedule.
|
67 |
+
beta_max: A `float` number. The largest beta for the linear schedule.
|
68 |
+
T: A `float` number. The ending time of the forward process.
|
69 |
+
|
70 |
+
===============================================================
|
71 |
+
|
72 |
+
Args:
|
73 |
+
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
74 |
+
'linear' for continuous-time DPMs.
|
75 |
+
Returns:
|
76 |
+
A wrapper object of the forward SDE (VP type).
|
77 |
+
|
78 |
+
===============================================================
|
79 |
+
|
80 |
+
Example:
|
81 |
+
|
82 |
+
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
83 |
+
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
84 |
+
|
85 |
+
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
86 |
+
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
87 |
+
|
88 |
+
# For continuous-time DPMs (VPSDE), linear schedule:
|
89 |
+
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
90 |
+
|
91 |
+
"""
|
92 |
+
|
93 |
+
if schedule not in ['discrete', 'linear']:
|
94 |
+
raise ValueError(
|
95 |
+
f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear'"
|
96 |
+
)
|
97 |
+
|
98 |
+
self.schedule = schedule
|
99 |
+
if schedule == 'discrete':
|
100 |
+
if betas is not None:
|
101 |
+
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
102 |
+
else:
|
103 |
+
assert alphas_cumprod is not None
|
104 |
+
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
105 |
+
self.T = 1.
|
106 |
+
self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
|
107 |
+
self.total_N = self.log_alpha_array.shape[1]
|
108 |
+
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
|
109 |
+
else:
|
110 |
+
self.T = 1.
|
111 |
+
self.total_N = 1000
|
112 |
+
self.beta_0 = continuous_beta_0
|
113 |
+
self.beta_1 = continuous_beta_1
|
114 |
+
|
115 |
+
def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
|
116 |
+
"""
|
117 |
+
For some beta schedules, such as the cosine schedule, the log-SNR has numerical issues.
|
118 |
+
We clip the log-SNR near t=T to be no less than -5.1 to ensure numerical stability.
|
119 |
+
Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion, and GLIDE.
|
120 |
+
"""
|
121 |
+
log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
|
122 |
+
lambs = log_alphas - log_sigmas
|
123 |
+
idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
|
124 |
+
if idx > 0:
|
125 |
+
log_alphas = log_alphas[:-idx]
|
126 |
+
return log_alphas
|
127 |
+
|
128 |
+
def marginal_log_mean_coeff(self, t):
|
129 |
+
"""
|
130 |
+
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
131 |
+
"""
|
132 |
+
if self.schedule == 'discrete':
|
133 |
+
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
|
134 |
+
self.log_alpha_array.to(t.device)).reshape((-1))
|
135 |
+
elif self.schedule == 'linear':
|
136 |
+
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
137 |
+
|
138 |
+
def marginal_alpha(self, t):
|
139 |
+
"""
|
140 |
+
Compute alpha_t of a given continuous-time label t in [0, T].
|
141 |
+
"""
|
142 |
+
return torch.exp(self.marginal_log_mean_coeff(t))
|
143 |
+
|
144 |
+
def marginal_std(self, t):
|
145 |
+
"""
|
146 |
+
Compute sigma_t of a given continuous-time label t in [0, T].
|
147 |
+
"""
|
148 |
+
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
149 |
+
|
150 |
+
def marginal_lambda(self, t):
|
151 |
+
"""
|
152 |
+
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
153 |
+
"""
|
154 |
+
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
155 |
+
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
156 |
+
return log_mean_coeff - log_std
|
157 |
+
|
158 |
+
def inverse_lambda(self, lamb):
|
159 |
+
"""
|
160 |
+
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
161 |
+
"""
|
162 |
+
if self.schedule == 'linear':
|
163 |
+
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
164 |
+
Delta = self.beta_0 ** 2 + tmp
|
165 |
+
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
166 |
+
elif self.schedule == 'discrete':
|
167 |
+
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
168 |
+
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
|
169 |
+
torch.flip(self.t_array.to(lamb.device), [1]))
|
170 |
+
return t.reshape((-1,))
|
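# --- Editor's note: illustrative sketch, not part of the original file. ---
# A minimal sanity check of the schedule identities documented above, using a
# hypothetical DDPM-style linear beta schedule as input.
def _sketch_noise_schedule_roundtrip():
    betas = torch.linspace(1e-4, 2e-2, 1000)  # hypothetical example betas
    ns = NoiseScheduleVP('discrete', betas=betas)
    t = torch.tensor([0.5])
    alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
    # VP property: alpha_t^2 + sigma_t^2 = 1 for all t.
    assert torch.allclose(alpha_t ** 2 + sigma_t ** 2, torch.ones_like(t))
    # lambda_t is invertible: inverse_lambda(marginal_lambda(t)) recovers t
    # (up to the piecewise linear interpolation error).
    assert torch.allclose(ns.inverse_lambda(ns.marginal_lambda(t)), t, atol=1e-3)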
171 |
+
|
172 |
+
|
173 |
+
def model_wrapper(
|
174 |
+
model,
|
175 |
+
noise_schedule,
|
176 |
+
model_type="noise",
|
177 |
+
model_kwargs={},
|
178 |
+
guidance_type="uncond",
|
179 |
+
condition=None,
|
180 |
+
unconditional_condition=None,
|
181 |
+
guidance_scale=1.,
|
182 |
+
classifier_fn=None,
|
183 |
+
classifier_kwargs={},
|
184 |
+
):
|
185 |
+
"""Create a wrapper function for the noise prediction model.
|
186 |
+
|
187 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
188 |
+
first wrap the model function into a noise prediction model that accepts the continuous time as input.
|
189 |
+
|
190 |
+
We support four types of the diffusion model by setting `model_type`:
|
191 |
+
|
192 |
+
1. "noise": noise prediction model. (Trained by predicting noise).
|
193 |
+
|
194 |
+
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
195 |
+
|
196 |
+
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
197 |
+
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
198 |
+
|
199 |
+
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
200 |
+
arXiv preprint arXiv:2202.00512 (2022).
|
201 |
+
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
202 |
+
arXiv preprint arXiv:2210.02303 (2022).
|
203 |
+
|
204 |
+
4. "score": marginal score function. (Trained by denoising score matching).
|
205 |
+
Note that the score function and the noise prediction model follow a simple relationship:
|
206 |
+
```
|
207 |
+
noise(x_t, t) = -sigma_t * score(x_t, t)
|
208 |
+
```
|
209 |
+
|
210 |
+
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
211 |
+
1. "uncond": unconditional sampling by DPMs.
|
212 |
+
The input `model` has the following format:
|
213 |
+
``
|
214 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
215 |
+
``
|
216 |
+
|
217 |
+
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
218 |
+
The input `model` has the following format:
|
219 |
+
``
|
220 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
221 |
+
``
|
222 |
+
|
223 |
+
The input `classifier_fn` has the following format:
|
224 |
+
``
|
225 |
+
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
226 |
+
``
|
227 |
+
|
228 |
+
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
229 |
+
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
230 |
+
|
231 |
+
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
232 |
+
The input `model` has the following format:
|
233 |
+
``
|
234 |
+
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
235 |
+
``
|
236 |
+
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
237 |
+
|
238 |
+
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
239 |
+
arXiv preprint arXiv:2207.12598 (2022).
|
240 |
+
|
241 |
+
|
242 |
+
The `t_input` is the time label of the model, which may be a discrete-time label (i.e. 0 to 999)
|
243 |
+
or a continuous-time label (i.e. epsilon to T).
|
244 |
+
|
245 |
+
We wrap the model function to accept only `x` and `t_continuous` as inputs, and to output the predicted noise:
|
246 |
+
``
|
247 |
+
def model_fn(x, t_continuous) -> noise:
|
248 |
+
t_input = get_model_input_time(t_continuous)
|
249 |
+
return noise_pred(model, x, t_input, **model_kwargs)
|
250 |
+
``
|
251 |
+
where `t_continuous` is the continuous time label (i.e. epsilon to T). We use `model_fn` for DPM-Solver.
|
252 |
+
|
253 |
+
===============================================================
|
254 |
+
|
255 |
+
Args:
|
256 |
+
model: A diffusion model with the corresponding format described above.
|
257 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
258 |
+
model_type: A `str`. The parameterization type of the diffusion model.
|
259 |
+
"noise" or "x_start" or "v" or "score".
|
260 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
261 |
+
guidance_type: A `str`. The type of the guidance for sampling.
|
262 |
+
"uncond" or "classifier" or "classifier-free".
|
263 |
+
condition: A pytorch tensor. The condition for the guided sampling.
|
264 |
+
Only used for "classifier" or "classifier-free" guidance type.
|
265 |
+
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
266 |
+
Only used for "classifier-free" guidance type.
|
267 |
+
guidance_scale: A `float`. The scale for the guided sampling.
|
268 |
+
classifier_fn: A classifier function. Only used for the classifier guidance.
|
269 |
+
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
270 |
+
Returns:
|
271 |
+
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
272 |
+
"""
|
273 |
+
|
274 |
+
def get_model_input_time(t_continuous):
|
275 |
+
"""
|
276 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
277 |
+
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
278 |
+
For continuous-time DPMs, we just use `t_continuous`.
|
279 |
+
"""
|
280 |
+
if noise_schedule.schedule == 'discrete':
|
281 |
+
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
|
282 |
+
else:
|
283 |
+
return t_continuous
|
284 |
+
|
285 |
+
def noise_pred_fn(x, t_continuous, cond=None, cond_2=None):
|
286 |
+
t_input = get_model_input_time(t_continuous)
|
287 |
+
if cond is None:
|
288 |
+
output = model(x, t_input, **model_kwargs)
|
289 |
+
else:
|
290 |
+
output = model(x, t_input, y=cond, img_feature=cond_2, **model_kwargs)
|
291 |
+
if model_type == "noise":
|
292 |
+
return output
|
293 |
+
elif model_type == "x_start":
|
294 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
295 |
+
return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
|
296 |
+
elif model_type == "v":
|
297 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
298 |
+
return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
|
299 |
+
elif model_type == "score":
|
300 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
301 |
+
return -expand_dims(sigma_t, x.dim()) * output
|
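# --- Editor's note: descriptive summary of the branches above, added for clarity. ---
# Each supported parameterization is converted into a noise prediction:
#   x_start: eps = (x - alpha_t * x0_pred) / sigma_t
#   v:       eps = alpha_t * v_pred + sigma_t * x
#   score:   eps = -sigma_t * score_pred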
302 |
+
|
303 |
+
def cond_grad_fn(x, t_input):
|
304 |
+
"""
|
305 |
+
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
306 |
+
"""
|
307 |
+
with torch.enable_grad():
|
308 |
+
x_in = x.detach().requires_grad_(True)
|
309 |
+
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
310 |
+
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
311 |
+
|
312 |
+
def model_fn(x, t_continuous):
|
313 |
+
"""
|
314 |
+
The noise prediction model function that is used for DPM-Solver.
|
315 |
+
"""
|
316 |
+
if guidance_type == "uncond":
|
317 |
+
return noise_pred_fn(x, t_continuous)
|
318 |
+
elif guidance_type == "classifier":
|
319 |
+
assert classifier_fn is not None
|
320 |
+
t_input = get_model_input_time(t_continuous)
|
321 |
+
cond_grad = cond_grad_fn(x, t_input)
|
322 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
323 |
+
noise = noise_pred_fn(x, t_continuous)
|
324 |
+
return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
|
325 |
+
elif guidance_type == "classifier-free":
|
326 |
+
if guidance_scale == 1. or unconditional_condition is None:
|
327 |
+
return noise_pred_fn(x, t_continuous, cond=condition)
|
328 |
+
x_in = torch.cat([x] * 2)
|
329 |
+
t_in = torch.cat([t_continuous] * 2)
|
330 |
+
# c_in = torch.cat([unconditional_condition, condition])
|
331 |
+
c_in_y = torch.cat([unconditional_condition[0], condition[0]])
|
332 |
+
c_in_dino = torch.cat([unconditional_condition[1], condition[1]])
|
333 |
+
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in_y, cond_2=c_in_dino).chunk(2)
|
334 |
+
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
335 |
+
|
336 |
+
assert model_type in ["noise", "x_start", "v", "score"]
|
337 |
+
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
338 |
+
return model_fn
|
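# --- Editor's note: illustrative sketch, not part of the original file. ---
# How this wrapper is typically called for classifier-free guidance. The names
# `my_model`, `cond_y`, `cond_dino`, `uncond_y`, and `uncond_dino` are
# hypothetical placeholders; this model_fn expects (y, dino_feature) condition
# pairs, matching the c_in_y / c_in_dino split above.
#
# model_fn = model_wrapper(
#     my_model, noise_schedule, model_type="noise",
#     guidance_type="classifier-free",
#     condition=(cond_y, cond_dino),
#     unconditional_condition=(uncond_y, uncond_dino),
#     guidance_scale=4.5,
# )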
339 |
+
|
340 |
+
|
341 |
+
class DPM_Solver:
|
342 |
+
def __init__(
|
343 |
+
self,
|
344 |
+
model_fn,
|
345 |
+
noise_schedule,
|
346 |
+
algorithm_type="dpmsolver++",
|
347 |
+
correcting_x0_fn=None,
|
348 |
+
correcting_xt_fn=None,
|
349 |
+
thresholding_max_val=1.,
|
350 |
+
dynamic_thresholding_ratio=0.995,
|
351 |
+
):
|
352 |
+
"""Construct a DPM-Solver.
|
353 |
+
|
354 |
+
We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
|
355 |
+
|
356 |
+
We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
|
357 |
+
can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
|
358 |
+
dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
|
359 |
+
DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
|
360 |
+
DPMs (such as stable-diffusion).
|
361 |
+
|
362 |
+
To support advanced algorithms in image-to-image applications, we also support corrector functions for
|
363 |
+
both x0 and xt.
|
364 |
+
|
365 |
+
Args:
|
366 |
+
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
|
367 |
+
``
|
368 |
+
def model_fn(x, t_continuous):
|
369 |
+
return noise
|
370 |
+
``
|
371 |
+
The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
|
372 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
373 |
+
algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
|
374 |
+
correcting_x0_fn: A `str` or a function with the following format:
|
375 |
+
```
|
376 |
+
def correcting_x0_fn(x0, t):
|
377 |
+
x0_new = ...
|
378 |
+
return x0_new
|
379 |
+
```
|
380 |
+
This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
|
381 |
+
```
|
382 |
+
x0_pred = data_pred_model(xt, t)
|
383 |
+
if correcting_x0_fn is not None:
|
384 |
+
x0_pred = correcting_x0_fn(x0_pred, t)
|
385 |
+
xt_1 = update(x0_pred, xt, t)
|
386 |
+
```
|
387 |
+
If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
|
388 |
+
correcting_xt_fn: A function with the following format:
|
389 |
+
```
|
390 |
+
def correcting_xt_fn(xt, t, step):
|
391 |
+
x_new = ...
|
392 |
+
return x_new
|
393 |
+
```
|
394 |
+
This function is to correct the intermediate samples xt at each sampling step. e.g.,
|
395 |
+
```
|
396 |
+
xt = ...
|
397 |
+
xt = correcting_xt_fn(xt, t, step)
|
398 |
+
```
|
399 |
+
thresholding_max_val: A `float`. The max value for thresholding.
|
400 |
+
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
|
401 |
+
dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
|
402 |
+
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
|
403 |
+
|
404 |
+
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
|
405 |
+
Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
|
406 |
+
with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
|
407 |
+
"""
|
408 |
+
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
|
409 |
+
self.noise_schedule = noise_schedule
|
410 |
+
assert algorithm_type in ["dpmsolver", "dpmsolver++"]
|
411 |
+
self.algorithm_type = algorithm_type
|
412 |
+
if correcting_x0_fn == "dynamic_thresholding":
|
413 |
+
self.correcting_x0_fn = self.dynamic_thresholding_fn
|
414 |
+
else:
|
415 |
+
self.correcting_x0_fn = correcting_x0_fn
|
416 |
+
self.correcting_xt_fn = correcting_xt_fn
|
417 |
+
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
|
418 |
+
self.thresholding_max_val = thresholding_max_val
|
419 |
+
|
420 |
+
def dynamic_thresholding_fn(self, x0, t):
|
421 |
+
"""
|
422 |
+
The dynamic thresholding method.
|
423 |
+
"""
|
424 |
+
dims = x0.dim()
|
425 |
+
p = self.dynamic_thresholding_ratio
|
426 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
427 |
+
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
428 |
+
x0 = torch.clamp(x0, -s, s) / s
|
429 |
+
return x0
|
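# --- Editor's note: descriptive summary, added for clarity. ---
# Per sample, s = max(quantile_p(|x0|), thresholding_max_val); then
# x0 <- clamp(x0, -s, s) / s, so every output lands in [-1, 1]. This keeps
# large-guidance-scale predictions from saturating in pixel space.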
430 |
+
|
431 |
+
def noise_prediction_fn(self, x, t):
|
432 |
+
"""
|
433 |
+
Return the noise prediction of the model.
|
434 |
+
"""
|
435 |
+
return self.model(x, t)
|
436 |
+
|
437 |
+
def data_prediction_fn(self, x, t):
|
438 |
+
"""
|
439 |
+
Return the data prediction of the model (with the corrector applied).
|
440 |
+
"""
|
441 |
+
noise = self.noise_prediction_fn(x, t)
|
442 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
443 |
+
x0 = (x - sigma_t * noise) / alpha_t
|
444 |
+
if self.correcting_x0_fn is not None:
|
445 |
+
x0 = self.correcting_x0_fn(x0, t)
|
446 |
+
return x0
|
447 |
+
|
448 |
+
def model_fn(self, x, t):
|
449 |
+
"""
|
450 |
+
Convert the model to the noise prediction model or the data prediction model.
|
451 |
+
"""
|
452 |
+
if self.algorithm_type == "dpmsolver++":
|
453 |
+
return self.data_prediction_fn(x, t)
|
454 |
+
else:
|
455 |
+
return self.noise_prediction_fn(x, t)
|
456 |
+
|
457 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
458 |
+
"""Compute the intermediate time steps for sampling.
|
459 |
+
|
460 |
+
Args:
|
461 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
462 |
+
- 'logSNR': uniform logSNR for the time steps.
|
463 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
|
464 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
|
465 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
466 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
467 |
+
N: A `int`. The total number of the spacing of the time steps.
|
468 |
+
device: A torch device.
|
469 |
+
Returns:
|
470 |
+
A pytorch tensor of the time steps, with the shape (N + 1,).
|
471 |
+
"""
|
472 |
+
if skip_type == 'logSNR':
|
473 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
474 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
475 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
476 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
477 |
+
elif skip_type == 'time_uniform':
|
478 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
479 |
+
elif skip_type == 'time_quadratic':
|
480 |
+
t_order = 2
|
481 |
+
return (
|
482 |
+
torch.linspace(
|
483 |
+
t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1
|
484 |
+
)
|
485 |
+
.pow(t_order)
|
486 |
+
.to(device)
|
487 |
+
)
|
488 |
+
else:
|
489 |
+
raise ValueError(
|
490 |
+
f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
|
491 |
+
)
|
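# --- Editor's note: illustrative sketch, not part of the original file. ---
# e.g. get_time_steps('time_uniform', t_T=1., t_0=1e-3, N=4, device) returns
# the 5 evenly spaced times [1.0, 0.75025, 0.5005, 0.25075, 0.001].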
492 |
+
|
493 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
494 |
+
"""
|
495 |
+
Get the order of each step for sampling by the singlestep DPM-Solver.
|
496 |
+
|
497 |
+
We combine DPM-Solver-1, 2, and 3 to use all the function evaluations; the combination is named "DPM-Solver-fast".
|
498 |
+
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
|
499 |
+
- If order == 1:
|
500 |
+
We take `steps` of DPM-Solver-1 (i.e. DDIM).
|
501 |
+
- If order == 2:
|
502 |
+
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
|
503 |
+
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
|
504 |
+
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
505 |
+
- If order == 3:
|
506 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
507 |
+
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
508 |
+
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
|
509 |
+
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
|
510 |
+
|
511 |
+
============================================
|
512 |
+
Args:
|
513 |
+
order: A `int`. The max order for the solver (2 or 3).
|
514 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
515 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
516 |
+
- 'logSNR': uniform logSNR for the time steps.
|
517 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
|
518 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
|
519 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
520 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
521 |
+
device: A torch device.
|
522 |
+
Returns:
|
523 |
+
orders: A list of the solver order of each step.
|
524 |
+
"""
|
525 |
+
if order == 3:
|
526 |
+
K = steps // 3 + 1
|
527 |
+
if steps % 3 == 0:
|
528 |
+
orders = [3, ] * (K - 2) + [2, 1]
|
529 |
+
elif steps % 3 == 1:
|
530 |
+
orders = [3, ] * (K - 1) + [1]
|
531 |
+
else:
|
532 |
+
orders = [3, ] * (K - 1) + [2]
|
533 |
+
elif order == 2:
|
534 |
+
if steps % 2 == 0:
|
535 |
+
K = steps // 2
|
536 |
+
orders = [2, ] * K
|
537 |
+
else:
|
538 |
+
K = steps // 2 + 1
|
539 |
+
orders = [2, ] * (K - 1) + [1]
|
540 |
+
elif order == 1:
|
541 |
+
K = 1
|
542 |
+
orders = [1, ] * steps
|
543 |
+
else:
|
544 |
+
raise ValueError("'order' must be '1' or '2' or '3'.")
|
545 |
+
if skip_type == 'logSNR':
|
546 |
+
# To reproduce the results in DPM-Solver paper
|
547 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
548 |
+
else:
|
549 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
|
550 |
+
torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)]
|
551 |
+
return timesteps_outer, orders
|
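# --- Editor's note: illustrative sketch, not part of the original file. ---
# Example of the splitting rule above: steps=10, order=3 gives
# K = 10 // 3 + 1 = 4 and, since 10 % 3 == 1, orders = [3, 3, 3, 1]
# (three steps of DPM-Solver-3 plus one step of DPM-Solver-1; NFE sums to 10).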
552 |
+
|
553 |
+
def denoise_to_zero_fn(self, x, s):
|
554 |
+
"""
|
555 |
+
Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
|
556 |
+
"""
|
557 |
+
return self.data_prediction_fn(x, s)
|
558 |
+
|
559 |
+
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
|
560 |
+
"""
|
561 |
+
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
|
562 |
+
|
563 |
+
Args:
|
564 |
+
x: A pytorch tensor. The initial value at time `s`.
|
565 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
566 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
567 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
568 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
569 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`.
|
570 |
+
Returns:
|
571 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
572 |
+
"""
|
573 |
+
ns = self.noise_schedule
|
574 |
+
dims = x.dim()
|
575 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
576 |
+
h = lambda_t - lambda_s
|
577 |
+
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
|
578 |
+
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
|
579 |
+
alpha_t = torch.exp(log_alpha_t)
|
580 |
+
|
581 |
+
if self.algorithm_type == "dpmsolver++":
|
582 |
+
phi_1 = torch.expm1(-h)
|
583 |
+
if model_s is None:
|
584 |
+
model_s = self.model_fn(x, s)
|
585 |
+
x_t = (
|
586 |
+
sigma_t / sigma_s * x
|
587 |
+
- alpha_t * phi_1 * model_s
|
588 |
+
)
|
589 |
+
else:
|
590 |
+
phi_1 = torch.expm1(h)
|
591 |
+
if model_s is None:
|
592 |
+
model_s = self.model_fn(x, s)
|
593 |
+
x_t = (
|
594 |
+
torch.exp(log_alpha_t - log_alpha_s) * x
|
595 |
+
- (sigma_t * phi_1) * model_s
|
596 |
+
)
|
597 |
+
return (x_t, {'model_s': model_s}) if return_intermediate else x_t
|
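# --- Editor's note: descriptive summary, added for clarity. ---
# In dpmsolver++ form, with h = lambda_t - lambda_s and phi_1 = expm1(-h):
#   x_t = (sigma_t / sigma_s) * x - alpha_t * phi_1 * x0_pred(x, s)
# This is exact whenever the data prediction is constant on [s, t], and it
# coincides with a DDIM step.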
598 |
+
|
599 |
+
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
|
600 |
+
solver_type='dpmsolver'):
|
601 |
+
"""
|
602 |
+
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
|
603 |
+
|
604 |
+
Args:
|
605 |
+
x: A pytorch tensor. The initial value at time `s`.
|
606 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
607 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
608 |
+
r1: A `float`. The hyperparameter of the second-order solver.
|
609 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
610 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
611 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
|
612 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
613 |
+
The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
|
614 |
+
Returns:
|
615 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
616 |
+
"""
|
617 |
+
if solver_type not in ['dpmsolver', 'taylor']:
|
618 |
+
raise ValueError(
|
619 |
+
f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
|
620 |
+
)
|
621 |
+
if r1 is None:
|
622 |
+
r1 = 0.5
|
623 |
+
ns = self.noise_schedule
|
624 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
625 |
+
h = lambda_t - lambda_s
|
626 |
+
lambda_s1 = lambda_s + r1 * h
|
627 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
628 |
+
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
|
629 |
+
s1), ns.marginal_log_mean_coeff(t)
|
630 |
+
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
|
631 |
+
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
|
632 |
+
|
633 |
+
if self.algorithm_type == "dpmsolver++":
|
634 |
+
phi_11 = torch.expm1(-r1 * h)
|
635 |
+
phi_1 = torch.expm1(-h)
|
636 |
+
|
637 |
+
if model_s is None:
|
638 |
+
model_s = self.model_fn(x, s)
|
639 |
+
x_s1 = (
|
640 |
+
(sigma_s1 / sigma_s) * x
|
641 |
+
- (alpha_s1 * phi_11) * model_s
|
642 |
+
)
|
643 |
+
model_s1 = self.model_fn(x_s1, s1)
|
644 |
+
if solver_type == 'dpmsolver':
|
645 |
+
x_t = (
|
646 |
+
(sigma_t / sigma_s) * x
|
647 |
+
- (alpha_t * phi_1) * model_s
|
648 |
+
- (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
|
649 |
+
)
|
650 |
+
elif solver_type == 'taylor':
|
651 |
+
x_t = (
|
652 |
+
(sigma_t / sigma_s) * x
|
653 |
+
- (alpha_t * phi_1) * model_s
|
654 |
+
+ (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
|
655 |
+
)
|
656 |
+
else:
|
657 |
+
phi_11 = torch.expm1(r1 * h)
|
658 |
+
phi_1 = torch.expm1(h)
|
659 |
+
|
660 |
+
if model_s is None:
|
661 |
+
model_s = self.model_fn(x, s)
|
662 |
+
x_s1 = (
|
663 |
+
torch.exp(log_alpha_s1 - log_alpha_s) * x
|
664 |
+
- (sigma_s1 * phi_11) * model_s
|
665 |
+
)
|
666 |
+
model_s1 = self.model_fn(x_s1, s1)
|
667 |
+
if solver_type == 'dpmsolver':
|
668 |
+
x_t = (
|
669 |
+
torch.exp(log_alpha_t - log_alpha_s) * x
|
670 |
+
- (sigma_t * phi_1) * model_s
|
671 |
+
- (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
|
672 |
+
)
|
673 |
+
elif solver_type == 'taylor':
|
674 |
+
x_t = (
|
675 |
+
torch.exp(log_alpha_t - log_alpha_s) * x
|
676 |
+
- (sigma_t * phi_1) * model_s
|
677 |
+
- (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
|
678 |
+
)
|
679 |
+
if return_intermediate:
|
680 |
+
return x_t, {'model_s': model_s, 'model_s1': model_s1}
|
681 |
+
else:
|
682 |
+
return x_t
|
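# --- Editor's note: descriptive summary, added for clarity. ---
# DPM-Solver-2 spends one extra model evaluation at the intermediate time s1
# (lambda_s1 = lambda_s + r1 * h). The difference (model_s1 - model_s) acts as
# a finite-difference estimate of the model's derivative in lambda, upgrading
# the first-order update to second order.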
683 |
+
|
684 |
+
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
|
685 |
+
return_intermediate=False, solver_type='dpmsolver'):
|
686 |
+
"""
|
687 |
+
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
|
688 |
+
|
689 |
+
Args:
|
690 |
+
x: A pytorch tensor. The initial value at time `s`.
|
691 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
692 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
693 |
+
r1: A `float`. The hyperparameter of the third-order solver.
|
694 |
+
r2: A `float`. The hyperparameter of the third-order solver.
|
695 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
696 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
697 |
+
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
|
698 |
+
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
|
699 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
|
700 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
701 |
+
The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
|
702 |
+
Returns:
|
703 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
704 |
+
"""
|
705 |
+
if solver_type not in ['dpmsolver', 'taylor']:
|
706 |
+
raise ValueError(
|
707 |
+
f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
|
708 |
+
)
|
709 |
+
if r1 is None:
|
710 |
+
r1 = 1. / 3.
|
711 |
+
if r2 is None:
|
712 |
+
r2 = 2. / 3.
|
713 |
+
ns = self.noise_schedule
|
714 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
715 |
+
h = lambda_t - lambda_s
|
716 |
+
lambda_s1 = lambda_s + r1 * h
|
717 |
+
lambda_s2 = lambda_s + r2 * h
|
718 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
719 |
+
s2 = ns.inverse_lambda(lambda_s2)
|
720 |
+
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
|
721 |
+
s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
|
722 |
+
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
|
723 |
+
s2), ns.marginal_std(t)
|
724 |
+
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
|
725 |
+
|
726 |
+
if self.algorithm_type == "dpmsolver++":
|
727 |
+
phi_11 = torch.expm1(-r1 * h)
|
728 |
+
phi_12 = torch.expm1(-r2 * h)
|
729 |
+
phi_1 = torch.expm1(-h)
|
730 |
+
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
|
731 |
+
phi_2 = phi_1 / h + 1.
|
732 |
+
phi_3 = phi_2 / h - 0.5
|
733 |
+
|
734 |
+
if model_s is None:
|
735 |
+
model_s = self.model_fn(x, s)
|
736 |
+
if model_s1 is None:
|
737 |
+
x_s1 = (
|
738 |
+
(sigma_s1 / sigma_s) * x
|
739 |
+
- (alpha_s1 * phi_11) * model_s
|
740 |
+
)
|
741 |
+
model_s1 = self.model_fn(x_s1, s1)
|
742 |
+
x_s2 = (
|
743 |
+
(sigma_s2 / sigma_s) * x
|
744 |
+
- (alpha_s2 * phi_12) * model_s
|
745 |
+
+ r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
|
746 |
+
)
|
747 |
+
model_s2 = self.model_fn(x_s2, s2)
|
748 |
+
if solver_type == 'dpmsolver':
|
749 |
+
x_t = (
|
750 |
+
(sigma_t / sigma_s) * x
|
751 |
+
- (alpha_t * phi_1) * model_s
|
752 |
+
+ (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
|
753 |
+
)
|
754 |
+
elif solver_type == 'taylor':
|
755 |
+
D1_0 = (1. / r1) * (model_s1 - model_s)
|
756 |
+
D1_1 = (1. / r2) * (model_s2 - model_s)
|
757 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
758 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
759 |
+
x_t = (
|
760 |
+
(sigma_t / sigma_s) * x
|
761 |
+
- (alpha_t * phi_1) * model_s
|
762 |
+
+ (alpha_t * phi_2) * D1
|
763 |
+
- (alpha_t * phi_3) * D2
|
764 |
+
)
|
765 |
+
else:
|
766 |
+
phi_11 = torch.expm1(r1 * h)
|
767 |
+
phi_12 = torch.expm1(r2 * h)
|
768 |
+
phi_1 = torch.expm1(h)
|
769 |
+
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
|
770 |
+
phi_2 = phi_1 / h - 1.
|
771 |
+
phi_3 = phi_2 / h - 0.5
|
772 |
+
|
773 |
+
if model_s is None:
|
774 |
+
model_s = self.model_fn(x, s)
|
775 |
+
if model_s1 is None:
|
776 |
+
x_s1 = (
|
777 |
+
(torch.exp(log_alpha_s1 - log_alpha_s)) * x
|
778 |
+
- (sigma_s1 * phi_11) * model_s
|
779 |
+
)
|
780 |
+
model_s1 = self.model_fn(x_s1, s1)
|
781 |
+
x_s2 = (
|
782 |
+
(torch.exp(log_alpha_s2 - log_alpha_s)) * x
|
783 |
+
- (sigma_s2 * phi_12) * model_s
|
784 |
+
- r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
|
785 |
+
)
|
786 |
+
model_s2 = self.model_fn(x_s2, s2)
|
787 |
+
if solver_type == 'dpmsolver':
|
788 |
+
x_t = (
|
789 |
+
(torch.exp(log_alpha_t - log_alpha_s)) * x
|
790 |
+
- (sigma_t * phi_1) * model_s
|
791 |
+
- (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
|
792 |
+
)
|
793 |
+
elif solver_type == 'taylor':
|
794 |
+
D1_0 = (1. / r1) * (model_s1 - model_s)
|
795 |
+
D1_1 = (1. / r2) * (model_s2 - model_s)
|
796 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
797 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
798 |
+
x_t = (
|
799 |
+
(torch.exp(log_alpha_t - log_alpha_s)) * x
|
800 |
+
- (sigma_t * phi_1) * model_s
|
801 |
+
- (sigma_t * phi_2) * D1
|
802 |
+
- (sigma_t * phi_3) * D2
|
803 |
+
)
|
804 |
+
|
805 |
+
if return_intermediate:
|
806 |
+
return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
|
807 |
+
else:
|
808 |
+
return x_t
|
809 |
+
|
810 |
+
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
|
811 |
+
"""
|
812 |
+
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
|
813 |
+
|
814 |
+
Args:
|
815 |
+
x: A pytorch tensor. The initial value at time `s`.
|
816 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
817 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
|
818 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
819 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
820 |
+
The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
|
821 |
+
Returns:
|
822 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
823 |
+
"""
|
824 |
+
if solver_type not in ['dpmsolver', 'taylor']:
|
825 |
+
raise ValueError(
|
826 |
+
f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
|
827 |
+
)
|
828 |
+
ns = self.noise_schedule
|
829 |
+
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
|
830 |
+
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
|
831 |
+
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
|
832 |
+
t_prev_0), ns.marginal_lambda(t)
|
833 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
834 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
835 |
+
alpha_t = torch.exp(log_alpha_t)
|
836 |
+
|
837 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
838 |
+
h = lambda_t - lambda_prev_0
|
839 |
+
r0 = h_0 / h
|
840 |
+
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
|
841 |
+
if self.algorithm_type == "dpmsolver++":
|
842 |
+
phi_1 = torch.expm1(-h)
|
843 |
+
if solver_type == 'dpmsolver':
|
844 |
+
x_t = (
|
845 |
+
(sigma_t / sigma_prev_0) * x
|
846 |
+
- (alpha_t * phi_1) * model_prev_0
|
847 |
+
- 0.5 * (alpha_t * phi_1) * D1_0
|
848 |
+
)
|
849 |
+
elif solver_type == 'taylor':
|
850 |
+
x_t = (
|
851 |
+
(sigma_t / sigma_prev_0) * x
|
852 |
+
- (alpha_t * phi_1) * model_prev_0
|
853 |
+
+ (alpha_t * (phi_1 / h + 1.)) * D1_0
|
854 |
+
)
|
855 |
+
else:
|
856 |
+
phi_1 = torch.expm1(h)
|
857 |
+
if solver_type == 'dpmsolver':
|
858 |
+
x_t = (
|
859 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
860 |
+
- (sigma_t * phi_1) * model_prev_0
|
861 |
+
- 0.5 * (sigma_t * phi_1) * D1_0
|
862 |
+
)
|
863 |
+
elif solver_type == 'taylor':
|
864 |
+
x_t = (
|
865 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
866 |
+
- (sigma_t * phi_1) * model_prev_0
|
867 |
+
- (sigma_t * (phi_1 / h - 1.)) * D1_0
|
868 |
+
)
|
869 |
+
return x_t
|
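# --- Editor's note: descriptive summary, added for clarity. ---
# The multistep variant reuses the cached previous evaluation instead of an
# extra one: with r0 = h_0 / h, D1_0 = (model_prev_0 - model_prev_1) / r0
# approximates h times the derivative of the model output w.r.t. lambda at
# lambda_prev_0, so each step still costs a single NFE.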
870 |
+
|
871 |
+
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
|
872 |
+
"""
|
873 |
+
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
|
874 |
+
|
875 |
+
Args:
|
876 |
+
x: A pytorch tensor. The initial value at time `s`.
|
877 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
878 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
|
879 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
880 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
881 |
+
The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
|
882 |
+
Returns:
|
883 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
884 |
+
"""
|
885 |
+
ns = self.noise_schedule
|
886 |
+
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
|
887 |
+
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
|
888 |
+
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
|
889 |
+
t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
|
890 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
891 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
892 |
+
alpha_t = torch.exp(log_alpha_t)
|
893 |
+
|
894 |
+
h_1 = lambda_prev_1 - lambda_prev_2
|
895 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
896 |
+
h = lambda_t - lambda_prev_0
|
897 |
+
r0, r1 = h_0 / h, h_1 / h
|
898 |
+
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
|
899 |
+
D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
|
900 |
+
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
901 |
+
D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
|
902 |
+
if self.algorithm_type == "dpmsolver++":
|
903 |
+
phi_1 = torch.expm1(-h)
|
904 |
+
phi_2 = phi_1 / h + 1.
|
905 |
+
phi_3 = phi_2 / h - 0.5
|
906 |
+
return (
|
907 |
+
(sigma_t / sigma_prev_0) * x
|
908 |
+
- (alpha_t * phi_1) * model_prev_0
|
909 |
+
+ (alpha_t * phi_2) * D1
|
910 |
+
- (alpha_t * phi_3) * D2
|
911 |
+
)
|
912 |
+
else:
|
913 |
+
phi_1 = torch.expm1(h)
|
914 |
+
phi_2 = phi_1 / h - 1.
|
915 |
+
phi_3 = phi_2 / h - 0.5
|
916 |
+
return (
|
917 |
+
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
918 |
+
- (sigma_t * phi_1) * model_prev_0
|
919 |
+
- (sigma_t * phi_2) * D1
|
920 |
+
- (sigma_t * phi_3) * D2
|
921 |
+
)
|
922 |
+
|
923 |
+
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None,
|
924 |
+
r2=None):
|
925 |
+
"""
|
926 |
+
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
|
927 |
+
|
928 |
+
Args:
|
929 |
+
x: A pytorch tensor. The initial value at time `s`.
|
930 |
+
s: A pytorch tensor. The starting time, with the shape (1,).
|
931 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
932 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
933 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
|
934 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
935 |
+
The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
|
936 |
+
r1: A `float`. The hyperparameter of the second-order or third-order solver.
|
937 |
+
r2: A `float`. The hyperparameter of the third-order solver.
|
938 |
+
Returns:
|
939 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
940 |
+
"""
|
941 |
+
if order == 1:
|
942 |
+
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
|
943 |
+
elif order == 2:
|
944 |
+
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
|
945 |
+
solver_type=solver_type, r1=r1)
|
946 |
+
elif order == 3:
|
947 |
+
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
|
948 |
+
solver_type=solver_type, r1=r1, r2=r2)
|
949 |
+
else:
|
950 |
+
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
|
951 |
+
|
952 |
+
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
|
953 |
+
"""
|
954 |
+
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
|
955 |
+
|
956 |
+
Args:
|
957 |
+
x: A pytorch tensor. The initial value at time `s`.
|
958 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
959 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
|
960 |
+
t: A pytorch tensor. The ending time, with the shape (1,).
|
961 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
962 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
963 |
+
The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
|
964 |
+
Returns:
|
965 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
966 |
+
"""
|
967 |
+
if order == 1:
|
968 |
+
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
|
969 |
+
elif order == 2:
|
970 |
+
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
971 |
+
elif order == 3:
|
972 |
+
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
973 |
+
else:
|
974 |
+
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
|
975 |
+
|
976 |
+
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
977 |
+
solver_type='dpmsolver'):
|
978 |
+
"""
|
979 |
+
The adaptive step size solver based on singlestep DPM-Solver.
|
980 |
+
|
981 |
+
Args:
|
982 |
+
x: A pytorch tensor. The initial value at time `t_T`.
|
983 |
+
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
|
984 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
985 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
986 |
+
h_init: A `float`. The initial step size (for logSNR).
|
987 |
+
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
|
988 |
+
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
|
989 |
+
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
|
990 |
+
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
|
991 |
+
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
|
992 |
+
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
|
993 |
+
The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
|
994 |
+
Returns:
|
995 |
+
x_0: A pytorch tensor. The approximated solution at time `t_0`.
|
996 |
+
|
997 |
+
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based generative models," arXiv preprint arXiv:2105.14080, 2021.
|
998 |
+
"""
|
999 |
+
ns = self.noise_schedule
|
1000 |
+
s = t_T * torch.ones((1,)).to(x)
|
1001 |
+
lambda_s = ns.marginal_lambda(s)
|
1002 |
+
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
|
1003 |
+
h = h_init * torch.ones_like(s).to(x)
|
1004 |
+
x_prev = x
|
1005 |
+
nfe = 0
|
1006 |
+
if order == 2:
|
1007 |
+
r1 = 0.5
|
1008 |
+
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
|
1009 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
1010 |
+
solver_type=solver_type,
|
1011 |
+
**kwargs)
|
1012 |
+
elif order == 3:
|
1013 |
+
r1, r2 = 1. / 3., 2. / 3.
|
1014 |
+
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
1015 |
+
return_intermediate=True,
|
1016 |
+
solver_type=solver_type)
|
1017 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
|
1018 |
+
solver_type=solver_type,
|
1019 |
+
**kwargs)
|
1020 |
+
else:
|
1021 |
+
raise ValueError(
|
1022 |
+
f"For adaptive step size solver, order must be 2 or 3, got {order}"
|
1023 |
+
)
|
1024 |
+
while torch.abs((s - t_0)).mean() > t_err:
|
1025 |
+
t = ns.inverse_lambda(lambda_s + h)
|
1026 |
+
x_lower, lower_noise_kwargs = lower_update(x, s, t)
|
1027 |
+
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
|
1028 |
+
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
|
1029 |
+
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
|
1030 |
+
E = norm_fn((x_higher - x_lower) / delta).max()
|
1031 |
+
if torch.all(E <= 1.):
|
1032 |
+
x = x_higher
|
1033 |
+
s = t
|
1034 |
+
x_prev = x_lower
|
1035 |
+
lambda_s = ns.marginal_lambda(s)
|
1036 |
+
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
|
1037 |
+
nfe += order
|
1038 |
+
print('adaptive solver nfe', nfe)
|
1039 |
+
return x
|
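# --- Editor's note: descriptive summary, added for clarity. ---
# Classic embedded-pair step-size control: the lower- and higher-order updates
# share model evaluations; a step is accepted when the scaled error E <= 1, and
# the next logSNR step is h <- min(theta * h * E**(-1/order), lambda_0 - lambda_s),
# following Jolicoeur-Martineau et al. [1].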
1040 |
+
|
1041 |
+
def add_noise(self, x, t, noise=None):
|
1042 |
+
"""
|
1043 |
+
Compute the noised input xt = alpha_t * x + sigma_t * noise.
|
1044 |
+
|
1045 |
+
Args:
|
1046 |
+
x: A `torch.Tensor` with shape `(batch_size, *shape)`.
|
1047 |
+
t: A `torch.Tensor` with shape `(t_size,)`.
|
1048 |
+
Returns:
|
1049 |
+
xt with shape `(t_size, batch_size, *shape)`.
|
1050 |
+
"""
|
1051 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
1052 |
+
if noise is None:
|
1053 |
+
noise = torch.randn((t.shape[0], *x.shape), device=x.device)
|
1054 |
+
x = x.reshape((-1, *x.shape))
|
1055 |
+
xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
|
1056 |
+
return xt.squeeze(0) if t.shape[0] == 1 else xt
|
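# --- Editor's note: descriptive summary, added for clarity. ---
# This implements the forward (noising) marginal documented above:
# x_t = alpha_t * x + sigma_t * noise, i.e. a draw from
# q_{t|0}(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 * I).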
1057 |
+
|
1058 |
+
def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
|
1059 |
+
method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
|
1060 |
+
atol=0.0078, rtol=0.05, return_intermediate=False,
|
1061 |
+
):
|
1062 |
+
"""
|
1063 |
+
Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
|
1064 |
+
For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
|
1065 |
+
"""
|
1066 |
+
t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
|
1067 |
+
t_T = self.noise_schedule.T if t_end is None else t_end
|
1068 |
+
assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
|
1069 |
+
return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
|
1070 |
+
method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero,
|
1071 |
+
solver_type=solver_type,
|
1072 |
+
atol=atol, rtol=rtol, return_intermediate=return_intermediate)
|
1073 |
+
|
1074 |
+
def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
|
1075 |
+
method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
|
1076 |
+
atol=0.0078, rtol=0.05, return_intermediate=False,
|
1077 |
+
):
|
1078 |
+
"""
|
1079 |
+
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
1080 |
+
|
1081 |
+
=====================================================
|
1082 |
+
|
1083 |
+
We support the following algorithms for both noise prediction model and data prediction model:
|
1084 |
+
- 'singlestep':
|
1085 |
+
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
|
1086 |
+
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
|
1087 |
+
The total number of function evaluations (NFE) == `steps`.
|
1088 |
+
Given a fixed NFE == `steps`, the sampling procedure is:
|
1089 |
+
- If `order` == 1:
|
1090 |
+
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
|
1091 |
+
- If `order` == 2:
|
1092 |
+
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
|
1093 |
+
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
|
1094 |
+
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
|
1095 |
+
- If `order` == 3:
|
1096 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
1097 |
+
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
|
1098 |
+
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
|
1099 |
+
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
|
1100 |
+
- 'multistep':
|
1101 |
+
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
|
1102 |
+
We initialize the first `order` values by lower order multistep solvers.
|
1103 |
+
Given a fixed NFE == `steps`, the sampling procedure is:
|
1104 |
+
Denote K = steps.
|
1105 |
+
- If `order` == 1:
|
1106 |
+
- We use K steps of DPM-Solver-1 (i.e. DDIM).
|
1107 |
+
- If `order` == 2:
|
1108 |
+
- We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
|
1109 |
+
- If `order` == 3:
|
1110 |
+
- We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
|
1111 |
+
- 'singlestep_fixed':
|
1112 |
+
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
|
1113 |
+
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
|
1114 |
+
- 'adaptive':
|
1115 |
+
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
|
1116 |
+
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
|
1117 |
+
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
|
1118 |
+
(NFE) and the sample quality.
|
1119 |
+
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
|
1120 |
+
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
|
1121 |
+
|
1122 |
+
=====================================================
|
1123 |
+
|
1124 |
+
Some advice on choosing the algorithm:
|
1125 |
+
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
|
1126 |
+
Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
|
1127 |
+
e.g., DPM-Solver:
|
1128 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
|
1129 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
|
1130 |
+
skip_type='time_uniform', method='singlestep')
|
1131 |
+
e.g., DPM-Solver++:
|
1132 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
|
1133 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
|
1134 |
+
skip_type='time_uniform', method='singlestep')
|
1135 |
+
- For **guided sampling with large guidance scale** by DPMs:
|
1136 |
+
Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
|
1137 |
+
e.g.
|
1138 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
|
1139 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
|
1140 |
+
skip_type='time_uniform', method='multistep')
|
1141 |
+
|
1142 |
+
We support three types of `skip_type`:
|
1143 |
+
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**
|
1144 |
+
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
|
1145 |
+
- 'time_quadratic': quadratic time for the time steps.
|
1146 |
+
|
1147 |
+
=====================================================
|
1148 |
+
Args:
|
1149 |
+
x: A pytorch tensor. The initial value at time `t_start`
|
1150 |
+
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
|
1151 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
1152 |
+
t_start: A `float`. The starting time of the sampling.
|
1153 |
+
If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
|
1154 |
+
t_end: A `float`. The ending time of the sampling.
|
1155 |
+
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
|
1156 |
+
e.g. if total_N == 1000, we have `t_end` == 1e-3.
|
1157 |
+
For discrete-time DPMs:
|
1158 |
+
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
|
1159 |
+
For continuous-time DPMs:
|
1160 |
+
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
|
1161 |
+
order: A `int`. The order of DPM-Solver.
|
1162 |
+
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
|
1163 |
+
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
|
1164 |
+
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
|
1165 |
+
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
|
1166 |
+
|
1167 |
+
This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
|
1168 |
+
score_sde (https://arxiv.org/abs/2011.13456). This trick can improve the FID
|
1169 |
+
for diffusion model sampling by diffusion SDEs for low-resolution images
|
1170 |
+
(such as CIFAR-10). However, we observed that this trick does not matter for
|
1171 |
+
high-resolution images. As it needs an additional NFE, we do not recommend
|
1172 |
+
it for high-resolution images.
|
1173 |
+
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
|
1174 |
+
Only valid for `method=multistep` and `steps < 15`. We empirically find that
|
1175 |
+
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
|
1176 |
+
(especially for steps <= 10). So we recommend to set it to be `True`.
|
1177 |
+
solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
|
1178 |
+
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
1179 |
+
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
1180 |
+
return_intermediate: A `bool`. Whether to save the xt at each step.
|
1181 |
+
When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
|
1182 |
+
Returns:
|
1183 |
+
x_end: A pytorch tensor. The approximated solution at time `t_end`.
|
1184 |
+
|
1185 |
+
"""
|
1186 |
+
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        if return_intermediate:
            assert method in ['multistep', 'singlestep',
                              'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
        if self.correcting_xt_fn is not None:
            assert method in ['multistep', 'singlestep',
                              'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
        device = x.device
        intermediates = []
        with torch.no_grad():
            if method == 'adaptive':
                x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
                                             solver_type=solver_type)
            elif method == 'multistep':
                assert steps >= order
                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
                assert timesteps.shape[0] - 1 == steps
                # Init the initial values.
                step = 0
                t = timesteps[step]
                t_prev_list = [t]
                model_prev_list = [self.model_fn(x, t)]
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                # Init the first `order` values by lower order multistep DPM-Solver.
                for step in range(1, order):
                    t = timesteps[step]
                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step,
                                                         solver_type=solver_type)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    t_prev_list.append(t)
                    model_prev_list.append(self.model_fn(x, t))
                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in tqdm(range(order, steps + 1)):
                    t = timesteps[step]
                    # We only use lower order for steps < 10
                    if lower_order_final and steps < 10:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order,
                                                         solver_type=solver_type)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        model_prev_list[-1] = self.model_fn(x, t)
            elif method in ['singlestep', 'singlestep_fixed']:
                if method == 'singlestep':
                    timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps,
                                                                                                  order=order,
                                                                                                  skip_type=skip_type,
                                                                                                  t_T=t_T, t_0=t_0,
                                                                                                  device=device)
                elif method == 'singlestep_fixed':
                    K = steps // order
                    orders = [order, ] * K
                    timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
                for step, order in enumerate(orders):
                    s, t = timesteps_outer[step], timesteps_outer[step + 1]
                    timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order,
                                                          device=device)
                    lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                    h = lambda_inner[-1] - lambda_inner[0]
                    r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                    r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                    x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
            else:
                raise ValueError(f"Got wrong method {method}")
            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        return (x, intermediates) if return_intermediate else x


#############################################################
# other utility functions
#############################################################

def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)


def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
DiT_VAE/diffusion/model/edm_sample.py
ADDED
@@ -0,0 +1,168 @@
import numpy as np
import torch  # added: both samplers below rely on torch throughout
from tqdm import tqdm


# ----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).

def edm_sampler(
        net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
        num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
):
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next


# ----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.

def ablation_sampler(
        net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
        num_steps=18, sigma_min=None, sigma_max=None, rho=7,
        solver='heun', discretization='edm', schedule='linear', scaling='none',
        epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
            sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
                sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
            t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
            torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
            t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction.
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
                torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
                t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next
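
Both samplers above derive their noise levels from the same rho-spaced schedule (the 'edm' discretization). A standalone sketch with illustrative values (editorial, not part of the commit):

    import torch

    num_steps, sigma_min, sigma_max, rho = 5, 0.002, 80.0, 7
    i = torch.arange(num_steps, dtype=torch.float64)
    t_steps = (sigma_max ** (1 / rho) + i / (num_steps - 1)
               * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    print(t_steps)  # roughly [80.0, 17.5, 2.5, 0.17, 0.002]: coarse at high noise, fine near sigma_min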
DiT_VAE/diffusion/model/gaussian_diffusion.py
ADDED
@@ -0,0 +1,1006 @@
# Modified from OpenAI's diffusion repos
#     GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
#     ADM:   https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
#     IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py


import enum
import math

import numpy as np
import torch as th
import torch.nn.functional as F

from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.
    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self in [LossType.KL, LossType.RESCALED_KL]

def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    return betas


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.
    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "quad":
        betas = (
                np.linspace(
                    beta_start ** 0.5,
                    beta_end ** 0.5,
                    num_diffusion_timesteps,
                    dtype=np.float64,
                )
                ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "squaredcos_cap_v2":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

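# Editorial usage sketch (not part of the commit); both calls return numpy
# arrays of shape (1000,):
# >>> betas_lin = get_named_beta_schedule("linear", 1000)             # 1e-4 ... 0.02
# >>> betas_cos = get_named_beta_schedule("squaredcos_cap_v2", 1000)  # cosine, capped at 0.999
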
class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.
    Originally ported from this codebase:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    """

    def __init__(
            self,
            *,
            betas,
            model_mean_type,
            model_var_type,
            loss_type,
            snr=False,
            return_startx=False,
    ):

        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.snr = snr
        self.return_startx = return_startx

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
                betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        ) if len(self.posterior_variance) > 1 else np.array([])

        self.posterior_mean_coef1 = (
                betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
                (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).
        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
                _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
                + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

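    # Editorial usage sketch (not part of the commit): q_sample implements the
    # closed form x_t = sqrt(acp_t) * x_0 + sqrt(1 - acp_t) * noise, with
    # acp_t = alphas_cumprod[t]. For example:
    # >>> diffusion = GaussianDiffusion(betas=get_named_beta_schedule("linear", 1000),
    # ...                               model_mean_type=ModelMeanType.EPSILON,
    # ...                               model_var_type=ModelVarType.LEARNED_RANGE,
    # ...                               loss_type=LossType.MSE)
    # >>> x_t = diffusion.q_sample(x_start, t=th.tensor([500]))  # x_start: [1, C, H, W]
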
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
                _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
                + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
                posterior_mean.shape[0]
                == posterior_variance.shape[0]
                == posterior_log_variance_clipped.shape[0]
                == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
        elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
        else:
            model_variance = th.zeros_like(model_output)
            model_log_variance = th.zeros_like(model_output)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            return x.clamp(-1, 1) if clip_denoised else x

        if self.model_mean_type == ModelMeanType.START_X:
            pred_xstart = process_xstart(model_output)
        else:
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
                _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
                - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
                       _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
               ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

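    # Editorial note (not part of the commit): the two helpers above are exact
    # inverses for a fixed (x_t, t), since
    #     x_0 = sqrt(1 / acp_t) * x_t - sqrt(1 / acp_t - 1) * eps
    # rearranges to
    #     eps = (sqrt(1 / acp_t) * x_t - x_0) / sqrt(1 / acp_t - 1),
    # with acp_t = alphas_cumprod[t]; round-tripping recovers eps up to float error.
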
    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **model_kwargs)
        return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.
        See condition_mean() for details on cond_fn.
        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    def p_sample(
            self,
            model,
            x,
            t,
            clip_denoised=True,
            denoised_fn=None,
            cond_fn=None,
            model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            cond_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
    ):
        """
        Generate samples from the model.
        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
                model,
                shape,
                noise=noise,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                device=device,
                progress=progress,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            cond_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.
        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        img = noise if noise is not None else th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                yield out
                img = out["sample"]

    def ddim_sample(
            self,
            model,
            x,
            t,
            clip_denoised=True,
            denoised_fn=None,
            cond_fn=None,
            model_kwargs=None,
            eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.
        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
                eta
                * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
                * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
                out["pred_xstart"] * th.sqrt(alpha_bar_prev)
                + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

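    # Editorial note (not part of the commit): with eta = 0 the update above is
    # the deterministic DDIM step of Song et al. (their Eq. 12),
    #     x_{t-1} = sqrt(acp_{t-1}) * x0_pred + sqrt(1 - acp_{t-1}) * eps,
    # while eta = 1 recovers DDPM-like ancestral sampling variance.
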
    def ddim_reverse_sample(
            self,
            model,
            x,
            t,
            clip_denoised=True,
            denoised_fn=None,
            cond_fn=None,
            model_kwargs=None,
            eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
                      _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                      - out["pred_xstart"]
              ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            cond_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
            eta=0.0,
    ):
        """
        Generate samples from the model using DDIM.
        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
                model,
                shape,
                noise=noise,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                device=device,
                progress=progress,
                eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
            self,
            model,
            shape,
            noise=None,
            clip_denoised=True,
            denoised_fn=None,
            cond_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
            eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.
        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        img = noise if noise is not None else th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]

    def _vb_terms_bpd(
            self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.
        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.
        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
        """
        Compute training losses for a single timestep.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param timestep: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        t = timestep
        if model_kwargs is None:
            model_kwargs = {}
        if skip_noise:
            x_t = x_start
        else:
            if noise is None:
                noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
            model_output = model(x_t, t, **model_kwargs)
            if isinstance(model_output, dict) and model_output.get('x', None) is not None:
                output = model_output['x']
            else:
                output = model_output

            if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
                return self._extracted_from_training_losses_diffusers(x_t, output, t)
            # self.model_var_type = ModelVarType.LEARNED_RANGE:4
            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert output.shape == (B, C * 2, *x_t.shape[2:])
                output, model_var_values = th.split(output, C, dim=1)
                # Learn the variance using the variational bound, but don't let it affect our mean prediction.
                frozen_out = th.cat([output.detach(), model_var_values], dim=1)
                # vb variational bound
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out, **kwargs: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert output.shape == target.shape == x_start.shape
            if self.snr:
                if self.model_mean_type == ModelMeanType.START_X:
                    pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
                    pred_startx = output
                elif self.model_mean_type == ModelMeanType.EPSILON:
                    pred_noise = output
                    pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
                # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
                # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)

                t = t[:, None, None, None].expand(pred_startx.shape)  # [128, 4, 32, 32]
                # best
                target = th.where(t > 249, noise, x_start)
                output = th.where(t > 249, pred_noise, pred_startx)
            loss = (target - output) ** 2
            if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
                assert 'mask' in model_output
                loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
                mask = model_output['mask']
                unmask = 1 - mask
                terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1] / unmask.sum(1)
                if model_kwargs['mask_loss_coef'] > 0:
                    terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1] / mask.sum(1)
            else:
                terms["mse"] = mean_flat(loss)
            terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
            if "mae" in terms:
                terms["loss"] = terms["loss"] + terms["mae"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms

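    # Editorial note (not part of the commit): with snr=True the loss target is
    # switched per-timestep above -- epsilon-prediction for t > 249 and
    # x_0-prediction for t <= 249 -- reweighting the objective toward the easier
    # target at each noise level; the hard-coded 249 cutoff appears to assume a
    # 1000-step schedule.
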
    def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
        """
        Compute training losses for a single timestep.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param timestep: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        t = timestep
        if model_kwargs is None:
            model_kwargs = {}
        if skip_noise:
            x_t = x_start
        else:
            if noise is None:
                noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type in [LossType.KL, LossType.RESCALED_KL]:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
            output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]

            if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
                return self._extracted_from_training_losses_diffusers(x_t, output, t)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert output.shape == (B, C * 2, *x_t.shape[2:])
                output, model_var_values = th.split(output, C, dim=1)
                # Learn the variance using the variational bound, but don't let it affect our mean prediction.
                frozen_out = th.cat([output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out, **kwargs: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert output.shape == target.shape == x_start.shape
            if self.snr:
                if self.model_mean_type == ModelMeanType.START_X:
                    pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
                    pred_startx = output
                elif self.model_mean_type == ModelMeanType.EPSILON:
                    pred_noise = output
                    pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
                # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
                # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)

                t = t[:, None, None, None].expand(pred_startx.shape)  # [128, 4, 32, 32]
                # best
                target = th.where(t > 249, noise, x_start)
                output = th.where(t > 249, pred_noise, pred_startx)
            loss = (target - output) ** 2
            terms["mse"] = mean_flat(loss)
            terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
            if "mae" in terms:
                terms["loss"] = terms["loss"] + terms["mae"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms

    def _extracted_from_training_losses_diffusers(self, x_t, output, t):
        B, C = x_t.shape[:2]
        assert output.shape == (B, C * 2, *x_t.shape[2:])
        output = th.split(output, C, dim=1)[0]
        return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t

def _prior_bpd(self, x_start):
|
923 |
+
"""
|
924 |
+
Get the prior KL term for the variational lower-bound, measured in
|
925 |
+
bits-per-dim.
|
926 |
+
This term can't be optimized, as it only depends on the encoder.
|
927 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
928 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
929 |
+
"""
|
930 |
+
batch_size = x_start.shape[0]
|
931 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
932 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
933 |
+
kl_prior = normal_kl(
|
934 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
935 |
+
)
|
936 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
937 |
+
|
938 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
939 |
+
"""
|
940 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
941 |
+
as well as other related quantities.
|
942 |
+
:param model: the model to evaluate loss on.
|
943 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
944 |
+
:param clip_denoised: if True, clip denoised samples.
|
945 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
946 |
+
pass to the model. This can be used for conditioning.
|
947 |
+
:return: a dict containing the following keys:
|
948 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
949 |
+
- prior_bpd: the prior term in the lower-bound.
|
950 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
951 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
952 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
953 |
+
"""
|
954 |
+
device = x_start.device
|
955 |
+
batch_size = x_start.shape[0]
|
956 |
+
|
957 |
+
vb = []
|
958 |
+
xstart_mse = []
|
959 |
+
mse = []
|
960 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
961 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
962 |
+
noise = th.randn_like(x_start)
|
963 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
964 |
+
# Calculate VLB term at the current timestep
|
965 |
+
with th.no_grad():
|
966 |
+
out = self._vb_terms_bpd(
|
967 |
+
model,
|
968 |
+
x_start=x_start,
|
969 |
+
x_t=x_t,
|
970 |
+
t=t_batch,
|
971 |
+
clip_denoised=clip_denoised,
|
972 |
+
model_kwargs=model_kwargs,
|
973 |
+
)
|
974 |
+
vb.append(out["output"])
|
975 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
976 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
977 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
978 |
+
|
979 |
+
vb = th.stack(vb, dim=1)
|
980 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
981 |
+
mse = th.stack(mse, dim=1)
|
982 |
+
|
983 |
+
prior_bpd = self._prior_bpd(x_start)
|
984 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
985 |
+
return {
|
986 |
+
"total_bpd": total_bpd,
|
987 |
+
"prior_bpd": prior_bpd,
|
988 |
+
"vb": vb,
|
989 |
+
"xstart_mse": xstart_mse,
|
990 |
+
"mse": mse,
|
991 |
+
}
|
992 |
+
|
993 |
+
|
994 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
995 |
+
"""
|
996 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
997 |
+
:param arr: the 1-D numpy array.
|
998 |
+
:param timesteps: a tensor of indices into the array to extract.
|
999 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
1000 |
+
dimension equal to the length of timesteps.
|
1001 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
1002 |
+
"""
|
1003 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
1004 |
+
while len(res.shape) < len(broadcast_shape):
|
1005 |
+
res = res[..., None]
|
1006 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
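`_extract_into_tensor` is the broadcasting workhorse for every per-timestep schedule array in this file. A short hedged usage sketch; the schedule values here are made up for illustration, only the call pattern comes from the code above:

    import numpy as np
    import torch as th

    arr = np.linspace(1.0, 0.01, 1000)       # illustrative schedule, e.g. sqrt_alphas_cumprod
    timesteps = th.tensor([0, 499, 999])     # one index per batch element
    coef = _extract_into_tensor(arr, timesteps, (3, 4, 32, 32))
    # coef has shape [3, 4, 32, 32]; every element of coef[i] equals arr[timesteps[i]].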
DiT_VAE/diffusion/model/hed.py
ADDED
@@ -0,0 +1,150 @@
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products.
# This implementation may produce slightly different results from Saining Xie's official implementation,
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
# Unlike the official and other implementations, this is an RGB-input model (rather than BGR),
# which works better with gradio's RGB protocol.
import sys
from pathlib import Path

current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent.parent))
from torch import nn
import torch
import numpy as np
from torchvision import transforms as T
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import json
from PIL import Image
import torchvision.transforms.functional as TF
from accelerate import Accelerator
from diffusers.models import AutoencoderKL
import os

image_resize = 1024


class DoubleConvBlock(nn.Module):
    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for i in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def forward(self, x, down_sampling=False):
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = conv(h)
            h = torch.nn.functional.relu(h)
        return h, self.projection(h)


class ControlNetHED_Apache2(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def forward(self, x):
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5


class InternData(Dataset):
    def __init__(self):
        ####
        with open('data/InternData/partition/data_info.json', 'r') as f:
            self.j = json.load(f)
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize(image_resize),  # Image.BICUBIC
            T.CenterCrop(image_resize),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.j)

    def getdata(self, idx):

        path = self.j[idx]['path']
        image = Image.open("data/InternImgs/" + path)
        image = self.transform(image)
        return image, path

    def __getitem__(self, idx):
        for i in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')


class HEDdetector(nn.Module):
    def __init__(self, feature=True, vae=None):
        super().__init__()
        self.model = ControlNetHED_Apache2()
        self.model.load_state_dict(torch.load('output/pretrained_models/ControlNetHED.pth', map_location='cpu'))
        self.model.eval()
        self.model.requires_grad_(False)
        if feature:
            if vae is None:
                self.vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema")
            else:
                self.vae = vae
            self.vae.eval()
            self.vae.requires_grad_(False)
        else:
            self.vae = None

    def forward(self, input_image):
        B, C, H, W = input_image.shape
        with torch.inference_mode():
            edges = self.model(input_image * 255.)
            edges = torch.cat([TF.resize(e, [H, W]) for e in edges], dim=1)
            # Sigmoid over the mean of the five side outputs gives a soft edge map.
            edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1, keepdim=True)))
            edge.clip_(0, 1)
            if self.vae:
                edge = TF.normalize(edge, [.5], [.5])
                edge = edge.repeat(1, 3, 1, 1)
                posterior = self.vae.encode(edge).latent_dist
                edge = torch.cat([posterior.mean, posterior.std], dim=1).cpu().numpy()
        return edge


def main():
    dataset = InternData()
    dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=8, pin_memory=True)
    hed = HEDdetector()

    accelerator = Accelerator()
    hed, dataloader = accelerator.prepare(hed, dataloader)

    for img, path in tqdm(dataloader):
        out = hed(img.cuda())
        for p, o in zip(path, out):
            save = f'data/InternalData/hed_feature_{image_resize}/' + p.replace('.png', '.npz')
            if os.path.exists(save):
                continue
            os.makedirs(os.path.dirname(save), exist_ok=True)
            np.savez_compressed(save, o)


if __name__ == "__main__":
    main()
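A hedged usage sketch for `HEDdetector` outside the batch-extraction script above; the checkpoint path follows the constructor default, 'example.png' is a hypothetical input file, and the input is a [0, 1]-normalized RGB batch:

    import torch
    from PIL import Image
    from torchvision import transforms as T

    detector = HEDdetector(feature=False).cuda()   # feature=False returns the edge map instead of VAE latents
    img = T.ToTensor()(Image.open('example.png').convert('RGB')).unsqueeze(0).cuda()
    edge = detector(img)                           # (1, 1, H, W) soft edge map in [0, 1]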
DiT_VAE/diffusion/model/image_embedding.py
ADDED
@@ -0,0 +1,15 @@
import torch
from transformers import AutoImageProcessor, Dinov2Model
from PIL import Image
import requests

# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw)
#
# processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
# model = AutoModel.from_pretrained('facebook/dinov2-base')
#
# inputs = processor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# last_hidden_states = outputs[0]
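The commented-out lines sketch DINOv2 feature extraction, though they reference `AutoModel` while the file only imports `Dinov2Model`. A runnable hedged version using the imported class; the output shape matches the `dino_channels=768` / `token_num=257` defaults used by the embedders later in this commit:

    import torch
    from transformers import AutoImageProcessor, Dinov2Model
    from PIL import Image
    import requests

    url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    image = Image.open(requests.get(url, stream=True).raw)

    processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
    model = Dinov2Model.from_pretrained('facebook/dinov2-base')

    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state  # (1, 257, 768): CLS token + 256 patch tokens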
DiT_VAE/diffusion/model/nets/PixArt_blocks.py
ADDED
@@ -0,0 +1,655 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
from timm.models.vision_transformer import Mlp, Attention as Attention_
from einops import rearrange
import xformers.ops


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift


class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs):
        super(MultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query: img tokens; key/value: condition; mask: if padding tokens
        B, N, C = x.shape

        # Flatten the batch into one long sequence; the block-diagonal attention
        # bias below keeps samples from attending across each other.
        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)

        k, v = kv.unbind(2)
        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        # q = self.q_linear(x).reshape(B, -1, self.num_heads, self.head_dim)
        # kv = self.kv_linear(cond).reshape(B, -1, 2, self.num_heads, self.head_dim)
        # k, v = kv.unbind(2)
        # attn_bias = None
        # if mask is not None:
        #     attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
        #     attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
        # x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
        # x = x.contiguous().reshape(B, -1, C)
        # x = self.proj(x)
        # x = self.proj_drop(x)

        return x
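`BlockDiagonalMask.from_seqlens` is what makes the flattened-batch trick above safe: attention is only computed inside matching (query-block, key-block) pairs. A small hedged demo of the masking behaviour (shapes only, random tensors; like the module above it needs a CUDA device with xformers kernels available):

    import torch
    import xformers.ops

    B, N, heads, d = 2, 4, 2, 8
    cond_lens = [3, 5]   # per-sample condition lengths
    q = torch.randn(1, B * N, heads, d, device='cuda', dtype=torch.float16)
    kv = torch.randn(1, sum(cond_lens), heads, d, device='cuda', dtype=torch.float16)
    bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, cond_lens)
    out = xformers.ops.memory_efficient_attention(q, kv, kv, attn_bias=bias)
    # out[0, :N] only attended to kv[0, :3]; out[0, N:] only to kv[0, 3:].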
class WindowAttention(Attention_):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            use_rel_pos=False,
            rel_pos_zero_init=True,
            input_size=None,
            **block_kwargs,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
        """
        super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))

            if not rel_pos_zero_init:
                nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
                nn.init.trunc_normal_(self.rel_pos_w, std=0.02)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.unbind(2)
        if use_fp32_attention := getattr(self, 'fp32_attention', False):
            q, k, v = q.float(), k.float(), v.float()

        attn_bias = None
        if mask is not None:
            attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
            attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

        x = x.view(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


#################################################################################
#   AMP attention with fp32 softmax to fix loss NaN problem during training    #
#################################################################################
class Attention(Attention_):
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        use_fp32_attention = getattr(self, 'fp32_attention', False)
        if use_fp32_attention:
            q, k = q.float(), k.float()
        with torch.cuda.amp.autocast(enabled=not use_fp32_attention):
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AttentionTest(Attention_):
    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.unbind(2)

        attn_bias = None
        if mask is not None:
            attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
            attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

        x = x.view(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt, using a learned scale/shift table (adaLN-single).
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
        self.out_channels = out_channels

    def forward(self, x, t):
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class MaskFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """

    def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
        )

    def forward(self, x, t):
        shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DecoderLayer(nn.Module):
    """
    The decoder projection layer of PixArt.
    """

    def __init__(self, hidden_size, decoder_hidden_size):
        super().__init__()
        self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, t):
        shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
        x = modulate(self.norm_decoder(x), shift, scale)
        x = self.linear(x)
        return x


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                #
#################################################################################
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype)
        return self.mlp(t_freq)

    @property
    def dtype(self):
        # Return the dtype of the model parameters.
        return next(self.parameters()).dtype
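The sinusoidal embedding above is the standard GLIDE/DiT formula: frequency i of `half = dim // 2` has period max_period^(i/half), and the cos/sin halves are concatenated. A quick hedged check of its shape and the t = 0 case:

    import torch

    t = torch.tensor([0., 250., 999.])
    emb = TimestepEmbedder.timestep_embedding(t, dim=256)
    print(emb.shape)      # torch.Size([3, 256])
    print(emb[0, :128])   # cos half: all ones at t = 0
    print(emb[0, 128:])   # sin half: all zeros at t = 0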
class SizeEmbedder(TimestepEmbedder):
    """
    Embeds scalar size conditions (e.g. resolution) into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size

    def forward(self, s, bs):
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        if s.shape[0] != bs:
            s = s.repeat(bs // s.shape[0], 1)
            assert s.shape[0] == bs
        b, dims = s.shape[0], s.shape[1]
        s = rearrange(s, "b d -> (b d)")
        s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
        s_emb = self.mlp(s_freq)
        s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return s_emb

    @property
    def dtype(self):
        # Return the dtype of the model parameters.
        return next(self.parameters()).dtype


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)


def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # make (bs, n_heads, length, dim_per_head) contiguous
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class ImageCaptionEmbedder(nn.Module):
    """
    Embeds image features into caption-like tokens via a Perceiver resampler.
    Also handles dropout for classifier-free guidance.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), depth=4,
                 dim_head=64, heads=12, ff_mult=4, token_num=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, token_num, hidden_size) / hidden_size ** 0.5)

        self.proj_in = nn.Linear(in_channels, hidden_size)

        self.proj_out = Mlp(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size,
                            act_layer=act_layer, drop=0)
        self.norm_out = nn.LayerNorm(hidden_size)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=hidden_size, dim_head=dim_head, heads=heads),
                        FeedForward(dim=hidden_size, mult=ff_mult),
                    ]
                )
            )
        self.uncond_prob = uncond_prob

    def forward(self, x, train, force_drop_ids=None):
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        image_caption = latents.unsqueeze(1)  # (N, 1, L, D)
        return image_caption


class DinoFeatureEmbedderQFormer(nn.Module):
    """
    Embeds DINO image features into tokens via a Perceiver resampler (Q-Former style).
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=257, depth=4,
                 dim_head=64, heads=12, ff_mult=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, token_num, hidden_size) / hidden_size ** 0.5)

        self.proj_in = nn.Linear(in_channels, hidden_size)

        self.proj_out = Mlp(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size,
                            act_layer=act_layer, drop=0)
        self.norm_out = nn.LayerNorm(hidden_size)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=hidden_size, dim_head=dim_head, heads=heads),
                        FeedForward(dim=hidden_size, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x, train, force_drop_ids=None):
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        image_caption = latents.unsqueeze(1)  # (N, 1, L, D)
        return image_caption


class DinoFeatureEmbedderV2(nn.Module):
    """
    Embeds DINO image features with a plain MLP projection.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=257, use_drop=True, dino_norm=False):
        super().__init__()
        self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
                          act_layer=act_layer, drop=0)
        self.dino_norm = dino_norm
        if self.dino_norm:
            self.norm_out = nn.LayerNorm(hidden_size)

    def forward(self, dino_feature):
        dino_feature = dino_feature.unsqueeze(1)
        dino_feature = self.y_proj(dino_feature)
        if self.dino_norm:
            dino_feature = self.norm_out(dino_feature)

        return dino_feature


class DinoFeatureEmbedder(nn.Module):
    """
    Embeds DINO image features. Also handles feature dropout for classifier-free guidance.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=257):
        super().__init__()
        self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
                          act_layer=act_layer, drop=0)
        self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, dino_feature, force_drop_ids=None):
        """
        Drops features to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(dino_feature.shape[0]).cuda() < self.uncond_prob
        else:
            force_drop_ids = torch.tensor(force_drop_ids).cuda()
            drop_ids = force_drop_ids == 1
        dino_feature = torch.where(drop_ids[:, None, None, None], self.y_embedding, dino_feature)
        return dino_feature

    def forward(self, dino_feature, train, force_drop_ids=None):
        # print("dino_2", dino_feature)
        dino_feature = dino_feature.unsqueeze(1)

        if train:
            assert dino_feature.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0

        if (train and use_dropout) or (force_drop_ids is not None and force_drop_ids != {} and len(force_drop_ids) != 0):
            dino_feature = self.token_drop(dino_feature, force_drop_ids)
        dino_feature = self.y_proj(dino_feature)
        # print("dino_3", dino_feature)

        return dino_feature


class FusionEmbedder(nn.Module):
    """
    Projects fused conditioning features into the model width.
    """

    def __init__(self, in_channels, hidden_size, act_layer=nn.GELU(approximate='tanh')):
        super().__init__()
        self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
                          act_layer=act_layer, drop=0)

    def forward(self, fusion_feature):
        return self.y_proj(fusion_feature)


class CaptionEmbedder(nn.Module):
    """
    Embeds caption features into vector representations. Also handles dropout for classifier-free guidance.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
        super().__init__()
        self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
                          act_layer=act_layer, drop=0)
        self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops captions to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption


class CaptionEmbedderDoubleBr(nn.Module):
    """
    Embeds caption features into a global and a per-token representation.
    Also handles dropout for classifier-free guidance.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
        super().__init__()
        self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
                        act_layer=act_layer, drop=0)
        self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
        self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
        self.uncond_prob = uncond_prob

    def token_drop(self, global_caption, caption, force_drop_ids=None):
        """
        Drops captions to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return global_caption, caption

    def forward(self, caption, train, force_drop_ids=None):
        assert caption.shape[2:] == self.y_embedding.shape
        global_caption = caption.mean(dim=2).squeeze()
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
        y_embed = self.proj(global_caption)
        return y_embed, caption
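`t2i_modulate` plus a learned `scale_shift_table` is the adaLN-single recipe: a shared table of modulation parameters is offset by the per-sample timestep embedding rather than predicting every parameter per block. A hedged shape sketch mirroring `T2IFinalLayer` above (sizes are illustrative):

    import torch

    hidden = 1152
    table = torch.randn(2, hidden) / hidden ** 0.5   # learned shift/scale table, as in T2IFinalLayer
    t = torch.randn(4, hidden)                       # per-sample timestep embedding (N, D)
    shift, scale = (table[None] + t[:, None]).chunk(2, dim=1)   # each (N, 1, D)
    x = torch.randn(4, 256, hidden)                  # (N, T, D) tokens
    out = x * (1 + scale) + shift                    # == t2i_modulate(x, shift, scale)
    print(out.shape)                                 # torch.Size([4, 256, 1152])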
DiT_VAE/diffusion/model/nets/TriDitCLIPDINO.py
ADDED
@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
import os
import numpy as np
from timm.models.layers import DropPath
from timm.models.vision_transformer import PatchEmbed, Mlp
from DiT_VAE.diffusion.model.builder import MODELS
from DiT_VAE.diffusion.model.utils import auto_grad_checkpoint, to_2tuple
from DiT_VAE.diffusion.model.nets.PixArt_blocks import t2i_modulate, WindowAttention, MultiHeadCrossAttention, \
    T2IFinalLayer, TimestepEmbedder, ImageCaptionEmbedder, DinoFeatureEmbedderQFormer
from DiT_VAE.diffusion.utils.logger import get_root_logger


class PixArtBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm (adaLN-single) conditioning.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None,
                 use_rel_pos=False, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
                                    input_size=input_size if window_size == 0 else (window_size, window_size),
                                    use_rel_pos=use_rel_pos, **block_kwargs)
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # to be compatible with lower-version pytorch
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
                       drop=0)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.window_size = window_size
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, img_feature=None, **kwargs):
        B, N, C = x.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
        if img_feature is None:
            x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
        else:
            # Concatenate image tokens for self-attention, then keep only the x tokens.
            x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
            img_feature = img_feature.squeeze(1)
            N_new = N + img_feature.shape[1]
            x_m = self.attn(torch.cat([x_m, img_feature], dim=1)).reshape(B, N_new, C)
            x_m = x_m[:, :N, :]
            x = x + self.drop_path(gate_msa * x_m)
        x = x + self.cross_attn(x, y, mask)
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x


#################################################################################
#                               Core PixArt Model                              #
#################################################################################
@MODELS.register_module()
class TriDitCLIPDINO(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(self, input_size, patch_size=2, in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0,
                 class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0,
                 window_block_indexes=None, use_rel_pos=False, caption_channels=1280, lewei_scale=1.0, config=None, dino_channels=768, img_feature_self_attention=False, dino_norm=False,
                 model_max_length=257, **kwargs):
        if window_block_indexes is None:
            window_block_indexes = []
        super().__init__()
        self.img_feature_self_attention = img_feature_self_attention
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.lewei_scale = lewei_scale
        assert isinstance(input_size, tuple)
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        num_patches = self.x_embedder.num_patches
        self.base_size_h = input_size[0] // self.patch_size
        self.base_size_w = input_size[1] // self.patch_size
        self.h = self.base_size_h
        self.w = self.base_size_w
        # Will use fixed sin-cos embedding:
        self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.dino_embedder = DinoFeatureEmbedderQFormer(in_channels=dino_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=256)
        self.y_embedder = ImageCaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size,
                                               uncond_prob=class_dropout_prob, act_layer=approx_gelu,
                                               token_num=16)
        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            PixArtBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                        input_size=(input_size[0] // patch_size, input_size[1] // patch_size),
                        window_size=window_size if i in window_block_indexes else 0,
                        use_rel_pos=use_rel_pos if i in window_block_indexes else False)
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

        if config:
            logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
            logger.warning(
                f"lewei scale: {self.lewei_scale}, base size h: {self.base_size_h} base size w: {self.base_size_w}")
        else:
            print(
                f'Warning: lewei scale: {self.lewei_scale}, base size h: {self.base_size_h} base size w: {self.base_size_w}')

    def forward(self, x, timestep, y, img_feature, drop_img_mask=None, mask=None, data_info=None, **kwargs):
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timestep: (N,) tensor of diffusion timesteps
        y: (N, 1, L, C) tensor of caption/CLIP conditioning tokens
        img_feature: tensor of DINO image features
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        img_feature = img_feature.to(self.dtype)
        pos_embed = self.pos_embed.to(self.dtype)
        self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        img_embedding = self.dino_embedder(img_feature, self.training)
        # y_fusion = y
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens, img_embedding)  # (N, T, D), supports grad checkpointing
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward_with_dpmsolver(self, x, timestep, y, img_feature, mask=None, **kwargs):
        """
        The DPM solver does not need variance prediction.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        model_out = self.forward(x, timestep, y, img_feature)
        return model_out.chunk(2, dim=1)[0]

    def forward_with_cfg(self, x, timestep, y, img_feature, cfg_scale, mask=None, **kwargs):
        """
        Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, timestep, y, img_feature, **kwargs)
        model_out = model_out['x'] if isinstance(model_out, dict) else model_out
        eps, rest = torch.split(model_out, self.in_channels, dim=1)
        # eps, rest = model_out[:, :3], model_out[:, 3:]

        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
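`forward_with_cfg` packs the conditional and unconditional branches into one batch and then applies the standard classifier-free guidance combination, eps = eps_uncond + s * (eps_cond - eps_uncond). A minimal hedged sketch of just the mixing step, lifted from the method above:

    import torch

    def cfg_mix(eps, cfg_scale):
        # eps: (2N, C, H, W) with the conditional half first, unconditional half second
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        return torch.cat([half_eps, half_eps], dim=0)  # duplicated to keep the 2N batch layout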
188 |
+
    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = int(x.shape[1] ** 0.5 * 2)
        w = int(x.shape[1] ** 0.5 / 2)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

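The h/w recovery above hard-codes a non-square token grid: with h = int(2*sqrt(T)) and w = int(sqrt(T)/2), the assert h * w == T holds in the intended case only when h == 4 * w. A quick standalone check (a sketch, independent of the repo):

import math

for w in (4, 8, 16):
    T = 4 * w * w                 # h = 4w  =>  T = h * w = 4 * w**2
    h_rec = int(T ** 0.5 * 2)
    w_rec = int(T ** 0.5 / 2)
    assert (h_rec, w_rec) == (4 * w, w) and h_rec * w_rec == T
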
    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5),
                                            lewei_scale=self.lewei_scale, base_size_h=self.base_size_h,
                                            base_size_w=self.base_size_w)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.proj_out.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.proj_out.fc2.weight, std=0.02)
        nn.init.normal_(self.y_embedder.proj_in.weight, std=0.02)

        # Initialize DINO embedding MLP:
        # nn.init.normal_(self.dino_embedder.y_proj.fc1.weight, std=0.02)
        # nn.init.normal_(self.dino_embedder.y_proj.fc2.weight, std=0.02)
        nn.init.normal_(self.dino_embedder.proj_out.fc1.weight, std=0.02)
        nn.init.normal_(self.dino_embedder.proj_out.fc2.weight, std=0.02)
        nn.init.normal_(self.dino_embedder.proj_in.weight, std=0.02)
        # if not self.img_feature_self_attention:
        #     # Initialize fusion embedding MLP:
        #     nn.init.normal_(self.fusion_embedder.y_proj.fc1.weight, std=0.02)
        #     nn.init.normal_(self.fusion_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        return next(self.parameters()).dtype


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0, base_size_h=16,
                            base_size_w=16):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = to_2tuple(grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size_h) / lewei_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size_w) / lewei_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)


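A quick sanity check of the 1D sin-cos builder; the helper below mirrors get_1d_sincos_pos_embed_from_grid so the sketch runs without the repo:

import numpy as np

def sincos_1d(embed_dim, pos):
    # mirrors get_1d_sincos_pos_embed_from_grid above
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega = 1. / 10000 ** (omega / (embed_dim / 2.))
    out = np.einsum('m,d->md', pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

emb = sincos_1d(32, np.arange(16, dtype=np.float64))
assert emb.shape == (16, 32)
# frequency 0 of omega is 1.0, so the first sin column is sin(pos) itself:
assert np.allclose(emb[:, 0], np.sin(np.arange(16)))
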
#################################################################################
#                                 PixArt Configs                                #
#################################################################################
@MODELS.register_module()
def TriDitCLIPDINO_XL_2(**kwargs):
    return TriDitCLIPDINO(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
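A minimal sketch of instantiating the registered XL/2 config. The constructor kwargs (input size, channel counts, caption dims, and whether initialize_weights is already called in __init__) live earlier in this file and are assumptions here:

from DiT_VAE.diffusion.model.nets import TriDitCLIPDINO_XL_2

model = TriDitCLIPDINO_XL_2()   # depth=28, hidden_size=1152, patch_size=2, num_heads=16
model.initialize_weights()      # sin-cos pos_embed, zeroed cross-attn proj and final layer (see above)
print(model.dtype)              # dtype of the first parameter, via the property above
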
DiT_VAE/diffusion/model/nets/__init__.py
ADDED
@@ -0,0 +1 @@
from .TriDitCLIPDINO import TriDitCLIPDINO_XL_2, TriDitCLIPDINO
DiT_VAE/diffusion/model/respace.py
ADDED
@@ -0,0 +1,131 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there are 300 timesteps and the section counts are [10, 15, 20],
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"  # report the requested count, not the total
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def training_losses_diffusers(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        # self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, timestep, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
        new_ts = map_tensor[timestep]
        # if self.rescale_timesteps:
        #     new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, timestep=new_ts, **kwargs)
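A quick self-contained check of space_timesteps above (building a full SpacedDiffusion additionally needs the betas and loss/variance kwargs from gaussian_diffusion, omitted here):

from DiT_VAE.diffusion.model.respace import space_timesteps

# "ddimN" picks a fixed integer stride that yields exactly N steps:
steps = space_timesteps(1000, "ddim25")
assert len(steps) == 25 and 0 in steps and max(steps) == 960

# A list splits the process into equal sections strided to 10/15/20 steps:
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45 and max(steps) == 299
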
DiT_VAE/diffusion/model/sa_solver.py
ADDED
@@ -0,0 +1,1129 @@
import torch
import torch.nn.functional as F
import math
from tqdm import tqdm


class NoiseScheduleVP:
    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
            dtype=torch.float32,
    ):
        """Thanks to DPM-Solver for their code base"""
        """Create a wrapper class for the forward SDE (VP type).
        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
                We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***
        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)
        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
            t = self.inverse_lambda(lambda_t)
        ===============================================================
        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
        1. For discrete-time DPMs:
            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            **Important**: Please pay special attention for the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
        2. For continuous-time DPMs:
            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
            schedule are the default settings in DDPM and improved-DDPM:
            Args:
                beta_min: A `float` number. The smallest beta for the linear schedule.
                beta_max: A `float` number. The largest beta for the linear schedule.
                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
                T: A `float` number. The ending time of the forward process.
        ===============================================================
        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

        ===============================================================
        Example:
        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)
        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError(
                f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'"
            )

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
            self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
                    1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            self.T = 0.9946 if schedule == 'cosine' else 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
                                  self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            return log_alpha_fn(t) - self.cosine_log_alpha_0

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0 ** 2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                               torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
                    1. + self.cosine_s) / math.pi - self.cosine_s
            return t_fn(log_alpha)

    def edm_sigma(self, t):
        return self.marginal_std(t) / self.marginal_alpha(t)

    def edm_inverse_sigma(self, edmsigma):
        alpha = 1 / (edmsigma ** 2 + 1).sqrt()
        sigma = alpha * edmsigma
        lambda_t = torch.log(alpha / sigma)
        return self.inverse_lambda(lambda_t)

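A minimal self-contained check of NoiseScheduleVP, assuming the interpolate_fn helper this class calls is defined later in the file (as in the DPM-Solver codebase it credits). The beta schedule is illustrative, not the repo's training schedule:

import torch
from DiT_VAE.diffusion.model.sa_solver import NoiseScheduleVP

betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP('discrete', betas=betas)

t = torch.tensor([0.5])
alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
# VP property: alpha_t^2 + sigma_t^2 == 1
assert torch.allclose(alpha_t ** 2 + sigma_t ** 2, torch.ones(1), atol=1e-5)
# lambda_t and its inverse round-trip through the piecewise-linear interpolation:
assert torch.allclose(ns.inverse_lambda(ns.marginal_lambda(t)), t, atol=1e-4)
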
def model_wrapper(
        model,
        noise_schedule,
        model_type="noise",
        model_kwargs={},
        guidance_type="uncond",
        condition=None,
        unconditional_condition=None,
        guidance_scale=1.,
        classifier_fn=None,
        classifier_kwargs={},
):
    """Thanks to DPM-Solver for their code base"""
    """Create a wrapper function for the noise prediction model.
    SA-Solver needs to solve the continuous-time diffusion SDEs. For DPMs trained on discrete-time labels, we need to
    first wrap the model function to a noise prediction model that accepts the continuous time as the input.
    We support four types of the diffusion model by setting `model_type`:
        1. "noise": noise prediction model. (Trained by predicting noise).
        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```
    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``
        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``
            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``
            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).

    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).
    We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for SA-Solver.
    ===============================================================
    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None, cond_2=None):
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, cond_2, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - alpha_t[0] * output) / sigma_t[0]
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return alpha_t[0] * output + sigma_t[0] * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -sigma_t[0] * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * sigma_t * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t_continuous] * 2)
            # c_in = torch.cat([unconditional_condition, condition])
            c_in_y = torch.cat([unconditional_condition[0], condition[0]])
            c_in_dino = torch.cat([unconditional_condition[1], condition[1]])
            noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in_y, cond_2=c_in_dino).chunk(2)
            return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]
    return model_fn

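A sketch of wiring these pieces together for sampling, using the SASolver class defined just below. All names here (net, y_cond, dino_cond, y_null, dino_null, betas, latent) are hypothetical stand-ins for objects built elsewhere in the repo, and the repo's actual sampler entry point may differ:

import torch
from DiT_VAE.diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver

ns = NoiseScheduleVP('discrete', betas=betas)
model_fn = model_wrapper(
    net, ns,
    model_type="noise",
    guidance_type="classifier-free",
    condition=(y_cond, dino_cond),                  # unpacked above as (cond, cond_2)
    unconditional_condition=(y_null, dino_null),
    guidance_scale=4.5,
)
solver = SASolver(model_fn, ns, algorithm_type="data_prediction")
sample = solver.sample_few_steps(torch.randn_like(latent), tau=lambda t: 1.0, steps=20)
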
class SASolver:
    def __init__(
            self,
            model_fn,
            noise_schedule,
            algorithm_type="data_prediction",
            correcting_x0_fn=None,
            correcting_xt_fn=None,
            thresholding_max_val=1.,
            dynamic_thresholding_ratio=0.995
    ):
        """
        Construct a SA-Solver.
        The default value for algorithm_type is "data_prediction" and we recommend not changing it to
        "noise_prediction". For details, please see Appendix A.2.4 in the SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
        """

        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["data_prediction", "noise_prediction"]

        if correcting_x0_fn == "dynamic_thresholding":
            self.correcting_x0_fn = self.dynamic_thresholding_fn
        else:
            self.correcting_x0_fn = correcting_x0_fn

        self.correcting_xt_fn = correcting_xt_fn
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.thresholding_max_val = thresholding_max_val

        self.predict_x0 = algorithm_type == "data_prediction"

        self.sigma_min = float(self.noise_schedule.edm_sigma(torch.tensor([1e-3])))
        self.sigma_max = float(self.noise_schedule.edm_sigma(torch.tensor([1])))

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with corrector).
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0)
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """

        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, order, device):
        """Compute the intermediate time steps for sampling.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = lambda_T + torch.linspace(torch.tensor(0.).cpu().item(),
                                                     (lambda_0 - lambda_T).cpu().item() ** (1. / order), N + 1).pow(
                order).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time':
            t = torch.linspace(t_T ** (1. / order), t_0 ** (1. / order), N + 1).pow(order).to(device)
            return t
        elif skip_type == 'karras':
            sigma_min = max(0.002, self.sigma_min)
            sigma_max = min(80, self.sigma_max)
            sigma_steps = torch.linspace(sigma_max ** (1. / 7), sigma_min ** (1. / 7), N + 1).pow(7).to(device)
            return self.noise_schedule.edm_inverse_sigma(sigma_steps)
        else:
            raise ValueError(
                f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time' or 'karras'"
            )

    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def get_coefficients_exponential_negative(self, order, interval_start, interval_end):
        """
        Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end.
        Used for the coefficients of the gradient terms after the Lagrange interpolation;
        see Eq.(15) and Eq.(18) in the SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
        For the noise_prediction formula.
        """
        assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

        if order == 0:
            return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1)
        elif order == 1:
            return torch.exp(-interval_end) * (
                    (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1))
        elif order == 2:
            return torch.exp(-interval_end) * (
                    (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - (
                    interval_end ** 2 + 2 * interval_end + 2))
        elif order == 3:
            return torch.exp(-interval_end) * (
                    (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp(
                interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6))

    def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau):
        """
        Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end.
        Used for the coefficients of the gradient terms after the Lagrange interpolation;
        see Eq.(15) and Eq.(18) in the SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
        For the data_prediction formula.
        """
        assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

        # after change of variable (cov)
        interval_end_cov = (1 + tau ** 2) * interval_end
        interval_start_cov = (1 + tau ** 2) * interval_start

        if order == 0:
            return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (
                (1 + tau ** 2))
        elif order == 1:
            return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(
                -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2)
        elif order == 2:
            return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - (
                    interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp(
                -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3)
        elif order == 3:
            return torch.exp(interval_end_cov) * (
                    (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - (
                    interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp(
                -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4)

    def lagrange_polynomial_coefficient(self, order, lambda_list):
        """
        Calculate the coefficients of the Lagrange basis polynomials, for Lagrange interpolation.
        """
        assert order in [0, 1, 2, 3]
        assert order == len(lambda_list) - 1
        if order == 0:
            return [[1]]
        elif order == 1:
            return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
                    [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
        elif order == 2:
            denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
            denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
            denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
            return [[1 / denominator1,
                     (-lambda_list[1] - lambda_list[2]) / denominator1,
                     lambda_list[1] * lambda_list[2] / denominator1],

                    [1 / denominator2,
                     (-lambda_list[0] - lambda_list[2]) / denominator2,
                     lambda_list[0] * lambda_list[2] / denominator2],

                    [1 / denominator3,
                     (-lambda_list[0] - lambda_list[1]) / denominator3,
                     lambda_list[0] * lambda_list[1] / denominator3]
                    ]
        elif order == 3:
            denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (
                    lambda_list[0] - lambda_list[3])
            denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (
                    lambda_list[1] - lambda_list[3])
            denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (
                    lambda_list[2] - lambda_list[3])
            denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (
                    lambda_list[3] - lambda_list[2])
            return [[1 / denominator1,
                     (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
                     (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[
                         3]) / denominator1,
                     (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],

                    [1 / denominator2,
                     (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
                     (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[
                         3]) / denominator2,
                     (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],

                    [1 / denominator3,
                     (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
                     (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[
                         3]) / denominator3,
                     (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],

                    [1 / denominator4,
                     (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
                     (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[
                         2]) / denominator4,
                     (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
                    ]

    def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau):
        """
        Calculate the coefficients of the gradients.
        """
        assert order in [1, 2, 3, 4]
        assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
        coefficients = []
        lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list)
        for i in range(order):
            coefficient = sum(
                lagrange_coefficient[i][j]
                * self.get_coefficients_exponential_positive(
                    order - 1 - j, interval_start, interval_end, tau
                )
                if self.predict_x0
                else lagrange_coefficient[i][j]
                * self.get_coefficients_exponential_negative(
                    order - 1 - j, interval_start, interval_end
                )
                for j in range(order)
            )
            coefficients.append(coefficient)
        assert len(coefficients) == order, 'the length of coefficients does not match the order'
        return coefficients

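A quick standalone check of the order-1 Lagrange coefficients returned above: for nodes (l0, l1), the basis polynomial through l0 is (x - l1)/(l0 - l1), i.e. the coefficient pair [1/(l0 - l1), -l1/(l0 - l1)] in descending powers of x, and each basis evaluates to 1 at its own node and 0 at the other (a sketch, independent of the class):

l0, l1 = 2.0, 0.5
coeffs = [[1 / (l0 - l1), -l1 / (l0 - l1)],
          [1 / (l1 - l0), -l0 / (l1 - l0)]]
basis = lambda c, x: c[0] * x + c[1]
assert abs(basis(coeffs[0], l0) - 1.0) < 1e-12 and abs(basis(coeffs[0], l1)) < 1e-12
assert abs(basis(coeffs[1], l1) - 1.0) < 1e-12 and abs(basis(coeffs[1], l0)) < 1e-12
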
def adams_bashforth_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
|
561 |
+
"""
|
562 |
+
SA-Predictor, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
|
563 |
+
"""
|
564 |
+
assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
|
565 |
+
|
566 |
+
# get noise schedule
|
567 |
+
ns = self.noise_schedule
|
568 |
+
alpha_t = ns.marginal_alpha(t)
|
569 |
+
sigma_t = ns.marginal_std(t)
|
570 |
+
lambda_t = ns.marginal_lambda(t)
|
571 |
+
alpha_prev = ns.marginal_alpha(t_prev_list[-1])
|
572 |
+
sigma_prev = ns.marginal_std(t_prev_list[-1])
|
573 |
+
gradient_part = torch.zeros_like(x)
|
574 |
+
h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
|
575 |
+
lambda_list = [ns.marginal_lambda(t_prev_list[-(i + 1)]) for i in range(order)]
|
576 |
+
gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
|
577 |
+
lambda_list, tau)
|
578 |
+
|
579 |
+
for i in range(order):
|
580 |
+
if self.predict_x0:
|
581 |
+
gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
|
582 |
+
i] * model_prev_list[-(i + 1)]
|
583 |
+
else:
|
584 |
+
gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
585 |
+
|
586 |
+
if self.predict_x0:
|
587 |
+
noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
|
588 |
+
else:
|
589 |
+
noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
|
590 |
+
|
591 |
+
if self.predict_x0:
|
592 |
+
x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
|
593 |
+
else:
|
594 |
+
x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
|
595 |
+
|
596 |
+
return x_t
|
597 |
+
|
598 |
+
def adams_moulton_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
|
599 |
+
"""
|
600 |
+
SA-Corrector, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
|
601 |
+
"""
|
602 |
+
|
603 |
+
assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
|
604 |
+
|
605 |
+
# get noise schedule
|
606 |
+
ns = self.noise_schedule
|
607 |
+
alpha_t = ns.marginal_alpha(t)
|
608 |
+
sigma_t = ns.marginal_std(t)
|
609 |
+
lambda_t = ns.marginal_lambda(t)
|
610 |
+
alpha_prev = ns.marginal_alpha(t_prev_list[-1])
|
611 |
+
sigma_prev = ns.marginal_std(t_prev_list[-1])
|
612 |
+
gradient_part = torch.zeros_like(x)
|
613 |
+
h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
|
614 |
+
t_list = t_prev_list + [t]
|
615 |
+
lambda_list = [ns.marginal_lambda(t_list[-(i + 1)]) for i in range(order)]
|
616 |
+
gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
|
617 |
+
lambda_list, tau)
|
618 |
+
|
619 |
+
for i in range(order):
|
620 |
+
if self.predict_x0:
|
621 |
+
gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
|
622 |
+
i] * model_prev_list[-(i + 1)]
|
623 |
+
else:
|
624 |
+
gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
625 |
+
|
626 |
+
if self.predict_x0:
|
627 |
+
noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
|
628 |
+
else:
|
629 |
+
noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
|
630 |
+
|
631 |
+
if self.predict_x0:
|
632 |
+
x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
|
633 |
+
else:
|
634 |
+
x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
|
635 |
+
|
636 |
+
return x_t
|
637 |
+
|
638 |
+
def adams_bashforth_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
|
639 |
+
"""
|
640 |
+
SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
|
641 |
+
"""
|
642 |
+
|
643 |
+
assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
|
644 |
+
|
645 |
+
# get noise schedule
|
646 |
+
ns = self.noise_schedule
|
647 |
+
alpha_t = ns.marginal_alpha(t)
|
648 |
+
sigma_t = ns.marginal_std(t)
|
649 |
+
lambda_t = ns.marginal_lambda(t)
|
650 |
+
alpha_prev = ns.marginal_alpha(t_prev_list[-1])
|
651 |
+
sigma_prev = ns.marginal_std(t_prev_list[-1])
|
652 |
+
gradient_part = torch.zeros_like(x)
|
653 |
+
h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
|
654 |
+
lambda_list = [ns.marginal_lambda(t_prev_list[-(i + 1)]) for i in range(order)]
|
655 |
+
gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
|
656 |
+
lambda_list, tau)
|
657 |
+
|
658 |
+
if self.predict_x0:
|
659 |
+
if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling.
|
660 |
+
# The added term is O(h^3). Empirically we find it will slightly improve the image quality.
|
661 |
+
# ODE case
|
662 |
+
# gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
|
663 |
+
# gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
|
664 |
+
gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
|
665 |
+
h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
|
666 |
+
(1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(
|
667 |
+
t_prev_list[-2]))
|
668 |
+
gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
|
669 |
+
h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
|
670 |
+
(1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(
|
671 |
+
t_prev_list[-2]))
|
672 |
+
|
673 |
+
for i in range(order):
|
674 |
+
if self.predict_x0:
|
675 |
+
gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
|
676 |
+
i] * model_prev_list[-(i + 1)]
|
677 |
+
else:
|
678 |
+
gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
679 |
+
|
680 |
+
if self.predict_x0:
|
681 |
+
noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
|
682 |
+
else:
|
683 |
+
noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
|
684 |
+
|
685 |
+
if self.predict_x0:
|
686 |
+
x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
|
687 |
+
else:
|
688 |
+
x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
|
689 |
+
|
690 |
+
return x_t
|
691 |
+
|
692 |
+
def adams_moulton_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
|
693 |
+
"""
|
694 |
+
SA-Corrector, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
|
695 |
+
"""
|
696 |
+
|
697 |
+
assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
|
698 |
+
|
699 |
+
# get noise schedule
|
700 |
+
ns = self.noise_schedule
|
701 |
+
alpha_t = ns.marginal_alpha(t)
|
702 |
+
sigma_t = ns.marginal_std(t)
|
703 |
+
lambda_t = ns.marginal_lambda(t)
|
704 |
+
alpha_prev = ns.marginal_alpha(t_prev_list[-1])
|
705 |
+
sigma_prev = ns.marginal_std(t_prev_list[-1])
|
706 |
+
gradient_part = torch.zeros_like(x)
|
707 |
+
h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
|
708 |
+
t_list = t_prev_list + [t]
|
709 |
+
lambda_list = [ns.marginal_lambda(t_list[-(i + 1)]) for i in range(order)]
|
710 |
+
gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
|
711 |
+
lambda_list, tau)
|
712 |
+
|
713 |
+
if self.predict_x0:
|
714 |
+
if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling.
|
715 |
+
# The added term is O(h^3). Empirically we find it will slightly improve the image quality.
|
716 |
+
# ODE case
|
717 |
+
# gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
|
718 |
+
# gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
|
719 |
+
gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
|
720 |
+
h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
|
721 |
+
(1 + tau ** 2) ** 2 * h))
|
722 |
+
gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
|
723 |
+
h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
|
724 |
+
(1 + tau ** 2) ** 2 * h))
|
725 |
+
|
726 |
+
for i in range(order):
|
727 |
+
if self.predict_x0:
|
728 |
+
gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
|
729 |
+
i] * model_prev_list[-(i + 1)]
|
730 |
+
else:
|
731 |
+
gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
|
732 |
+
|
733 |
+
if self.predict_x0:
|
734 |
+
noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
|
735 |
+
else:
|
736 |
+
noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
|
737 |
+
|
738 |
+
if self.predict_x0:
|
739 |
+
x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
|
740 |
+
else:
|
741 |
+
x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
|
742 |
+
|
743 |
+
return x_t
|
744 |
+
|
745 |
+
    def sample_few_steps(self, x, tau, steps=5, t_start=None, t_end=None, skip_type='time', skip_order=1,
                         predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False):
        """
        For the PC mode, please refer to the wiki page
        https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode
        'PEC' needs one model evaluation per step, while 'PECE' needs two model evaluations per step.
        We recommend pc_mode='PEC' when the number of NFEs is limited; 'PECE' mode is only for testing with sufficient NFEs.
        """

        skip_first_step = False
        skip_final_step = True
        lower_order_final = True
        denoise_to_zero = False

        assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE'
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of the betas array"

        device = x.device
        intermediates = []
        with torch.no_grad():
            assert steps >= max(predictor_order, corrector_order - 1)
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order,
                                            device=device)
            assert timesteps.shape[0] - 1 == steps
            # Init the initial values.
            step = 0
            t = timesteps[step]
            noise = torch.randn_like(x)
            t_prev_list = [t]
            # do not evaluate if skip_first_step
            if skip_first_step:
                if self.predict_x0:
                    alpha_t = self.noise_schedule.marginal_alpha(t)
                    sigma_t = self.noise_schedule.marginal_std(t)
                    model_prev_list = [(1 - sigma_t) / alpha_t * x]
                else:
                    model_prev_list = [x]
            else:
                model_prev_list = [self.model_fn(x, t)]

            if self.correcting_xt_fn is not None:
                x = self.correcting_xt_fn(x, t, step)
            if return_intermediate:
                intermediates.append(x)

            # determine the first several values
            for step in tqdm(range(1, max(predictor_order, corrector_order - 1))):

                t = timesteps[step]
                predictor_order_used = min(predictor_order, step)
                corrector_order_used = min(corrector_order, step + 1)
                noise = torch.randn_like(x)
                # predictor step
                x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t),
                                                            model_prev_list=model_prev_list, t_prev_list=t_prev_list,
                                                            noise=noise, t=t)
                # evaluation step
                model_x = self.model_fn(x_p, t)

                # update model_list
                model_prev_list.append(model_x)
                # corrector step
                if corrector_order > 0:
                    x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t),
                                                            model_prev_list=model_prev_list, t_prev_list=t_prev_list,
                                                            noise=noise, t=t)
                else:
                    x = x_p

                # evaluation step if correction and mode = pece
                if corrector_order > 0 and pc_mode == 'PECE':
                    model_x = self.model_fn(x, t)
                    del model_prev_list[-1]
                    model_prev_list.append(model_x)

                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)

                t_prev_list.append(t)

            for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)):
                if lower_order_final:
                    predictor_order_used = min(predictor_order, steps - step + 1)
                    corrector_order_used = min(corrector_order, steps - step + 2)
                else:
                    predictor_order_used = predictor_order
                    corrector_order_used = corrector_order
                t = timesteps[step]
                noise = torch.randn_like(x)

                # predictor step
                if skip_final_step and step == steps and not denoise_to_zero:
                    x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=0,
                                                                model_prev_list=model_prev_list,
                                                                t_prev_list=t_prev_list, noise=noise, t=t)
                else:
                    x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t),
                                                                model_prev_list=model_prev_list,
                                                                t_prev_list=t_prev_list, noise=noise, t=t)

                # evaluation step
                # do not evaluate if skip_final_step and step = steps
                if not skip_final_step or step < steps:
                    model_x = self.model_fn(x_p, t)

                # update model_list
                # do not update if skip_final_step and step = steps
                if not skip_final_step or step < steps:
                    model_prev_list.append(model_x)

                # corrector step
                # do not correct if skip_final_step and step = steps
                if corrector_order > 0 and (not skip_final_step or step < steps):
                    x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t),
                                                            model_prev_list=model_prev_list,
                                                            t_prev_list=t_prev_list, noise=noise, t=t)
                else:
                    x = x_p

                # evaluation step if mode = pece and step != steps
                if corrector_order > 0 and (pc_mode == 'PECE' and step < steps):
                    model_x = self.model_fn(x, t)
                    del model_prev_list[-1]
                    model_prev_list.append(model_x)

                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)

                t_prev_list.append(t)
                del model_prev_list[0]

            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        return (x, intermediates) if return_intermediate else x

    def sample_more_steps(self, x, tau, steps=20, t_start=None, t_end=None, skip_type='time', skip_order=1,
                          predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False):
        """
        For the PC mode, please refer to the wiki page
        https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode
        'PEC' needs one model evaluation per step, while 'PECE' needs two model evaluations per step.
        We recommend pc_mode='PEC' when the number of NFEs is limited; 'PECE' mode is only for testing with sufficient NFEs.
        """

        skip_first_step = False
        skip_final_step = False
        lower_order_final = True
        denoise_to_zero = True

        assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE'
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of the betas array"

        device = x.device
        intermediates = []
        with torch.no_grad():
            assert steps >= max(predictor_order, corrector_order - 1)
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order,
                                            device=device)
            assert timesteps.shape[0] - 1 == steps
            # Init the initial values.
            step = 0
            t = timesteps[step]
            noise = torch.randn_like(x)
            t_prev_list = [t]
            # do not evaluate if skip_first_step
            if skip_first_step:
                if self.predict_x0:
                    alpha_t = self.noise_schedule.marginal_alpha(t)
                    sigma_t = self.noise_schedule.marginal_std(t)
                    model_prev_list = [(1 - sigma_t) / alpha_t * x]
                else:
                    model_prev_list = [x]
            else:
                model_prev_list = [self.model_fn(x, t)]

            if self.correcting_xt_fn is not None:
                x = self.correcting_xt_fn(x, t, step)
            if return_intermediate:
                intermediates.append(x)

            # determine the first several values
            for step in tqdm(range(1, max(predictor_order, corrector_order - 1))):

                t = timesteps[step]
                predictor_order_used = min(predictor_order, step)
                corrector_order_used = min(corrector_order, step + 1)
                noise = torch.randn_like(x)
                # predictor step
                x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t),
                                                  model_prev_list=model_prev_list, t_prev_list=t_prev_list,
                                                  noise=noise, t=t)
                # evaluation step
                model_x = self.model_fn(x_p, t)

                # update model_list
                model_prev_list.append(model_x)
                # corrector step
                if corrector_order > 0:
                    x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t),
                                                  model_prev_list=model_prev_list, t_prev_list=t_prev_list,
                                                  noise=noise, t=t)
                else:
                    x = x_p

                # evaluation step if mode = pece
                if corrector_order > 0 and pc_mode == 'PECE':
                    model_x = self.model_fn(x, t)
                    del model_prev_list[-1]
                    model_prev_list.append(model_x)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)

                t_prev_list.append(t)

            for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)):
                if lower_order_final:
                    predictor_order_used = min(predictor_order, steps - step + 1)
                    corrector_order_used = min(corrector_order, steps - step + 2)
                else:
                    predictor_order_used = predictor_order
                    corrector_order_used = corrector_order
                t = timesteps[step]
                noise = torch.randn_like(x)

                # predictor step
                if skip_final_step and step == steps and not denoise_to_zero:
                    x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=0,
                                                      model_prev_list=model_prev_list, t_prev_list=t_prev_list,
                                                      noise=noise, t=t)
                else:
                    x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t),
                                                      model_prev_list=model_prev_list, t_prev_list=t_prev_list,
                                                      noise=noise, t=t)

                # evaluation step
                # do not evaluate if skip_final_step and step = steps
                if not skip_final_step or step < steps:
                    model_x = self.model_fn(x_p, t)

                # update model_list
                # do not update if skip_final_step and step = steps
                if not skip_final_step or step < steps:
                    model_prev_list.append(model_x)

                # corrector step
                # do not correct if skip_final_step and step = steps
                if corrector_order > 0:
                    if not skip_final_step or step < steps:
                        x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t),
                                                      model_prev_list=model_prev_list, t_prev_list=t_prev_list,
                                                      noise=noise, t=t)
                    else:
                        x = x_p
                else:
                    x = x_p

                # evaluation step if mode = pece and step != steps
                if corrector_order > 0 and (pc_mode == 'PECE' and step < steps):
                    model_x = self.model_fn(x, t)
                    del model_prev_list[-1]
                    model_prev_list.append(model_x)

                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)

                t_prev_list.append(t)
                del model_prev_list[0]

            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        if return_intermediate:
            return x, intermediates
        else:
            return x

    def sample(self, mode, x, tau, steps, t_start=None, t_end=None, skip_type='time', skip_order=1, predictor_order=3,
               corrector_order=4, pc_mode='PEC', return_intermediate=False):
        """
        For the PC mode, please refer to the wiki page
        https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode
        'PEC' needs one model evaluation per step, while 'PECE' needs two model evaluations per step.
        We recommend pc_mode='PEC' when the number of NFEs is limited; 'PECE' mode is only for testing with sufficient NFEs.

        'few_steps' mode is recommended. The differences between 'few_steps' and 'more_steps' are as follows:
        1) 'few_steps' does not correct at the final step and does not denoise to zero, while 'more_steps' does both.
           Thus the NFEs for 'few_steps' = steps, while the NFEs for 'more_steps' = steps + 2.
           For most experiments and tasks, we find that these two operations do not help sample quality much.
        2) 'few_steps' uses a rescaling trick as in Appendix D of the SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf.
           We find it slightly improves sample quality, especially with few steps.
        """
        assert mode in ['few_steps', 'more_steps'], "mode must be either 'few_steps' or 'more_steps'"
        if mode == 'few_steps':
            return self.sample_few_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type,
                                         skip_order=skip_order, predictor_order=predictor_order,
                                         corrector_order=corrector_order, pc_mode=pc_mode,
                                         return_intermediate=return_intermediate)
        else:
            return self.sample_more_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type,
                                          skip_order=skip_order, predictor_order=predictor_order,
                                          corrector_order=corrector_order, pc_mode=pc_mode,
                                          return_intermediate=return_intermediate)

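# --- Usage sketch (added for illustration; not part of the original file). ---
# Assumes a NoiseScheduleVP instance `ns` and a wrapped denoiser `model_fn`
# built with `model_wrapper` earlier in this file; the shapes and the tau
# schedule below are illustrative only.
#
#     sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction")
#     tau_t = lambda t: 1.0 if 0.2 <= t <= 0.8 else 0   # inject SDE noise only mid-trajectory
#     x_T = torch.randn(4, 4, 32, 32)                   # start from pure noise
#     x_0 = sasolver.sample(mode='few_steps', x=x_T, tau=tau_t, steps=20,
#                           predictor_order=2, corrector_order=2, pc_mode='PEC')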
#############################################################
# other utility functions
#############################################################

def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined on the whole x-axis. (For x beyond the bounds of xp, we use the outermost
    points of xp to define the linear function.)
    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand

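# --- Worked example (added for illustration; not part of the original file). ---
# One channel, three keypoints; queries inside and beyond the keypoint range:
#
#     x = torch.tensor([[0.5], [2.5]])        # [N=2, C=1] query points
#     xp = torch.tensor([[0.0, 1.0, 2.0]])    # [C=1, K=3] keypoint x's
#     yp = torch.tensor([[0.0, 10.0, 20.0]])  # [C=1, K=3] keypoint y's
#     interpolate_fn(x, xp, yp)               # -> tensor([[ 5.], [25.]])
#
# 0.5 lies between keypoints (0, 0) and (1, 10), giving 5; 2.5 is beyond xp,
# so the outermost segment (1, 10)-(2, 20) is extrapolated, giving 25.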

def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dimension `dims`.
    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] whose total number of dimensions is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
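# --- Example (added for illustration; not part of the original file). ---
# Broadcasting a per-sample scalar over image dimensions:
#
#     v = torch.randn(8)          # shape [8]
#     expand_dims(v, 4).shape     # -> torch.Size([8, 1, 1, 1])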
DiT_VAE/diffusion/model/timestep_sampler.py
ADDED
@@ -0,0 +1,150 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.
    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")

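# --- Usage sketch (added for illustration; not part of the original file). ---
# `diffusion` is any object exposing `num_timesteps` (e.g. the IDDPM object
# built elsewhere in this repo); the loss-weighting line is illustrative.
#
#     sampler = create_named_schedule_sampler("loss-second-moment", diffusion)
#     t, weights = sampler.sample(batch_size=16, device="cuda")
#     # loss = (weights * per_timestep_losses).mean()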
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.
    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.
        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.
        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights

class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.
        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.
        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs, device=local_ts.device) for _ in batch_sizes]
        loss_batches = [th.zeros(max_bs, device=local_losses.device) for _ in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.
        Sub-classes should override this method to update the reweighting
        using losses from the model.
        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.
        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """

class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # Fixed: `np.int` was removed in NumPy 1.24; use a concrete integer dtype.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
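# --- Sanity-check sketch (added for illustration; not part of the original file). ---
# The weights returned by `ScheduleSampler.sample` are 1 / (T * p[t]), so the
# reweighted objective E[w * loss_t] matches the uniform-average objective for
# any sampling distribution p:
#
#     p = np.array([0.1, 0.2, 0.3, 0.4])
#     w = 1 / (len(p) * p)
#     assert np.isclose((p * w).sum(), 1.0)   # unbiasedness of the reweighting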
DiT_VAE/diffusion/model/utils.py
ADDED
@@ -0,0 +1,510 @@
import os
import sys
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.nn.functional as F
import torch
import torch.distributed as dist
import re
import math
from collections.abc import Iterable
from itertools import repeat
from torchvision import transforms as T
import random
from PIL import Image


def _ntuple(n):
    def parse(x):
        if isinstance(x, Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)

def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    assert isinstance(model, nn.Module)

    def set_attr(module):
        module.grad_checkpointing = True
        module.fp32_attention = use_fp32_attention
        module.grad_checkpointing_step = gc_step
    model.apply(set_attr)


def auto_grad_checkpoint(module, *args, **kwargs):
    if getattr(module, 'grad_checkpointing', False):
        if not isinstance(module, Iterable):
            return checkpoint(module, *args, **kwargs)
        gc_step = module[0].grad_checkpointing_step
        return checkpoint_sequential(module, gc_step, *args, **kwargs)
    return module(*args, **kwargs)

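# --- Usage sketch (added for illustration; not part of the original file). ---
# `MyTransformer` is a hypothetical nn.Module whose forward routes its blocks
# through `auto_grad_checkpoint` to trade compute for activation memory:
#
#     model = MyTransformer()
#     set_grad_checkpoint(model, gc_step=2)   # tags every submodule
#     # inside MyTransformer.forward:
#     #     x = auto_grad_checkpoint(self.blocks, x)   # Sequential -> checkpoint_sequential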
# Note: this intentionally overrides the `checkpoint_sequential` imported above
# so that extra positional args can be forwarded to each function in the chain.
def checkpoint_sequential(functions, step, input, *args, **kwargs):

    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input, *args)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # the last chunk has to be non-volatile
    end = -1
    segment = len(functions) // step
    for start in range(0, step * (segment - 1), step):
        end = start + step - 1
        input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
    return run_function(end + 1, len(functions) - 1, functions)(input)

def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and remove padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x

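# --- Round-trip example (added for illustration; not part of the original file). ---
# Partition pads H and W up to a multiple of the window size; unpartition crops
# the padding back off, so the round trip is exact:
#
#     x = torch.randn(2, 14, 14, 96)                            # [B, H, W, C]
#     windows, (Hp, Wp) = window_partition(x, window_size=7)    # -> [2*4, 7, 7, 96], (14, 14)
#     y = window_unpartition(windows, 7, (Hp, Wp), (14, 14))
#     assert torch.equal(x, y)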
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]

def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


def mean_flat(tensor):
    return tensor.mean(dim=list(range(1, tensor.ndim)))


#################################################################################
#                          Token Masking and Unmasking                          #
#################################################################################
def get_mask(batch, length, mask_ratio, device, mask_type=None, data_info=None, extra_len=0):
    """
    Get the binary mask for the input sequence.
    Args:
        - batch: batch size
        - length: sequence length
        - mask_ratio: ratio of tokens to mask
        - data_info: dictionary with info for reconstruction
    return:
        mask_dict with following keys:
        - mask: binary mask, 0 is keep, 1 is remove
        - ids_keep: indices of tokens to keep
        - ids_restore: indices to restore the original order
    """
    assert mask_type in ['random', 'fft', 'laplacian', 'group']
    mask = torch.ones([batch, length], device=device)
    len_keep = int(length * (1 - mask_ratio)) - extra_len

    if mask_type in ['random', 'group']:
        noise = torch.rand(batch, length, device=device)  # noise in [0, 1]
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        ids_removed = ids_shuffle[:, len_keep:]

    elif mask_type in ['fft', 'laplacian']:
        if 'strength' in data_info:
            strength = data_info['strength']

        else:
            N = data_info['N'][0]
            img = data_info['ori_img']
            # Get the size of the original image.
            _, C, H, W = img.shape
            if mask_type == 'fft':
                # Reshape the image into patches: (3, H/N, N, W/N, N).
                reshaped_image = img.reshape((batch, -1, H // N, N, W // N, N))
                fft_image = torch.fft.fftn(reshaped_image, dim=(3, 5))
                # Take the absolute value and sum it to get the frequency strength.
                strength = torch.sum(torch.abs(fft_image), dim=(1, 3, 5)).reshape((batch, -1,))
            elif mask_type == 'laplacian':  # fixed: previously compared the builtin `type` instead of `mask_type`
                laplacian_kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32).reshape(1, 1, 3, 3)
                laplacian_kernel = laplacian_kernel.repeat(C, 1, 1, 1)
                # Reshape the image into patches: (3, H/N, N, W/N, N).
                reshaped_image = img.reshape(-1, C, H // N, N, W // N, N).permute(0, 2, 4, 1, 3, 5).reshape(-1, C, N, N)
                laplacian_response = F.conv2d(reshaped_image, laplacian_kernel, padding=1, groups=C)
                strength = laplacian_response.sum(dim=[1, 2, 3]).reshape((batch, -1,))

        # Normalize the frequency strength, then sample with torch.multinomial.
        probabilities = strength / (strength.max(dim=1)[0][:, None] + 1e-5)
        ids_shuffle = torch.multinomial(probabilities.clip(1e-5, 1), length, replacement=False)
        ids_keep = ids_shuffle[:, :len_keep]
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_removed = ids_shuffle[:, len_keep:]

    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return {'mask': mask,
            'ids_keep': ids_keep,
            'ids_restore': ids_restore,
            'ids_removed': ids_removed}

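# --- Usage sketch (added for illustration; not part of the original file). ---
# Random token masking for a batch of sequences; with mask_ratio=0.5 and
# length=256, half the tokens are kept:
#
#     md = get_mask(batch=4, length=256, mask_ratio=0.5, device='cpu', mask_type='random')
#     md['ids_keep'].shape       # -> torch.Size([4, 128]): indices of kept tokens
#     md['mask'].sum(dim=1)      # -> 128 per sample (1 marks a removed token)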
def mask_out_token(x, ids_keep, ids_removed=None):
    """
    Mask out the tokens specified by ids_keep.
    Args:
        - x: input sequence, [N, L, D]
        - ids_keep: indices of tokens to keep
    return:
        - x_masked: masked sequence
    """
    N, L, D = x.shape  # batch, length, dim
    x_remain = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    if ids_removed is not None:
        x_masked = torch.gather(x, dim=1, index=ids_removed.unsqueeze(-1).repeat(1, 1, D))
        return x_remain, x_masked
    else:
        return x_remain


def mask_tokens(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore

def unmask_tokens(x, ids_restore, mask_token):
    # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D]
    mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
    x = torch.cat([x, mask_tokens], dim=1)
    x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    return x


# Parse 'None' to None and others to float value
def parse_float_none(s):
    assert isinstance(s, str)
    return None if s == 'None' else float(s)


#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        if m := range_re.match(p):
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges

def init_processes(fn, args):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = args.master_address
    os.environ['MASTER_PORT'] = str(random.randint(2000, 6000))
    print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}')
    print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}')
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size)
    fn(args)
    if args.global_size > 1:
        cleanup()


def mprint(*args, **kwargs):
    """
    Print only from rank 0.
    """
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def cleanup():
    """
    End DDP training.
    """
    dist.barrier()
    mprint("Done!")
    dist.barrier()
    dist.destroy_process_group()


#----------------------------------------------------------------------------
# logging info.
class Logger(object):
    """
    Redirect stderr to stdout, optionally print stdout to a file,
    and optionally force flushing on both stdout and the file.
    """

    def __init__(self, file_name=None, file_mode="w", should_flush=True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def write(self, text):
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self):
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self):
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()

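# --- Usage sketch (added for illustration; not part of the original file). ---
# Mirror stdout/stderr to a log file for the duration of a run:
#
#     with Logger(file_name='train.log'):
#         print('this line goes to the console and to train.log')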
class StackedRandomGenerator:
    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]

    def randn(self, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])

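# --- Usage sketch (added for illustration; not part of the original file). ---
# One generator per sample makes each image in the batch reproducible from its
# own seed, independent of batch composition:
#
#     seeds = [0, 1, 2, 3]
#     rnd = StackedRandomGenerator('cpu', seeds)
#     latents = rnd.randn([len(seeds), 4, 32, 32], device='cpu')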
def prepare_prompt_ar(prompt, ratios, device='cpu', show=True):
    # get aspect_ratio or ar
    aspect_ratios = re.findall(r"--aspect_ratio\s+(\d+:\d+)", prompt)
    ars = re.findall(r"--ar\s+(\d+:\d+)", prompt)
    custom_hw = re.findall(r"--hw\s+(\d+:\d+)", prompt)
    if show:
        print("aspect_ratios:", aspect_ratios, "ars:", ars, "hws:", custom_hw)
    prompt_clean = prompt.split("--aspect_ratio")[0].split("--ar")[0].split("--hw")[0]
    if len(aspect_ratios) + len(ars) + len(custom_hw) == 0 and show:
        print("Wrong prompt format. Falling back to the default aspect ratio of 1. "
              "Use '--ar h:w' or '--hw h:w' in your prompt to control generation.")
    if len(aspect_ratios) != 0:
        ar = float(aspect_ratios[0].split(':')[0]) / float(aspect_ratios[0].split(':')[1])
    elif len(ars) != 0:
        ar = float(ars[0].split(':')[0]) / float(ars[0].split(':')[1])
    else:
        ar = 1.
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
    if len(custom_hw) != 0:
        custom_hw = [float(custom_hw[0].split(':')[0]), float(custom_hw[0].split(':')[1])]
    else:
        custom_hw = ratios[closest_ratio]
    default_hw = ratios[closest_ratio]
    prompt_show = f'prompt: {prompt_clean.strip()}\nSize: --ar {closest_ratio}, --bin hw {ratios[closest_ratio]}, --custom hw {custom_hw}'
    return prompt_clean, prompt_show, torch.tensor(default_hw, device=device)[None], torch.tensor([float(closest_ratio)], device=device)[None], torch.tensor(custom_hw, device=device)[None]

def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int):
    orig_hw = torch.tensor([samples.shape[2], samples.shape[3]], dtype=torch.int)
    custom_hw = torch.tensor([int(new_height), int(new_width)], dtype=torch.int)

    # Note: `.all()` means resizing only happens when *both* dimensions differ
    # from the target; if only one differs, the input is returned unchanged.
    if (orig_hw != custom_hw).all():
        ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1])
        resized_width = int(orig_hw[1] * ratio)
        resized_height = int(orig_hw[0] * ratio)

        transform = T.Compose([
            T.Resize((resized_height, resized_width)),
            T.CenterCrop(custom_hw.tolist())
        ])
        return transform(samples)
    else:
        return samples


def resize_and_crop_img(img: Image, new_width, new_height):
    orig_width, orig_height = img.size

    ratio = max(new_width / orig_width, new_height / orig_height)
    resized_width = int(orig_width * ratio)
    resized_height = int(orig_height * ratio)

    img = img.resize((resized_width, resized_height), Image.LANCZOS)

    left = (resized_width - new_width) / 2
    top = (resized_height - new_height) / 2
    right = (resized_width + new_width) / 2
    bottom = (resized_height + new_height) / 2

    img = img.crop((left, top, right, bottom))

    return img


def mask_feature(emb, mask):
    if emb.shape[0] == 1:
        keep_index = mask.sum().item()
        return emb[:, :, :keep_index, :], keep_index
    else:
        masked_feature = emb * mask[:, None, :, None]
        return masked_feature, emb.shape[2]
DiT_VAE/diffusion/sa_sampler.py
ADDED
@@ -0,0 +1,66 @@
"""SAMPLING ONLY."""

import torch
import numpy as np

from DiT_VAE.diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver
from .model import gaussian_diffusion as gd


class SASolverSampler(object):
    def __init__(self, model,
                 noise_schedule="linear",
                 diffusion_steps=1000,
                 device='cpu',
                 ):
        super().__init__()
        self.model = model
        self.device = device
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)
        betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
        alphas = 1.0 - betas
        self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0)))

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor and attr.device != torch.device("cuda"):
            attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    @torch.no_grad()
    def sample(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None,
               img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1.,
               noise_dropout=0., score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None,
               log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None,
               model_kwargs=None, **kwargs):
        if model_kwargs is None:
            model_kwargs = {}
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            elif conditioning.shape[0] != batch_size:
                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        device = self.device
        img = torch.randn(size, device=device) if x_T is None else x_T
        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        model_fn = model_wrapper(
            self.model,
            ns,
            model_type="noise",
            guidance_type="classifier-free",
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            guidance_scale=unconditional_guidance_scale,
            model_kwargs=model_kwargs,
        )

        sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction")

        tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0

        x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1,
                            predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False)

        return x.to(device), None
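# --- Usage sketch (added for illustration; not part of the original file). ---
# `model` is a noise-prediction network and `cond`/`uncond` are its conditioning
# tensors; the latent shape (C, H, W) is illustrative. eta > 0 injects SDE noise
# only for 0.2 <= t <= 0.8 via the tau schedule above.
#
#     sampler = SASolverSampler(model, noise_schedule="linear", device='cuda')
#     samples, _ = sampler.sample(S=25, batch_size=4, shape=(4, 32, 32),
#                                 conditioning=cond, unconditional_conditioning=uncond,
#                                 unconditional_guidance_scale=4.5, eta=1.0)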
DiT_VAE/diffusion/sa_solver_diffusers.py
ADDED
@@ -0,0 +1,840 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: check https://arxiv.org/abs/2309.05019
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
from typing import List, Optional, Tuple, Union, Callable

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`
    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

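# --- Quick check (added for illustration; not part of the original file). ---
# The cosine schedule yields small betas early and larger ones late, capped at
# max_beta:
#
#     betas = betas_for_alpha_bar(1000)
#     betas.shape                            # -> torch.Size([1000])
#     bool(betas[0] < betas[-1] <= 0.999)    # -> True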
class SASolverScheduler(SchedulerMixin, ConfigMixin):
    """
    `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        predictor_order (`int`, defaults to 2):
            The predictor order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `predictor_order=2` for
            guided sampling, and `predictor_order=3` for unconditional sampling.
        corrector_order (`int`, defaults to 2):
            The corrector order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `corrector_order=2` for
            guided sampling, and `corrector_order=3` for unconditional sampling.
        predictor_corrector_mode (`str`, defaults to `PEC`):
            The predictor-corrector mode, which can be `PEC` or `PECE`. It is recommended to use `PEC` for fast
            sampling, and `PECE` for high-quality sampling (PECE needs around twice as many model evaluations as PEC).
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `data_prediction`):
            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
            `data_prediction` with `solver_order=2` for guided sampling like in Stable Diffusion.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
            contains the predicted Gaussian variance.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        predictor_order: int = 2,
        corrector_order: int = 2,
        predictor_corrector_mode: str = 'PEC',
        prediction_type: str = "epsilon",
        tau_func: Callable = lambda t: 1 if t >= 200 and t <= 800 else 0,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "data_prediction",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        if algorithm_type not in ["data_prediction", "noise_prediction"]:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.timestep_list = [None] * max(predictor_order, corrector_order - 1)
        self.model_outputs = [None] * max(predictor_order, corrector_order - 1)

        self.tau_func = tau_func
        self.predict_x0 = algorithm_type == "data_prediction"
        self.lower_order_nums = 0
        self.last_sample = None

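    # --- Illustrative note, not part of the original file: `tau_func` controls
    # where stochasticity is injected. tau(t) > 0 makes the update at timestep t
    # an SDE step, tau(t) = 0 a deterministic ODE step; the default enables
    # noise only for t in [200, 800]. A minimal construction sketch:
    #
    #     scheduler = SASolverScheduler(
    #         beta_schedule="linear",
    #         algorithm_type="data_prediction",
    #         tau_func=lambda t: 1.0 if 200 <= t <= 800 else 0.0,
    #     )
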
    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
        last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = last_timestep // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            timesteps = np.flip(timesteps).copy().astype(np.int64)

        self.sigmas = torch.from_numpy(sigmas)

        # when num_inference_steps == num_train_timesteps, we can end up with
        # duplicates in timesteps.
        _, unique_indices = np.unique(timesteps, return_index=True)
        timesteps = timesteps[np.sort(unique_indices)]

        self.timesteps = torch.from_numpy(timesteps).to(device)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * max(self.config.predictor_order, self.config.corrector_order - 1)
        self.lower_order_nums = 0
        self.last_sample = None

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, height, width = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(sigma)

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

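    # --- Illustrative helper, not part of the original file: a minimal sketch of
    # how `set_timesteps` is called before sampling. The method name is ours.
    def _demo_set_timesteps(self, num_inference_steps: int = 20):
        self.set_timesteps(num_inference_steps)
        # timesteps run from high noise to low noise and are strictly decreasing
        # (duplicates are removed in `set_timesteps`), e.g. [999, 949, ..., 50]
        # with "linspace" spacing and 1000 training steps
        assert torch.all(self.timesteps[1:] < self.timesteps[:-1])
        return self.timesteps
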
    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.FloatTensor`:
                The converted model output.
        """

        # SA-Solver_data_prediction needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["data_prediction"]:
            if self.config.prediction_type == "epsilon":
                # SA-Solver only needs the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the SASolverScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["noise_prediction"]:
            if self.config.prediction_type == "epsilon":
                # SA-Solver only needs the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the SASolverScheduler."
                )

            if self.config.thresholding:
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon

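    # --- Illustrative helper, not part of the original file: the two
    # parameterizations handled above are tied by x_t = alpha_t * x0 + sigma_t * eps,
    # so converting an epsilon prediction to x0 and back must round-trip.
    def _demo_prediction_round_trip(self, sample: torch.FloatTensor, eps: torch.FloatTensor, timestep: int):
        alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
        x0 = (sample - sigma_t * eps) / alpha_t       # epsilon -> data prediction
        eps_back = (sample - alpha_t * x0) / sigma_t  # data -> epsilon prediction
        return torch.allclose(eps, eps_back, atol=1e-5)
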
    def get_coefficients_exponential_negative(self, order, interval_start, interval_end):
        """
        Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end
        """
        assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

        if order == 0:
            return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1)
        elif order == 1:
            return torch.exp(-interval_end) * (
                (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1))
        elif order == 2:
            return torch.exp(-interval_end) * (
                (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - (
                    interval_end ** 2 + 2 * interval_end + 2))
        elif order == 3:
            return torch.exp(-interval_end) * (
                (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp(
                    interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6))

    def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau):
        """
        Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
        """
        assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

        # after change of variable (cov)
        interval_end_cov = (1 + tau ** 2) * interval_end
        interval_start_cov = (1 + tau ** 2) * interval_start

        if order == 0:
            return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (
                (1 + tau ** 2))
        elif order == 1:
            return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(
                -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2)
        elif order == 2:
            return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - (
                interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp(
                -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3)
        elif order == 3:
            return torch.exp(interval_end_cov) * (
                (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - (
                    interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp(
                    -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4)

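    # --- Illustrative helper, not part of the original file: a numerical sanity
    # check of the order-0 closed form above, \int_a^b e^{-x} dx = e^{-b}(e^{b-a} - 1),
    # against simple trapezoidal quadrature.
    def _demo_check_exponential_negative(self, a: float = 0.5, b: float = 1.5):
        closed_form = self.get_coefficients_exponential_negative(0, torch.tensor(a), torch.tensor(b))
        xs = torch.linspace(a, b, 10001)
        numeric = torch.trapz(torch.exp(-xs), xs)  # trapezoidal rule on a fine grid
        return torch.allclose(closed_form, numeric, atol=1e-4)
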
    def lagrange_polynomial_coefficient(self, order, lambda_list):
        """
        Calculate the coefficients of the Lagrange basis polynomials
        """

        assert order in [0, 1, 2, 3]
        assert order == len(lambda_list) - 1
        if order == 0:
            return [[1]]
        elif order == 1:
            return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
                    [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
        elif order == 2:
            denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
            denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
            denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
            return [[1 / denominator1,
                     (-lambda_list[1] - lambda_list[2]) / denominator1,
                     lambda_list[1] * lambda_list[2] / denominator1],

                    [1 / denominator2,
                     (-lambda_list[0] - lambda_list[2]) / denominator2,
                     lambda_list[0] * lambda_list[2] / denominator2],

                    [1 / denominator3,
                     (-lambda_list[0] - lambda_list[1]) / denominator3,
                     lambda_list[0] * lambda_list[1] / denominator3]
                    ]
        elif order == 3:
            denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (
                lambda_list[0] - lambda_list[3])
            denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (
                lambda_list[1] - lambda_list[3])
            denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (
                lambda_list[2] - lambda_list[3])
            denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (
                lambda_list[3] - lambda_list[2])
            return [[1 / denominator1,
                     (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
                     (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[
                         3]) / denominator1,
                     (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],

                    [1 / denominator2,
                     (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
                     (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[
                         3]) / denominator2,
                     (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],

                    [1 / denominator3,
                     (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
                     (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[
                         3]) / denominator3,
                     (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],

                    [1 / denominator4,
                     (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
                     (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[
                         2]) / denominator4,
                     (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
                    ]

    def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau):
        assert order in [1, 2, 3, 4]
        assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
        coefficients = []
        lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list)
        for i in range(order):
            coefficient = sum(
                lagrange_coefficient[i][j]
                * self.get_coefficients_exponential_positive(
                    order - 1 - j, interval_start, interval_end, tau
                )
                if self.predict_x0
                else lagrange_coefficient[i][j]
                * self.get_coefficients_exponential_negative(
                    order - 1 - j, interval_start, interval_end
                )
                for j in range(order)
            )
            coefficients.append(coefficient)
        assert len(coefficients) == order, 'the length of coefficients does not match the order'
        return coefficients

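    # --- Illustrative helper, not part of the original file: the rows returned by
    # `lagrange_polynomial_coefficient` are the Lagrange basis polynomials
    # (coefficients ordered from highest degree to constant), so row i must
    # evaluate to 1 at lambda_list[i] and 0 at every other node.
    def _demo_check_lagrange_basis(self):
        lambda_list = [torch.tensor(v) for v in (0.0, 1.0, 3.0)]
        coeffs = self.lagrange_polynomial_coefficient(2, lambda_list)
        for i in range(3):
            for j in range(3):
                value = sum(c * lambda_list[j] ** (2 - k) for k, c in enumerate(coeffs[i]))
                expected = 1.0 if i == j else 0.0
                assert abs(float(value) - expected) < 1e-6
        return True
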
    def stochastic_adams_bashforth_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        noise: torch.FloatTensor,
        order: int,
        tau: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        One step for the SA-Predictor.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model at the current timestep.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            noise (`torch.FloatTensor`):
                The Gaussian noise injected by the stochastic predictor step.
            order (`int`):
                The order of SA-Predictor at this timestep.
            tau (`torch.FloatTensor`):
                The stochasticity parameter at this timestep (0 gives a deterministic step).

        Returns:
            `torch.FloatTensor`:
                The sample tensor at the previous timestep.
        """

        assert noise is not None
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
        s0, t = self.timestep_list[-1], prev_timestep
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        gradient_part = torch.zeros_like(sample)
        h = lambda_t - lambda_s0
        lambda_list = [self.lambda_t[timestep_list[-(i + 1)]] for i in range(order)]
        gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)

        x = sample

        if self.predict_x0 and order == 2:
            gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
                    (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[
                timestep_list[-2]])
            gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
                    (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[
                timestep_list[-2]])

        for i in range(order):
            if self.predict_x0:
                gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
                    i] * model_output_list[-(i + 1)]
            else:
                gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)]

        if self.predict_x0:
            noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
        else:
            noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise

        if self.predict_x0:
            x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
        else:
            x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part

        x_t = x_t.to(x.dtype)
        return x_t

    def stochastic_adams_moulton_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        last_noise: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
        tau: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        One step for the SA-Corrector.

        Args:
            this_model_output (`torch.FloatTensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                The current timestep `t`.
            last_sample (`torch.FloatTensor`):
                The generated sample before the last predictor `x_{t-1}`.
            last_noise (`torch.FloatTensor`):
                The noise used by the last predictor step.
            this_sample (`torch.FloatTensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The order of SA-Corrector at this step.
            tau (`torch.FloatTensor`):
                The stochasticity parameter at this timestep (0 gives a deterministic step).

        Returns:
            `torch.FloatTensor`:
                The corrected sample tensor at the current timestep.
        """

        assert last_noise is not None
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
        s0, t = self.timestep_list[-1], this_timestep
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        gradient_part = torch.zeros_like(this_sample)
        h = lambda_t - lambda_s0
        t_list = timestep_list + [this_timestep]
        lambda_list = [self.lambda_t[t_list[-(i + 1)]] for i in range(order)]
        model_prev_list = model_output_list + [this_model_output]

        gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)

        x = last_sample

        if self.predict_x0 and order == 2:
            gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
                    (1 + tau ** 2) ** 2 * h))
            gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
                h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
                    (1 + tau ** 2) ** 2 * h))

        for i in range(order):
            if self.predict_x0:
                gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
                    i] * model_prev_list[-(i + 1)]
            else:
                gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]

        if self.predict_x0:
            noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * last_noise
        else:
            noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise

        if self.predict_x0:
            x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
        else:
            x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part

        x_t = x_t.to(x.dtype)
        return x_t

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the SA-Solver.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero()
        if len(step_index) == 0:
            step_index = len(self.timesteps) - 1
        else:
            step_index = step_index.item()

        use_corrector = step_index > 0 and self.last_sample is not None

        model_output_convert = self.convert_model_output(model_output, timestep, sample)

        if use_corrector:
            current_tau = self.tau_func(self.timestep_list[-1])
            sample = self.stochastic_adams_moulton_update(
                this_model_output=model_output_convert,
                this_timestep=timestep,
                last_sample=self.last_sample,
                last_noise=self.last_noise,
                this_sample=sample,
                order=self.this_corrector_order,
                tau=current_tau,
            )

        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]

        for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )

        if self.config.lower_order_final:
            this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - step_index)
            this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - step_index + 1)
        else:
            this_predictor_order = self.config.predictor_order
            this_corrector_order = self.config.corrector_order

        self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1)  # warmup for multistep
        self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2)  # warmup for multistep
        assert self.this_predictor_order > 0
        assert self.this_corrector_order > 0

        self.last_sample = sample
        self.last_noise = noise

        current_tau = self.tau_func(self.timestep_list[-1])
        prev_sample = self.stochastic_adams_bashforth_update(
            model_output=model_output_convert,
            prev_timestep=prev_timestep,
            sample=sample,
            noise=noise,
            order=self.this_predictor_order,
            tau=current_tau,
        )

        if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1):
            self.lower_order_nums += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

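    # --- Illustrative usage sketch, not part of the original file. `unet` and the
    # latent shape are assumptions; the loop mirrors the standard diffusers
    # predict-then-step pattern:
    #
    #     scheduler = SASolverScheduler()
    #     scheduler.set_timesteps(num_inference_steps=25, device="cuda")
    #     latents = torch.randn(1, 4, 64, 64, device="cuda") * scheduler.init_noise_sigma
    #     for t in scheduler.timesteps:
    #         model_output = unet(latents, t)  # epsilon prediction
    #         latents = scheduler.step(model_output, t, latents).prev_sample
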
    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

    def __len__(self):
        return self.config.num_train_timesteps
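
# --- Illustrative sketch, not part of the original file: `add_noise` implements
# the forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
# The helper name `_demo_add_noise` is ours.
def _demo_add_noise():
    scheduler = SASolverScheduler()
    x0 = torch.randn(2, 4, 8, 8)
    noise = torch.randn_like(x0)
    timesteps = torch.tensor([10, 900])  # small t stays close to x0, large t close to noise
    noisy = scheduler.add_noise(x0, noise, timesteps)
    assert noisy.shape == x0.shape
    return noisy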
DiT_VAE/diffusion/utils/__init__.py
ADDED
@@ -0,0 +1 @@
# from .logger import get_root_logger
DiT_VAE/diffusion/utils/checkpoint.py
ADDED
@@ -0,0 +1,80 @@
import os
import re
import torch

from DiT_VAE.diffusion.utils.logger import get_root_logger


def save_checkpoint(work_dir,
                    epoch,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    keep_last=False,
                    step=None,
                    ):
    os.makedirs(work_dir, exist_ok=True)
    state_dict = dict(state_dict=model.state_dict())
    if model_ema is not None:
        state_dict['state_dict_ema'] = model_ema.state_dict()
    if optimizer is not None:
        state_dict['optimizer'] = optimizer.state_dict()
    if lr_scheduler is not None:
        state_dict['scheduler'] = lr_scheduler.state_dict()
    if epoch is not None:
        state_dict['epoch'] = epoch
        file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
        if step is not None:
            file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
    logger = get_root_logger()
    torch.save(state_dict, file_path)
    logger.info(f'Saved checkpoint of epoch {epoch} to {file_path}.')
    if keep_last:
        # keep only the latest checkpoint: remove the files saved for earlier epochs
        for i in range(epoch):
            previous_ckpt = os.path.join(work_dir, f"epoch_{i}.pth")
            if os.path.exists(previous_ckpt):
                os.remove(previous_ckpt)


def load_checkpoint(checkpoint,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    load_ema=False,
                    resume_optimizer=True,
                    resume_lr_scheduler=True
                    ):
    assert isinstance(checkpoint, str)
    ckpt_file = checkpoint
    checkpoint = torch.load(ckpt_file, map_location="cpu")

    state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']
    for key in state_dict_keys:
        if key in checkpoint['state_dict']:
            del checkpoint['state_dict'][key]
            if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']:
                del checkpoint['state_dict_ema'][key]
            break

    if load_ema:
        state_dict = checkpoint['state_dict_ema']
    else:
        state_dict = checkpoint.get('state_dict', checkpoint)  # to be compatible with the official checkpoint
    # model.load_state_dict(state_dict)
    missing, unexpect = model.load_state_dict(state_dict, strict=False)
    if model_ema is not None:
        model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
    if optimizer is not None and resume_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None and resume_lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['scheduler'])
    logger = get_root_logger()
    if optimizer is not None:
        # fall back to parsing the epoch number out of the file name if it is not stored in the checkpoint
        epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*.pth', ckpt_file).group(1))
        logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
                    f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
        return epoch, missing, unexpect
    logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
    return missing, unexpect
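
# --- Illustrative usage sketch, not part of the original file. `model`,
# `optimizer`, `lr_scheduler`, and the paths below are assumptions:
#
#     save_checkpoint("work_dirs/run1", epoch=3, model=model,
#                     optimizer=optimizer, lr_scheduler=lr_scheduler, step=1200)
#     # -> work_dirs/run1/epoch_3_step_1200.pth
#     epoch, missing, unexpected = load_checkpoint(
#         "work_dirs/run1/epoch_3_step_1200.pth", model,
#         optimizer=optimizer, lr_scheduler=lr_scheduler)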
DiT_VAE/diffusion/utils/data_sampler.py
ADDED
@@ -0,0 +1,138 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Sequence
from torch.utils.data import BatchSampler, Sampler, Dataset
from random import shuffle, choice
from copy import deepcopy
from DiT_VAE.diffusion.utils.logger import get_root_logger


class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 aspect_ratios: dict,
                 drop_last: bool = False,
                 config=None,
                 valid_num=0,  # take as valid aspect-ratio when sample number >= valid_num
                 **kwargs) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.ratio_nums_gt = kwargs.get('ratio_nums', None)
        self.config = config
        assert self.ratio_nums_gt
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num]
        logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
        logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")

    def __iter__(self) -> Sequence[int]:
        for idx in self.sampler:
            data_info = self.dataset.get_data_info(idx)
            height, width = data_info['height'], data_info['width']
            ratio = height / width
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # yield the rest data and reset the buckets
        for bucket in self._aspect_ratio_buckets.values():
            while len(bucket) > 0:
                if len(bucket) <= self.batch_size:
                    if not self.drop_last:
                        yield bucket[:]
                    bucket = []
                else:
                    yield bucket[:self.batch_size]
                    bucket = bucket[self.batch_size:]


class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assign samples to each bucket
        self.ratio_nums_gt = kwargs.get('ratio_nums', None)
        assert self.ratio_nums_gt
        self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()}
        self.original_buckets = {}
        self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000]
        self.all_available_keys = deepcopy(self.current_available_bucket_keys)
        self.exhausted_bucket_keys = []
        self.total_batches = len(self.sampler) // self.batch_size
        self._aspect_ratio_count = {}
        for k in self.all_available_keys:
            self._aspect_ratio_count[float(k)] = 0
            self.original_buckets[float(k)] = []
        logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log'))
        logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")

    def __iter__(self) -> Sequence[int]:
        i = 0
        for idx in self.sampler:
            data_info = self.dataset.get_data_info(idx)
            height, width = data_info['height'], data_info['width']
            ratio = height / width
            closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
            if closest_ratio not in self.all_available_keys:
                continue
            if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
                self._aspect_ratio_count[closest_ratio] += 1
                self._aspect_ratio_buckets[closest_ratio].append(idx)
                self.original_buckets[closest_ratio].append(idx)  # Save the original samples for each bucket
            if not self.current_available_bucket_keys:
                self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []

            if closest_ratio not in self.current_available_bucket_keys:
                continue
            key = closest_ratio
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) == self.batch_size:
                yield bucket[:self.batch_size]
                del bucket[:self.batch_size]
                i += 1
                self.exhausted_bucket_keys.append(key)
                self.current_available_bucket_keys.remove(key)

        for _ in range(self.total_batches - i):
            key = choice(self.all_available_keys)
            bucket = self._aspect_ratio_buckets[key]
            if len(bucket) >= self.batch_size:
                yield bucket[:self.batch_size]
                del bucket[:self.batch_size]

                # If a bucket is exhausted, refill it from the saved original samples
                if not bucket:
                    self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                    shuffle(self._aspect_ratio_buckets[key])
            else:
                self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
                shuffle(self._aspect_ratio_buckets[key])
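
# --- Illustrative usage sketch, not part of the original file. The dataset is
# assumed to expose `get_data_info(idx) -> {'height': ..., 'width': ...}` as used
# in `__iter__` above; the `aspect_ratios` mapping below is a made-up example.
#
#     from torch.utils.data import DataLoader, RandomSampler
#
#     batch_sampler = AspectRatioBatchSampler(
#         sampler=RandomSampler(dataset),
#         dataset=dataset,
#         batch_size=16,
#         aspect_ratios={'0.5': (384, 768), '1.0': (512, 512), '2.0': (768, 384)},
#         ratio_nums={'0.5': 1200, '1.0': 90000, '2.0': 800},
#         valid_num=1000,  # buckets with fewer samples than this are skipped
#     )
#     loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)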
DiT_VAE/diffusion/utils/dist_utils.py
ADDED
@@ -0,0 +1,303 @@
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import os
import pickle
import shutil

import gc
import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def is_distributed():
    return get_world_size() > 1


def get_world_size():
    if not dist.is_available():
        return 1
    return dist.get_world_size() if dist.is_initialized() else 1


def get_rank():
    if not dist.is_available():
        return 0
    return dist.get_rank() if dist.is_initialized() else 0


def get_local_rank():
    if not dist.is_available():
        return 0
    return int(os.getenv('LOCAL_RANK', 0)) if dist.is_initialized() else 0


def is_master():
    return get_rank() == 0


def is_local_master():
    return get_local_rank() == 0


def get_local_proc_group(group_size=8):
    world_size = get_world_size()
    if world_size <= group_size or group_size == 1:
        return None
    assert world_size % group_size == 0, f'world size ({world_size}) should be evenly divided by group size ({group_size}).'
    process_groups = getattr(get_local_proc_group, 'process_groups', {})
    if group_size not in process_groups:
        num_groups = dist.get_world_size() // group_size
        groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]
        process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]})
        get_local_proc_group.process_groups = process_groups

    group_idx = get_rank() // group_size
    return get_local_proc_group.process_groups.get(group_size)[group_idx]


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    to_device = torch.device("cuda")
    # to_device = torch.device("cpu")

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(to_device)

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to(to_device)
    size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = [
        torch.ByteTensor(size=(max_size,)).to(to_device) for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


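# --- Illustrative sketch, not part of the original file: `all_gather` above works
# on arbitrary picklable objects by pickling to byte tensors and padding them to a
# common length before the collective call. Under e.g. `torchrun --nproc_per_node=2`:
#
#     stats = {"rank": get_rank(), "loss": 0.123}
#     gathered = all_gather(stats)  # one dict per rank, returned on every rank

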
def reduce_dict(input_dict, average=True):
|
121 |
+
"""
|
122 |
+
Args:
|
123 |
+
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        reduced_dict = _extracted_from_reduce_dict_14(input_dict, average, world_size)
+    return reduced_dict
+
+
+# TODO Rename this here and in `reduce_dict`
+def _extracted_from_reduce_dict_14(input_dict, average, world_size):
+    names = []
+    values = []
+    # sort the keys so that they are consistent across processes
+    for k in sorted(input_dict.keys()):
+        names.append(k)
+        values.append(input_dict[k])
+    values = torch.stack(values, dim=0)
+    dist.reduce(values, dst=0)
+    if dist.get_rank() == 0 and average:
+        # only main process gets accumulated, so only divide by
+        # world_size in this case
+        values /= world_size
+    return dict(zip(names, values))
+
+
+def broadcast(data, **kwargs):
+    if get_world_size() == 1:
+        return data
+    data = [data]
+    dist.broadcast_object_list(data, **kwargs)
+    return data[0]
+
+
+def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True):
+    rank, world_size = get_dist_info()
+    if tmpdir is None:
+        tmpdir = './tmp'
+    if rank == 0:
+        mmcv.mkdir_or_exist(tmpdir)
+    synchronize()
+    # dump the part result to the dir
+    mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
+    synchronize()
+    if collect_by_master and rank != 0:
+        return None
+    # load results of all parts from the tmp dir
+    results = []
+    for i in range(world_size):
+        part_file = os.path.join(tmpdir, f'part_{i}.pkl')
+        results.append(mmcv.load(part_file))
+    if not collect_by_master:
+        synchronize()
+    # remove tmp dir
+    if rank == 0:
+        shutil.rmtree(tmpdir)
+    return results
+
+
+def all_gather_tensor(tensor, group_size=None, group=None):
+    if group_size is None:
+        group_size = get_world_size()
+    if group_size == 1:
+        output = [tensor]
+    else:
+        output = [torch.zeros_like(tensor) for _ in range(group_size)]
+        dist.all_gather(output, tensor, group=group)
+    return output
+
+
+def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None):
+    world_size = get_world_size()
+    if world_size == 1:
+        return feat if concat else [feat]
+    num_samples, *feat_dim = feat.size()
+    # pad every rank's tensor to the max number of samples
+    feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim))
+    feat_padding[:num_samples] = feat
+    # gather
+    feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size)
+    for r, num in enumerate(num_samples_list):
+        feat_gather[r] = feat_gather[r][:num]
+    if concat:
+        feat_gather = torch.cat(feat_gather)
+    return feat_gather
+
+
+class GatherLayer(torch.autograd.Function):
+    """Gather tensors from all processes, supporting backward propagation."""
+
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device)
+        ctx.num_samples_list = all_gather_tensor(num_samples)
+        output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):  # grads of tuple(output)
+        input, = ctx.saved_tensors
+        num_samples_list = ctx.num_samples_list
+        rank = get_rank()
+        start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1])
+        grads = torch.cat(grads)
+        if is_distributed():
+            dist.all_reduce(grads)
+        grad_out = torch.zeros_like(input)
+        grad_out[:] = grads[start:end]
+        return grad_out, None, None
+
+
+class GatherLayerWithGroup(torch.autograd.Function):
+    """Gather tensors from all processes in a group, supporting backward propagation."""
+
+    @staticmethod
+    def forward(ctx, input, group, group_size):
+        ctx.save_for_backward(input)
+        ctx.group_size = group_size
+        output = all_gather_tensor(input, group=group, group_size=group_size)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):  # grads of tuple(output)
+        input, = ctx.saved_tensors
+        grads = torch.stack(grads)
+        if is_distributed():
+            dist.all_reduce(grads)
+        grad_out = torch.zeros_like(input)
+        grad_out[:] = grads[get_rank() % ctx.group_size]
+        return grad_out, None, None
+
+
+def gather_layer_with_group(data, group=None, group_size=None):
+    if group_size is None:
+        group_size = get_world_size()
+    # dispatch to the group-aware autograd Function; GatherLayer.apply only
+    # accepts the input tensor
+    return GatherLayerWithGroup.apply(data, group, group_size)
+
+
+from typing import Union
+import math
+# from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm
+
+
+@torch.no_grad()
+def clip_grad_norm_(
+        self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
+) -> torch.Tensor:
+    self._lazy_init()
+    self._wait_for_previous_optim_step()
+    assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
+    self._assert_state(TrainingState_.IDLE)
+
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    # compute the max norm for this shard's gradients and sync across workers
+    local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda()  # type: ignore[arg-type]
+    if norm_type == math.inf:
+        total_norm = local_norm
+        dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
+    else:
+        total_norm = local_norm ** norm_type
+        dist.all_reduce(total_norm, group=self.process_group)
+        total_norm = total_norm ** (1.0 / norm_type)
+
+    clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
+    if clip_coef < 1:
+        # multiply by clip_coef, i.e. (max_norm / total_norm)
+        for p in self.params_with_grad:
+            assert p.grad is not None
+            p.grad.detach().mul_(clip_coef.to(p.grad.device))
+    return total_norm
+
+
+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
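A minimal usage sketch for the differentiable gather helpers above (assuming the import path `DiT_VAE.diffusion.utils.dist_utils` and a process group launched with `torchrun`; the squared-sum loss is purely illustrative):

```python
# Sketch: gather variable-length per-rank features while keeping gradients
# flowing, e.g. before a global contrastive loss.
# Launch with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist

from DiT_VAE.diffusion.utils.dist_utils import GatherLayer  # assumed path


def gather_with_grad(local_feat):
    # GatherLayer returns one tensor per rank; backward all-reduces the
    # incoming grads and slices out this rank's rows.
    return torch.cat(GatherLayer.apply(local_feat), dim=0)


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    feat = torch.randn(rank + 1, 16, device="cuda", requires_grad=True)  # ragged sizes
    all_feat = gather_with_grad(feat)
    all_feat.pow(2).sum().backward()
    print(rank, tuple(all_feat.shape), tuple(feat.grad.shape))
```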
DiT_VAE/diffusion/utils/logger.py
ADDED
@@ -0,0 +1,94 @@
+import logging
+import os
+import torch.distributed as dist
+from datetime import datetime
+from .dist_utils import is_local_master
+from mmcv.utils.logging import logger_initialized
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'):
+    """Get root logger.
+
+    Args:
+        log_file (str, optional): File path of log. Defaults to None.
+        log_level (int, optional): The level of logger.
+            Defaults to logging.INFO.
+        name (str): logger name
+    Returns:
+        :obj:`logging.Logger`: The obtained logger
+    """
+    if log_file is None:
+        log_file = '/dev/null'
+    return get_logger(name=name, log_file=log_file, log_level=log_level)
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO):
+    """Initialize and get a logger by name.
+
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers, otherwise the initialized logger will
+    be directly returned. During initialization, a StreamHandler will always be
+    added. If `log_file` is specified and the process rank is 0, a FileHandler
+    will also be added.
+
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected, and other processes will set the level to
+            "Error" thus be silent most of the time.
+
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    logger = logging.getLogger(name)
+    logger.propagate = False  # disable root logger to avoid duplicate logging
+
+    if name in logger_initialized:
+        return logger
+    # handle hierarchical names
+    # e.g., logger "a" is initialized, then logger "a.b" will skip the
+    # initialization since it is a child of "a".
+    for logger_name in logger_initialized:
+        if name.startswith(logger_name):
+            return logger
+
+    stream_handler = logging.StreamHandler()
+    handlers = [stream_handler]
+
+    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+    # only rank 0 will add a FileHandler
+    if rank == 0 and log_file is not None:
+        file_handler = logging.FileHandler(log_file, 'w')
+        handlers.append(file_handler)
+
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    for handler in handlers:
+        handler.setFormatter(formatter)
+        handler.setLevel(log_level)
+        logger.addHandler(handler)
+
+    # only rank 0 of each node will print logs
+    log_level = log_level if is_local_master() else logging.ERROR
+    logger.setLevel(log_level)
+
+    logger_initialized[name] = True
+
+    return logger
+
+
+def rename_file_with_creation_time(file_path):
+    # get the file's creation time
+    creation_time = os.path.getctime(file_path)
+    creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')
+
+    # build the new file name
+    dir_name, file_name = os.path.split(file_path)
+    name, ext = os.path.splitext(file_name)
+    new_file_name = f"{name}_{creation_time_str}{ext}"
+    new_file_path = os.path.join(dir_name, new_file_name)
+
+    # rename the file
+    os.rename(file_path, new_file_path)
+    print(f"File renamed to: {new_file_path}")
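A quick usage sketch for the logger above (the log path is illustrative):

```python
# Sketch: rank 0 writes to the file, every local master prints to stdout.
from DiT_VAE.diffusion.utils.logger import get_root_logger  # assumed path

logger = get_root_logger(log_file="work_dirs/train.log")  # illustrative path
logger.info("visible on local masters; also written to the file on rank 0")
logger = get_root_logger()  # same name -> returns the cached, initialized logger
```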
DiT_VAE/diffusion/utils/lr_scheduler.py
ADDED
@@ -0,0 +1,89 @@
+from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+import math
+
+from DiT_VAE.diffusion.utils.logger import get_root_logger
+
+
+def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
+    if not config.get('lr_schedule_args', None):
+        config.lr_schedule_args = {}
+    if config.get('lr_warmup_steps', None):
+        config['num_warmup_steps'] = config.get('lr_warmup_steps')  # for compatibility with old version
+
+    logger = get_root_logger()
+    logger.info(
+        f'Lr schedule: {config.lr_schedule}, ' + ",".join(
+            [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
+    if config.lr_schedule == 'cosine':
+        lr_scheduler = get_cosine_schedule_with_warmup(
+            optimizer=optimizer,
+            **config.lr_schedule_args,
+            num_training_steps=(len(train_dataloader) * config.num_epochs),
+        )
+    elif config.lr_schedule == 'constant':
+        lr_scheduler = get_constant_schedule_with_warmup(
+            optimizer=optimizer,
+            **config.lr_schedule_args,
+        )
+    elif config.lr_schedule == 'cosine_decay_to_constant':
+        assert lr_scale_ratio >= 1
+        lr_scheduler = get_cosine_decay_to_constant_with_warmup(
+            optimizer=optimizer,
+            **config.lr_schedule_args,
+            final_lr=1 / lr_scale_ratio,
+            num_training_steps=(len(train_dataloader) * config.num_epochs),
+        )
+    else:
+        raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
+    return lr_scheduler
+
+
+def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
+                                             num_warmup_steps: int,
+                                             num_training_steps: int,
+                                             final_lr: float = 0.0,
+                                             num_decay: float = 0.667,
+                                             num_cycles: float = 0.5,
+                                             last_epoch: int = -1
+                                             ):
+    """
+    Create a schedule with a cosine annealing lr followed by a constant lr.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The number of total training steps.
+        final_lr (`float`):
+            The final constant lr after cosine decay, as a fraction of the base lr.
+        num_decay (`float`):
+            The fraction of the total training steps spent in the cosine-decay
+            phase; after `num_training_steps * num_decay` steps the lr stays at
+            `final_lr`.
+        num_cycles (`float`):
+            The number of cosine half-cycles within the decay phase.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+
+        num_decay_steps = int(num_training_steps * num_decay)
+        if current_step > num_decay_steps:
+            return final_lr
+
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
+        return (
+            max(
+                0.0,
+                0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)),
+            )
+            * (1 - final_lr)
+        ) + final_lr
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
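A small sanity-check sketch of the `cosine_decay_to_constant` curve (toy optimizer; the numbers follow directly from `lr_lambda`):

```python
import torch

from DiT_VAE.diffusion.utils.lr_scheduler import get_cosine_decay_to_constant_with_warmup  # assumed path

opt = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
sched = get_cosine_decay_to_constant_with_warmup(
    opt, num_warmup_steps=10, num_training_steps=100, final_lr=0.1)

lrs = []
for _ in range(100):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
# linear warmup to 1e-3, cosine decay over the first ~2/3 of training
# (num_decay=0.667), then constant at final_lr * base lr = 1e-4
print(lrs[0], lrs[9], lrs[-1])
```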
DiT_VAE/diffusion/utils/misc.py
ADDED
@@ -0,0 +1,366 @@
+import collections
+import datetime
+import os
+import random
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv import Config
+from mmcv.runner import get_dist_info
+
+from .logger import get_root_logger
+
+os.environ["MOX_SILENT_MODE"] = "1"  # mute moxing log
+
+
+def read_config(file):
+    # solve config loading conflict when multi-processes
+    while True:
+        config = Config.fromfile(file)
+        if len(config) == 0:
+            time.sleep(0.1)
+            continue
+        break
+    return config
+
+
+def init_random_seed(seed=None, device='cuda'):
+    """Initialize random seed.
+
+    If the seed is not set, the seed will be automatically randomized,
+    and then broadcast to all processes to prevent some potential bugs.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+
+    Returns:
+        int: Seed to be used.
+    """
+    if seed is not None:
+        return seed
+
+    # Make sure all ranks share the same random seed to prevent
+    # some potential bugs. Please refer to
+    # https://github.com/open-mmlab/mmdetection/issues/6339
+    rank, world_size = get_dist_info()
+    seed = np.random.randint(2 ** 31)
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
+
+
+def set_random_seed(seed, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+class SimpleTimer:
+    def __init__(self, num_tasks, log_interval=1, desc="Process"):
+        self.num_tasks = num_tasks
+        self.desc = desc
+        self.count = 0
+        self.log_interval = log_interval
+        self.start_time = time.time()
+        self.logger = get_root_logger()
+
+    def log(self):
+        self.count += 1
+        if (self.count % self.log_interval) == 0 or self.count == self.num_tasks:
+            time_elapsed = time.time() - self.start_time
+            avg_time = time_elapsed / self.count
+            eta_sec = avg_time * (self.num_tasks - self.count)
+            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+            elapsed_str = str(datetime.timedelta(seconds=int(time_elapsed)))
+            log_info = f"{self.desc} [{self.count}/{self.num_tasks}], elapsed_time:{elapsed_str}," \
+                       f" avg_time: {avg_time}, eta: {eta_str}."
+            self.logger.info(log_info)
+
+
+class DebugUnderflowOverflow:
+    """
+    This debug class helps detect and understand where the model starts getting very large or very small, and more
+    importantly `nan` or `inf` weight and activation elements.
+    There are 2 working modes:
+    1. Underflow/overflow detection (default)
+    2. Specific batch absolute min/max tracing without detection
+    Mode 1: Underflow/overflow detection
+    To activate the underflow/overflow detection, initialize the object with the model:
+    ```python
+    debug_overflow = DebugUnderflowOverflow(model)
+    ```
+    then run the training as normal, and if `nan` or `inf` gets detected in at least one of the weight, input or
+    output elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this
+    event, each frame reporting
+    1. the fully qualified module name plus the class name whose `forward` was run
+    2. the absolute min and max value of all elements for each module weights, and the inputs and output
+    For example, here is the header and the last few frames in a detection report for `google/mt5-small` run in fp16
+    mixed precision:
+    ```
+    Detected inf/nan during batch_number=0
+    Last 21 forward frames:
+    abs min  abs max  metadata
+    [...]
+    encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
+    2.17e-07 4.50e+00 weight
+    1.79e-06 4.65e+00 input[0]
+    2.68e-06 3.70e+01 output
+    encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
+    8.08e-07 2.66e+01 weight
+    1.79e-06 4.65e+00 input[0]
+    1.27e-04 2.37e+02 output
+    encoder.block.2.layer.1.DenseReluDense.wo Linear
+    1.01e-06 6.44e+00 weight
+    0.00e+00 9.74e+03 input[0]
+    3.18e-04 6.27e+04 output
+    encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
+    1.79e-06 4.65e+00 input[0]
+    3.18e-04 6.27e+04 output
+    encoder.block.2.layer.1.dropout Dropout
+    3.18e-04 6.27e+04 input[0]
+    0.00e+00 inf output
+    ```
+    You can see here that `T5DenseGatedGeluDense.forward` resulted in output activations whose absolute max value
+    was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout`, which
+    renormalizes the weights after it zeroed some of the elements, which pushes the absolute max value past
+    64K, and we get an overflow.
+    As you can see, it's the previous frames that we need to look into when the numbers start getting very large for
+    fp16 numbers.
+    The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
+    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example:
+    ```python
+    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
+    ```
+    To validate that you have set up this debugging feature correctly, and you intend to use it in a training that may
+    take hours to complete, first run it with normal tracing enabled for one or a few batches, as explained in the
+    next section.
+    Mode 2. Specific batch absolute min/max tracing without detection
+    The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
+    Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
+    given batch, and only do that for batches 1 and 3. Then you instantiate this class as:
+    ```python
+    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
+    ```
+    And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
+    This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
+    fast-forward right to that area.
+    Early stopping:
+    You can also specify the batch number after which to stop the training, with:
+    ```python
+    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
+    ```
+    This feature is mainly useful in the tracing mode, but you can use it for any mode.
+    **Performance**:
+    As this module measures the absolute `min`/`max` of each weight of the model on every forward, it'll slow the
+    training down. Therefore remember to turn it off once the debugging needs have been met.
+    Args:
+        model (`nn.Module`):
+            The model to debug.
+        max_frames_to_save (`int`, *optional*, defaults to 21):
+            How many frames back to record.
+        trace_batch_nums (`List[int]`, *optional*, defaults to `[]`):
+            Which batch numbers to trace (turns detection off).
+        abort_after_batch_num (`int`, *optional*):
+            Whether to abort after a certain batch number has finished.
+    """
+
+    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
+        if trace_batch_nums is None:
+            trace_batch_nums = []
+        self.model = model
+        self.trace_batch_nums = trace_batch_nums
+        self.abort_after_batch_num = abort_after_batch_num
+
+        # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
+        self.frames = collections.deque([], max_frames_to_save)
+        self.frame = []
+        self.batch_number = 0
+        self.total_calls = 0
+        self.detected_overflow = False
+        self.prefix = "                 "
+
+        self.analyse_model()
+
+        self.register_forward_hook()
+
+    def save_frame(self, frame=None):
+        if frame is not None:
+            self.expand_frame(frame)
+        self.frames.append("\n".join(self.frame))
+        self.frame = []  # start a new frame
+
+    def expand_frame(self, line):
+        self.frame.append(line)
+
+    def trace_frames(self):
+        print("\n".join(self.frames))
+        self.frames = []
+
+    def reset_saved_frames(self):
+        self.frames = []
+
+    def dump_saved_frames(self):
+        # print a header, then the buffered frames that led up to the event
+        print(f"\nDetected inf/nan during batch_number={self.batch_number}")
+        print(f"Last {len(self.frames)} forward frames:")
+        print(f"{'abs min':8} {'abs max':8} metadata")
+        print("\n".join(self.frames))
+        print("\n\n")
+        self.frames = []
+
+    def analyse_model(self):
+        # extract the fully qualified module names, to be able to report at run time. e.g.:
+        # encoder.block.2.layer.0.SelfAttention.o
+        #
+        # for shared weights only the first shared module name will be registered
+        self.module_names = {m: name for name, m in self.model.named_modules()}
+        # self.longest_module_name = max(len(v) for v in self.module_names.values())
+
+    def analyse_variable(self, var, ctx):
+        if torch.is_tensor(var):
+            self.expand_frame(self.get_abs_min_max(var, ctx))
+            if self.detect_overflow(var, ctx):
+                self.detected_overflow = True
+        elif var is None:
+            self.expand_frame(f"{'None':>17} {ctx}")
+        else:
+            self.expand_frame(f"{'not a tensor':>17} {ctx}")
+
+    def batch_start_frame(self):
+        self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
+        self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
+
+    def batch_end_frame(self):
+        self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n")
+
+    def create_frame(self, module, input, output):
+        self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
+
+        # params
+        for name, p in module.named_parameters(recurse=False):
+            self.analyse_variable(p, name)
+
+        # inputs
+        if isinstance(input, tuple):
+            for i, x in enumerate(input):
+                self.analyse_variable(x, f"input[{i}]")
+        else:
+            self.analyse_variable(input, "input")
+
+        # outputs
+        if isinstance(output, tuple):
+            for i, x in enumerate(output):
+                # possibly a tuple of tuples
+                if isinstance(x, tuple):
+                    for j, y in enumerate(x):
+                        self.analyse_variable(y, f"output[{i}][{j}]")
+                else:
+                    self.analyse_variable(x, f"output[{i}]")
+        else:
+            self.analyse_variable(output, "output")
+
+        self.save_frame()
+
+    def register_forward_hook(self):
+        self.model.apply(self._register_forward_hook)
+
+    def _register_forward_hook(self, module):
+        module.register_forward_hook(self.forward_hook)
+
+    def forward_hook(self, module, input, output):
+        # - input is a tuple of packed inputs (could be non-Tensors)
+        # - output could be a Tensor or a tuple of Tensors and non-Tensors
+
+        last_frame_of_batch = False
+
+        trace_mode = self.batch_number in self.trace_batch_nums
+        if trace_mode:
+            self.reset_saved_frames()
+
+        if self.total_calls == 0:
+            self.batch_start_frame()
+        self.total_calls += 1
+
+        # count batch numbers - the very first forward hook of the batch will be called when the
+        # batch completes - i.e. it gets called very last - we know this batch has finished
+        if module == self.model:
+            self.batch_number += 1
+            last_frame_of_batch = True
+
+        self.create_frame(module, input, output)
+
+        # if last_frame_of_batch:
+        #     self.batch_end_frame()
+
+        if trace_mode:
+            self.trace_frames()
+
+        if last_frame_of_batch:
+            self.batch_start_frame()
+
+        if self.detected_overflow and not trace_mode:
+            self.dump_saved_frames()
+
+            # now we can abort, as it's pointless to continue running
+            raise ValueError(
+                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
+                "Please scroll up above this traceback to see the activation values prior to this event."
+            )
+
+        # abort after certain batch if requested to do so
+        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
+            raise ValueError(
+                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to `abort_after_batch_num={self.abort_after_batch_num}` arg"
+            )
+
+    @staticmethod
+    def get_abs_min_max(var, ctx):
+        abs_var = var.abs()
+        return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
+
+    @staticmethod
+    def detect_overflow(var, ctx):
+        """
+        Report whether the tensor contains any `nan` or `inf` entries.
+        This is useful for detecting overflows/underflows and is best called right after the function that did some
+        math that modified the tensor in question.
+        This function contains a few other helper features that you can enable and tweak directly if you want to
+        track various other things.
+        Args:
+            var: the tensor variable to check
+            ctx: the message to print as a context
+        Return:
+            `True` if `inf` or `nan` was detected, `False` otherwise
+        """
+        detected = False
+        if torch.isnan(var).any().item():
+            detected = True
+            print(f"{ctx} has nans")
+        if torch.isinf(var).any().item():
+            detected = True
+            print(f"{ctx} has infs")
+        if var.dtype == torch.float32 and torch.ge(var.abs(), 65535).any().item():
+            detected = True
+            print(f"{ctx} has overflow values {var.abs().max().item()}.")
+        return detected
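A minimal sketch of the overflow debugger on a toy model (the module and import path are illustrative):

```python
# Sketch: the hook buffers per-module min/max frames and raises as soon as a
# forward produces inf/nan anywhere in weights, inputs, or outputs.
import torch
import torch.nn as nn

from DiT_VAE.diffusion.utils.misc import DebugUnderflowOverflow  # assumed path

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=5)

x = torch.randn(2, 8)
model(x)  # healthy batch: frames are buffered, nothing is raised
model[2].weight.data.fill_(float("inf"))
model(x)  # weight stats now contain inf -> frames dumped, ValueError raised
```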
DiT_VAE/diffusion/utils/optimizer.py
ADDED
@@ -0,0 +1,237 @@
+import math
+
+from mmcv import Config
+from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \
+    OPTIMIZERS
+from mmcv.utils import _BatchNorm, _InstanceNorm
+from torch.nn import GroupNorm, LayerNorm
+
+from .logger import get_root_logger
+
+from typing import Tuple, Optional, Callable
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
+    assert rule in ['linear', 'sqrt']
+    logger = get_root_logger()
+    # scale by world size
+    if rule == 'sqrt':
+        scale_ratio = math.sqrt(effective_bs / base_batch_size)
+    elif rule == 'linear':
+        scale_ratio = effective_bs / base_batch_size
+    optimizer_cfg['lr'] *= scale_ratio
+    logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.7f} (using {rule} scaling rule).')
+    return scale_ratio
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class MyOptimizerConstructor(DefaultOptimizerConstructor):
+
+    def add_params(self, params, module, prefix='', is_dcn_module=None):
+        """Add all parameters of module to the params list.
+
+        The parameters of the given module will be added to the list of param
+        groups, with specific rules defined by paramwise_cfg.
+
+        Args:
+            params (list[dict]): A list of param groups, it will be modified
+                in place.
+            module (nn.Module): The module to be added.
+            prefix (str): The prefix of the module.
+        """
+        # get param-wise options
+        custom_keys = self.paramwise_cfg.get('custom_keys', {})
+        # first sort with alphabet order and then sort with reversed len of str
+        # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
+        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
+        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
+        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
+        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
+
+        # special rules for norm layers and depth-wise conv layers
+        is_norm = isinstance(module,
+                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+
+        for name, param in module.named_parameters(recurse=False):
+            base_lr = self.base_lr
+            if name == 'bias' and not is_norm and not is_dcn_module:
+                base_lr *= bias_lr_mult
+
+            # apply weight decay policies
+            base_wd = self.base_wd
+            # norm decay
+            if is_norm:
+                if self.base_wd is not None:
+                    base_wd *= norm_decay_mult
+            elif name == 'bias' and not is_dcn_module:
+                if self.base_wd is not None:
+                    # TODO: current bias_decay_mult will have affect on DCN
+                    base_wd *= bias_decay_mult
+
+            param_group = {'params': [param]}
+            if not param.requires_grad:
+                param_group['requires_grad'] = False
+                params.append(param_group)
+                continue
+            if bypass_duplicate and self._is_in(param_group, params):
+                logger = get_root_logger()
+                logger.warning(f'{prefix} is duplicate. It is skipped since '
+                               f'bypass_duplicate={bypass_duplicate}')
+                continue
+            # if the parameter matches one of the custom keys, ignore other rules
+            is_custom = False
+            for key in custom_keys:
+                scope, key_name = key if isinstance(key, tuple) else (None, key)
+                if scope is not None and scope not in f'{prefix}':
+                    continue
+                if key_name in f'{prefix}.{name}':
+                    is_custom = True
+                    if 'lr_mult' in custom_keys[key]:
+                        # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
+                        #     param_group['lr'] = self.base_lr
+                        # else:
+                        param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
+                    elif 'lr' not in param_group:
+                        param_group['lr'] = base_lr
+                    if self.base_wd is not None:
+                        if 'decay_mult' in custom_keys[key]:
+                            param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
+                        elif 'weight_decay' not in param_group:
+                            param_group['weight_decay'] = base_wd
+
+            if not is_custom:
+                # bias_lr_mult affects all bias parameters
+                # except for norm.bias dcn.conv_offset.bias
+                if base_lr != self.base_lr:
+                    param_group['lr'] = base_lr
+                if base_wd != self.base_wd:
+                    param_group['weight_decay'] = base_wd
+            params.append(param_group)
+
+        for child_name, child_mod in module.named_children():
+            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
+            self.add_params(
+                params,
+                child_mod,
+                prefix=child_prefix,
+                is_dcn_module=is_dcn_module)
+
+
+def build_optimizer(model, optimizer_cfg):
+    # default parameter-wise config
+    logger = get_root_logger()
+
+    if hasattr(model, 'module'):
+        model = model.module
+    # set optimizer constructor
+    optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
+    # parameter-wise setting: cancel weight decay for some specific modules
+    custom_keys = dict()
+    for name, module in model.named_modules():
+        if hasattr(module, 'zero_weight_decay'):
+            custom_keys |= {
+                (name, key): dict(decay_mult=0)
+                for key in module.zero_weight_decay
+            }
+
+    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
+    if given_cfg := optimizer_cfg.get('paramwise_cfg'):
+        paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
+    optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg
+    # build optimizer
+    optimizer = mm_build_optimizer(model, optimizer_cfg)
+
+    weight_decay_groups = dict()
+    lr_groups = dict()
+    for group in optimizer.param_groups:
+        if not group.get('requires_grad', True):
+            continue
+        lr_groups.setdefault(group['lr'], []).append(group)
+        weight_decay_groups.setdefault(group['weight_decay'], []).append(group)
+
+    learnable_count, fix_count = 0, 0
+    for p in model.parameters():
+        if p.requires_grad:
+            learnable_count += 1
+        else:
+            fix_count += 1
+    fix_info = f"{learnable_count} are learnable, {fix_count} are fix"
+    lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
+    wd_info = "Weight decay group: " + ", ".join(
+        [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
+    opt_info = f"Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
+    logger.info(opt_info)
+
+    return optimizer
+
+
+@OPTIMIZERS.register_module()
+class Lion(Optimizer):
+    def __init__(
+            self,
+            params,
+            lr: float = 1e-4,
+            betas: Tuple[float, float] = (0.9, 0.99),
+            weight_decay: float = 0.0,
+    ):
+        assert lr > 0.
+        assert all(0. <= beta <= 1. for beta in betas)
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+
+        super().__init__(params, defaults)
+
+    @staticmethod
+    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
+        # stepweight decay
+        p.data.mul_(1 - lr * wd)
+
+        # weight update
+        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
+        p.add_(update, alpha=-lr)
+
+        # decay the momentum running average coefficient
+        exp_avg.lerp_(grad, 1 - beta2)
+
+    @staticmethod
+    def exists(val):
+        return val is not None
+
+    @torch.no_grad()
+    def step(
+            self,
+            closure: Optional[Callable] = None
+    ):
+
+        loss = None
+        if self.exists(closure):
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in filter(lambda p: self.exists(p.grad), group['params']):
+
+                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
+                    self.state[p]
+
+                # init state - exponential moving average of gradient values
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+
+                exp_avg = state['exp_avg']
+
+                self.update_fn(
+                    p,
+                    grad,
+                    exp_avg,
+                    lr,
+                    wd,
+                    beta1,
+                    beta2
+                )
+
+        return loss
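A usage sketch for the registered `Lion` optimizer (toy regression target; the import path is an assumption based on this repo layout):

```python
import torch
import torch.nn as nn

from DiT_VAE.diffusion.utils.optimizer import Lion  # assumed path

model = nn.Linear(16, 4)
opt = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

x, y = torch.randn(8, 16), torch.randn(8, 4)
for _ in range(10):
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()  # sign(lerp(exp_avg, grad, 1 - beta1)) step + decoupled weight decay
print(loss.item())
```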
DiT_VAE/train_diffusion.py
ADDED
@@ -0,0 +1,5 @@
+# TODO: Implement model training and evaluation.
+# This script will load data, train a deep learning model, and evaluate its performance.
+# Future improvements may include hyperparameter tuning and multi-GPU training.
+
+
DiT_VAE/train_vae.py
ADDED
@@ -0,0 +1,369 @@
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
|
5 |
+
import sys
|
6 |
+
|
7 |
+
current_path = os.path.abspath(__file__)
|
8 |
+
father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")
|
9 |
+
sys.path.append((os.path.join(father_path, 'Next3d')))
|
10 |
+
|
11 |
+
from typing import Dict, Optional, Tuple
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
import torch
|
14 |
+
import logging
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.utils.checkpoint
|
17 |
+
from torch.utils.data import Dataset
|
18 |
+
import inspect
|
19 |
+
from accelerate import Accelerator
|
20 |
+
from accelerate.logging import get_logger
|
21 |
+
from accelerate.utils import set_seed
|
22 |
+
import dnnlib
|
23 |
+
from diffusers.optimization import get_scheduler
|
24 |
+
from tqdm.auto import tqdm
|
25 |
+
from vae.triplane_vae import AutoencoderKL, AutoencoderKLRollOut
|
26 |
+
from vae.data.dataset_online_vae import TriplaneDataset
|
27 |
+
from einops import rearrange
|
28 |
+
from vae.utils.common_utils import instantiate_from_config
|
29 |
+
from Next3d.training_avatar_texture.triplane_generation import TriPlaneGenerator
|
30 |
+
import Next3d.legacy as legacy
|
31 |
+
|
32 |
+
from torch_utils import misc
|
33 |
+
import datetime
|
34 |
+
|
35 |
+
logger = get_logger(__name__, log_level="INFO")
|
36 |
+
|
37 |
+
|
38 |
+
def collate_fn(data):
|
39 |
+
model_names = [example["data_model_name"] for example in data]
|
40 |
+
zs = torch.cat([example["data_z"] for example in data], dim=0)
|
41 |
+
verts = torch.cat([example["data_vert"] for example in data], dim=0)
|
42 |
+
|
43 |
+
return {
|
44 |
+
'model_names': model_names,
|
45 |
+
'zs': zs,
|
46 |
+
'verts': verts
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
def rollout_fn(triplane):
|
51 |
+
triplane = rearrange(triplane, "b c f h w -> b f c h w")
|
52 |
+
b, f, c, h, w = triplane.shape
|
53 |
+
triplane = triplane.permute(0, 2, 3, 1, 4).reshape(-1, c, h, f * w)
|
54 |
+
return triplane
|
55 |
+
|
56 |
+
|
57 |
+
def unrollout_fn(triplane):
|
58 |
+
res = triplane.shape[-2]
|
59 |
+
ch = triplane.shape[1]
|
60 |
+
triplane = triplane.reshape(-1, ch // 3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, 3, ch, res, res)
|
61 |
+
triplane = rearrange(triplane, "b f c h w -> b c f h w")
|
62 |
+
return triplane
|
63 |
+
|
64 |
+
|
65 |
+
def triplane_generate(G_model, z, conditioning_params, std, mean, truncation_psi=0.7, truncation_cutoff=14):
|
66 |
+
w = G_model.mapping(z, conditioning_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
|
67 |
+
triplane = G_model.synthesis(w, noise_mode='const')
|
68 |
+
triplane = (triplane - mean) / std
|
69 |
+
return triplane
|
70 |
+
|
71 |
+
|
72 |
+
def gan_model(gan_models, device, gan_model_base_dir):
|
73 |
+
gan_model_dict = gan_models
|
74 |
+
gan_model_load = {}
|
75 |
+
for model_name in gan_model_dict.keys():
|
76 |
+
model_pkl = os.path.join(gan_model_base_dir, model_name + '.pkl')
|
77 |
+
with dnnlib.util.open_url(model_pkl) as f:
|
78 |
+
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
|
79 |
+
G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
|
80 |
+
misc.copy_params_and_buffers(G, G_new, require_all=True)
|
81 |
+
G_new.neural_rendering_resolution = G.neural_rendering_resolution
|
82 |
+
G_new.rendering_kwargs = G.rendering_kwargs
|
83 |
+
gan_model_load[model_name] = G_new
|
84 |
+
return gan_model_load
|
85 |
+
|
86 |
+
|
87 |
+
def main(vae_config: str,
|
88 |
+
gan_model_config: str,
|
89 |
+
output_dir: str,
|
90 |
+
std_dir: str,
|
91 |
+
mean_dir: str,
|
92 |
+
conditioning_params_dir: str,
|
93 |
+
gan_model_base_dir: str,
|
94 |
+
train_data: Dict,
|
95 |
+
train_batch_size: int = 2,
|
96 |
+
max_train_steps: int = 500,
|
97 |
+
learning_rate: float = 3e-5,
|
98 |
+
scale_lr: bool = False,
|
99 |
+
lr_scheduler: str = "constant",
|
100 |
+
lr_warmup_steps: int = 0,
|
101 |
+
adam_beta1: float = 0.5,
|
102 |
+
adam_beta2: float = 0.9,
|
103 |
+
adam_weight_decay: float = 1e-2,
|
104 |
+
adam_epsilon: float = 1e-08,
|
105 |
+
max_grad_norm: float = 1.0,
|
106 |
+
gradient_accumulation_steps: int = 1,
|
107 |
+
gradient_checkpointing: bool = True,
|
108 |
+
checkpointing_steps: int = 500,
|
109 |
+
pretrained_model_path_zero123: str = None,
|
110 |
+
resume_from_checkpoint: Optional[str] = None,
|
111 |
+
mixed_precision: Optional[str] = "fp16",
|
112 |
+
use_8bit_adam: bool = False,
|
113 |
+
rollout: bool = False,
|
114 |
+
enable_xformers_memory_efficient_attention: bool = True,
|
115 |
+
seed: Optional[int] = None, ):
|
116 |
+
*_, config = inspect.getargvalues(inspect.currentframe())
|
117 |
+
base_dir = output_dir
|
118 |
+
|
119 |
+
accelerator = Accelerator(
|
120 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
121 |
+
mixed_precision=mixed_precision,
|
122 |
+
)
|
123 |
+
logging.basicConfig(
|
124 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
125 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
126 |
+
level=logging.INFO,
|
127 |
+
)
|
128 |
+
logger.info(accelerator.state, main_process_only=False)
|
129 |
+
# If passed along, set the training seed now.
|
130 |
+
if seed is not None:
|
131 |
+
set_seed(seed)
|
132 |
+
if accelerator.is_main_process:
|
133 |
+
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
134 |
+
output_dir = os.path.join(output_dir, now)
|
135 |
+
os.makedirs(output_dir, exist_ok=True)
|
136 |
+
os.makedirs(f"{output_dir}/samples", exist_ok=True)
|
137 |
+
os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
|
138 |
+
OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
|
139 |
+
|
140 |
+
config_vae = OmegaConf.load(vae_config)
|
141 |
+
|
142 |
+
if rollout:
|
143 |
+
vae = AutoencoderKLRollOut(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
|
144 |
+
|
145 |
+
else:
|
146 |
+
vae = AutoencoderKL(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
|
147 |
+
print(f"VAE total params = {len(list(vae.named_parameters()))} ")
|
148 |
+
if 'perceptual_weight' in config_vae['lossconfig']['params'].keys():
|
149 |
+
config_vae['lossconfig']['params']['device'] = str(accelerator.device)
|
150 |
+
loss_fn = instantiate_from_config(config_vae['lossconfig'])
|
151 |
+
conditioning_params = torch.load(conditioning_params_dir).to(str(accelerator.device))
|
152 |
+
data_std = torch.load(std_dir).to(str(accelerator.device)).reshape(1, -1, 1, 1, 1)
|
153 |
+
|
154 |
+
data_mean = torch.load(mean_dir).to(str(accelerator.device)).reshape(1, -1, 1, 1, 1)
|
155 |
+
|
156 |
+
# define the gan model
|
157 |
+
print("########## gan model load ##########")
|
158 |
+
config_gan_model = OmegaConf.load(gan_model_config)
|
159 |
+
gan_model_all = gan_model(config_gan_model['gan_models'], str(accelerator.device), gan_model_base_dir)
|
160 |
+
print("########## gan model loaded ##########")
|
161 |
+
if scale_lr:
|
162 |
+
learning_rate = (
|
163 |
+
learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
|
164 |
+
)
|
165 |
+
|
166 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
167 |
+
if use_8bit_adam:
|
168 |
+
try:
|
169 |
+
import bitsandbytes as bnb
|
170 |
+
except ImportError:
|
171 |
+
raise ImportError(
|
172 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
173 |
+
)
|
174 |
+
|
175 |
+
optimizer_cls = bnb.optim.AdamW8bit
|
176 |
+
else:
|
177 |
+
optimizer_cls = torch.optim.AdamW
|
178 |
+
|
179 |
+
optimizer = optimizer_cls(
|
180 |
+
vae.parameters(),
|
181 |
+
lr=learning_rate,
|
182 |
+
betas=(adam_beta1, adam_beta2),
|
183 |
+
weight_decay=adam_weight_decay,
|
184 |
+
eps=adam_epsilon,
|
185 |
+
)
|
186 |
+
|
187 |
+
train_dataset = TriplaneDataset(**train_data)
|
188 |
+
|
189 |
+
# Preprocessing the dataset
|
190 |
+
|
191 |
+
# DataLoaders creation:
|
192 |
+
train_dataloader = torch.utils.data.DataLoader(
|
193 |
+
train_dataset, batch_size=train_batch_size, collate_fn=collate_fn, shuffle=True, num_workers=2
|
194 |
+
)
|
195 |
+
|
196 |
+
lr_scheduler = get_scheduler(
|
197 |
+
lr_scheduler,
|
198 |
+
optimizer=optimizer,
|
199 |
+
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
200 |
+
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
201 |
+
)
|
202 |
+
|
203 |
+
vae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
204 |
+
vae, optimizer, train_dataloader, lr_scheduler
|
205 |
+
)
|
206 |
+
|
207 |
+
weight_dtype = torch.float32
|
208 |
+
|
209 |
+
# Move text_encode and vae to gpu and cast to weight_dtype
|
210 |
+
|
211 |
+
if accelerator.mixed_precision == "fp16":
|
212 |
+
weight_dtype = torch.float16
|
213 |
+
elif accelerator.mixed_precision == "bf16":
|
214 |
+
weight_dtype = torch.bfloat16
|
215 |
+
|
216 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
217 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
218 |
+
# Afterwards we recalculate our number of training epochs
|
219 |
+
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
220 |
+
|
221 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
222 |
+
# The trackers initializes automatically on the main process.
|
223 |
+
if accelerator.is_main_process:
|
224 |
+
accelerator.init_trackers("trainvae", config=vars(args))
|
225 |
+
|
226 |
+
# Train!
|
227 |
+
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
228 |
+
|
229 |
+
logger.info("***** Running training *****")
|
230 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
231 |
+
logger.info(f" Num Epochs = {num_train_epochs}")
|
232 |
+
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
233 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
234 |
+
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
235 |
+
logger.info(f" Total optimization steps = {max_train_steps}")
|
236 |
+
global_step = 0
|
237 |
+
first_epoch = 0
|
238 |
+
|
239 |
+
# Potentially load in the weights and states from a previous save
|
240 |
+
if resume_from_checkpoint:
|
241 |
+
if resume_from_checkpoint != "latest":
|
242 |
+
path = os.path.basename(resume_from_checkpoint)
|
243 |
+
else:
|
244 |
+
# Get the most recent checkpoint
|
245 |
+
dirs = os.listdir(output_dir)
|
246 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
247 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
248 |
+
path = dirs[-1]
|
249 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
250 |
+
if resume_from_checkpoint != "latest":
|
251 |
+
accelerator.load_state(resume_from_checkpoint)
|
252 |
+
else:
|
253 |
+
accelerator.load_state(os.path.join(output_dir, path))
|
254 |
+
|
255 |
+
global_step = int(path.split("-")[1])
|
256 |
+
|
257 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
258 |
+
resume_step = global_step % num_update_steps_per_epoch
|
259 |
+
else:
|
260 |
+
all_final_training_dirs = []
|
261 |
+
dirs = os.listdir(base_dir)
|
262 |
+
if len(dirs) != 0:
|
263 |
+
dirs = [d for d in dirs if d.startswith("2024")] # specific years
|
264 |
+
if len(dirs) != 0:
|
265 |
+
base_resume_paths = [os.path.join(base_dir, d) for d in dirs]
|
266 |
+
for base_resume_path in base_resume_paths:
|
267 |
+
checkpoint_file_names = os.listdir(base_resume_path)
|
268 |
+
checkpoint_file_names = [d for d in checkpoint_file_names if d.startswith("checkpoint")]
|
269 |
+
if len(checkpoint_file_names) != 0:
|
270 |
+
for checkpoint_file_name in checkpoint_file_names:
|
271 |
+
final_training_dir = os.path.join(base_resume_path, checkpoint_file_name)
|
272 |
+
all_final_training_dirs.append(final_training_dir)
|
273 |
+
if len(all_final_training_dirs) != 0:
|
274 |
+
                    sorted_all_final_training_dirs = sorted(all_final_training_dirs, key=lambda x: int(x.split("-")[1]))
                    latest_dir = sorted_all_final_training_dirs[-1]
                    path = os.path.basename(latest_dir)
                    accelerator.print(f"Resuming from checkpoint {path}")
                    accelerator.load_state(latest_dir)
                    global_step = int(path.split("-")[1])

                    first_epoch = global_step // num_update_steps_per_epoch
                    resume_step = global_step % num_update_steps_per_epoch
                else:
                    accelerator.print("Training from start")
            else:
                accelerator.print("Training from start")
        else:
            accelerator.print("Training from start")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        vae.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Optionally skip steps already covered when resuming mid-epoch:
            # if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            #     if step % gradient_accumulation_steps == 0:
            #         progress_bar.update(1)
            #     continue
            with accelerator.accumulate(vae):
                # Sample triplanes online from the frozen GANs for this batch of latents.
                z_values = batch["zs"].to(weight_dtype)
                model_names = batch["model_names"]

                triplane_values = []
                with torch.no_grad():
                    for z_id in range(z_values.shape[0]):
                        z_value = z_values[z_id].unsqueeze(0)
                        model_name = model_names[z_id]
                        triplane_value = triplane_generate(gan_model_all[model_name], z_value,
                                                           conditioning_params, data_std, data_mean)
                        triplane_values.append(triplane_value)
                triplane_values = torch.cat(triplane_values, dim=0)
                vert_values = batch["verts"].to(weight_dtype)
                triplane_values = rearrange(triplane_values, "b f c h w -> b c f h w")
                if rollout:
                    triplane_values_roll = rollout_fn(triplane_values.clone())
                    reconstructions, posterior = vae(triplane_values_roll)
                    reconstructions_unroll = unrollout_fn(reconstructions)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions_unroll, posterior, vert_values,
                                                split="train")
                else:
                    reconstructions, posterior = vae(triplane_values)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions, posterior, vert_values,
                                                split="train")
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(vae.parameters(), max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = log_dict_ae
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= max_train_steps:
                break

    accelerator.wait_for_everyone()
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/triplane_vae.yaml")
    args = parser.parse_args()
    main(**OmegaConf.load(args.config))
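For reference, a hypothetical sketch of the rollout_fn/unrollout_fn pair consumed above. The real functions come from the VAE's rollout utilities elsewhere in the repo; this only illustrates the idea of flattening the f triplane planes into a single 2D grid before encoding and restoring them afterwards (the plane count of 3 is an assumption here):

from einops import rearrange

def rollout_fn(x):    # (b, c, f, h, w) -> (b, c, f*h, w): stack planes along height
    return rearrange(x, "b c f h w -> b c (f h) w")

def unrollout_fn(x):  # inverse: split the stacked height axis back into f planes
    return rearrange(x, "b c (f h) w -> b c f h w", f=3)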
DiT_VAE/util.py
ADDED
@@ -0,0 +1,217 @@
import os
import imageio
import numpy as np
from typing import Union

import torch
import torchvision

from tqdm import tqdm
from einops import rearrange


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)


# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])

    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents


def rendering():
    pass


def force_zero_snr(betas):
    # Rescale a beta schedule so the terminal cumulative alpha is ~0 (zero terminal SNR).
    alphas = 1 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_bar ** (1 / 2)
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - 1e-6
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas], 0)
    betas = 1 - alphas
    return betas


def make_beta_schedule(schedule="scaled_linear", n_timestep=1000, linear_start=0.00085, linear_end=0.012,
                       cosine_s=8e-3, shift_scale=None):
    if schedule == "scaled_linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
    elif schedule == 'linear':
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = torch.clamp(betas, min=0, max=0.999)  # was np.clip, which returns an ndarray and breaks betas.numpy() below

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    elif schedule == 'linear_force_zero_snr':
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
        betas = force_zero_snr(betas)
    elif schedule == 'linear_100':
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
        betas = betas[:100]
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")

    if shift_scale is not None:
        print("shift_scale")
        betas = shift_schedule(betas, shift_scale)

    return betas.numpy()


def shift_schedule(base_betas, shift_scale):
    # Shift the SNR of a base schedule: snr = ab / (1 - ab)  <=>  ab = snr / (1 + snr)
    alphas = 1 - base_betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    snr = alphas_bar / (1 - alphas_bar)
    shifted_snr = snr * ((1 / shift_scale) ** 2)
    shifted_alphas_bar = shifted_snr / (1 + shifted_snr)
    shifted_alphas = shifted_alphas_bar[1:] / shifted_alphas_bar[:-1]
    shifted_alphas = torch.cat([shifted_alphas_bar[0:1], shifted_alphas], 0)
    shifted_betas = 1 - shifted_alphas
    return shifted_betas


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    # Move dimension src_dim of x to position dest_dim.
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim

    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    dims = list(range(n_dims))
    del dims[src_dim]

    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x


# reshapes tensor start from dim i (inclusive)
# to dim j (exclusive) to the desired shape
# e.g. if x.shape = (b, thw, c) then
# view_range(x, 1, 2, (t, h, w)) returns
# x of shape (b, t, h, w, c)
def view_range(x, i, j, shape):
    shape = tuple(shape)

    n_dims = len(x.shape)
    if i < 0:
        i = n_dims + i

    if j is None:
        j = n_dims
    elif j < 0:
        j = n_dims + j

    assert 0 <= i < j <= n_dims

    x_shape = x.shape
    target_shape = x_shape[:i] + shape + x_shape[j:]
    return x.view(target_shape)


def tensor_slice(x, begin, size):
    # Slice x starting at `begin` with extent `size` per dim (-1 means "to the end").
    assert all([b >= 0 for b in begin])
    size = [l - b if s == -1 else s
            for s, b, l in zip(size, begin, x.shape)]
    assert all([s >= 0 for s in size])

    slices = [slice(b, b + s) for b, s in zip(begin, size)]
    return x[tuple(slices)]
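A quick sanity check of the schedule utilities above (a sketch, not part of the file): forcing zero terminal SNR should drive the final cumulative alpha to ~0, and shift_schedule with shift_scale=2 should divide the SNR at every timestep by 4:

import torch
from DiT_VAE.util import make_beta_schedule, shift_schedule

betas = torch.from_numpy(make_beta_schedule("linear_force_zero_snr", n_timestep=1000))
alphas_bar = torch.cumprod(1 - betas, dim=0)
print(alphas_bar[-1])            # ~0: the terminal step is pure noise

shifted = shift_schedule(betas, shift_scale=2.0)
ab2 = torch.cumprod(1 - shifted, dim=0)
print((ab2 / (1 - ab2))[:5])     # SNR reduced 4x relative to the base schedule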
DiT_VAE/vae/__init__.py
ADDED
File without changes
DiT_VAE/vae/aemodules3d.py
ADDED
@@ -0,0 +1,368 @@
# TATS
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention_vae import MultiHeadAttention


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim

    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    dims = list(range(n_dims))
    del dims[src_dim]

    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x


def silu(x):
    return x * torch.sigmoid(x)


class SiLU(nn.Module):
    def __init__(self):
        super(SiLU, self).__init__()

    def forward(self, x):
        return silu(x)


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss


def Normalize(in_channels, norm_type='group'):
    assert norm_type in ['group', 'batch']
    if norm_type == 'group':
        return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    elif norm_type == 'batch':
        return torch.nn.SyncBatchNorm(in_channels)


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group',
                 padding_type='replicate'):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
        self.dropout = torch.nn.Dropout(dropout)
        # norm2 follows conv1, whose output has out_channels (the original used
        # in_channels, which only worked because all call sites pass in == out).
        self.norm2 = Normalize(out_channels, norm_type)
        self.conv2 = SamePadConv3d(out_channels, out_channels, kernel_size=3, padding_type=padding_type)
        if self.in_channels != self.out_channels:
            self.conv_shortcut = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)

        return x + h


# Does not support dilation
class SamePadConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # assumes that the input shape is divisible by stride
        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())

        self.pad_input = pad_input
        self.padding_type = padding_type

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)
        self.weight = self.conv.weight

    def forward(self, x):
        return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))


class SamePadConvTranspose3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type
        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                        stride=stride, bias=bias,
                                        padding=tuple([k - 1 for k in kernel_size]))

    def forward(self, x):
        return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))


class AxialBlock(nn.Module):
    def __init__(self, n_hiddens, n_head):
        super().__init__()
        kwargs = dict(shape=(0,) * 3, dim_q=n_hiddens,
                      dim_kv=n_hiddens, n_head=n_head,
                      n_layer=1, causal=False, attn_type='axial')
        self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2), **kwargs)
        self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3), **kwargs)
        self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4), **kwargs)

    def forward(self, x):
        x = shift_dim(x, 1, -1)
        x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x)
        x = shift_dim(x, -1, 1)
        return x


class AttentionResidualBlock(nn.Module):
    def __init__(self, n_hiddens):
        super().__init__()
        self.block = nn.Sequential(
            Normalize(n_hiddens),
            SiLU(),
            SamePadConv3d(n_hiddens, n_hiddens // 2, 3, bias=False),
            Normalize(n_hiddens // 2),
            SiLU(),
            SamePadConv3d(n_hiddens // 2, n_hiddens, 1, bias=False),
            Normalize(n_hiddens),
            SiLU(),
            AxialBlock(n_hiddens, 2)
        )

    def forward(self, x):
        return x + self.block(x)


class Encoder(nn.Module):
    def __init__(self, n_hiddens, downsample, z_channels, double_z, image_channel=3, norm_type='group',
                 padding_type='replicate', res_num=1):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.conv_blocks = nn.ModuleList()
        max_ds = n_times_downsample.max()
        self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)

        for i in range(max_ds):
            block = nn.Module()
            in_channels = n_hiddens * 2 ** i
            out_channels = n_hiddens * 2 ** (i + 1)
            stride = list(tuple([2 if d > 0 else 1 for d in n_times_downsample]))
            stride[0] = 1  # never downsample the plane axis
            stride = tuple(stride)
            block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
            block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_downsample -= 1

        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type),
            SiLU(),
            SamePadConv3d(out_channels, 2 * z_channels if double_z else z_channels,
                          kernel_size=3,
                          stride=1,
                          padding_type=padding_type)
        )
        self.out_channels = out_channels

    def forward(self, x):
        h = self.conv_first(x)
        for block in self.conv_blocks:
            h = block.down(h)
            h = block.res(h)
        h = self.final_block(h)
        return h


class Decoder(nn.Module):
    def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group'):
        super().__init__()

        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = n_times_upsample.max()
        in_channels = z_channels
        self.conv_blocks = nn.ModuleList()
        for i in range(max_us):
            block = nn.Module()
            in_channels = in_channels if i == 0 else n_hiddens * 2 ** (max_us - i + 1)
            out_channels = n_hiddens * 2 ** (max_us - i)
            us = list(tuple([2 if d > 0 else 1 for d in n_times_upsample]))
            us[0] = 1
            us = tuple(us)
            block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=us)
            block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_upsample -= 1

        self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)

    def forward(self, x):
        h = x
        for i, block in enumerate(self.conv_blocks):
            h = block.up(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.conv_out(h)
        return h


class EncoderRe(nn.Module):
    def __init__(self, n_hiddens, downsample, z_channels, double_z, image_channel=3, norm_type='group',
                 padding_type='replicate', n_res_layers=2):
        super().__init__()
        self.conv_blocks = nn.ModuleList()
        self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)

        for i, step in enumerate(downsample):
            block = nn.Module()
            in_channels = n_hiddens
            out_channels = n_hiddens
            stride = tuple([1, downsample[i], downsample[i]])
            block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
            block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)

        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(out_channels)
              for _ in range(n_res_layers)]
        )
        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type),
            SiLU(),
            SamePadConv3d(out_channels, 2 * z_channels if double_z else z_channels,
                          kernel_size=3,
                          stride=1,
                          padding_type=padding_type)
        )
        self.out_channels = out_channels

    def forward(self, x):
        h = self.conv_first(x)
        for block in self.conv_blocks:
            h = block.down(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.res_stack(h)
        h = self.final_block(h)
        return h


class DecoderRe(nn.Module):
    def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group', padding_type='replicate',
                 n_res_layers=2):
        super().__init__()
        self.conv_first = SamePadConv3d(z_channels, n_hiddens, kernel_size=3, padding_type=padding_type)
        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(n_hiddens)
              for _ in range(n_res_layers)]
        )
        self.conv_blocks = nn.ModuleList()
        for i, step in enumerate(upsample):
            block = nn.Module()
            in_channels = n_hiddens
            out_channels = n_hiddens
            stride = tuple([1, upsample[i], upsample[i]])
            block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=stride)
            block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)

        self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)

    def forward(self, x):
        h = x
        h = self.conv_first(h)
        h = self.res_stack(h)
        for i, block in enumerate(self.conv_blocks):
            h = block.up(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.conv_out(h)
        return h


# unit test
if __name__ == '__main__':
    encoder = EncoderRe(n_hiddens=320, downsample=[1, 2, 2, 2], z_channels=8, double_z=True, image_channel=96,
                        norm_type='group', padding_type='replicate')
    encoder = encoder.cuda()
    en_input = torch.rand(1, 96, 3, 256, 256).cuda()
    out = encoder(en_input)
    print(out.shape)
    mean, logvar = torch.chunk(out, 2, dim=1)
    decoder = DecoderRe(n_hiddens=320, upsample=[2, 2, 2, 1], z_channels=8, image_channel=96,
                        norm_type='group')

    decoder = decoder.cuda()
    out = decoder(mean)
    print(out.shape)
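A minimal CPU sketch (the sizes here are illustrative, not the training config) of the shape contract of EncoderRe/DecoderRe above: spatial dims shrink by the product of downsample, the plane axis is untouched, and double_z stacks mean and logvar along the channel axis:

import torch
from DiT_VAE.vae.aemodules3d import EncoderRe, DecoderRe

enc = EncoderRe(n_hiddens=64, downsample=[2, 2], z_channels=4, double_z=True, image_channel=8)
dec = DecoderRe(n_hiddens=64, upsample=[2, 2], z_channels=4, image_channel=8)
x = torch.rand(1, 8, 3, 32, 32)               # (B, C, planes, H, W)
mean, logvar = torch.chunk(enc(x), 2, dim=1)  # each (1, 4, 3, 8, 8)
print(dec(mean).shape)                        # torch.Size([1, 8, 3, 32, 32])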
DiT_VAE/vae/attention_vae.py
ADDED
@@ -0,0 +1,620 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def tensor_slice(x, begin, size):
    assert all([b >= 0 for b in begin])
    size = [l - b if s == -1 else s
            for s, b, l in zip(size, begin, x.shape)]
    assert all([s >= 0 for s in size])

    slices = [slice(b, b + s) for b, s in zip(begin, size)]
    return x[tuple(slices)]


# reshapes tensor start from dim i (inclusive)
# to dim j (exclusive) to the desired shape
# e.g. if x.shape = (b, thw, c) then
# view_range(x, 1, 2, (t, h, w)) returns
# x of shape (b, t, h, w, c)
def view_range(x, i, j, shape):
    shape = tuple(shape)

    n_dims = len(x.shape)
    if i < 0:
        i = n_dims + i

    if j is None:
        j = n_dims
    elif j < 0:
        j = n_dims + j

    assert 0 <= i < j <= n_dims

    x_shape = x.shape
    target_shape = x_shape[:i] + shape + x_shape[j:]
    return x.view(target_shape)


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim

    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    dims = list(range(n_dims))
    del dims[src_dim]

    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x


class AttentionStack(nn.Module):
    def __init__(
            self, shape, embd_dim, n_head, n_layer, dropout,
            attn_type, attn_dropout, class_cond_dim, frame_cond_shape,
    ):
        super().__init__()
        self.shape = shape
        self.embd_dim = embd_dim
        self.use_frame_cond = frame_cond_shape is not None

        self.right_shift = RightShift(embd_dim)
        self.pos_embd = AddBroadcastPosEmbed(
            shape=shape, embd_dim=embd_dim
        )

        self.attn_nets = nn.ModuleList(
            [
                AttentionBlock(
                    shape=shape,
                    embd_dim=embd_dim,
                    n_head=n_head,
                    n_layer=n_layer,
                    dropout=dropout,
                    attn_type=attn_type,
                    attn_dropout=attn_dropout,
                    class_cond_dim=class_cond_dim,
                    frame_cond_shape=frame_cond_shape
                )
                for i in range(n_layer)
            ]
        )

    def forward(self, x, cond, decode_step, decode_idx):
        """
        Args
        ------
            x: (b, d1, d2, ..., dn, embd_dim)
            cond: a dictionary of conditioning tensors

            (below is used only when sampling for fast decoding)
            decode_step: the enumerated raster-scan order of the current idx being sampled
            decode_idx: a tuple representing the current idx being sampled
        """
        x = self.right_shift(x, decode_step)
        x = self.pos_embd(x, decode_step, decode_idx)
        for net in self.attn_nets:
            x = net(x, cond, decode_step, decode_idx)

        return x


class AttentionBlock(nn.Module):
    def __init__(self, shape, embd_dim, n_head, n_layer, dropout,
                 attn_type, attn_dropout, class_cond_dim, frame_cond_shape):
        super().__init__()
        self.use_frame_cond = frame_cond_shape is not None

        self.pre_attn_norm = LayerNorm(embd_dim, class_cond_dim)
        self.post_attn_dp = nn.Dropout(dropout)
        self.attn = MultiHeadAttention(shape, embd_dim, embd_dim, n_head,
                                       n_layer, causal=True, attn_type=attn_type,
                                       attn_kwargs=dict(attn_dropout=attn_dropout))

        if frame_cond_shape is not None:
            enc_len = np.prod(frame_cond_shape[:-1])
            self.pre_enc_norm = LayerNorm(embd_dim, class_cond_dim)
            self.post_enc_dp = nn.Dropout(dropout)
            self.enc_attn = MultiHeadAttention(shape, embd_dim, frame_cond_shape[-1],
                                               n_head, n_layer, attn_type='full',
                                               attn_kwargs=dict(attn_dropout=0.), causal=False)

        self.pre_fc_norm = LayerNorm(embd_dim, class_cond_dim)
        self.post_fc_dp = nn.Dropout(dropout)
        self.fc_block = nn.Sequential(
            nn.Linear(in_features=embd_dim, out_features=embd_dim * 4),
            GeLU2(),
            nn.Linear(in_features=embd_dim * 4, out_features=embd_dim),
        )

    def forward(self, x, cond, decode_step, decode_idx):
        h = self.pre_attn_norm(x, cond)
        if self.training:
            h = checkpoint(self.attn, h, h, h, decode_step, decode_idx)
        else:
            h = self.attn(h, h, h, decode_step, decode_idx)
        h = self.post_attn_dp(h)
        x = x + h

        if self.use_frame_cond:
            h = self.pre_enc_norm(x, cond)
            if self.training:
                h = checkpoint(self.enc_attn, h, cond['frame_cond'], cond['frame_cond'],
                               decode_step, decode_idx)
            else:
                h = self.enc_attn(h, cond['frame_cond'], cond['frame_cond'],
                                  decode_step, decode_idx)
            h = self.post_enc_dp(h)
            x = x + h

        h = self.pre_fc_norm(x, cond)
        if self.training:
            h = checkpoint(self.fc_block, h)
        else:
            h = self.fc_block(h)
        h = self.post_fc_dp(h)
        x = x + h

        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, shape, dim_q, dim_kv, n_head, n_layer,
                 causal, attn_type, attn_kwargs):
        super().__init__()
        self.causal = causal
        self.shape = shape

        self.d_k = dim_q // n_head
        self.d_v = dim_kv // n_head
        self.n_head = n_head

        self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False)  # q
        self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q))

        self.w_ks = nn.Linear(dim_kv, n_head * self.d_k, bias=False)  # k
        self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_kv))

        self.w_vs = nn.Linear(dim_kv, n_head * self.d_v, bias=False)  # v
        self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_kv))

        self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True)  # c
        self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer))

        if attn_type == 'full':
            self.attn = FullAttention(shape, causal, **attn_kwargs)
        elif attn_type == 'axial':
            assert not causal, 'causal axial attention is not supported'
            self.attn = AxialAttention(len(shape), **attn_kwargs)
        elif attn_type == 'sparse':
            self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs)

        self.cache = None

    def forward(self, q, k, v, decode_step=None, decode_idx=None):
        """ Compute multi-head attention
        Args
            q, k, v: a [b, d1, ..., dn, c] tensor or
                     a [b, 1, ..., 1, c] tensor if decode_step is not None

        Returns
            The output after performing attention
        """

        # compute k, q, v
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        q = view_range(self.w_qs(q), -1, None, (n_head, d_k))
        k = view_range(self.w_ks(k), -1, None, (n_head, d_k))
        v = view_range(self.w_vs(v), -1, None, (n_head, d_v))

        # b x n_head x seq_len x d
        # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d)
        q = shift_dim(q, -2, 1)
        k = shift_dim(k, -2, 1)
        v = shift_dim(v, -2, 1)

        # fast decoding
        if decode_step is not None:
            if decode_step == 0:
                if self.causal:
                    k_shape = (q.shape[0], n_head, *self.shape, self.d_k)
                    v_shape = (q.shape[0], n_head, *self.shape, self.d_v)
                    self.cache = dict(k=torch.zeros(k_shape, dtype=k.dtype, device=q.device),
                                      v=torch.zeros(v_shape, dtype=v.dtype, device=q.device))
                else:
                    # cache only once in the non-causal case
                    self.cache = dict(k=k.clone(), v=v.clone())
            if self.causal:
                idx = (slice(None, None), slice(None, None), *[slice(i, i + 1) for i in decode_idx])
                self.cache['k'][idx] = k
                self.cache['v'][idx] = v
            k, v = self.cache['k'], self.cache['v']

        a = self.attn(q, k, v, decode_step, decode_idx)

        # (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d)
        a = shift_dim(a, 1, -2).flatten(start_dim=-2)
        a = self.fc(a)  # (b x seq_len x embd_dim)

        return a


############## Attention #######################
class FullAttention(nn.Module):
    def __init__(self, shape, causal, attn_dropout):
        super().__init__()
        self.causal = causal
        self.attn_dropout = attn_dropout

        seq_len = np.prod(shape)
        if self.causal:
            self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)))

    def forward(self, q, k, v, decode_step, decode_idx):
        mask = self.mask if self.causal else None
        if decode_step is not None and mask is not None:
            mask = mask[[decode_step]]

        old_shape = q.shape[2:-1]
        q = q.flatten(start_dim=2, end_dim=-2)
        k = k.flatten(start_dim=2, end_dim=-2)
        v = v.flatten(start_dim=2, end_dim=-2)

        out = scaled_dot_product_attention(q, k, v, mask=mask,
                                           attn_dropout=self.attn_dropout,
                                           training=self.training)

        return view_range(out, 2, 3, old_shape)


class AxialAttention(nn.Module):
    def __init__(self, n_dim, axial_dim):
        super().__init__()
        if axial_dim < 0:
            axial_dim = 2 + n_dim + 1 + axial_dim
        else:
            axial_dim += 2  # account for batch, head, dim
        self.axial_dim = axial_dim

    def forward(self, q, k, v, decode_step, decode_idx):
        q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3)
        k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3)
        v = shift_dim(v, self.axial_dim, -2)
        old_shape = list(v.shape)
        v = v.flatten(end_dim=-3)

        out = scaled_dot_product_attention(q, k, v, training=self.training)
        out = out.view(*old_shape)
        out = shift_dim(out, -2, self.axial_dim)
        return out


class SparseAttention(nn.Module):
    ops = dict()
    attn_mask = dict()
    block_layout = dict()

    def __init__(self, shape, n_head, causal, num_local_blocks=4, block=32,
                 attn_dropout=0.):  # does not use attn_dropout
        super().__init__()
        self.causal = causal
        self.shape = shape

        self.sparsity_config = StridedSparsityConfig(shape=shape, n_head=n_head,
                                                     causal=causal, block=block,
                                                     num_local_blocks=num_local_blocks)

        if self.shape not in SparseAttention.block_layout:
            SparseAttention.block_layout[self.shape] = self.sparsity_config.make_layout()
        if causal and self.shape not in SparseAttention.attn_mask:
            SparseAttention.attn_mask[self.shape] = self.sparsity_config.make_sparse_attn_mask()

    def get_ops(self):
        try:
            from deepspeed.ops.sparse_attention import MatMul, Softmax
        except ImportError:
            raise Exception('Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`')
        if self.shape not in SparseAttention.ops:
            sparsity_layout = self.sparsity_config.make_layout()
            sparse_dot_sdd_nt = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'sdd',
                                       trans_a=False,
                                       trans_b=True)

            sparse_dot_dsd_nn = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'dsd',
                                       trans_a=False,
                                       trans_b=False)

            sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)

            SparseAttention.ops[self.shape] = (sparse_dot_sdd_nt,
                                               sparse_dot_dsd_nn,
                                               sparse_softmax)
        return SparseAttention.ops[self.shape]

    def forward(self, q, k, v, decode_step, decode_idx):
        if self.training and self.shape not in SparseAttention.ops:
            self.get_ops()

        SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[self.shape].to(q)
        if self.causal:
            SparseAttention.attn_mask[self.shape] = SparseAttention.attn_mask[self.shape].to(q).type_as(q)
        attn_mask = SparseAttention.attn_mask[self.shape] if self.causal else None

        old_shape = q.shape[2:-1]
        q = q.flatten(start_dim=2, end_dim=-2)
        k = k.flatten(start_dim=2, end_dim=-2)
        v = v.flatten(start_dim=2, end_dim=-2)

        if decode_step is not None:
            mask = self.sparsity_config.get_non_block_layout_row(SparseAttention.block_layout[self.shape], decode_step)
            out = scaled_dot_product_attention(q, k, v, mask=mask, training=self.training)
        else:
            if q.shape != k.shape or k.shape != v.shape:
                raise Exception('SparseAttention only support self-attention')
            sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops()
            scaling = float(q.shape[-1]) ** -0.5

            attn_output_weights = sparse_dot_sdd_nt(q, k)
            if attn_mask is not None:
                attn_output_weights = attn_output_weights.masked_fill(attn_mask == 0,
                                                                      float('-inf'))
            attn_output_weights = sparse_softmax(
                attn_output_weights,
                scale=scaling
            )

            out = sparse_dot_dsd_nn(attn_output_weights, v)

        return view_range(out, 2, 3, old_shape)


class StridedSparsityConfig(object):
    """
    Strided Sparse configuration specified in https://arxiv.org/abs/1904.10509 that
    generalizes to arbitrary dimensions
    """

    def __init__(self, shape, n_head, causal, block, num_local_blocks):
        self.n_head = n_head
        self.shape = shape
        self.causal = causal
        self.block = block
        self.num_local_blocks = num_local_blocks

        assert self.num_local_blocks >= 1, 'Must have at least 1 local block'
        assert self.seq_len % self.block == 0, 'seq len must be divisible by block size'

        self._block_shape = self._compute_block_shape()
        self._block_shape_cum = self._block_shape_cum_sizes()

    @property
    def seq_len(self):
        return np.prod(self.shape)

    @property
    def num_blocks(self):
        return self.seq_len // self.block

    def set_local_layout(self, layout):
        num_blocks = self.num_blocks
        for row in range(0, num_blocks):
            end = min(row + self.num_local_blocks, num_blocks)
            for col in range(
                    max(0, row - self.num_local_blocks),
                    (row + 1 if self.causal else end)):
                layout[:, row, col] = 1
        return layout

    def set_global_layout(self, layout):
        num_blocks = self.num_blocks
        n_dim = len(self._block_shape)
        for row in range(num_blocks):
            assert self._to_flattened_idx(self._to_unflattened_idx(row)) == row
            cur_idx = self._to_unflattened_idx(row)
            # no strided attention over last dim
            for d in range(n_dim - 1):
                end = self._block_shape[d]
                for i in range(0, (cur_idx[d] + 1 if self.causal else end)):
                    new_idx = list(cur_idx)
                    new_idx[d] = i
                    new_idx = tuple(new_idx)

                    col = self._to_flattened_idx(new_idx)
                    layout[:, row, col] = 1

        return layout

    def make_layout(self):
        layout = torch.zeros((self.n_head, self.num_blocks, self.num_blocks), dtype=torch.int64)
        layout = self.set_local_layout(layout)
        layout = self.set_global_layout(layout)
        return layout

    def make_sparse_attn_mask(self):
        block_layout = self.make_layout()
        assert block_layout.shape[1] == block_layout.shape[2] == self.num_blocks

        num_dense_blocks = block_layout.sum().item()
        attn_mask = torch.ones(num_dense_blocks, self.block, self.block)
        counter = 0
        for h in range(self.n_head):
            for i in range(self.num_blocks):
                for j in range(self.num_blocks):
                    elem = block_layout[h, i, j].item()
                    if elem == 1:
                        assert i >= j
                        if i == j:  # need to mask within block on diagonals
                            attn_mask[counter] = torch.tril(attn_mask[counter])
                        counter += 1
        assert counter == num_dense_blocks

        return attn_mask.unsqueeze(0)

    def get_non_block_layout_row(self, block_layout, row):
        block_row = row // self.block
        block_row = block_layout[:, [block_row]]  # n_head x 1 x n_blocks
        block_row = block_row.repeat_interleave(self.block, dim=-1)
        block_row[:, :, row + 1:] = 0.
        return block_row

    ############# Helper functions ##########################
    def _compute_block_shape(self):
        n_dim = len(self.shape)
        cum_prod = 1
        for i in range(n_dim - 1, -1, -1):
            cum_prod *= self.shape[i]
            if cum_prod > self.block:
                break
        assert cum_prod % self.block == 0
        new_shape = (*self.shape[:i], cum_prod // self.block)

        assert np.prod(new_shape) == np.prod(self.shape) // self.block

        return new_shape

    def _block_shape_cum_sizes(self):
        bs = np.flip(np.array(self._block_shape))
        return tuple(np.flip(np.cumprod(bs)[:-1])) + (1,)

    def _to_flattened_idx(self, idx):
        assert len(idx) == len(self._block_shape), f"{len(idx)} != {len(self._block_shape)}"
        flat_idx = 0
        for i in range(len(self._block_shape)):
            flat_idx += idx[i] * self._block_shape_cum[i]
        return flat_idx

    def _to_unflattened_idx(self, flat_idx):
        assert flat_idx < np.prod(self._block_shape)
        idx = []
        for i in range(len(self._block_shape)):
            idx.append(flat_idx // self._block_shape_cum[i])
            flat_idx %= self._block_shape_cum[i]
        return tuple(idx)


################ Spatiotemporal broadcasted positional embeddings ###############
class AddBroadcastPosEmbed(nn.Module):
    def __init__(self, shape, embd_dim, dim=-1):
        super().__init__()
        assert dim in [-1, 1]  # only first or last dim supported
        self.shape = shape
        self.n_dim = n_dim = len(shape)
        self.embd_dim = embd_dim
        self.dim = dim

        assert embd_dim % n_dim == 0, f"{embd_dim} % {n_dim} != 0"
        self.emb = nn.ParameterDict({
            f'd_{i}': nn.Parameter(torch.randn(shape[i], embd_dim // n_dim) * 0.01
                                   if dim == -1 else
                                   torch.randn(embd_dim // n_dim, shape[i]) * 0.01)
            for i in range(n_dim)
        })

    def forward(self, x, decode_step=None, decode_idx=None):
        embs = []
        for i in range(self.n_dim):
            e = self.emb[f'd_{i}']
            if self.dim == -1:
                # (1, 1, ..., 1, self.shape[i], 1, ..., -1)
                e = e.view(1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)), -1)
                e = e.expand(1, *self.shape, -1)
            else:
                e = e.view(1, -1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)))
                e = e.expand(1, -1, *self.shape)
            embs.append(e)

        embs = torch.cat(embs, dim=self.dim)
        if decode_step is not None:
            embs = tensor_slice(embs, [0, *decode_idx, 0],
                                [x.shape[0], *(1,) * self.n_dim, x.shape[-1]])

        return x + embs


################# Helper Functions ###################################
def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout=0., training=True):
    # Performs scaled dot-product attention over the second to last dimension dn

    # (b, n_head, d1, ..., dn, d)
    attn = torch.matmul(q, k.transpose(-1, -2))
    attn = attn / np.sqrt(q.shape[-1])
    if mask is not None:
        attn = attn.masked_fill(mask == 0, float('-inf'))
    attn_float = F.softmax(attn, dim=-1)
    attn = attn_float.type_as(attn)  # b x n_head x d1 x ... x dn x d
    attn = F.dropout(attn, p=attn_dropout, training=training)

    a = torch.matmul(attn, v)  # b x n_head x d1 x ... x dn x d

    return a


class RightShift(nn.Module):
    def __init__(self, embd_dim):
        super().__init__()
        self.embd_dim = embd_dim
        self.sos = nn.Parameter(torch.FloatTensor(embd_dim).normal_(std=0.02), requires_grad=True)

    def forward(self, x, decode_step):
        if decode_step is not None and decode_step > 0:
            return x

        x_shape = list(x.shape)
        x = x.flatten(start_dim=1, end_dim=-2)  # (b, seq_len, embd_dim)
        sos = torch.ones(x_shape[0], 1, self.embd_dim, dtype=torch.float32).to(self.sos) * self.sos
        sos = sos.type_as(x)
        x = torch.cat([sos, x[:, :-1, :]], dim=1)
        x = x.view(*x_shape)

        return x


class GeLU2(nn.Module):
    def forward(self, x):
        return (1.702 * x).sigmoid() * x


class LayerNorm(nn.Module):
    def __init__(self, embd_dim, class_cond_dim):
        super().__init__()
        self.conditional = class_cond_dim is not None

        if self.conditional:
            self.w = nn.Linear(class_cond_dim, embd_dim, bias=False)
            nn.init.constant_(self.w.weight.data, 1. / np.sqrt(class_cond_dim))
            self.wb = nn.Linear(class_cond_dim, embd_dim, bias=False)
        else:
            self.g = nn.Parameter(torch.ones(embd_dim, dtype=torch.float32), requires_grad=True)
            self.b = nn.Parameter(torch.zeros(embd_dim, dtype=torch.float32), requires_grad=True)

    def forward(self, x, cond):
        if self.conditional:  # (b, cond_dim)
            g = 1 + self.w(cond['class_cond']).view(x.shape[0], *(1,) * (len(x.shape) - 2), x.shape[-1])  # (b, ..., embd_dim)
            b = self.wb(cond['class_cond']).view(x.shape[0], *(1,) * (len(x.shape) - 2), x.shape[-1])
        else:
            g = self.g  # (embd_dim,)
            b = self.b

        x_float = x.float()

        mu = x_float.mean(dim=-1, keepdims=True)
        s = (x_float - mu).square().mean(dim=-1, keepdims=True)
        x_float = (x_float - mu) * (1e-5 + s.rsqrt())  # (b, ..., embd_dim)
        x_float = x_float * g + b

        x = x_float.type_as(x)
        return x
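A small sketch (the shape is arbitrary) of the strided sparsity pattern built by StridedSparsityConfig above: each row attends to a window of local blocks near the diagonal plus strided "global" columns, restricted to the lower triangle when causal:

import torch
from DiT_VAE.vae.attention_vae import StridedSparsityConfig

cfg = StridedSparsityConfig(shape=(4, 8, 8), n_head=2, causal=True,
                            block=32, num_local_blocks=2)
layout = cfg.make_layout()   # (n_head, num_blocks, num_blocks) 0/1 block mask
print(layout.shape)          # torch.Size([2, 8, 8]) -- 256 tokens / block size 32
print(layout[0])             # local window + strided columns, lower-triangular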
DiT_VAE/vae/data/__init__.py
ADDED
File without changes
DiT_VAE/vae/data/dataset_online_vae.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import json
|
5 |
+
import zipfile
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
# from transformers import CLIPImageProcessor
from torch.utils.data import Dataset
import io
from omegaconf import OmegaConf
import numpy as np
# from torchvision import transforms
# from einops import rearrange
# import random
# import os
# from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDIMScheduler
# import time
# import io
# import array
# import numpy as np
#
# from training.triplane import TriPlaneGenerator


def to_rgb_image(maybe_rgba: Image.Image):
    # Pass RGB through; composite RGBA onto a mid-gray (127) background.
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        # `np` is the numpy alias imported above
        img = np.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)


# image (contain style), z, pose, text
class TriplaneDataset(Dataset):
    # image, triplane, ref_feature
    def __init__(self, json_file, data_base_dir, model_names):
        super().__init__()
        with open(json_file) as jf:  # close the handle instead of leaking it
            self.dict_data_image = json.load(jf)  # {'image_name': pose}
        self.data_base_dir = data_base_dir
        self.data_list = list(self.dict_data_image.keys())
        # open each per-model zip archive once and cache the handle
        self.zip_file_dict = {}
        config_gan_model = OmegaConf.load(model_names)
        all_models = config_gan_model['gan_models'].keys()
        for model_name in all_models:
            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
            self.zip_file_dict[model_name] = zipfile.ZipFile(zipfile_path)

    def getdata(self, idx):
        # need z and expression and model name
        # image: "seed0035.png"
        # data_each_dict = {
        #     'vert_dir': vert_dir,
        #     'z_dir': z_dir,
        #     'pose_dir': pose_dir,
        #     'img_dir': img_dir,
        #     'model_name': model_name
        # }
        data_name = self.data_list[idx]
        data_model_name = self.dict_data_image[data_name]['model_name']
        zipfile_loaded = self.zip_file_dict[data_model_name]
        # zipfile_path = os.path.join(self.data_base_dir, data_model_name)
        # zipfile_loaded = zipfile.ZipFile(zipfile_path)
        with zipfile_loaded.open(self.dict_data_image[data_name]['z_dir'], 'r') as f:
            buffer = io.BytesIO(f.read())
            data_z = torch.load(buffer)
            buffer.close()
        with zipfile_loaded.open(self.dict_data_image[data_name]['vert_dir'], 'r') as ff:
            buffer_v = io.BytesIO(ff.read())
            data_vert = torch.load(buffer_v)
            buffer_v.close()
        # raw_image = to_rgb_image(Image.open(f))
        #
        # data_model_name = self.dict_data_image[data_name]['model_name']
        # data_z_dir = os.path.join(self.data_base_dir, data_model_name, self.dict_data_image[data_name]['z_dir'])
        # data_vert_dir = os.path.join(self.data_base_dir, data_model_name, self.dict_data_image[data_name]['vert_dir'])
        # data_z = torch.load(data_z_dir)
        # data_vert = torch.load(data_vert_dir)

        return {
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name
        }

    def __getitem__(self, idx):
        # retry with a fresh random index so one corrupt sample does not abort training
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

    def __len__(self):
        return len(self.data_list)

# for zip files
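A minimal usage sketch for the dataset above (the index JSON, zip directory, and OmegaConf file names are hypothetical placeholders, not paths from this commit). The cached zipfile.ZipFile handles cannot be pickled into forked DataLoader workers, so the sketch keeps num_workers=0:

from torch.utils.data import DataLoader

# Hypothetical paths for illustration only.
dataset = TriplaneDataset(
    json_file='data/index.json',            # sample name -> {'z_dir', 'vert_dir', 'model_name', ...}
    data_base_dir='data/zips',              # holds one <model_name>.zip per GAN model
    model_names='configs/gan_models.yaml',  # OmegaConf file with a top-level 'gan_models' mapping
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch['data_z'].shape, batch['data_model_name'])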
DiT_VAE/vae/distributions.py ADDED
@@ -0,0 +1,94 @@
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self, noise=None):
        if noise is None:
            noise = torch.randn(self.mean.shape)
        x = self.mean + self.std * noise.to(device=self.parameters.device, dtype=self.parameters.dtype)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
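With other=None, kl() evaluates the closed-form divergence from the posterior to a standard-normal prior, summed over the channel and spatial dimensions:

$$ D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,1)\big) = \tfrac{1}{2}\sum\big(\mu^{2}+\sigma^{2}-1-\log\sigma^{2}\big) $$

A minimal sketch of the intended call pattern (the [B, 2C, H, W] moment tensor is fabricated for illustration; it is not an output of this repo's encoder):

import torch

moments = torch.randn(2, 8, 16, 16)                # stand-in encoder output: 4 mean + 4 logvar channels
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                             # reparameterized draw: mean + std * eps
kl = posterior.kl()                                # closed-form KL to the unit Gaussian, one value per batch item
print(z.shape, kl.shape)                           # torch.Size([2, 4, 16, 16]) torch.Size([2])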
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
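Because normal_kl broadcasts, a batched posterior can be compared directly against scalar prior parameters; a quick sanity check (illustrative values only):

import torch

kl = normal_kl(torch.zeros(4, 3), torch.zeros(4, 3), 0.0, 0.0)  # two identical unit Gaussians
print(kl.shape, kl.abs().max().item())                          # torch.Size([4, 3]) 0.0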
DiT_VAE/vae/losses/__init__.py ADDED
@@ -0,0 +1 @@
from .contperceptual import LPIPSithTVLoss
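LPIPSithTVLoss itself is defined in DiT_VAE/vae/losses/contperceptual.py elsewhere in this commit (the spelling is as committed); the name suggests an LPIPS perceptual term combined with a total-variation (TV) regularizer. Purely as an illustrative sketch of what a TV component typically looks like (an assumption from the name, not this repo's implementation):

import torch

def tv_loss(x: torch.Tensor) -> torch.Tensor:
    # Mean absolute difference between neighboring positions along the last
    # two axes; penalizes high-frequency noise in decoded planes.
    dh = (x[..., 1:, :] - x[..., :-1, :]).abs().mean()
    dw = (x[..., :, 1:] - x[..., :, :-1]).abs().mean()
    return dh + dw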