刘虹雨 committed
Commit 8ed2f16 · 1 Parent(s): 8f481d2
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +8 -32
  3. DiT_VAE/.DS_Store +0 -0
  4. DiT_VAE/__init__.py +0 -0
  5. DiT_VAE/diffusion/__init__.py +8 -0
  6. DiT_VAE/diffusion/configs/PixArt_xl2_4D_Triplane.py +64 -0
  7. DiT_VAE/diffusion/configs/PixArt_xl2_img256_4D_Triplane.py +41 -0
  8. DiT_VAE/diffusion/configs/__init__.py +0 -0
  9. DiT_VAE/diffusion/configs/vae_model.yaml +24 -0
  10. DiT_VAE/diffusion/data/__init__.py +2 -0
  11. DiT_VAE/diffusion/data/builder.py +67 -0
  12. DiT_VAE/diffusion/data/transforms.py +29 -0
  13. DiT_VAE/diffusion/dpm_solver.py +28 -0
  14. DiT_VAE/diffusion/iddpm.py +51 -0
  15. DiT_VAE/diffusion/lcm_scheduler.py +455 -0
  16. DiT_VAE/diffusion/model/__init__.py +2 -0
  17. DiT_VAE/diffusion/model/builder.py +14 -0
  18. DiT_VAE/diffusion/model/diffusion_utils.py +92 -0
  19. DiT_VAE/diffusion/model/dpm_solver.py +1337 -0
  20. DiT_VAE/diffusion/model/edm_sample.py +168 -0
  21. DiT_VAE/diffusion/model/gaussian_diffusion.py +1006 -0
  22. DiT_VAE/diffusion/model/hed.py +150 -0
  23. DiT_VAE/diffusion/model/image_embedding.py +15 -0
  24. DiT_VAE/diffusion/model/nets/PixArt_blocks.py +655 -0
  25. DiT_VAE/diffusion/model/nets/TriDitCLIPDINO.py +315 -0
  26. DiT_VAE/diffusion/model/nets/__init__.py +1 -0
  27. DiT_VAE/diffusion/model/respace.py +131 -0
  28. DiT_VAE/diffusion/model/sa_solver.py +1129 -0
  29. DiT_VAE/diffusion/model/timestep_sampler.py +150 -0
  30. DiT_VAE/diffusion/model/utils.py +510 -0
  31. DiT_VAE/diffusion/sa_sampler.py +66 -0
  32. DiT_VAE/diffusion/sa_solver_diffusers.py +840 -0
  33. DiT_VAE/diffusion/utils/__init__.py +1 -0
  34. DiT_VAE/diffusion/utils/checkpoint.py +80 -0
  35. DiT_VAE/diffusion/utils/data_sampler.py +138 -0
  36. DiT_VAE/diffusion/utils/dist_utils.py +303 -0
  37. DiT_VAE/diffusion/utils/logger.py +94 -0
  38. DiT_VAE/diffusion/utils/lr_scheduler.py +89 -0
  39. DiT_VAE/diffusion/utils/misc.py +366 -0
  40. DiT_VAE/diffusion/utils/optimizer.py +237 -0
  41. DiT_VAE/train_diffusion.py +5 -0
  42. DiT_VAE/train_vae.py +369 -0
  43. DiT_VAE/util.py +217 -0
  44. DiT_VAE/vae/__init__.py +0 -0
  45. DiT_VAE/vae/aemodules3d.py +368 -0
  46. DiT_VAE/vae/attention_vae.py +620 -0
  47. DiT_VAE/vae/data/__init__.py +0 -0
  48. DiT_VAE/vae/data/dataset_online_vae.py +108 -0
  49. DiT_VAE/vae/distributions.py +94 -0
  50. DiT_VAE/vae/losses/__init__.py +1 -0
.DS_Store ADDED
Binary file (10.2 kB).
 
.gitattributes CHANGED
@@ -1,36 +1,12 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ data_process/lib/FaceVerse/v3/faceverse_v3_1.npy filter=lfs diff=lfs merge=lfs -text
+ data_process/lib/faceverse_process/BgMatting_models/rvm_resnet50_fp32.torchscript filter=lfs diff=lfs merge=lfs -text
+ data_process/lib/faceverse_process/metamodel/v3/faceverse_v3_1.npy filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.torchscript filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
  *.png filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
DiT_VAE/.DS_Store ADDED
Binary file (6.15 kB).
 
DiT_VAE/__init__.py ADDED
File without changes
DiT_VAE/diffusion/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Modified from OpenAI's diffusion repos
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+ from .iddpm import IDDPM
+ from .dpm_solver import DPMS
+ from .sa_sampler import SASolverSampler
DiT_VAE/diffusion/configs/PixArt_xl2_4D_Triplane.py ADDED
@@ -0,0 +1,64 @@
+ data_root = '/data/data'
+ data = dict(type='TriplaneData', data_base_dir='triplane', data_json_file='/nas8/liuhongyu/HeadGallery_Data/data.json', model_names='configs/gan_model.yaml')
+ image_size = 256  # the generated image resolution
+ train_batch_size = 32
+ eval_batch_size = 16
+ use_fsdp = False  # if use FSDP mode
+ valid_num = 0  # take as valid aspect-ratio when sample number >= valid_num
+ triplane_size = (256*4, 256)
+ # model setting
+ image_encoder_path = "/home/liuhongyu/code/IP-Adapter/models/image_encoder"
+ vae_triplane_config_path = "vae_model.yaml"
+ std_dir = '/nas8/liuhongyu/HeadGallery_Data/final_std.pt'
+ mean_dir = '/nas8/liuhongyu/HeadGallery_Data/final_mean.pt'
+ conditioning_params_dir = '/nas8/liuhongyu/HeadGallery_Data/conditioning_params.pkl'
+ gan_model_base_dir = '/nas8/liuhongyu/HeadGallery_Data/gan_models'
+ model = 'PixArt_XL_2'
+ aspect_ratio_type = None  # base aspect ratio [ASPECT_RATIO_512 or ASPECT_RATIO_256]
+ multi_scale = False  # if use multiscale dataset model training
+ lewei_scale = 1.0  # lewei_scale for positional embedding interpolation
+ # training setting
+ num_workers = 4
+ train_sampling_steps = 1000
+ eval_sampling_steps = 250
+ model_max_length = 8
+ lora_rank = 4
+
+ num_epochs = 80
+ gradient_accumulation_steps = 1
+ grad_checkpointing = False
+ gradient_clip = 1.0
+ gc_step = 1
+ auto_lr = dict(rule='sqrt')
+
+ # we use different weight decay with the official implementation since it results better result
+ optimizer = dict(type='AdamW', lr=1e-4, weight_decay=3e-2, eps=1e-10)
+ lr_schedule = 'constant'
+ lr_schedule_args = dict(num_warmup_steps=500)
+
+ save_image_epochs = 1
+ save_model_epochs = 1
+ save_model_steps = 1000000
+
+ sample_posterior = True
+ mixed_precision = 'fp16'
+ scale_factor = 0.3994218
+ ema_rate = 0.9999
+ tensorboard_mox_interval = 50
+ log_interval = 50
+ cfg_scale = 4
+ mask_type = 'null'
+ num_group_tokens = 0
+ mask_loss_coef = 0.
+ load_mask_index = False  # load prepared mask_type index
+ # load model settings
+ vae_pretrained = "/cache/pretrained_models/sd-vae-ft-ema"
+ load_from = None
+ resume_from = dict(checkpoint=None, load_ema=False, resume_optimizer=True, resume_lr_scheduler=True)
+ snr_loss = False
+
+ # work dir settings
+ work_dir = '/cache/exps/'
+ s3_work_dir = None
+
+ seed = 43
DiT_VAE/diffusion/configs/PixArt_xl2_img256_4D_Triplane.py ADDED
@@ -0,0 +1,41 @@
+ _base_ = ['./PixArt_xl2_4D_Triplane.py']
+ data_root = 'data'
+
+ data = dict(type='TriplaneData', data_base_dir='/nas8/liuhongyu/HeadGallery_Data',
+             data_json_file='/nas8/liuhongyu/HeadGallery_Data/data_combine.json', model_names='configs/gan_model.yaml',
+             dino_path='/nas8/liuhongyu/model/dinov2-base')
+ image_size = 256
+
+ # model setting
+ gan_model_config = "./configs/gan_model.yaml"
+ image_encoder_path = "/home/liuhongyu/code/IP-Adapter/models/image_encoder"
+ vae_triplane_config_path = "./vae_model.yaml"
+ std_dir = '/nas8/liuhongyu/HeadGallery_Data/final_std.pt'
+ mean_dir = '/nas8/liuhongyu/HeadGallery_Data/final_mean.pt'
+ conditioning_params_dir = '/nas8/liuhongyu/HeadGallery_Data/conditioning_params.pkl'
+ gan_model_base_dir = '/nas8/liuhongyu/HeadGallery_Data/gan_models'
+ dino_pretrained = '/nas8/liuhongyu/HeadGallery_Data/dinov2-base'
+ window_block_indexes = []
+ window_size = 0
+ use_rel_pos = False
+ model = 'PixArt_XL_2'
+ fp32_attention = True
+ dino_norm = False
+ img_feature_self_attention = False
+ load_from = None
+ vae_pretrained = "/nas8/liuhongyu/all_training_results/VAE/checkpoint-140000"
+ # training setting
+ eval_sampling_steps = 200
+ save_model_steps = 10000
+ num_workers = 2
+ train_batch_size = 8  # 32 # max 96 for PixArt-L/4 when grad_checkpoint
+ num_epochs = 200  # 3
+ gradient_accumulation_steps = 1
+ grad_checkpointing = True
+ gradient_clip = 0.01
+ optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
+ lr_schedule_args = dict(num_warmup_steps=1000)
+
+ log_interval = 20
+ save_model_epochs = 5
+ work_dir = 'output/debug'
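Both configs are plain-Python variable files in the PixArt/mmcv style: `_base_` pulls in PixArt_xl2_4D_Triplane.py, and any name re-assigned in this file overrides the inherited value. A minimal loading sketch, assuming an mmcv-style Config reader (the commit's own train scripts may use a different helper):

# Hedged sketch of loading an mmcv-style config with `_base_` inheritance.
from mmcv import Config

cfg = Config.fromfile('DiT_VAE/diffusion/configs/PixArt_xl2_img256_4D_Triplane.py')
print(cfg.model)             # 'PixArt_XL_2', inherited from the base config
print(cfg.train_batch_size)  # 8, overridden in this file
print(cfg.optimizer)         # AdamW with lr=2e-5, overriding the base lr=1e-4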
DiT_VAE/diffusion/configs/__init__.py ADDED
File without changes
DiT_VAE/diffusion/configs/vae_model.yaml ADDED
@@ -0,0 +1,24 @@
+ embed_dim: 8
+ ddconfig:
+   double_z: True
+   z_channels: 8
+   encoder:
+     target: DiT_VAE.vae.aemodules3d.Encoder
+     params:
+       n_hiddens: 128
+       downsample: [4, 8, 8]
+       image_channel: 32
+       norm_type: group
+       padding_type: replicate
+       double_z: True
+       z_channels: 8
+
+   decoder:
+     target: DiT_VAE.vae.aemodules3d.Decoder
+     params:
+       n_hiddens: 128
+       upsample: [4, 8, 8]
+       z_channels: 8
+       image_channel: 32
+       norm_type: group
+
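A hedged sketch of how the `target`/`params` entries might be instantiated, assuming the nesting shown above and a latent-diffusion-style helper; the loader actually used by DiT_VAE/train_vae.py is not visible in this truncated view:

# Hypothetical helper: import `target` by dotted path and construct it with `params`.
import importlib
import yaml

def instantiate_from_config(node):
    module_name, cls_name = node['target'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**node.get('params', {}))

with open('DiT_VAE/diffusion/configs/vae_model.yaml') as f:
    vae_cfg = yaml.safe_load(f)

encoder = instantiate_from_config(vae_cfg['ddconfig']['encoder'])  # DiT_VAE.vae.aemodules3d.Encoder
decoder = instantiate_from_config(vae_cfg['ddconfig']['decoder'])  # DiT_VAE.vae.aemodules3d.Decoder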
DiT_VAE/diffusion/data/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .datasets import *
+ from .transforms import get_transform
DiT_VAE/diffusion/data/builder.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import time
+
+ from mmcv import Registry, build_from_cfg
+ from torch.utils.data import DataLoader
+
+ from DiT_VAE.diffusion.data.transforms import get_transform
+ from DiT_VAE.diffusion.utils.logger import get_root_logger
+
+ DATASETS = Registry('datasets')
+
+ DATA_ROOT = '/cache/data'
+
+
+ def set_data_root(data_root):
+     global DATA_ROOT
+     DATA_ROOT = data_root
+
+
+ def get_data_path(data_dir):
+     if os.path.isabs(data_dir):
+         return data_dir
+     global DATA_ROOT
+     return os.path.join(DATA_ROOT, data_dir)
+
+
+ def build_dataset_triplane(cfg, resolution=256, **kwargs):
+     logger = get_root_logger()
+
+     dataset_type = cfg.get('type')
+     logger.info(f"Constructing dataset {dataset_type}...")
+     t = time.time()
+
+     dataset = build_from_cfg(cfg, DATASETS, default_args=dict(image_size=resolution, **kwargs))
+     logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length: {len(dataset)} ")
+     return dataset
+
+ def build_dataset(cfg, resolution=224, **kwargs):
+     logger = get_root_logger()
+
+     dataset_type = cfg.get('type')
+     logger.info(f"Constructing dataset {dataset_type}...")
+     t = time.time()
+     transform = cfg.pop('transform', 'default_train')
+     transform = get_transform(transform, resolution)
+     dataset = build_from_cfg(cfg, DATASETS, default_args=dict(transform=transform, resolution=resolution, **kwargs))
+     logger.info(f"Dataset {dataset_type} constructed. time: {(time.time() - t):.2f} s, length (use/ori): {len(dataset)}/{dataset.ori_imgs_nums}")
+     return dataset
+
+
+ def build_dataloader(dataset, batch_size=256, num_workers=2, shuffle=True, **kwargs):
+     return (
+         DataLoader(
+             dataset,
+             batch_sampler=kwargs['batch_sampler'],
+             num_workers=num_workers,
+             pin_memory=True,
+         )
+         if 'batch_sampler' in kwargs
+         else DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=num_workers,
+             pin_memory=True,
+             **kwargs
+         )
+     )
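A hedged usage sketch for the registry and loader helpers above; it assumes `TriplaneData` is registered with `DATASETS` in the datasets module re-exported by data/__init__.py, which is outside the 50 files shown in this view:

# Hedged sketch: build a triplane dataset and dataloader from a config dict.
from DiT_VAE.diffusion.data.builder import build_dataset_triplane, build_dataloader

data_cfg = dict(
    type='TriplaneData',  # assumed to be registered via the DATASETS registry
    data_base_dir='/nas8/liuhongyu/HeadGallery_Data',
    data_json_file='/nas8/liuhongyu/HeadGallery_Data/data_combine.json',
    model_names='configs/gan_model.yaml',
)
dataset = build_dataset_triplane(data_cfg, resolution=256)
loader = build_dataloader(dataset, batch_size=8, num_workers=2, shuffle=True)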
DiT_VAE/diffusion/data/transforms.py ADDED
@@ -0,0 +1,29 @@
+ import torchvision.transforms as T
+
+ TRANSFORMS = {}
+
+
+ def register_transform(transform):
+     name = transform.__name__
+     if name in TRANSFORMS:
+         raise RuntimeError(f'Transform {name} has already registered.')
+     TRANSFORMS.update({name: transform})
+
+
+ def get_transform(type, resolution):
+     transform = TRANSFORMS[type](resolution)
+     transform = T.Compose(transform)
+     transform.image_size = resolution
+     return transform
+
+
+ @register_transform
+ def default_train(n_px):
+     return [
+         T.Lambda(lambda img: img.convert('RGB')),
+         T.Resize(n_px),  # Image.BICUBIC
+         T.CenterCrop(n_px),
+         # T.RandomHorizontalFlip(),
+         T.ToTensor(),
+         T.Normalize([0.5], [0.5]),
+     ]
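A short usage sketch for the transform registry above; `default_train` is the only transform registered in this file and is looked up by its function name:

# Sketch: fetch the registered transform and apply it to an image.
from PIL import Image
from DiT_VAE.diffusion.data.transforms import get_transform

transform = get_transform('default_train', 256)  # Compose: RGB convert, resize, center-crop, normalize
img = Image.open('example.png')                  # hypothetical input image
x = transform(img)                               # float tensor in [-1, 1], shape (3, 256, 256)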
DiT_VAE/diffusion/dpm_solver.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ from .model import gaussian_diffusion as gd
+ from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
+
+
+ def DPMS(model, condition, uncondition, cfg_scale, model_type='noise', noise_schedule="linear", guidance_type='classifier-free', model_kwargs=None, diffusion_steps=1000):
+     if model_kwargs is None:
+         model_kwargs = {}
+     betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
+
+     ## 1. Define the noise schedule.
+     noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
+
+     ## 2. Convert your discrete-time `model` to the continuous-time
+     ## noise prediction model. Here is an example for a diffusion model
+     ## `model` with the noise prediction type ("noise").
+     model_fn = model_wrapper(
+         model,
+         noise_schedule,
+         model_type=model_type,
+         model_kwargs=model_kwargs,
+         guidance_type=guidance_type,
+         condition=condition,
+         unconditional_condition=uncondition,
+         guidance_scale=cfg_scale,
+     )
+     ## 3. Define dpm-solver and sample by multistep DPM-Solver.
+     return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
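A hedged sampling sketch built on this wrapper. The `.sample(...)` call follows the public multistep DPM-Solver++ API; the tail of model/dpm_solver.py is not shown in this view, and `net`, `cond`, `uncond` and the latent shape are placeholders:

# Hedged sketch: classifier-free guided sampling with the DPMS wrapper.
import torch
from DiT_VAE.diffusion import DPMS

z = torch.randn(1, 8, 256, 64)  # hypothetical latent shape
solver = DPMS(net, condition=cond, uncondition=uncond, cfg_scale=4.0)
samples = solver.sample(z, steps=20, order=2, skip_type="time_uniform", method="multistep")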
DiT_VAE/diffusion/iddpm.py ADDED
@@ -0,0 +1,51 @@
+ # Modified from OpenAI's diffusion repos
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+ from .model.respace import SpacedDiffusion, space_timesteps
+ from .model import gaussian_diffusion as gd
+
+
+ def IDDPM(
+     timestep_respacing,
+     noise_schedule="linear",
+     use_kl=False,
+     sigma_small=False,
+     predict_xstart=False,
+     learn_sigma=True,
+     pred_sigma=True,
+     rescale_learned_sigmas=False,
+     diffusion_steps=1000,
+     snr=False,
+     return_startx=False,
+ ):
+     betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+     if use_kl:
+         loss_type = gd.LossType.RESCALED_KL
+     elif rescale_learned_sigmas:
+         loss_type = gd.LossType.RESCALED_MSE
+     else:
+         loss_type = gd.LossType.MSE
+     if timestep_respacing is None or timestep_respacing == "":
+         timestep_respacing = [diffusion_steps]
+     return SpacedDiffusion(
+         use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+         betas=betas,
+         model_mean_type=(
+             gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
+         ),
+         model_var_type=(
+             (gd.ModelVarType.LEARNED_RANGE if learn_sigma else (
+                 gd.ModelVarType.FIXED_LARGE
+                 if not sigma_small
+                 else gd.ModelVarType.FIXED_SMALL
+             )
+             )
+             if pred_sigma
+             else None
+         ),
+         loss_type=loss_type,
+         snr=snr,
+         return_startx=return_startx,
+         # rescale_timesteps=rescale_timesteps,
+     )
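A hedged training-loss sketch for this factory, assuming the SpacedDiffusion/gaussian_diffusion code added in this commit keeps the OpenAI `training_losses` interface; `net` and `latents` are placeholders:

# Hedged sketch: one training step with the IDDPM factory.
import torch
from DiT_VAE.diffusion import IDDPM

diffusion = IDDPM(timestep_respacing="")  # "" -> full 1000-step training schedule
t = torch.randint(0, diffusion.num_timesteps, (latents.shape[0],), device=latents.device)
loss_dict = diffusion.training_losses(net, latents, t)
loss = loss_dict["loss"].mean()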
DiT_VAE/diffusion/lcm_scheduler.py ADDED
@@ -0,0 +1,455 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers import ConfigMixin, SchedulerMixin
26
+ from diffusers.configuration_utils import register_to_config
27
+ from diffusers.utils import BaseOutput
28
+
29
+
30
+ @dataclass
31
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
32
+ class LCMSchedulerOutput(BaseOutput):
33
+ """
34
+ Output class for the scheduler's `step` function output.
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
41
+ `pred_original_sample` can be used to preview progress or for guidance.
42
+ """
43
+
44
+ prev_sample: torch.FloatTensor
45
+ denoised: Optional[torch.FloatTensor] = None
46
+
47
+
48
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
49
+ def betas_for_alpha_bar(
50
+ num_diffusion_timesteps,
51
+ max_beta=0.999,
52
+ alpha_transform_type="cosine",
53
+ ):
54
+ """
55
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
56
+ (1-beta) over time from t = [0,1].
57
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
58
+ to that part of the diffusion process.
59
+ Args:
60
+ num_diffusion_timesteps (`int`): the number of betas to produce.
61
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
62
+ prevent singularities.
63
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
64
+ Choose from `cosine` or `exp`
65
+ Returns:
66
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
67
+ """
68
+ if alpha_transform_type == "cosine":
69
+
70
+ def alpha_bar_fn(t):
71
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
72
+
73
+ elif alpha_transform_type == "exp":
74
+
75
+ def alpha_bar_fn(t):
76
+ return math.exp(t * -12.0)
77
+
78
+ else:
79
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
80
+
81
+ betas = []
82
+ for i in range(num_diffusion_timesteps):
83
+ t1 = i / num_diffusion_timesteps
84
+ t2 = (i + 1) / num_diffusion_timesteps
85
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
86
+ return torch.tensor(betas, dtype=torch.float32)
87
+
88
+
89
+ def rescale_zero_terminal_snr(betas):
90
+ """
91
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
92
+ Args:
93
+ betas (`torch.FloatTensor`):
94
+ the betas that the scheduler is being initialized with.
95
+ Returns:
96
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
97
+ """
98
+ # Convert betas to alphas_bar_sqrt
99
+ alphas = 1.0 - betas
100
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
101
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
102
+
103
+ # Store old values.
104
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
105
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
106
+
107
+ # Shift so the last timestep is zero.
108
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
109
+
110
+ # Scale so the first timestep is back to the old value.
111
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
112
+
113
+ # Convert alphas_bar_sqrt to betas
114
+ alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
115
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
116
+ alphas = torch.cat([alphas_bar[:1], alphas])
117
+ betas = 1 - alphas
118
+
119
+ return betas
120
+
121
+
122
+ class LCMScheduler(SchedulerMixin, ConfigMixin):
123
+ """
124
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic vae (DDPMs) with
125
+ non-Markovian guidance.
126
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
127
+ methods the library implements for all schedulers such as loading and saving.
128
+ Args:
129
+ num_train_timesteps (`int`, defaults to 1000):
130
+ The number of diffusion steps to train the model.
131
+ beta_start (`float`, defaults to 0.0001):
132
+ The starting `beta` value of inference.
133
+ beta_end (`float`, defaults to 0.02):
134
+ The final `beta` value.
135
+ beta_schedule (`str`, defaults to `"linear"`):
136
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
137
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
138
+ trained_betas (`np.ndarray`, *optional*):
139
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
140
+ clip_sample (`bool`, defaults to `True`):
141
+ Clip the predicted sample for numerical stability.
142
+ clip_sample_range (`float`, defaults to 1.0):
143
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
144
+ set_alpha_to_one (`bool`, defaults to `True`):
145
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
146
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
147
+ otherwise it uses the alpha value at step 0.
148
+ steps_offset (`int`, defaults to 0):
149
+ An offset added to the inference steps. You can use a combination of `offset=1` and
150
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
151
+ Diffusion.
152
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
153
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
154
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
155
+ Video](https://imagen.research.google/video/paper.pdf) paper).
156
+ thresholding (`bool`, defaults to `False`):
157
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion vae such
158
+ as Stable Diffusion.
159
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
160
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
161
+ sample_max_value (`float`, defaults to 1.0):
162
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
163
+ timestep_spacing (`str`, defaults to `"leading"`):
164
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
165
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
166
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
167
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
168
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
169
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
170
+ """
171
+
172
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
173
+ order = 1
174
+
175
+ @register_to_config
176
+ def __init__(
177
+ self,
178
+ num_train_timesteps: int = 1000,
179
+ beta_start: float = 0.0001,
180
+ beta_end: float = 0.02,
181
+ beta_schedule: str = "linear",
182
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
183
+ clip_sample: bool = True,
184
+ set_alpha_to_one: bool = True,
185
+ steps_offset: int = 0,
186
+ prediction_type: str = "epsilon",
187
+ thresholding: bool = False,
188
+ dynamic_thresholding_ratio: float = 0.995,
189
+ clip_sample_range: float = 1.0,
190
+ sample_max_value: float = 1.0,
191
+ timestep_spacing: str = "leading",
192
+ rescale_betas_zero_snr: bool = False,
193
+ ):
194
+ if trained_betas is not None:
195
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
196
+ elif beta_schedule == "linear":
197
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
198
+ elif beta_schedule == "scaled_linear":
199
+ # this schedule is very specific to the latent diffusion model.
200
+ self.betas = (
201
+ torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
202
+ )
203
+ elif beta_schedule == "squaredcos_cap_v2":
204
+ # Glide cosine schedule
205
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
206
+ else:
207
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
208
+
209
+ # Rescale for zero SNR
210
+ if rescale_betas_zero_snr:
211
+ self.betas = rescale_zero_terminal_snr(self.betas)
212
+
213
+ self.alphas = 1.0 - self.betas
214
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
215
+
216
+ # At every step in ddim, we are looking into the previous alphas_cumprod
217
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
218
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
219
+ # whether we use the final alpha of the "non-previous" one.
220
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
221
+
222
+ # standard deviation of the initial noise distribution
223
+ self.init_noise_sigma = 1.0
224
+
225
+ # setable values
226
+ self.num_inference_steps = None
227
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
228
+
229
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
230
+ """
231
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
232
+ current timestep.
233
+ Args:
234
+ sample (`torch.FloatTensor`):
235
+ The input sample.
236
+ timestep (`int`, *optional*):
237
+ The current timestep in the diffusion chain.
238
+ Returns:
239
+ `torch.FloatTensor`:
240
+ A scaled input sample.
241
+ """
242
+ return sample
243
+
244
+ def _get_variance(self, timestep, prev_timestep):
245
+ alpha_prod_t = self.alphas_cumprod[timestep]
246
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
247
+ beta_prod_t = 1 - alpha_prod_t
248
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
249
+
250
+ return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
251
+
252
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
253
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
254
+ """
255
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
256
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
257
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
258
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
259
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
260
+ https://arxiv.org/abs/2205.11487
261
+ """
262
+ dtype = sample.dtype
263
+ batch_size, channels, height, width = sample.shape
264
+
265
+ if dtype not in (torch.float32, torch.float64):
266
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
267
+
268
+ # Flatten sample for doing quantile calculation along each image
269
+ sample = sample.reshape(batch_size, channels * height * width)
270
+
271
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
272
+
273
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
274
+ s = torch.clamp(
275
+ s, min=1, max=self.config.sample_max_value
276
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
277
+
278
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
279
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
280
+
281
+ sample = sample.reshape(batch_size, channels, height, width)
282
+ sample = sample.to(dtype)
283
+
284
+ return sample
285
+
286
+ def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
287
+ """
288
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
289
+ Args:
290
+ num_inference_steps (`int`):
291
+ The number of diffusion steps used when generating samples with a pre-trained model.
292
+ """
293
+
294
+ if num_inference_steps > self.config.num_train_timesteps:
295
+ raise ValueError(
296
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
297
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
298
+ f" maximal {self.config.num_train_timesteps} timesteps."
299
+ )
300
+
301
+ self.num_inference_steps = num_inference_steps
302
+
303
+ # LCM Timesteps Setting: # Linear Spacing
304
+ c = self.config.num_train_timesteps // lcm_origin_steps
305
+ lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
306
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
307
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
308
+
309
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
310
+
311
+ def get_scalings_for_boundary_condition_discrete(self, t):
312
+ self.sigma_data = 0.5 # Default: 0.5
313
+
314
+ # By dividing 0.1: This is almost a delta function at t=0.
315
+ c_skip = self.sigma_data ** 2 / ((t / 0.1) ** 2 + self.sigma_data ** 2)
316
+ c_out = ((t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data ** 2) ** 0.5)
317
+ return c_skip, c_out
318
+
319
+ def step(
320
+ self,
321
+ model_output: torch.FloatTensor,
322
+ timeindex: int,
323
+ timestep: int,
324
+ sample: torch.FloatTensor,
325
+ eta: float = 0.0,
326
+ use_clipped_model_output: bool = False,
327
+ generator=None,
328
+ variance_noise: Optional[torch.FloatTensor] = None,
329
+ return_dict: bool = True,
330
+ ) -> Union[LCMSchedulerOutput, Tuple]:
331
+ """
332
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
333
+ process from the learned model outputs (most often the predicted noise).
334
+ Args:
335
+ model_output (`torch.FloatTensor`):
336
+ The direct output from learned diffusion model.
337
+ timestep (`float`):
338
+ The current discrete timestep in the diffusion chain.
339
+ sample (`torch.FloatTensor`):
340
+ A current instance of a sample created by the diffusion process.
341
+ eta (`float`):
342
+ The weight of noise for added noise in diffusion step.
343
+ use_clipped_model_output (`bool`, defaults to `False`):
344
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
345
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
346
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
347
+ `use_clipped_model_output` has no effect.
348
+ generator (`torch.Generator`, *optional*):
349
+ A random number generator.
350
+ variance_noise (`torch.FloatTensor`):
351
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
352
+ itself. Useful for methods such as [`CycleDiffusion`].
353
+ return_dict (`bool`, *optional*, defaults to `True`):
354
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
355
+ Returns:
356
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
357
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
358
+ tuple is returned where the first element is the sample tensor.
359
+ """
360
+ if self.num_inference_steps is None:
361
+ raise ValueError(
362
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
363
+ )
364
+
365
+ # 1. get previous step value
366
+ prev_timeindex = timeindex + 1
367
+ if prev_timeindex < len(self.timesteps):
368
+ prev_timestep = self.timesteps[prev_timeindex]
369
+ else:
370
+ prev_timestep = timestep
371
+
372
+ # 2. compute alphas, betas
373
+ alpha_prod_t = self.alphas_cumprod[timestep]
374
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
375
+
376
+ beta_prod_t = 1 - alpha_prod_t
377
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
378
+
379
+ # 3. Get scalings for boundary conditions
380
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
381
+
382
+ # 4. Different Parameterization:
383
+ parameterization = self.config.prediction_type
384
+
385
+ if parameterization == "epsilon": # noise-prediction
386
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
387
+
388
+ elif parameterization == "sample": # x-prediction
389
+ pred_x0 = model_output
390
+
391
+ elif parameterization == "v_prediction": # v-prediction
392
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
393
+
394
+ # 4. Denoise model output using boundary conditions
395
+ denoised = c_out * pred_x0 + c_skip * sample
396
+
397
+ # 5. Sample z ~ N(0, I), For MultiStep Inference
398
+ # Noise is not used for one-step sampling.
399
+ if len(self.timesteps) > 1:
400
+ noise = torch.randn(model_output.shape).to(model_output.device)
401
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
402
+ else:
403
+ prev_sample = denoised
404
+
405
+ if not return_dict:
406
+ return (prev_sample, denoised)
407
+
408
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
409
+
410
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
411
+ def add_noise(
412
+ self,
413
+ original_samples: torch.FloatTensor,
414
+ noise: torch.FloatTensor,
415
+ timesteps: torch.IntTensor,
416
+ ) -> torch.FloatTensor:
417
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
418
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
419
+ timesteps = timesteps.to(original_samples.device)
420
+
421
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
422
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
423
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
424
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
425
+
426
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
427
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
428
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
429
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
430
+
431
+ return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
432
+
433
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
434
+ def get_velocity(
435
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
436
+ ) -> torch.FloatTensor:
437
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
438
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
439
+ timesteps = timesteps.to(sample.device)
440
+
441
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
442
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
443
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
444
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
445
+
446
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
447
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
448
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
449
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
450
+
451
+ return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
452
+
453
+ def __len__(self):
454
+ return self.config.num_train_timesteps
455
+
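A hedged inference sketch for the scheduler above, following its `set_timesteps(num_inference_steps, lcm_origin_steps, device)` and `step(model_output, timeindex, timestep, sample)` signatures; `net` and the latent shape are placeholders:

# Hedged sketch: few-step LCM sampling with the scheduler defined above.
import torch
from DiT_VAE.diffusion.lcm_scheduler import LCMScheduler

scheduler = LCMScheduler(num_train_timesteps=1000, prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=4, lcm_origin_steps=50)

latents = torch.randn(1, 8, 64, 64)  # hypothetical latent shape
for i, t in enumerate(scheduler.timesteps):
    model_output = net(latents, t)   # placeholder; the real denoiser also takes conditioning
    latents, denoised = scheduler.step(model_output, i, t, latents, return_dict=False)
# `denoised` from the last step is the final x0 prediction.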
DiT_VAE/diffusion/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .nets import *
+ # import utils
DiT_VAE/diffusion/model/builder.py ADDED
@@ -0,0 +1,14 @@
+ from mmcv import Registry
+
+ from DiT_VAE.diffusion.model.utils import set_grad_checkpoint
+
+ MODELS = Registry('vae')
+
+
+ def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
+     if isinstance(cfg, str):
+         cfg = dict(type=cfg)
+     model = MODELS.build(cfg, default_args=kwargs)
+     if use_grad_checkpoint:
+         set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
+     return model
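A hedged usage sketch for the model registry above; it assumes the network classes under diffusion/model/nets register themselves with `MODELS`, and the extra kwargs are hypothetical:

# Hedged sketch: build a network by name through the MODELS registry.
from DiT_VAE.diffusion.model.builder import build_model

model = build_model(
    'PixArt_XL_2',             # a plain string is wrapped into dict(type='PixArt_XL_2')
    use_grad_checkpoint=True,  # enables set_grad_checkpoint on the built network
    use_fp32_attention=True,
    input_size=32,             # hypothetical extra kwarg forwarded to the constructor
)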
DiT_VAE/diffusion/model/diffusion_utils.py ADDED
@@ -0,0 +1,92 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = next(
17
+ (
18
+ obj
19
+ for obj in (mean1, logvar1, mean2, logvar2)
20
+ if isinstance(obj, th.Tensor)
21
+ ),
22
+ None,
23
+ )
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ return 0.5 * (
34
+ -1.0
35
+ + logvar2
36
+ - logvar1
37
+ + th.exp(logvar1 - logvar2)
38
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39
+ )
40
+
41
+
42
+ def approx_standard_normal_cdf(x):
43
+ """
44
+ A fast approximation of the cumulative distribution function of the
45
+ standard normal.
46
+ """
47
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48
+
49
+
50
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
51
+ """
52
+ Compute the log-likelihood of a continuous Gaussian distribution.
53
+ :param x: the targets
54
+ :param means: the Gaussian mean Tensor.
55
+ :param log_scales: the Gaussian log stddev Tensor.
56
+ :return: a tensor like x of log probabilities (in nats).
57
+ """
58
+ centered_x = x - means
59
+ inv_stdv = th.exp(-log_scales)
60
+ normalized_x = centered_x * inv_stdv
61
+ return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
62
+ normalized_x
63
+ )
64
+
65
+
66
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
67
+ """
68
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
69
+ given image.
70
+ :param x: the target images. It is assumed that this was uint8 values,
71
+ rescaled to the range [-1, 1].
72
+ :param means: the Gaussian mean Tensor.
73
+ :param log_scales: the Gaussian log stddev Tensor.
74
+ :return: a tensor like x of log probabilities (in nats).
75
+ """
76
+ assert x.shape == means.shape == log_scales.shape
77
+ centered_x = x - means
78
+ inv_stdv = th.exp(-log_scales)
79
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
80
+ cdf_plus = approx_standard_normal_cdf(plus_in)
81
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
82
+ cdf_min = approx_standard_normal_cdf(min_in)
83
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
84
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
85
+ cdf_delta = cdf_plus - cdf_min
86
+ log_probs = th.where(
87
+ x < -0.999,
88
+ log_cdf_plus,
89
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
90
+ )
91
+ assert log_probs.shape == x.shape
92
+ return log_probs
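A small hedged sanity-check sketch for the helpers above: the KL between two identical Gaussians is zero, and the discretized log-likelihood expects images in [-1, 1] with means and log-scales of matching shape:

# Hedged sketch: quick numerical check of the Gaussian helpers.
import torch
from DiT_VAE.diffusion.model.diffusion_utils import normal_kl, discretized_gaussian_log_likelihood

mean = torch.zeros(2, 3, 8, 8)
logvar = torch.zeros(2, 3, 8, 8)
kl = normal_kl(mean, logvar, mean, logvar)        # identical Gaussians -> all zeros
x = torch.rand(2, 3, 8, 8) * 2 - 1                # fake image batch scaled to [-1, 1]
ll = discretized_gaussian_log_likelihood(x, means=mean, log_scales=logvar)
print(kl.mean().item(), ll.shape)                 # 0.0, torch.Size([2, 3, 8, 8])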
DiT_VAE/diffusion/model/dpm_solver.py ADDED
@@ -0,0 +1,1337 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+
4
+
5
+ class NoiseScheduleVP:
6
+ def __init__(
7
+ self,
8
+ schedule='discrete',
9
+ betas=None,
10
+ alphas_cumprod=None,
11
+ continuous_beta_0=0.1,
12
+ continuous_beta_1=20.,
13
+ dtype=torch.float32,
14
+ ):
15
+ """Create a wrapper class for the forward SDE (VP type).
16
+
17
+ ***
18
+ Update: We support discrete-time diffusion vae by implementing a picewise linear interpolation for log_alpha_t.
19
+ We recommend to use schedule='discrete' for the discrete-time diffusion vae, especially for high-resolution images.
20
+ ***
21
+
22
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+
30
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
31
+
32
+ t = self.inverse_lambda(lambda_t)
33
+
34
+ ===============================================================
35
+
36
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
37
+
38
+ 1. For discrete-time DPMs:
39
+
40
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
41
+ t_i = (i + 1) / N
42
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
43
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
44
+
45
+ Args:
46
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
47
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
48
+
49
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
50
+
51
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
52
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
53
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
54
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
55
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
56
+ and
57
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
58
+
59
+
60
+ 2. For continuous-time DPMs:
61
+
62
+ We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
63
+ schedule are the default settings in Yang Song's ScoreSDE:
64
+
65
+ Args:
66
+ beta_min: A `float` number. The smallest beta for the linear schedule.
67
+ beta_max: A `float` number. The largest beta for the linear schedule.
68
+ T: A `float` number. The ending time of the forward process.
69
+
70
+ ===============================================================
71
+
72
+ Args:
73
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
74
+ 'linear' for continuous-time DPMs.
75
+ Returns:
76
+ A wrapper object of the forward SDE (VP type).
77
+
78
+ ===============================================================
79
+
80
+ Example:
81
+
82
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
83
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
84
+
85
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
86
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
87
+
88
+ # For continuous-time DPMs (VPSDE), linear schedule:
89
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
90
+
91
+ """
92
+
93
+ if schedule not in ['discrete', 'linear']:
94
+ raise ValueError(
95
+ f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear'"
96
+ )
97
+
98
+ self.schedule = schedule
99
+ if schedule == 'discrete':
100
+ if betas is not None:
101
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
102
+ else:
103
+ assert alphas_cumprod is not None
104
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
105
+ self.T = 1.
106
+ self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
107
+ self.total_N = self.log_alpha_array.shape[1]
108
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
109
+ else:
110
+ self.T = 1.
111
+ self.total_N = 1000
112
+ self.beta_0 = continuous_beta_0
113
+ self.beta_1 = continuous_beta_1
114
+
115
+ def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
116
+ """
117
+ For some beta schedules such as cosine schedule, the log-SNR has numerical isssues.
118
+ We clip the log-SNR near t=T within -5.1 to ensure the stability.
119
+ Such a trick is very useful for diffusion vae with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.
120
+ """
121
+ log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
122
+ lambs = log_alphas - log_sigmas
123
+ idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
124
+ if idx > 0:
125
+ log_alphas = log_alphas[:-idx]
126
+ return log_alphas
127
+
128
+ def marginal_log_mean_coeff(self, t):
129
+ """
130
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
131
+ """
132
+ if self.schedule == 'discrete':
133
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
134
+ self.log_alpha_array.to(t.device)).reshape((-1))
135
+ elif self.schedule == 'linear':
136
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
137
+
138
+ def marginal_alpha(self, t):
139
+ """
140
+ Compute alpha_t of a given continuous-time label t in [0, T].
141
+ """
142
+ return torch.exp(self.marginal_log_mean_coeff(t))
143
+
144
+ def marginal_std(self, t):
145
+ """
146
+ Compute sigma_t of a given continuous-time label t in [0, T].
147
+ """
148
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
149
+
150
+ def marginal_lambda(self, t):
151
+ """
152
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
153
+ """
154
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
155
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
156
+ return log_mean_coeff - log_std
157
+
158
+ def inverse_lambda(self, lamb):
159
+ """
160
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
161
+ """
162
+ if self.schedule == 'linear':
163
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
164
+ Delta = self.beta_0 ** 2 + tmp
165
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
166
+ elif self.schedule == 'discrete':
167
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
168
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
169
+ torch.flip(self.t_array.to(lamb.device), [1]))
170
+ return t.reshape((-1,))
171
+
172
+
173
+ def model_wrapper(
174
+ model,
175
+ noise_schedule,
176
+ model_type="noise",
177
+ model_kwargs={},
178
+ guidance_type="uncond",
179
+ condition=None,
180
+ unconditional_condition=None,
181
+ guidance_scale=1.,
182
+ classifier_fn=None,
183
+ classifier_kwargs={},
184
+ ):
185
+ """Create a wrapper function for the noise prediction model.
186
+
187
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
188
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
189
+
190
+ We support four types of the diffusion model by setting `model_type`:
191
+
192
+ 1. "noise": noise prediction model. (Trained by predicting noise).
193
+
194
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
195
+
196
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
197
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
198
+
199
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion vae."
200
+ arXiv preprint arXiv:2202.00512 (2022).
201
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
202
+ arXiv preprint arXiv:2210.02303 (2022).
203
+
204
+ 4. "score": marginal score function. (Trained by denoising score matching).
205
+ Note that the score function and the noise prediction model follows a simple relationship:
206
+ ```
207
+ noise(x_t, t) = -sigma_t * score(x_t, t)
208
+ ```
209
+
210
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
211
+ 1. "uncond": unconditional sampling by DPMs.
212
+ The input `model` has the following format:
213
+ ``
214
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
215
+ ``
216
+
217
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
218
+ The input `model` has the following format:
219
+ ``
220
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
221
+ ``
222
+
223
+ The input `classifier_fn` has the following format:
224
+ ``
225
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
226
+ ``
227
+
228
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion vae beat GANs on image synthesis,"
229
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
230
+
231
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
232
+ The input `model` has the following format:
233
+ ``
234
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
235
+ ``
236
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
237
+
238
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
239
+ arXiv preprint arXiv:2207.12598 (2022).
240
+
241
+
242
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
243
+ or continuous-time labels (i.e. epsilon to T).
244
+
245
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and to output the predicted noise:
246
+ ``
247
+ def model_fn(x, t_continuous) -> noise:
248
+ t_input = get_model_input_time(t_continuous)
249
+ return noise_pred(model, x, t_input, **model_kwargs)
250
+ ``
251
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
252
+
253
+ ===============================================================
254
+
255
+ Args:
256
+ model: A diffusion model with the corresponding format described above.
257
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
258
+ model_type: A `str`. The parameterization type of the diffusion model.
259
+ "noise" or "x_start" or "v" or "score".
260
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
261
+ guidance_type: A `str`. The type of the guidance for sampling.
262
+ "uncond" or "classifier" or "classifier-free".
263
+ condition: A pytorch tensor. The condition for the guided sampling.
264
+ Only used for "classifier" or "classifier-free" guidance type.
265
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
266
+ Only used for "classifier-free" guidance type.
267
+ guidance_scale: A `float`. The scale for the guided sampling.
268
+ classifier_fn: A classifier function. Only used for the classifier guidance.
269
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
270
+ Returns:
271
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
272
+ """
273
+
274
+ def get_model_input_time(t_continuous):
275
+ """
276
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
277
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
278
+ For continuous-time DPMs, we just use `t_continuous`.
279
+ """
280
+ if noise_schedule.schedule == 'discrete':
281
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
282
+ else:
283
+ return t_continuous
284
+
285
+ def noise_pred_fn(x, t_continuous, cond=None, cond_2=None):
286
+ t_input = get_model_input_time(t_continuous)
287
+ if cond is None:
288
+ output = model(x, t_input, **model_kwargs)
289
+ else:
290
+ output = model(x, t_input, y=cond, img_feature=cond_2, **model_kwargs)
291
+ if model_type == "noise":
292
+ return output
293
+ elif model_type == "x_start":
294
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
295
+ return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
296
+ elif model_type == "v":
297
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
298
+ return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
299
+ elif model_type == "score":
300
+ sigma_t = noise_schedule.marginal_std(t_continuous)
301
+ return -expand_dims(sigma_t, x.dim()) * output
302
+
303
+ def cond_grad_fn(x, t_input):
304
+ """
305
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
306
+ """
307
+ with torch.enable_grad():
308
+ x_in = x.detach().requires_grad_(True)
309
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
310
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
311
+
312
+ def model_fn(x, t_continuous):
313
+ """
314
+ The noise prediction model function that is used for DPM-Solver.
315
+ """
316
+ if guidance_type == "uncond":
317
+ return noise_pred_fn(x, t_continuous)
318
+ elif guidance_type == "classifier":
319
+ assert classifier_fn is not None
320
+ t_input = get_model_input_time(t_continuous)
321
+ cond_grad = cond_grad_fn(x, t_input)
322
+ sigma_t = noise_schedule.marginal_std(t_continuous)
323
+ noise = noise_pred_fn(x, t_continuous)
324
+ return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
325
+ elif guidance_type == "classifier-free":
326
+ if guidance_scale == 1. or unconditional_condition is None:
327
+ return noise_pred_fn(x, t_continuous, cond=condition)
328
+ x_in = torch.cat([x] * 2)
329
+ t_in = torch.cat([t_continuous] * 2)
330
+ # c_in = torch.cat([unconditional_condition, condition])
331
+ c_in_y = torch.cat([unconditional_condition[0], condition[0]])
332
+ c_in_dino = torch.cat([unconditional_condition[1], condition[1]])
333
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in_y, cond_2=c_in_dino).chunk(2)
334
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
335
+
336
+ assert model_type in ["noise", "x_start", "v", "score"]
337
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
338
+ return model_fn
339
+
340
+
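A minimal usage sketch of the wrapper above, assuming a hypothetical discrete-time noise-prediction network `net`, its `betas` tensor, and placeholder condition tensors `y_emb`, `dino_feat`, `null_y_emb`, `null_dino_feat` (none of these names come from this repository); only `NoiseScheduleVP` (defined earlier in this file), `model_wrapper` and `DPM_Solver` are from this module. It shows the classifier-free call pattern handled by `noise_pred_fn`:

    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
    model_fn = model_wrapper(
        net, noise_schedule,
        model_type="noise",
        guidance_type="classifier-free",
        condition=[y_emb, dino_feat],                          # forwarded as y= / img_feature=
        unconditional_condition=[null_y_emb, null_dino_feat],  # null embeddings (assumed placeholders)
        guidance_scale=4.5,
    )
    dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
    x0 = dpm_solver.sample(x_T, steps=20, order=2, skip_type='time_uniform', method='multistep')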
341
+ class DPM_Solver:
342
+ def __init__(
343
+ self,
344
+ model_fn,
345
+ noise_schedule,
346
+ algorithm_type="dpmsolver++",
347
+ correcting_x0_fn=None,
348
+ correcting_xt_fn=None,
349
+ thresholding_max_val=1.,
350
+ dynamic_thresholding_ratio=0.995,
351
+ ):
352
+ """Construct a DPM-Solver.
353
+
354
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
355
+
356
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion vae, you
357
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
358
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
359
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
360
+ DPMs (such as stable-diffusion).
361
+
362
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
363
+ both x0 and xt.
364
+
365
+ Args:
366
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
367
+ ``
368
+ def model_fn(x, t_continuous):
369
+ return noise
370
+ ``
371
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
372
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
373
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
374
+ correcting_x0_fn: A `str` or a function with the following format:
375
+ ```
376
+ def correcting_x0_fn(x0, t):
377
+ x0_new = ...
378
+ return x0_new
379
+ ```
380
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
381
+ ```
382
+ x0_pred = data_pred_model(xt, t)
383
+ if correcting_x0_fn is not None:
384
+ x0_pred = correcting_x0_fn(x0_pred, t)
385
+ xt_1 = update(x0_pred, xt, t)
386
+ ```
387
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
388
+ correcting_xt_fn: A function with the following format:
389
+ ```
390
+ def correcting_xt_fn(xt, t, step):
391
+ x_new = ...
392
+ return x_new
393
+ ```
394
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
395
+ ```
396
+ xt = ...
397
+ xt = correcting_xt_fn(xt, t, step)
398
+ ```
399
+ thresholding_max_val: A `float`. The max value for thresholding.
400
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
401
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
402
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
403
+
404
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
405
+ Burcu Karagol Ayan, S Sara Mahdavi, Raphael Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
406
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
407
+ """
408
+ self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
409
+ self.noise_schedule = noise_schedule
410
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
411
+ self.algorithm_type = algorithm_type
412
+ if correcting_x0_fn == "dynamic_thresholding":
413
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
414
+ else:
415
+ self.correcting_x0_fn = correcting_x0_fn
416
+ self.correcting_xt_fn = correcting_xt_fn
417
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
418
+ self.thresholding_max_val = thresholding_max_val
419
+
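A construction sketch for the dynamic-thresholding configuration described in the docstring, assuming `model_fn` and `noise_schedule` were built as above (intended for pixel-space models with large guidance scales):

    dpm_solver = DPM_Solver(model_fn, noise_schedule,
                            algorithm_type="dpmsolver++",
                            correcting_x0_fn="dynamic_thresholding",
                            thresholding_max_val=1.0,
                            dynamic_thresholding_ratio=0.995)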
420
+ def dynamic_thresholding_fn(self, x0, t):
421
+ """
422
+ The dynamic thresholding method.
423
+ """
424
+ dims = x0.dim()
425
+ p = self.dynamic_thresholding_ratio
426
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
427
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
428
+ x0 = torch.clamp(x0, -s, s) / s
429
+ return x0
430
+
431
+ def noise_prediction_fn(self, x, t):
432
+ """
433
+ Return the noise prediction model.
434
+ """
435
+ return self.model(x, t)
436
+
437
+ def data_prediction_fn(self, x, t):
438
+ """
439
+ Return the data prediction model (with corrector).
440
+ """
441
+ noise = self.noise_prediction_fn(x, t)
442
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
443
+ x0 = (x - sigma_t * noise) / alpha_t
444
+ if self.correcting_x0_fn is not None:
445
+ x0 = self.correcting_x0_fn(x0, t)
446
+ return x0
447
+
448
+ def model_fn(self, x, t):
449
+ """
450
+ Convert the model to the noise prediction model or the data prediction model.
451
+ """
452
+ if self.algorithm_type == "dpmsolver++":
453
+ return self.data_prediction_fn(x, t)
454
+ else:
455
+ return self.noise_prediction_fn(x, t)
456
+
457
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
458
+ """Compute the intermediate time steps for sampling.
459
+
460
+ Args:
461
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
462
+ - 'logSNR': uniform logSNR for the time steps.
463
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
464
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
465
+ t_T: A `float`. The starting time of the sampling (default is T).
466
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
467
+ N: A `int`. The total number of the spacing of the time steps.
468
+ device: A torch device.
469
+ Returns:
470
+ A pytorch tensor of the time steps, with the shape (N + 1,).
471
+ """
472
+ if skip_type == 'logSNR':
473
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
474
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
475
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
476
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
477
+ elif skip_type == 'time_uniform':
478
+ return torch.linspace(t_T, t_0, N + 1).to(device)
479
+ elif skip_type == 'time_quadratic':
480
+ t_order = 2
481
+ return (
482
+ torch.linspace(
483
+ t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1
484
+ )
485
+ .pow(t_order)
486
+ .to(device)
487
+ )
488
+ else:
489
+ raise ValueError(
490
+ f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
491
+ )
492
+
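For intuition, a small sketch of the first two spacings with t_T = 1.0, t_0 = 1e-3 and N = 4 (plain torch, no solver state needed; the 'logSNR' option additionally requires marginal_lambda / inverse_lambda from the noise schedule):

    t_uniform   = torch.linspace(1.0, 1e-3, 5)                      # ~[1.00, 0.75, 0.50, 0.25, 0.001]
    t_quadratic = torch.linspace(1.0 ** 0.5, 1e-3 ** 0.5, 5) ** 2   # ~[1.00, 0.57, 0.27, 0.075, 0.001], denser near t_0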
493
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
494
+ """
495
+ Get the order of each step for sampling by the singlestep DPM-Solver.
496
+
497
+ We combine DPM-Solver-1, 2 and 3 to use all the function evaluations, which is named "DPM-Solver-fast".
498
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
499
+ - If order == 1:
500
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
501
+ - If order == 2:
502
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
503
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
504
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
505
+ - If order == 3:
506
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
507
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
508
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
509
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
510
+
511
+ ============================================
512
+ Args:
513
+ order: A `int`. The max order for the solver (2 or 3).
514
+ steps: A `int`. The total number of function evaluations (NFE).
515
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
516
+ - 'logSNR': uniform logSNR for the time steps.
517
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
518
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
519
+ t_T: A `float`. The starting time of the sampling (default is T).
520
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
521
+ device: A torch device.
522
+ Returns:
523
+ orders: A list of the solver order of each step.
524
+ """
525
+ if order == 3:
526
+ K = steps // 3 + 1
527
+ if steps % 3 == 0:
528
+ orders = [3, ] * (K - 2) + [2, 1]
529
+ elif steps % 3 == 1:
530
+ orders = [3, ] * (K - 1) + [1]
531
+ else:
532
+ orders = [3, ] * (K - 1) + [2]
533
+ elif order == 2:
534
+ if steps % 2 == 0:
535
+ K = steps // 2
536
+ orders = [2, ] * K
537
+ else:
538
+ K = steps // 2 + 1
539
+ orders = [2, ] * (K - 1) + [1]
540
+ elif order == 1:
541
+ K = 1
542
+ orders = [1, ] * steps
543
+ else:
544
+ raise ValueError("'order' must be '1' or '2' or '3'.")
545
+ if skip_type == 'logSNR':
546
+ # To reproduce the results in DPM-Solver paper
547
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
548
+ else:
549
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
550
+ torch.cumsum(torch.tensor([0, ] + orders), 0).to(device)]
551
+ return timesteps_outer, orders
552
+
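A worked example of the allocation above: with steps = 12 and order = 3, K = 12 // 3 + 1 = 5 and steps % 3 == 0, so orders = [3, 3, 3, 2, 1] (12 NFE in total); with steps = 7 and order = 2, orders = [2, 2, 2, 1]. Assuming a DPM_Solver instance `dpm_solver` already exists:

    timesteps_outer, orders = dpm_solver.get_orders_and_timesteps_for_singlestep_solver(
        steps=12, order=3, skip_type='time_uniform', t_T=1.0, t_0=1e-3, device='cpu')
    # orders == [3, 3, 3, 2, 1]; timesteps_outer has len(orders) + 1 == 6 entries.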
553
+ def denoise_to_zero_fn(self, x, s):
554
+ """
555
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
556
+ """
557
+ return self.data_prediction_fn(x, s)
558
+
559
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
560
+ """
561
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
562
+
563
+ Args:
564
+ x: A pytorch tensor. The initial value at time `s`.
565
+ s: A pytorch tensor. The starting time, with the shape (1,).
566
+ t: A pytorch tensor. The ending time, with the shape (1,).
567
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
568
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
569
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
570
+ Returns:
571
+ x_t: A pytorch tensor. The approximated solution at time `t`.
572
+ """
573
+ ns = self.noise_schedule
574
+ dims = x.dim()
575
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
576
+ h = lambda_t - lambda_s
577
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
578
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
579
+ alpha_t = torch.exp(log_alpha_t)
580
+
581
+ if self.algorithm_type == "dpmsolver++":
582
+ phi_1 = torch.expm1(-h)
583
+ if model_s is None:
584
+ model_s = self.model_fn(x, s)
585
+ x_t = (
586
+ sigma_t / sigma_s * x
587
+ - alpha_t * phi_1 * model_s
588
+ )
589
+ else:
590
+ phi_1 = torch.expm1(h)
591
+ if model_s is None:
592
+ model_s = self.model_fn(x, s)
593
+ x_t = (
594
+ torch.exp(log_alpha_t - log_alpha_s) * x
595
+ - (sigma_t * phi_1) * model_s
596
+ )
597
+ return (x_t, {'model_s': model_s}) if return_intermediate else x_t
598
+
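In closed form, the first-order update above is x_t = (sigma_t / sigma_s) * x - alpha_t * (exp(-h) - 1) * x0_theta(x, s) for "dpmsolver++" (data prediction), and x_t = (alpha_t / alpha_s) * x - sigma_t * (exp(h) - 1) * eps_theta(x, s) for "dpmsolver" (noise prediction), with h = lambda_t - lambda_s; evaluated once per step this is exactly the DDIM update, as the docstring states.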
599
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
600
+ solver_type='dpmsolver'):
601
+ """
602
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
603
+
604
+ Args:
605
+ x: A pytorch tensor. The initial value at time `s`.
606
+ s: A pytorch tensor. The starting time, with the shape (1,).
607
+ t: A pytorch tensor. The ending time, with the shape (1,).
608
+ r1: A `float`. The hyperparameter of the second-order solver.
609
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
610
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
611
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
612
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
613
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
614
+ Returns:
615
+ x_t: A pytorch tensor. The approximated solution at time `t`.
616
+ """
617
+ if solver_type not in ['dpmsolver', 'taylor']:
618
+ raise ValueError(
619
+ f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
620
+ )
621
+ if r1 is None:
622
+ r1 = 0.5
623
+ ns = self.noise_schedule
624
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
625
+ h = lambda_t - lambda_s
626
+ lambda_s1 = lambda_s + r1 * h
627
+ s1 = ns.inverse_lambda(lambda_s1)
628
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
629
+ s1), ns.marginal_log_mean_coeff(t)
630
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
631
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
632
+
633
+ if self.algorithm_type == "dpmsolver++":
634
+ phi_11 = torch.expm1(-r1 * h)
635
+ phi_1 = torch.expm1(-h)
636
+
637
+ if model_s is None:
638
+ model_s = self.model_fn(x, s)
639
+ x_s1 = (
640
+ (sigma_s1 / sigma_s) * x
641
+ - (alpha_s1 * phi_11) * model_s
642
+ )
643
+ model_s1 = self.model_fn(x_s1, s1)
644
+ if solver_type == 'dpmsolver':
645
+ x_t = (
646
+ (sigma_t / sigma_s) * x
647
+ - (alpha_t * phi_1) * model_s
648
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
649
+ )
650
+ elif solver_type == 'taylor':
651
+ x_t = (
652
+ (sigma_t / sigma_s) * x
653
+ - (alpha_t * phi_1) * model_s
654
+ + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
655
+ )
656
+ else:
657
+ phi_11 = torch.expm1(r1 * h)
658
+ phi_1 = torch.expm1(h)
659
+
660
+ if model_s is None:
661
+ model_s = self.model_fn(x, s)
662
+ x_s1 = (
663
+ torch.exp(log_alpha_s1 - log_alpha_s) * x
664
+ - (sigma_s1 * phi_11) * model_s
665
+ )
666
+ model_s1 = self.model_fn(x_s1, s1)
667
+ if solver_type == 'dpmsolver':
668
+ x_t = (
669
+ torch.exp(log_alpha_t - log_alpha_s) * x
670
+ - (sigma_t * phi_1) * model_s
671
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
672
+ )
673
+ elif solver_type == 'taylor':
674
+ x_t = (
675
+ torch.exp(log_alpha_t - log_alpha_s) * x
676
+ - (sigma_t * phi_1) * model_s
677
+ - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
678
+ )
679
+ if return_intermediate:
680
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
681
+ else:
682
+ return x_t
683
+
684
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
685
+ return_intermediate=False, solver_type='dpmsolver'):
686
+ """
687
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
688
+
689
+ Args:
690
+ x: A pytorch tensor. The initial value at time `s`.
691
+ s: A pytorch tensor. The starting time, with the shape (1,).
692
+ t: A pytorch tensor. The ending time, with the shape (1,).
693
+ r1: A `float`. The hyperparameter of the third-order solver.
694
+ r2: A `float`. The hyperparameter of the third-order solver.
695
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
696
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
697
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
698
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
699
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
700
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
701
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
702
+ Returns:
703
+ x_t: A pytorch tensor. The approximated solution at time `t`.
704
+ """
705
+ if solver_type not in ['dpmsolver', 'taylor']:
706
+ raise ValueError(
707
+ f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
708
+ )
709
+ if r1 is None:
710
+ r1 = 1. / 3.
711
+ if r2 is None:
712
+ r2 = 2. / 3.
713
+ ns = self.noise_schedule
714
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
715
+ h = lambda_t - lambda_s
716
+ lambda_s1 = lambda_s + r1 * h
717
+ lambda_s2 = lambda_s + r2 * h
718
+ s1 = ns.inverse_lambda(lambda_s1)
719
+ s2 = ns.inverse_lambda(lambda_s2)
720
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
721
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
722
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
723
+ s2), ns.marginal_std(t)
724
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
725
+
726
+ if self.algorithm_type == "dpmsolver++":
727
+ phi_11 = torch.expm1(-r1 * h)
728
+ phi_12 = torch.expm1(-r2 * h)
729
+ phi_1 = torch.expm1(-h)
730
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
731
+ phi_2 = phi_1 / h + 1.
732
+ phi_3 = phi_2 / h - 0.5
733
+
734
+ if model_s is None:
735
+ model_s = self.model_fn(x, s)
736
+ if model_s1 is None:
737
+ x_s1 = (
738
+ (sigma_s1 / sigma_s) * x
739
+ - (alpha_s1 * phi_11) * model_s
740
+ )
741
+ model_s1 = self.model_fn(x_s1, s1)
742
+ x_s2 = (
743
+ (sigma_s2 / sigma_s) * x
744
+ - (alpha_s2 * phi_12) * model_s
745
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
746
+ )
747
+ model_s2 = self.model_fn(x_s2, s2)
748
+ if solver_type == 'dpmsolver':
749
+ x_t = (
750
+ (sigma_t / sigma_s) * x
751
+ - (alpha_t * phi_1) * model_s
752
+ + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
753
+ )
754
+ elif solver_type == 'taylor':
755
+ D1_0 = (1. / r1) * (model_s1 - model_s)
756
+ D1_1 = (1. / r2) * (model_s2 - model_s)
757
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
758
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
759
+ x_t = (
760
+ (sigma_t / sigma_s) * x
761
+ - (alpha_t * phi_1) * model_s
762
+ + (alpha_t * phi_2) * D1
763
+ - (alpha_t * phi_3) * D2
764
+ )
765
+ else:
766
+ phi_11 = torch.expm1(r1 * h)
767
+ phi_12 = torch.expm1(r2 * h)
768
+ phi_1 = torch.expm1(h)
769
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
770
+ phi_2 = phi_1 / h - 1.
771
+ phi_3 = phi_2 / h - 0.5
772
+
773
+ if model_s is None:
774
+ model_s = self.model_fn(x, s)
775
+ if model_s1 is None:
776
+ x_s1 = (
777
+ (torch.exp(log_alpha_s1 - log_alpha_s)) * x
778
+ - (sigma_s1 * phi_11) * model_s
779
+ )
780
+ model_s1 = self.model_fn(x_s1, s1)
781
+ x_s2 = (
782
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
783
+ - (sigma_s2 * phi_12) * model_s
784
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
785
+ )
786
+ model_s2 = self.model_fn(x_s2, s2)
787
+ if solver_type == 'dpmsolver':
788
+ x_t = (
789
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
790
+ - (sigma_t * phi_1) * model_s
791
+ - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
792
+ )
793
+ elif solver_type == 'taylor':
794
+ D1_0 = (1. / r1) * (model_s1 - model_s)
795
+ D1_1 = (1. / r2) * (model_s2 - model_s)
796
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
797
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
798
+ x_t = (
799
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
800
+ - (sigma_t * phi_1) * model_s
801
+ - (sigma_t * phi_2) * D1
802
+ - (sigma_t * phi_3) * D2
803
+ )
804
+
805
+ if return_intermediate:
806
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
807
+ else:
808
+ return x_t
809
+
810
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
811
+ """
812
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
813
+
814
+ Args:
815
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
816
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
817
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
818
+ t: A pytorch tensor. The ending time, with the shape (1,).
819
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
820
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
821
+ Returns:
822
+ x_t: A pytorch tensor. The approximated solution at time `t`.
823
+ """
824
+ if solver_type not in ['dpmsolver', 'taylor']:
825
+ raise ValueError(
826
+ f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}"
827
+ )
828
+ ns = self.noise_schedule
829
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
830
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
831
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
832
+ t_prev_0), ns.marginal_lambda(t)
833
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
834
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
835
+ alpha_t = torch.exp(log_alpha_t)
836
+
837
+ h_0 = lambda_prev_0 - lambda_prev_1
838
+ h = lambda_t - lambda_prev_0
839
+ r0 = h_0 / h
840
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
841
+ if self.algorithm_type == "dpmsolver++":
842
+ phi_1 = torch.expm1(-h)
843
+ if solver_type == 'dpmsolver':
844
+ x_t = (
845
+ (sigma_t / sigma_prev_0) * x
846
+ - (alpha_t * phi_1) * model_prev_0
847
+ - 0.5 * (alpha_t * phi_1) * D1_0
848
+ )
849
+ elif solver_type == 'taylor':
850
+ x_t = (
851
+ (sigma_t / sigma_prev_0) * x
852
+ - (alpha_t * phi_1) * model_prev_0
853
+ + (alpha_t * (phi_1 / h + 1.)) * D1_0
854
+ )
855
+ else:
856
+ phi_1 = torch.expm1(h)
857
+ if solver_type == 'dpmsolver':
858
+ x_t = (
859
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
860
+ - (sigma_t * phi_1) * model_prev_0
861
+ - 0.5 * (sigma_t * phi_1) * D1_0
862
+ )
863
+ elif solver_type == 'taylor':
864
+ x_t = (
865
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
866
+ - (sigma_t * phi_1) * model_prev_0
867
+ - (sigma_t * (phi_1 / h - 1.)) * D1_0
868
+ )
869
+ return x_t
870
+
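Here D1_0 = (model_prev_0 - model_prev_1) / r0 is a backward finite-difference estimate of the derivative of the model output with respect to lambda at lambda_prev_0; the multistep update reuses the cached previous evaluation instead of the extra model call a singlestep second-order solver would spend.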
871
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
872
+ """
873
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
874
+
875
+ Args:
876
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
877
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
878
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
879
+ t: A pytorch tensor. The ending time, with the shape (1,).
880
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
881
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
882
+ Returns:
883
+ x_t: A pytorch tensor. The approximated solution at time `t`.
884
+ """
885
+ ns = self.noise_schedule
886
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
887
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
888
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
889
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
890
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
891
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
892
+ alpha_t = torch.exp(log_alpha_t)
893
+
894
+ h_1 = lambda_prev_1 - lambda_prev_2
895
+ h_0 = lambda_prev_0 - lambda_prev_1
896
+ h = lambda_t - lambda_prev_0
897
+ r0, r1 = h_0 / h, h_1 / h
898
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
899
+ D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
900
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
901
+ D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
902
+ if self.algorithm_type == "dpmsolver++":
903
+ phi_1 = torch.expm1(-h)
904
+ phi_2 = phi_1 / h + 1.
905
+ phi_3 = phi_2 / h - 0.5
906
+ return (
907
+ (sigma_t / sigma_prev_0) * x
908
+ - (alpha_t * phi_1) * model_prev_0
909
+ + (alpha_t * phi_2) * D1
910
+ - (alpha_t * phi_3) * D2
911
+ )
912
+ else:
913
+ phi_1 = torch.expm1(h)
914
+ phi_2 = phi_1 / h - 1.
915
+ phi_3 = phi_2 / h - 0.5
916
+ return (
917
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
918
+ - (sigma_t * phi_1) * model_prev_0
919
+ - (sigma_t * phi_2) * D1
920
+ - (sigma_t * phi_3) * D2
921
+ )
922
+
923
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None,
924
+ r2=None):
925
+ """
926
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
927
+
928
+ Args:
929
+ x: A pytorch tensor. The initial value at time `s`.
930
+ s: A pytorch tensor. The starting time, with the shape (1,).
931
+ t: A pytorch tensor. The ending time, with the shape (1,).
932
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
933
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
934
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
935
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
936
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
937
+ r2: A `float`. The hyperparameter of the third-order solver.
938
+ Returns:
939
+ x_t: A pytorch tensor. The approximated solution at time `t`.
940
+ """
941
+ if order == 1:
942
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
943
+ elif order == 2:
944
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
945
+ solver_type=solver_type, r1=r1)
946
+ elif order == 3:
947
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
948
+ solver_type=solver_type, r1=r1, r2=r2)
949
+ else:
950
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
951
+
952
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
953
+ """
954
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
955
+
956
+ Args:
957
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
958
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
959
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
960
+ t: A pytorch tensor. The ending time, with the shape (1,).
961
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
962
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
963
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
964
+ Returns:
965
+ x_t: A pytorch tensor. The approximated solution at time `t`.
966
+ """
967
+ if order == 1:
968
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
969
+ elif order == 2:
970
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
971
+ elif order == 3:
972
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
973
+ else:
974
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
975
+
976
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
977
+ solver_type='dpmsolver'):
978
+ """
979
+ The adaptive step size solver based on singlestep DPM-Solver.
980
+
981
+ Args:
982
+ x: A pytorch tensor. The initial value at time `t_T`.
983
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
984
+ t_T: A `float`. The starting time of the sampling (default is T).
985
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
986
+ h_init: A `float`. The initial step size (for logSNR).
987
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
988
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
989
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
990
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
991
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
992
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
993
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
994
+ Returns:
995
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
996
+
997
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based generative models," arXiv preprint arXiv:2105.14080, 2021.
998
+ """
999
+ ns = self.noise_schedule
1000
+ s = t_T * torch.ones((1,)).to(x)
1001
+ lambda_s = ns.marginal_lambda(s)
1002
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
1003
+ h = h_init * torch.ones_like(s).to(x)
1004
+ x_prev = x
1005
+ nfe = 0
1006
+ if order == 2:
1007
+ r1 = 0.5
1008
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
1009
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
1010
+ solver_type=solver_type,
1011
+ **kwargs)
1012
+ elif order == 3:
1013
+ r1, r2 = 1. / 3., 2. / 3.
1014
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
1015
+ return_intermediate=True,
1016
+ solver_type=solver_type)
1017
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
1018
+ solver_type=solver_type,
1019
+ **kwargs)
1020
+ else:
1021
+ raise ValueError(
1022
+ f"For adaptive step size solver, order must be 2 or 3, got {order}"
1023
+ )
1024
+ while torch.abs((s - t_0)).mean() > t_err:
1025
+ t = ns.inverse_lambda(lambda_s + h)
1026
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1027
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1028
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1029
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1030
+ E = norm_fn((x_higher - x_lower) / delta).max()
1031
+ if torch.all(E <= 1.):
1032
+ x = x_higher
1033
+ s = t
1034
+ x_prev = x_lower
1035
+ lambda_s = ns.marginal_lambda(s)
1036
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
1037
+ nfe += order
1038
+ print('adaptive solver nfe', nfe)
1039
+ return x
1040
+
1041
+ def add_noise(self, x, t, noise=None):
1042
+ """
1043
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1044
+
1045
+ Args:
1046
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1047
+ t: A `torch.Tensor` with shape `(t_size,)`.
1048
+ Returns:
1049
+ xt with shape `(t_size, batch_size, *shape)`.
1050
+ """
1051
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1052
+ if noise is None:
1053
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1054
+ x = x.reshape((-1, *x.shape))
1055
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1056
+ return xt.squeeze(0) if t.shape[0] == 1 else xt
1057
+
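A small usage sketch, assuming a DPM_Solver instance `dpm_solver` and a clean batch `x0` of shape (batch_size, *shape):

    t  = torch.tensor([0.1, 0.5, 0.9], device=x0.device)    # three continuous times
    xt = dpm_solver.add_noise(x0, t)                         # shape (3, batch_size, *shape)
    x1 = dpm_solver.add_noise(x0, torch.tensor([0.5], device=x0.device))  # t_size == 1 -> same shape as x0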
1058
+ def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1059
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1060
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1061
+ ):
1062
+ """
1063
+ Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.
1064
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
1065
+ """
1066
+ t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
1067
+ t_T = self.noise_schedule.T if t_end is None else t_end
1068
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1069
+ return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
1070
+ method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero,
1071
+ solver_type=solver_type,
1072
+ atol=atol, rtol=rtol, return_intermediate=return_intermediate)
1073
+
1074
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1075
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1076
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1077
+ ):
1078
+ """
1079
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1080
+
1081
+ =====================================================
1082
+
1083
+ We support the following algorithms for both noise prediction model and data prediction model:
1084
+ - 'singlestep':
1085
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1086
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1087
+ The total number of function evaluations (NFE) == `steps`.
1088
+ Given a fixed NFE == `steps`, the sampling procedure is:
1089
+ - If `order` == 1:
1090
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1091
+ - If `order` == 2:
1092
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1093
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1094
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1095
+ - If `order` == 3:
1096
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1097
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1098
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1099
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1100
+ - 'multistep':
1101
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1102
+ We initialize the first `order` values by lower order multistep solvers.
1103
+ Given a fixed NFE == `steps`, the sampling procedure is:
1104
+ Denote K = steps.
1105
+ - If `order` == 1:
1106
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1107
+ - If `order` == 2:
1108
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1109
+ - If `order` == 3:
1110
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1111
+ - 'singlestep_fixed':
1112
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1113
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1114
+ - 'adaptive':
1115
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1116
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1117
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1118
+ (NFE) and the sample quality.
1119
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1120
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1121
+
1122
+ =====================================================
1123
+
1124
+ Some advice for choosing the algorithm:
1125
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1126
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1127
+ e.g., DPM-Solver:
1128
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1129
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1130
+ skip_type='time_uniform', method='singlestep')
1131
+ e.g., DPM-Solver++:
1132
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1133
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1134
+ skip_type='time_uniform', method='singlestep')
1135
+ - For **guided sampling with large guidance scale** by DPMs:
1136
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1137
+ e.g.
1138
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1139
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1140
+ skip_type='time_uniform', method='multistep')
1141
+
1142
+ We support three types of `skip_type`:
1143
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1144
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1145
+ - 'time_quadratic': quadratic time for the time steps.
1146
+
1147
+ =====================================================
1148
+ Args:
1149
+ x: A pytorch tensor. The initial value at time `t_start`
1150
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1151
+ steps: A `int`. The total number of function evaluations (NFE).
1152
+ t_start: A `float`. The starting time of the sampling.
1153
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1154
+ t_end: A `float`. The ending time of the sampling.
1155
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1156
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1157
+ For discrete-time DPMs:
1158
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1159
+ For continuous-time DPMs:
1160
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1161
+ order: A `int`. The order of DPM-Solver.
1162
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1163
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1164
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1165
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1166
+
1167
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1168
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1169
+ for diffusion model sampling by diffusion SDEs for low-resolutional images
1170
+ (such as CIFAR-10). However, we observed that such trick does not matter for
1171
+ high-resolutional images. As it needs an additional NFE, we do not recommend
1172
+ it for high-resolutional images.
1173
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1174
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1175
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1176
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1177
+ solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1178
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1179
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1180
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1181
+ When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
1182
+ Returns:
1183
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1184
+
1185
+ """
1186
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1187
+ t_T = self.noise_schedule.T if t_start is None else t_start
1188
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1189
+ if return_intermediate:
1190
+ assert method in ['multistep', 'singlestep',
1191
+ 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
1192
+ if self.correcting_xt_fn is not None:
1193
+ assert method in ['multistep', 'singlestep',
1194
+ 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
1195
+ device = x.device
1196
+ intermediates = []
1197
+ with torch.no_grad():
1198
+ if method == 'adaptive':
1199
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1200
+ solver_type=solver_type)
1201
+ elif method == 'multistep':
1202
+ assert steps >= order
1203
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1204
+ assert timesteps.shape[0] - 1 == steps
1205
+ # Init the initial values.
1206
+ step = 0
1207
+ t = timesteps[step]
1208
+ t_prev_list = [t]
1209
+ model_prev_list = [self.model_fn(x, t)]
1210
+ if self.correcting_xt_fn is not None:
1211
+ x = self.correcting_xt_fn(x, t, step)
1212
+ if return_intermediate:
1213
+ intermediates.append(x)
1214
+ # Init the first `order` values by lower order multistep DPM-Solver.
1215
+ for step in range(1, order):
1216
+ t = timesteps[step]
1217
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step,
1218
+ solver_type=solver_type)
1219
+ if self.correcting_xt_fn is not None:
1220
+ x = self.correcting_xt_fn(x, t, step)
1221
+ if return_intermediate:
1222
+ intermediates.append(x)
1223
+ t_prev_list.append(t)
1224
+ model_prev_list.append(self.model_fn(x, t))
1225
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1226
+ for step in tqdm(range(order, steps + 1)):
1227
+ t = timesteps[step]
1228
+ # We only use lower order for steps < 10
1229
+ if lower_order_final and steps < 10:
1230
+ step_order = min(order, steps + 1 - step)
1231
+ else:
1232
+ step_order = order
1233
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order,
1234
+ solver_type=solver_type)
1235
+ if self.correcting_xt_fn is not None:
1236
+ x = self.correcting_xt_fn(x, t, step)
1237
+ if return_intermediate:
1238
+ intermediates.append(x)
1239
+ for i in range(order - 1):
1240
+ t_prev_list[i] = t_prev_list[i + 1]
1241
+ model_prev_list[i] = model_prev_list[i + 1]
1242
+ t_prev_list[-1] = t
1243
+ # We do not need to evaluate the final model value.
1244
+ if step < steps:
1245
+ model_prev_list[-1] = self.model_fn(x, t)
1246
+ elif method in ['singlestep', 'singlestep_fixed']:
1247
+ if method == 'singlestep':
1248
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps,
1249
+ order=order,
1250
+ skip_type=skip_type,
1251
+ t_T=t_T, t_0=t_0,
1252
+ device=device)
1253
+ elif method == 'singlestep_fixed':
1254
+ K = steps // order
1255
+ orders = [order, ] * K
1256
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1257
+ for step, order in enumerate(orders):
1258
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1259
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order,
1260
+ device=device)
1261
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1262
+ h = lambda_inner[-1] - lambda_inner[0]
1263
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1264
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1265
+ x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1266
+ if self.correcting_xt_fn is not None:
1267
+ x = self.correcting_xt_fn(x, t, step)
1268
+ if return_intermediate:
1269
+ intermediates.append(x)
1270
+ else:
1271
+ raise ValueError(f"Got wrong method {method}")
1272
+ if denoise_to_zero:
1273
+ t = torch.ones((1,)).to(device) * t_0
1274
+ x = self.denoise_to_zero_fn(x, t)
1275
+ if self.correcting_xt_fn is not None:
1276
+ x = self.correcting_xt_fn(x, t, step + 1)
1277
+ if return_intermediate:
1278
+ intermediates.append(x)
1279
+ return (x, intermediates) if return_intermediate else x
1280
+
1281
+
1282
+ #############################################################
1283
+ # other utility functions
1284
+ #############################################################
1285
+
1286
+ def interpolate_fn(x, xp, yp):
1287
+ """
1288
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1289
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1290
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1291
+
1292
+ Args:
1293
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1294
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1295
+ yp: PyTorch tensor with shape [C, K].
1296
+ Returns:
1297
+ The function values f(x), with shape [N, C].
1298
+ """
1299
+ N, K = x.shape[0], xp.shape[1]
1300
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1301
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1302
+ x_idx = torch.argmin(x_indices, dim=2)
1303
+ cand_start_idx = x_idx - 1
1304
+ start_idx = torch.where(
1305
+ torch.eq(x_idx, 0),
1306
+ torch.tensor(1, device=x.device),
1307
+ torch.where(
1308
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1309
+ ),
1310
+ )
1311
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1312
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1313
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1314
+ start_idx2 = torch.where(
1315
+ torch.eq(x_idx, 0),
1316
+ torch.tensor(0, device=x.device),
1317
+ torch.where(
1318
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1319
+ ),
1320
+ )
1321
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1322
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1323
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1324
+ return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1325
+
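A tiny check of the helper above with a single keypoint channel (C = 1):

    xp = torch.tensor([[0., 1., 2.]])         # shape [C, K]
    yp = torch.tensor([[0., 10., 20.]])
    x  = torch.tensor([[0.5], [1.5], [3.0]])  # shape [N, C]
    interpolate_fn(x, xp, yp)                 # tensor([[ 5.], [15.], [30.]]); 3.0 extrapolates along the last segment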
1326
+
1327
+ def expand_dims(v, dims):
1328
+ """
1329
+ Expand the tensor `v` to the dim `dims`.
1330
+
1331
+ Args:
1332
+ `v`: a PyTorch tensor with shape [N].
1333
+ `dim`: a `int`.
1334
+ Returns:
1335
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1336
+ """
1337
+ return v[(...,) + (None,) * (dims - 1)]
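For example, expand_dims(torch.ones(4), 3) has shape (4, 1, 1), so a per-sample coefficient such as alpha_t broadcasts cleanly against a batched latent:

    alpha_t = torch.rand(4)
    x = torch.randn(4, 8, 256)
    out = expand_dims(alpha_t, x.dim()) * x   # (4, 1, 1) * (4, 8, 256)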
DiT_VAE/diffusion/model/edm_sample.py ADDED
@@ -0,0 +1,168 @@
1
+ import numpy as np
+ import torch
2
+ from tqdm import tqdm
3
+
4
+
5
+ # ----------------------------------------------------------------------------
6
+ # Proposed EDM sampler (Algorithm 2).
7
+
8
+ def edm_sampler(
9
+ net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
10
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
11
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
12
+ ):
13
+ # Adjust noise levels based on what's supported by the network.
14
+ sigma_min = max(sigma_min, net.sigma_min)
15
+ sigma_max = min(sigma_max, net.sigma_max)
16
+
17
+ # Time step discretization.
18
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
19
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
20
+ sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
21
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
22
+
23
+ # Main sampling loop.
24
+ x_next = latents.to(torch.float64) * t_steps[0]
25
+ for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))): # 0, ..., N-1
26
+ x_cur = x_next
27
+
28
+ # Increase noise temporarily.
29
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
30
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
31
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
32
+
33
+ # Euler step.
34
+ denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
35
+ d_cur = (x_hat - denoised) / t_hat
36
+ x_next = x_hat + (t_next - t_hat) * d_cur
37
+
38
+ # Apply 2nd order correction.
39
+ if i < num_steps - 1:
40
+ denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
41
+ d_prime = (x_next - denoised) / t_next
42
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
43
+
44
+ return x_next
45
+
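For intuition about the rho-spaced discretization above, a plain-numpy sketch of the noise levels for num_steps = 5, sigma_min = 0.002, sigma_max = 80, rho = 7 (no network needed):

    import numpy as np
    i = np.arange(5)
    sigmas = (80 ** (1 / 7) + i / 4 * (0.002 ** (1 / 7) - 80 ** (1 / 7))) ** 7
    # ~[80.0, 17.5, 2.5, 0.17, 0.002]; the sampler then appends t_N = 0 and runs the Heun loop above.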
46
+
47
+ # ----------------------------------------------------------------------------
48
+ # Generalized ablation sampler, representing the superset of all sampling
49
+ # methods discussed in the paper.
50
+
51
+ def ablation_sampler(
52
+ net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
53
+ num_steps=18, sigma_min=None, sigma_max=None, rho=7,
54
+ solver='heun', discretization='edm', schedule='linear', scaling='none',
55
+ epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
56
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
57
+ ):
58
+ assert solver in ['euler', 'heun']
59
+ assert discretization in ['vp', 've', 'iddpm', 'edm']
60
+ assert schedule in ['vp', 've', 'linear']
61
+ assert scaling in ['vp', 'none']
62
+
63
+ # Helper functions for VP & VE noise level schedules.
64
+ vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
65
+ vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
66
+ vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
67
+ sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
68
+ ve_sigma = lambda t: t.sqrt()
69
+ ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
70
+ ve_sigma_inv = lambda sigma: sigma ** 2
71
+
72
+ # Select default noise level range based on the specified time step discretization.
73
+ if sigma_min is None:
74
+ vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
75
+ sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
76
+ if sigma_max is None:
77
+ vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
78
+ sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
79
+
80
+ # Adjust noise levels based on what's supported by the network.
81
+ sigma_min = max(sigma_min, net.sigma_min)
82
+ sigma_max = min(sigma_max, net.sigma_max)
83
+
84
+ # Compute corresponding betas for VP.
85
+ vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
86
+ vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
87
+
88
+ # Define time steps in terms of noise level.
89
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
90
+ if discretization == 'vp':
91
+ orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
92
+ sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
93
+ elif discretization == 've':
94
+ orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
95
+ sigma_steps = ve_sigma(orig_t_steps)
96
+ elif discretization == 'iddpm':
97
+ u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
98
+ alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
99
+ for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
100
+ u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
101
+ u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
102
+ sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
103
+ else:
104
+ assert discretization == 'edm'
105
+ sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
106
+ sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
107
+
108
+ # Define noise level schedule.
109
+ if schedule == 'vp':
110
+ sigma = vp_sigma(vp_beta_d, vp_beta_min)
111
+ sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
112
+ sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
113
+ elif schedule == 've':
114
+ sigma = ve_sigma
115
+ sigma_deriv = ve_sigma_deriv
116
+ sigma_inv = ve_sigma_inv
117
+ else:
118
+ assert schedule == 'linear'
119
+ sigma = lambda t: t
120
+ sigma_deriv = lambda t: 1
121
+ sigma_inv = lambda sigma: sigma
122
+
123
+ # Define scaling schedule.
124
+ if scaling == 'vp':
125
+ s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
126
+ s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
127
+ else:
128
+ assert scaling == 'none'
129
+ s = lambda t: 1
130
+ s_deriv = lambda t: 0
131
+
132
+ # Compute final time steps based on the corresponding noise levels.
133
+ t_steps = sigma_inv(net.round_sigma(sigma_steps))
134
+ t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
135
+
136
+ # Main sampling loop.
137
+ t_next = t_steps[0]
138
+ x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
139
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
140
+ x_cur = x_next
141
+
142
+ # Increase noise temporarily.
143
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
144
+ t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
145
+ x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
146
+ t_hat) * S_noise * randn_like(x_cur)
147
+
148
+ # Euler step.
149
+ h = t_next - t_hat
150
+ denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(
151
+ torch.float64)
152
+ d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
153
+ t_hat) / sigma(t_hat) * denoised
154
+ x_prime = x_hat + alpha * h * d_cur
155
+ t_prime = t_hat + alpha * h
156
+
157
+ # Apply 2nd order correction.
158
+ if solver == 'euler' or i == num_steps - 1:
159
+ x_next = x_hat + h * d_cur
160
+ else:
161
+ assert solver == 'heun'
162
+ denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(
163
+ torch.float64)
164
+ d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
165
+ t_prime) * s(t_prime) / sigma(t_prime) * denoised
166
+ x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
167
+
168
+ return x_next
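+
+ # Minimal usage sketch (illustrative only): `net` is any EDM-style wrapper exposing
+ # `sigma_min`, `sigma_max`, `round_sigma(sigma)` and returning a dict with the denoised
+ # prediction under key 'x'; B, C, H, W stand in for the triplane latent shape.
+ #   latents = torch.randn(B, C, H, W, device=device)
+ #   samples = ablation_sampler(net, latents, cfg_scale=4.0, num_steps=18, solver='heun',
+ #                              discretization='edm', schedule='linear', scaling='none')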
DiT_VAE/diffusion/model/gaussian_diffusion.py ADDED
@@ -0,0 +1,1006 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import enum
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch as th
12
+ import torch.nn.functional as F
13
+
14
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
15
+
16
+
17
+ def mean_flat(tensor):
18
+ """
19
+ Take the mean over all non-batch dimensions.
20
+ """
21
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
22
+
23
+
24
+ class ModelMeanType(enum.Enum):
25
+ """
26
+ Which type of output the model predicts.
27
+ """
28
+
29
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
30
+ START_X = enum.auto() # the model predicts x_0
31
+ EPSILON = enum.auto() # the model predicts epsilon
32
+
33
+
34
+ class ModelVarType(enum.Enum):
35
+ """
36
+ What is used as the model's output variance.
37
+ The LEARNED_RANGE option has been added to allow the model to predict
38
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
39
+ """
40
+
41
+ LEARNED = enum.auto()
42
+ FIXED_SMALL = enum.auto()
43
+ FIXED_LARGE = enum.auto()
44
+ LEARNED_RANGE = enum.auto()
45
+
46
+
47
+ class LossType(enum.Enum):
48
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
49
+ RESCALED_MSE = (
50
+ enum.auto()
51
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
52
+ KL = enum.auto() # use the variational lower-bound
53
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
54
+
55
+ def is_vb(self):
56
+ return self in [LossType.KL, LossType.RESCALED_KL]
57
+
58
+
59
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
60
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
61
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
62
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
63
+ return betas
64
+
65
+
66
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
67
+ """
68
+ This is the deprecated API for creating beta schedules.
69
+ See get_named_beta_schedule() for the new library of schedules.
70
+ """
71
+ if beta_schedule == "quad":
72
+ betas = (
73
+ np.linspace(
74
+ beta_start ** 0.5,
75
+ beta_end ** 0.5,
76
+ num_diffusion_timesteps,
77
+ dtype=np.float64,
78
+ )
79
+ ** 2
80
+ )
81
+ elif beta_schedule == "linear":
82
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
83
+ elif beta_schedule == "warmup10":
84
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
85
+ elif beta_schedule == "warmup50":
86
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
87
+ elif beta_schedule == "const":
88
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
89
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
90
+ betas = 1.0 / np.linspace(
91
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
92
+ )
93
+ else:
94
+ raise NotImplementedError(beta_schedule)
95
+ assert betas.shape == (num_diffusion_timesteps,)
96
+ return betas
97
+
98
+
99
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
100
+ """
101
+ Get a pre-defined beta schedule for the given name.
102
+ The beta schedule library consists of beta schedules which remain similar
103
+ in the limit of num_diffusion_timesteps.
104
+ Beta schedules may be added, but should not be removed or changed once
105
+ they are committed to maintain backwards compatibility.
106
+ """
107
+ if schedule_name == "linear":
108
+ # Linear schedule from Ho et al, extended to work for any number of
109
+ # diffusion steps.
110
+ scale = 1000 / num_diffusion_timesteps
111
+ return get_beta_schedule(
112
+ "linear",
113
+ beta_start=scale * 0.0001,
114
+ beta_end=scale * 0.02,
115
+ num_diffusion_timesteps=num_diffusion_timesteps,
116
+ )
117
+ elif schedule_name == "squaredcos_cap_v2":
118
+ return betas_for_alpha_bar(
119
+ num_diffusion_timesteps,
120
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
121
+ )
122
+ else:
123
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
124
+
125
+
126
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
127
+ """
128
+ Create a beta schedule that discretizes the given alpha_t_bar function,
129
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
130
+ :param num_diffusion_timesteps: the number of betas to produce.
131
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
132
+ produces the cumulative product of (1-beta) up to that
133
+ part of the diffusion process.
134
+ :param max_beta: the maximum beta to use; use values lower than 1 to
135
+ prevent singularities.
136
+ """
137
+ betas = []
138
+ for i in range(num_diffusion_timesteps):
139
+ t1 = i / num_diffusion_timesteps
140
+ t2 = (i + 1) / num_diffusion_timesteps
141
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
142
+ return np.array(betas)
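+
+ # Worked example (cosine "squaredcos_cap_v2" schedule): with
+ #   alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2
+ # each beta_i = 1 - alpha_bar((i + 1) / T) / alpha_bar(i / T). For T = 4 this gives
+ # approximately [0.15, 0.42, 0.71, 0.999]; the last value hits the max_beta clip because
+ # alpha_bar(1) is ~0, and the cumulative product of (1 - beta) follows the cosine curve.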
143
+
144
+
145
+ class GaussianDiffusion:
146
+ """
147
+ Utilities for training and sampling diffusion models.
148
+ Original ported from this codebase:
149
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
150
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
151
+ starting at T and going to 1.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ *,
157
+ betas,
158
+ model_mean_type,
159
+ model_var_type,
160
+ loss_type,
161
+ snr=False,
162
+ return_startx=False,
163
+ ):
164
+
165
+ self.model_mean_type = model_mean_type
166
+ self.model_var_type = model_var_type
167
+ self.loss_type = loss_type
168
+ self.snr = snr
169
+ self.return_startx = return_startx
170
+
171
+ # Use float64 for accuracy.
172
+ betas = np.array(betas, dtype=np.float64)
173
+ self.betas = betas
174
+ assert len(betas.shape) == 1, "betas must be 1-D"
175
+ assert (betas > 0).all() and (betas <= 1).all()
176
+
177
+ self.num_timesteps = int(betas.shape[0])
178
+
179
+ alphas = 1.0 - betas
180
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
181
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
182
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
183
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
184
+
185
+ # calculations for diffusion q(x_t | x_{t-1}) and others
186
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
187
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
188
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
189
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
190
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
191
+
192
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
193
+ self.posterior_variance = (
194
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
195
+ )
196
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
197
+ self.posterior_log_variance_clipped = np.log(
198
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
199
+ ) if len(self.posterior_variance) > 1 else np.array([])
200
+
201
+ self.posterior_mean_coef1 = (
202
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
203
+ )
204
+ self.posterior_mean_coef2 = (
205
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
206
+ )
207
+
208
+ def q_mean_variance(self, x_start, t):
209
+ """
210
+ Get the distribution q(x_t | x_0).
211
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
212
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
213
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
214
+ """
215
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
216
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
217
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
218
+ return mean, variance, log_variance
219
+
220
+ def q_sample(self, x_start, t, noise=None):
221
+ """
222
+ Diffuse the data for a given number of diffusion steps.
223
+ In other words, sample from q(x_t | x_0).
224
+ :param x_start: the initial data batch.
225
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
226
+ :param noise: if specified, the split-out normal noise.
227
+ :return: A noisy version of x_start.
228
+ """
229
+ if noise is None:
230
+ noise = th.randn_like(x_start)
231
+ assert noise.shape == x_start.shape
232
+ return (
233
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
234
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
235
+ )
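+
+ # Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I),
+ # so a single call jumps straight from x_0 to any timestep. For a GaussianDiffusion
+ # instance `diffusion` (assuming num_timesteps > 500):
+ #   t = th.full((x_start.shape[0],), 500, device=x_start.device)
+ #   x_t = diffusion.q_sample(x_start, t)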
236
+
237
+ def q_posterior_mean_variance(self, x_start, x_t, t):
238
+ """
239
+ Compute the mean and variance of the diffusion posterior:
240
+ q(x_{t-1} | x_t, x_0)
241
+ """
242
+ assert x_start.shape == x_t.shape
243
+ posterior_mean = (
244
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
245
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
246
+ )
247
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
248
+ posterior_log_variance_clipped = _extract_into_tensor(
249
+ self.posterior_log_variance_clipped, t, x_t.shape
250
+ )
251
+ assert (
252
+ posterior_mean.shape[0]
253
+ == posterior_variance.shape[0]
254
+ == posterior_log_variance_clipped.shape[0]
255
+ == x_start.shape[0]
256
+ )
257
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
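+
+ # Closed form used above (matching the coefficients precomputed in __init__):
+ #   mean = beta_t * sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t) * x_0
+ #        + (1 - alpha_bar_{t-1}) * sqrt(alpha_t) / (1 - alpha_bar_t) * x_t
+ #   variance = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)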
258
+
259
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
260
+ """
261
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
262
+ the initial x, x_0.
263
+ :param model: the model, which takes a signal and a batch of timesteps
264
+ as input.
265
+ :param x: the [N x C x ...] tensor at time t.
266
+ :param t: a 1-D Tensor of timesteps.
267
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
268
+ :param denoised_fn: if not None, a function which applies to the
269
+ x_start prediction before it is used to sample. Applies before
270
+ clip_denoised.
271
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
272
+ pass to the model. This can be used for conditioning.
273
+ :return: a dict with the following keys:
274
+ - 'mean': the model mean output.
275
+ - 'variance': the model variance output.
276
+ - 'log_variance': the log of 'variance'.
277
+ - 'pred_xstart': the prediction for x_0.
278
+ """
279
+ if model_kwargs is None:
280
+ model_kwargs = {}
281
+
282
+ B, C = x.shape[:2]
283
+ assert t.shape == (B,)
284
+ model_output = model(x, t, **model_kwargs)
285
+ if isinstance(model_output, tuple):
286
+ model_output, extra = model_output
287
+ else:
288
+ extra = None
289
+
290
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
291
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
292
+ model_output, model_var_values = th.split(model_output, C, dim=1)
293
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
294
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
295
+ # The model_var_values is [-1, 1] for [min_var, max_var].
296
+ frac = (model_var_values + 1) / 2
297
+ model_log_variance = frac * max_log + (1 - frac) * min_log
298
+ model_variance = th.exp(model_log_variance)
299
+ elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
300
+ model_variance, model_log_variance = {
301
+ # for fixedlarge, we set the initial (log-)variance like so
302
+ # to get a better decoder log likelihood.
303
+ ModelVarType.FIXED_LARGE: (
304
+ np.append(self.posterior_variance[1], self.betas[1:]),
305
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
306
+ ),
307
+ ModelVarType.FIXED_SMALL: (
308
+ self.posterior_variance,
309
+ self.posterior_log_variance_clipped,
310
+ ),
311
+ }[self.model_var_type]
312
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
313
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
314
+ else:
315
+ model_variance = th.zeros_like(model_output)
316
+ model_log_variance = th.zeros_like(model_output)
317
+
318
+ def process_xstart(x):
319
+ if denoised_fn is not None:
320
+ x = denoised_fn(x)
321
+ return x.clamp(-1, 1) if clip_denoised else x
322
+
323
+ if self.model_mean_type == ModelMeanType.START_X:
324
+ pred_xstart = process_xstart(model_output)
325
+ else:
326
+ pred_xstart = process_xstart(
327
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
328
+ )
329
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
330
+
331
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
332
+ return {
333
+ "mean": model_mean,
334
+ "variance": model_variance,
335
+ "log_variance": model_log_variance,
336
+ "pred_xstart": pred_xstart,
337
+ "extra": extra,
338
+ }
339
+
340
+ def _predict_xstart_from_eps(self, x_t, t, eps):
341
+ assert x_t.shape == eps.shape
342
+ return (
343
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
344
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
345
+ )
346
+
347
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
348
+ return (
349
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
350
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
351
+
352
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
353
+ """
354
+ Compute the mean for the previous step, given a function cond_fn that
355
+ computes the gradient of a conditional log probability with respect to
356
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
357
+ condition on y.
358
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
359
+ """
360
+ gradient = cond_fn(x, t, **model_kwargs)
361
+ return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
362
+
363
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
364
+ """
365
+ Compute what the p_mean_variance output would have been, should the
366
+ model's score function be conditioned by cond_fn.
367
+ See condition_mean() for details on cond_fn.
368
+ Unlike condition_mean(), this instead uses the conditioning strategy
369
+ from Song et al (2020).
370
+ """
371
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
372
+
373
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
374
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
375
+
376
+ out = p_mean_var.copy()
377
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
378
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
379
+ return out
380
+
381
+ def p_sample(
382
+ self,
383
+ model,
384
+ x,
385
+ t,
386
+ clip_denoised=True,
387
+ denoised_fn=None,
388
+ cond_fn=None,
389
+ model_kwargs=None,
390
+ ):
391
+ """
392
+ Sample x_{t-1} from the model at the given timestep.
393
+ :param model: the model to sample from.
394
+ :param x: the current tensor at x_{t-1}.
395
+ :param t: the value of t, starting at 0 for the first diffusion step.
396
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
397
+ :param denoised_fn: if not None, a function which applies to the
398
+ x_start prediction before it is used to sample.
399
+ :param cond_fn: if not None, this is a gradient function that acts
400
+ similarly to the model.
401
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
402
+ pass to the model. This can be used for conditioning.
403
+ :return: a dict containing the following keys:
404
+ - 'sample': a random sample from the model.
405
+ - 'pred_xstart': a prediction of x_0.
406
+ """
407
+ out = self.p_mean_variance(
408
+ model,
409
+ x,
410
+ t,
411
+ clip_denoised=clip_denoised,
412
+ denoised_fn=denoised_fn,
413
+ model_kwargs=model_kwargs,
414
+ )
415
+ noise = th.randn_like(x)
416
+ nonzero_mask = (
417
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
418
+ ) # no noise when t == 0
419
+ if cond_fn is not None:
420
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
421
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
422
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
423
+
424
+ def p_sample_loop(
425
+ self,
426
+ model,
427
+ shape,
428
+ noise=None,
429
+ clip_denoised=True,
430
+ denoised_fn=None,
431
+ cond_fn=None,
432
+ model_kwargs=None,
433
+ device=None,
434
+ progress=False,
435
+ ):
436
+ """
437
+ Generate samples from the model.
438
+ :param model: the model module.
439
+ :param shape: the shape of the samples, (N, C, H, W).
440
+ :param noise: if specified, the noise from the encoder to sample.
441
+ Should be of the same shape as `shape`.
442
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
443
+ :param denoised_fn: if not None, a function which applies to the
444
+ x_start prediction before it is used to sample.
445
+ :param cond_fn: if not None, this is a gradient function that acts
446
+ similarly to the model.
447
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
448
+ pass to the model. This can be used for conditioning.
449
+ :param device: if specified, the device to create the samples on.
450
+ If not specified, use a model parameter's device.
451
+ :param progress: if True, show a tqdm progress bar.
452
+ :return: a non-differentiable batch of samples.
453
+ """
454
+ final = None
455
+ for sample in self.p_sample_loop_progressive(
456
+ model,
457
+ shape,
458
+ noise=noise,
459
+ clip_denoised=clip_denoised,
460
+ denoised_fn=denoised_fn,
461
+ cond_fn=cond_fn,
462
+ model_kwargs=model_kwargs,
463
+ device=device,
464
+ progress=progress,
465
+ ):
466
+ final = sample
467
+ return final["sample"]
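+
+ # Typical call (shape values are placeholders for the triplane latent dimensions; the
+ # conditioning keyword `y` is only an example of what model_kwargs might carry):
+ #   samples = diffusion.p_sample_loop(
+ #       model, (batch_size, in_channels, height, width),
+ #       clip_denoised=False, model_kwargs=dict(y=cond), progress=True,
+ #   )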
468
+
469
+ def p_sample_loop_progressive(
470
+ self,
471
+ model,
472
+ shape,
473
+ noise=None,
474
+ clip_denoised=True,
475
+ denoised_fn=None,
476
+ cond_fn=None,
477
+ model_kwargs=None,
478
+ device=None,
479
+ progress=False,
480
+ ):
481
+ """
482
+ Generate samples from the model and yield intermediate samples from
483
+ each timestep of diffusion.
484
+ Arguments are the same as p_sample_loop().
485
+ Returns a generator over dicts, where each dict is the return value of
486
+ p_sample().
487
+ """
488
+ if device is None:
489
+ device = next(model.parameters()).device
490
+ assert isinstance(shape, (tuple, list))
491
+ img = noise if noise is not None else th.randn(*shape, device=device)
492
+ indices = list(range(self.num_timesteps))[::-1]
493
+
494
+ if progress:
495
+ # Lazy import so that we don't depend on tqdm.
496
+ from tqdm.auto import tqdm
497
+
498
+ indices = tqdm(indices)
499
+
500
+ for i in indices:
501
+ t = th.tensor([i] * shape[0], device=device)
502
+ with th.no_grad():
503
+ out = self.p_sample(
504
+ model,
505
+ img,
506
+ t,
507
+ clip_denoised=clip_denoised,
508
+ denoised_fn=denoised_fn,
509
+ cond_fn=cond_fn,
510
+ model_kwargs=model_kwargs,
511
+ )
512
+ yield out
513
+ img = out["sample"]
514
+
515
+ def ddim_sample(
516
+ self,
517
+ model,
518
+ x,
519
+ t,
520
+ clip_denoised=True,
521
+ denoised_fn=None,
522
+ cond_fn=None,
523
+ model_kwargs=None,
524
+ eta=0.0,
525
+ ):
526
+ """
527
+ Sample x_{t-1} from the model using DDIM.
528
+ Same usage as p_sample().
529
+ """
530
+ out = self.p_mean_variance(
531
+ model,
532
+ x,
533
+ t,
534
+ clip_denoised=clip_denoised,
535
+ denoised_fn=denoised_fn,
536
+ model_kwargs=model_kwargs,
537
+ )
538
+ if cond_fn is not None:
539
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
540
+
541
+ # Usually our model outputs epsilon, but we re-derive it
542
+ # in case we used x_start or x_prev prediction.
543
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
544
+
545
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
546
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
547
+ sigma = (
548
+ eta
549
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
550
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
551
+ )
552
+ # Equation 12.
553
+ noise = th.randn_like(x)
554
+ mean_pred = (
555
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
556
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
557
+ )
558
+ nonzero_mask = (
559
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
560
+ ) # no noise when t == 0
561
+ sample = mean_pred + nonzero_mask * sigma * noise
562
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
563
+
564
+ def ddim_reverse_sample(
565
+ self,
566
+ model,
567
+ x,
568
+ t,
569
+ clip_denoised=True,
570
+ denoised_fn=None,
571
+ cond_fn=None,
572
+ model_kwargs=None,
573
+ eta=0.0,
574
+ ):
575
+ """
576
+ Sample x_{t+1} from the model using DDIM reverse ODE.
577
+ """
578
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
579
+ out = self.p_mean_variance(
580
+ model,
581
+ x,
582
+ t,
583
+ clip_denoised=clip_denoised,
584
+ denoised_fn=denoised_fn,
585
+ model_kwargs=model_kwargs,
586
+ )
587
+ if cond_fn is not None:
588
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
589
+ # Usually our model outputs epsilon, but we re-derive it
590
+ # in case we used x_start or x_prev prediction.
591
+ eps = (
592
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
593
+ - out["pred_xstart"]
594
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
595
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
596
+
597
+ # Equation 12. reversed
598
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
599
+
600
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
601
+
602
+ def ddim_sample_loop(
603
+ self,
604
+ model,
605
+ shape,
606
+ noise=None,
607
+ clip_denoised=True,
608
+ denoised_fn=None,
609
+ cond_fn=None,
610
+ model_kwargs=None,
611
+ device=None,
612
+ progress=False,
613
+ eta=0.0,
614
+ ):
615
+ """
616
+ Generate samples from the model using DDIM.
617
+ Same usage as p_sample_loop().
618
+ """
619
+ final = None
620
+ for sample in self.ddim_sample_loop_progressive(
621
+ model,
622
+ shape,
623
+ noise=noise,
624
+ clip_denoised=clip_denoised,
625
+ denoised_fn=denoised_fn,
626
+ cond_fn=cond_fn,
627
+ model_kwargs=model_kwargs,
628
+ device=device,
629
+ progress=progress,
630
+ eta=eta,
631
+ ):
632
+ final = sample
633
+ return final["sample"]
634
+
635
+ def ddim_sample_loop_progressive(
636
+ self,
637
+ model,
638
+ shape,
639
+ noise=None,
640
+ clip_denoised=True,
641
+ denoised_fn=None,
642
+ cond_fn=None,
643
+ model_kwargs=None,
644
+ device=None,
645
+ progress=False,
646
+ eta=0.0,
647
+ ):
648
+ """
649
+ Use DDIM to sample from the model and yield intermediate samples from
650
+ each timestep of DDIM.
651
+ Same usage as p_sample_loop_progressive().
652
+ """
653
+ if device is None:
654
+ device = next(model.parameters()).device
655
+ assert isinstance(shape, (tuple, list))
656
+ img = noise if noise is not None else th.randn(*shape, device=device)
657
+ indices = list(range(self.num_timesteps))[::-1]
658
+
659
+ if progress:
660
+ # Lazy import so that we don't depend on tqdm.
661
+ from tqdm.auto import tqdm
662
+
663
+ indices = tqdm(indices)
664
+
665
+ for i in indices:
666
+ t = th.tensor([i] * shape[0], device=device)
667
+ with th.no_grad():
668
+ out = self.ddim_sample(
669
+ model,
670
+ img,
671
+ t,
672
+ clip_denoised=clip_denoised,
673
+ denoised_fn=denoised_fn,
674
+ cond_fn=cond_fn,
675
+ model_kwargs=model_kwargs,
676
+ eta=eta,
677
+ )
678
+ yield out
679
+ img = out["sample"]
680
+
681
+ def _vb_terms_bpd(
682
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
683
+ ):
684
+ """
685
+ Get a term for the variational lower-bound.
686
+ The resulting units are bits (rather than nats, as one might expect).
687
+ This allows for comparison to other papers.
688
+ :return: a dict with the following keys:
689
+ - 'output': a shape [N] tensor of NLLs or KLs.
690
+ - 'pred_xstart': the x_0 predictions.
691
+ """
692
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
693
+ x_start=x_start, x_t=x_t, t=t
694
+ )
695
+ out = self.p_mean_variance(
696
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
697
+ )
698
+ kl = normal_kl(
699
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
700
+ )
701
+ kl = mean_flat(kl) / np.log(2.0)
702
+
703
+ decoder_nll = -discretized_gaussian_log_likelihood(
704
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
705
+ )
706
+ assert decoder_nll.shape == x_start.shape
707
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
708
+
709
+ # At the first timestep return the decoder NLL,
710
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
711
+ output = th.where((t == 0), decoder_nll, kl)
712
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
713
+
714
+ def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
715
+ """
716
+ Compute training losses for a single timestep.
717
+ :param model: the model to evaluate loss on.
718
+ :param x_start: the [N x C x ...] tensor of inputs.
719
+ :param timestep: a batch of timestep indices.
720
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
721
+ pass to the model. This can be used for conditioning.
722
+ :param noise: if specified, the specific Gaussian noise to try to remove.
723
+ :return: a dict with the key "loss" containing a tensor of shape [N].
724
+ Some mean or variance settings may also have other keys.
725
+ """
726
+ t = timestep
727
+ if model_kwargs is None:
728
+ model_kwargs = {}
729
+ if skip_noise:
730
+ x_t = x_start
731
+ else:
732
+ if noise is None:
733
+ noise = th.randn_like(x_start)
734
+ x_t = self.q_sample(x_start, t, noise=noise)
735
+
736
+ terms = {}
737
+
738
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
739
+ terms["loss"] = self._vb_terms_bpd(
740
+ model=model,
741
+ x_start=x_start,
742
+ x_t=x_t,
743
+ t=t,
744
+ clip_denoised=False,
745
+ model_kwargs=model_kwargs,
746
+ )["output"]
747
+ if self.loss_type == LossType.RESCALED_KL:
748
+ terms["loss"] *= self.num_timesteps
749
+ elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
750
+ model_output = model(x_t, t, **model_kwargs)
751
+ if isinstance(model_output, dict) and model_output.get('x', None) is not None:
752
+ output = model_output['x']
753
+ else:
754
+ output = model_output
755
+
756
+ if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
757
+ return self._extracted_from_training_losses_diffusers(x_t, output, t)
758
+ # When the variance is learned (e.g. self.model_var_type == ModelVarType.LEARNED_RANGE),
+ # the model output carries 2*C channels: the mean prediction and the variance-range values.
759
+ if self.model_var_type in [
760
+ ModelVarType.LEARNED,
761
+ ModelVarType.LEARNED_RANGE,
762
+ ]:
763
+ B, C = x_t.shape[:2]
764
+ assert output.shape == (B, C * 2, *x_t.shape[2:])
765
+ output, model_var_values = th.split(output, C, dim=1)
766
+ # Learn the variance using the variational bound, but don't let it affect our mean prediction.
767
+ frozen_out = th.cat([output.detach(), model_var_values], dim=1)
768
+ # vb variational bound
769
+ terms["vb"] = self._vb_terms_bpd(
770
+ model=lambda *args, r=frozen_out, **kwargs: r,
771
+ x_start=x_start,
772
+ x_t=x_t,
773
+ t=t,
774
+ clip_denoised=False,
775
+ )["output"]
776
+ if self.loss_type == LossType.RESCALED_MSE:
777
+ # Divide by 1000 for equivalence with initial implementation.
778
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
779
+ terms["vb"] *= self.num_timesteps / 1000.0
780
+
781
+ target = {
782
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
783
+ x_start=x_start, x_t=x_t, t=t
784
+ )[0],
785
+ ModelMeanType.START_X: x_start,
786
+ ModelMeanType.EPSILON: noise,
787
+ }[self.model_mean_type]
788
+ assert output.shape == target.shape == x_start.shape
789
+ if self.snr:
790
+ if self.model_mean_type == ModelMeanType.START_X:
791
+ pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
792
+ pred_startx = output
793
+ elif self.model_mean_type == ModelMeanType.EPSILON:
794
+ pred_noise = output
795
+ pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
796
+ # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
797
+ # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
798
+
799
+ t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
800
+ # Empirically chosen split: supervise the noise prediction at high-noise steps (t > 249)
+ # and the x_0 prediction at low-noise steps.
801
+ target = th.where(t > 249, noise, x_start)
802
+ output = th.where(t > 249, pred_noise, pred_startx)
803
+ loss = (target - output) ** 2
804
+ if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
805
+ assert 'mask' in model_output
806
+ loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
807
+ mask = model_output['mask']
808
+ unmask = 1 - mask
809
+ terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1)
810
+ if model_kwargs['mask_loss_coef'] > 0:
811
+ terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1)
812
+ else:
813
+ terms["mse"] = mean_flat(loss)
814
+ terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
815
+ if "mae" in terms:
816
+ terms["loss"] = terms["loss"] + terms["mae"]
817
+ else:
818
+ raise NotImplementedError(self.loss_type)
819
+
820
+ return terms
821
+
822
+ def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
823
+ """
824
+ Compute training losses for a single timestep.
825
+ :param model: the model to evaluate loss on.
826
+ :param x_start: the [N x C x ...] tensor of inputs.
827
+ :param timestep: a batch of timestep indices.
828
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
829
+ pass to the model. This can be used for conditioning.
830
+ :param noise: if specified, the specific Gaussian noise to try to remove.
831
+ :return: a dict with the key "loss" containing a tensor of shape [N].
832
+ Some mean or variance settings may also have other keys.
833
+ """
834
+ t = timestep
835
+ if model_kwargs is None:
836
+ model_kwargs = {}
837
+ if skip_noise:
838
+ x_t = x_start
839
+ else:
840
+ if noise is None:
841
+ noise = th.randn_like(x_start)
842
+ x_t = self.q_sample(x_start, t, noise=noise)
843
+
844
+ terms = {}
845
+
846
+ if self.loss_type in [LossType.KL, LossType.RESCALED_KL]:
847
+ terms["loss"] = self._vb_terms_bpd(
848
+ model=model,
849
+ x_start=x_start,
850
+ x_t=x_t,
851
+ t=t,
852
+ clip_denoised=False,
853
+ model_kwargs=model_kwargs,
854
+ )["output"]
855
+ if self.loss_type == LossType.RESCALED_KL:
856
+ terms["loss"] *= self.num_timesteps
857
+ elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
858
+ output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]
859
+
860
+ if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
861
+ return self._extracted_from_training_losses_diffusers(x_t, output, t)
862
+
863
+ if self.model_var_type in [
864
+ ModelVarType.LEARNED,
865
+ ModelVarType.LEARNED_RANGE,
866
+ ]:
867
+ B, C = x_t.shape[:2]
868
+ assert output.shape == (B, C * 2, *x_t.shape[2:])
869
+ output, model_var_values = th.split(output, C, dim=1)
870
+ # Learn the variance using the variational bound, but don't let it affect our mean prediction.
871
+ frozen_out = th.cat([output.detach(), model_var_values], dim=1)
872
+ terms["vb"] = self._vb_terms_bpd(
873
+ model=lambda *args, r=frozen_out, **kwargs: r,
874
+ x_start=x_start,
875
+ x_t=x_t,
876
+ t=t,
877
+ clip_denoised=False,
878
+ )["output"]
879
+ if self.loss_type == LossType.RESCALED_MSE:
880
+ # Divide by 1000 for equivalence with initial implementation.
881
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
882
+ terms["vb"] *= self.num_timesteps / 1000.0
883
+
884
+ target = {
885
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
886
+ x_start=x_start, x_t=x_t, t=t
887
+ )[0],
888
+ ModelMeanType.START_X: x_start,
889
+ ModelMeanType.EPSILON: noise,
890
+ }[self.model_mean_type]
891
+ assert output.shape == target.shape == x_start.shape
892
+ if self.snr:
893
+ if self.model_mean_type == ModelMeanType.START_X:
894
+ pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
895
+ pred_startx = output
896
+ elif self.model_mean_type == ModelMeanType.EPSILON:
897
+ pred_noise = output
898
+ pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
899
+ # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
900
+ # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
901
+
902
+ t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
903
+ # best
904
+ target = th.where(t > 249, noise, x_start)
905
+ output = th.where(t > 249, pred_noise, pred_startx)
906
+ loss = (target - output) ** 2
907
+ terms["mse"] = mean_flat(loss)
908
+ terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
909
+ if "mae" in terms:
910
+ terms["loss"] = terms["loss"] + terms["mae"]
911
+ else:
912
+ raise NotImplementedError(self.loss_type)
913
+
914
+ return terms
915
+
916
+ def _extracted_from_training_losses_diffusers(self, x_t, output, t):
917
+ B, C = x_t.shape[:2]
918
+ assert output.shape == (B, C * 2, *x_t.shape[2:])
919
+ output = th.split(output, C, dim=1)[0]
920
+ return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t
921
+
922
+ def _prior_bpd(self, x_start):
923
+ """
924
+ Get the prior KL term for the variational lower-bound, measured in
925
+ bits-per-dim.
926
+ This term can't be optimized, as it only depends on the encoder.
927
+ :param x_start: the [N x C x ...] tensor of inputs.
928
+ :return: a batch of [N] KL values (in bits), one per batch element.
929
+ """
930
+ batch_size = x_start.shape[0]
931
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
932
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
933
+ kl_prior = normal_kl(
934
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
935
+ )
936
+ return mean_flat(kl_prior) / np.log(2.0)
937
+
938
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
939
+ """
940
+ Compute the entire variational lower-bound, measured in bits-per-dim,
941
+ as well as other related quantities.
942
+ :param model: the model to evaluate loss on.
943
+ :param x_start: the [N x C x ...] tensor of inputs.
944
+ :param clip_denoised: if True, clip denoised samples.
945
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
946
+ pass to the model. This can be used for conditioning.
947
+ :return: a dict containing the following keys:
948
+ - total_bpd: the total variational lower-bound, per batch element.
949
+ - prior_bpd: the prior term in the lower-bound.
950
+ - vb: an [N x T] tensor of terms in the lower-bound.
951
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
952
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
953
+ """
954
+ device = x_start.device
955
+ batch_size = x_start.shape[0]
956
+
957
+ vb = []
958
+ xstart_mse = []
959
+ mse = []
960
+ for t in list(range(self.num_timesteps))[::-1]:
961
+ t_batch = th.tensor([t] * batch_size, device=device)
962
+ noise = th.randn_like(x_start)
963
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
964
+ # Calculate VLB term at the current timestep
965
+ with th.no_grad():
966
+ out = self._vb_terms_bpd(
967
+ model,
968
+ x_start=x_start,
969
+ x_t=x_t,
970
+ t=t_batch,
971
+ clip_denoised=clip_denoised,
972
+ model_kwargs=model_kwargs,
973
+ )
974
+ vb.append(out["output"])
975
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
976
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
977
+ mse.append(mean_flat((eps - noise) ** 2))
978
+
979
+ vb = th.stack(vb, dim=1)
980
+ xstart_mse = th.stack(xstart_mse, dim=1)
981
+ mse = th.stack(mse, dim=1)
982
+
983
+ prior_bpd = self._prior_bpd(x_start)
984
+ total_bpd = vb.sum(dim=1) + prior_bpd
985
+ return {
986
+ "total_bpd": total_bpd,
987
+ "prior_bpd": prior_bpd,
988
+ "vb": vb,
989
+ "xstart_mse": xstart_mse,
990
+ "mse": mse,
991
+ }
992
+
993
+
994
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
995
+ """
996
+ Extract values from a 1-D numpy array for a batch of indices.
997
+ :param arr: the 1-D numpy array.
998
+ :param timesteps: a tensor of indices into the array to extract.
999
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1000
+ dimension equal to the length of timesteps.
1001
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1002
+ """
1003
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1004
+ while len(res.shape) < len(broadcast_shape):
1005
+ res = res[..., None]
1006
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
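+
+ # Worked example: with arr = diffusion.sqrt_alphas_cumprod (a length-T numpy array),
+ # timesteps = th.tensor([3, 7]) and broadcast_shape = (2, 4, 32, 32), the result is a
+ # (2, 4, 32, 32) float tensor filled with arr[3] for the first batch element and arr[7]
+ # for the second.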
DiT_VAE/diffusion/model/hed.py ADDED
@@ -0,0 +1,150 @@
1
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
+ # Please use this implementation in your products
3
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
4
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
+ # and in this way it works better for gradio's RGB protocol
7
+ import sys
8
+ from pathlib import Path
9
+ current_file_path = Path(__file__).resolve()
10
+ sys.path.insert(0, str(current_file_path.parent.parent.parent))
11
+ from torch import nn
12
+ import torch
13
+ import numpy as np
14
+ from torchvision import transforms as T
15
+ from tqdm import tqdm
16
+ from torch.utils.data import Dataset, DataLoader
17
+ import json
18
+ from PIL import Image
19
+ import torchvision.transforms.functional as TF
20
+ from accelerate import Accelerator
21
+ from diffusers.models import AutoencoderKL
22
+ import os
23
+
24
+ image_resize = 1024
25
+
26
+
27
+ class DoubleConvBlock(nn.Module):
28
+ def __init__(self, input_channel, output_channel, layer_number):
29
+ super().__init__()
30
+ self.convs = torch.nn.Sequential()
31
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
32
+ for i in range(1, layer_number):
33
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
34
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
35
+
36
+ def forward(self, x, down_sampling=False):
37
+ h = x
38
+ if down_sampling:
39
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
40
+ for conv in self.convs:
41
+ h = conv(h)
42
+ h = torch.nn.functional.relu(h)
43
+ return h, self.projection(h)
44
+
45
+
46
+ class ControlNetHED_Apache2(nn.Module):
47
+ def __init__(self):
48
+ super().__init__()
49
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
50
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
51
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
52
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
53
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
54
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
55
+
56
+ def forward(self, x):
57
+ h = x - self.norm
58
+ h, projection1 = self.block1(h)
59
+ h, projection2 = self.block2(h, down_sampling=True)
60
+ h, projection3 = self.block3(h, down_sampling=True)
61
+ h, projection4 = self.block4(h, down_sampling=True)
62
+ h, projection5 = self.block5(h, down_sampling=True)
63
+ return projection1, projection2, projection3, projection4, projection5
64
+
65
+
66
+ class InternData(Dataset):
67
+ def __init__(self):
68
+ ####
69
+ with open('data/InternData/partition/data_info.json', 'r') as f:
70
+ self.j = json.load(f)
71
+ self.transform = T.Compose([
72
+ T.Lambda(lambda img: img.convert('RGB')),
73
+ T.Resize(image_resize), # Image.BICUBIC
74
+ T.CenterCrop(image_resize),
75
+ T.ToTensor(),
76
+ ])
77
+
78
+ def __len__(self):
79
+ return len(self.j)
80
+
81
+ def getdata(self, idx):
82
+
83
+ path = self.j[idx]['path']
84
+ image = Image.open("data/InternImgs/" + path)
85
+ image = self.transform(image)
86
+ return image, path
87
+
88
+ def __getitem__(self, idx):
89
+ for i in range(20):
90
+ try:
91
+ data = self.getdata(idx)
92
+ return data
93
+ except Exception as e:
94
+ print(f"Error details: {str(e)}")
95
+ idx = np.random.randint(len(self))
96
+ raise RuntimeError('Too many bad data.')
97
+
98
+ class HEDdetector(nn.Module):
99
+ def __init__(self, feature=True, vae=None):
100
+ super().__init__()
101
+ self.model = ControlNetHED_Apache2()
102
+ self.model.load_state_dict(torch.load('output/pretrained_models/ControlNetHED.pth', map_location='cpu'))
103
+ self.model.eval()
104
+ self.model.requires_grad_(False)
105
+ if feature:
106
+ if vae is None:
107
+ self.vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema")
108
+ else:
109
+ self.vae = vae
110
+ self.vae.eval()
111
+ self.vae.requires_grad_(False)
112
+ else:
113
+ self.vae = None
114
+
115
+ def forward(self, input_image):
116
+ B, C, H, W = input_image.shape
117
+ with torch.inference_mode():
118
+ edges = self.model(input_image * 255.)
119
+ edges = torch.cat([TF.resize(e, [H, W]) for e in edges], dim=1)
120
+ edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1, keepdim=True)))
121
+ edge.clip_(0, 1)
122
+ if self.vae:
123
+ edge = TF.normalize(edge, [.5], [.5])
124
+ edge = edge.repeat(1, 3, 1, 1)
125
+ posterior = self.vae.encode(edge).latent_dist
126
+ edge = torch.cat([posterior.mean, posterior.std], dim=1).cpu().numpy()
127
+ return edge
128
+
129
+
130
+ def main():
131
+ dataset = InternData()
132
+ dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=8, pin_memory=True)
133
+ hed = HEDdetector()
134
+
135
+ accelerator = Accelerator()
136
+ hed, dataloader = accelerator.prepare(hed, dataloader)
137
+
138
+
139
+ for img, path in tqdm(dataloader):
140
+ out = hed(img.cuda())
141
+ for p, o in zip(path, out):
142
+ save = f'data/InternalData/hed_feature_{image_resize}/' + p.replace('.png', '.npz')
143
+ if os.path.exists(save):
144
+ continue
145
+ os.makedirs(os.path.dirname(save), exist_ok=True)
146
+ np.savez_compressed(save, o)
147
+
148
+
149
+ if __name__ == "__main__":
150
+ main()
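+
+ # Standalone usage sketch (the checkpoint/VAE paths above are the layout this repo
+ # expects; the input below is an illustrative placeholder):
+ #   detector = HEDdetector(feature=False).cuda()
+ #   img = torch.rand(1, 3, 512, 512, device='cuda')   # RGB in [0, 1]
+ #   edge = detector(img)                               # (1, 1, 512, 512) soft edge map in [0, 1]
+ # With feature=True the edge map is additionally normalized, repeated to 3 channels,
+ # VAE-encoded, and returned as a numpy array of concatenated posterior mean and std.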
DiT_VAE/diffusion/model/image_embedding.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+ from transformers import AutoImageProcessor, Dinov2Model
3
+ from PIL import Image
4
+ import requests
5
+
6
+ # url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
7
+ # image = Image.open(requests.get(url, stream=True).raw)
8
+ #
9
+ # processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
10
+ # model = AutoModel.from_pretrained('facebook/dinov2-base')
11
+ #
12
+ # inputs = processor(images=image, return_tensors="pt")
13
+ # outputs = model(**inputs)
14
+ # last_hidden_states = outputs[0]
15
+
DiT_VAE/diffusion/model/nets/PixArt_blocks.py ADDED
@@ -0,0 +1,655 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ from timm.models.vision_transformer import Mlp, Attention as Attention_
15
+ from einops import rearrange
16
+ import xformers.ops
17
+
18
+
19
+ def modulate(x, shift, scale):
20
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21
+
22
+
23
+ def t2i_modulate(x, shift, scale):
24
+ return x * (1 + scale) + shift
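+
+ # Shape conventions for the blocks below: x is (B, N, C) token features. modulate() takes
+ # per-sample shift/scale of shape (B, C) and unsqueezes them to (B, 1, C); t2i_modulate()
+ # expects shift/scale already broadcastable to x, e.g. (B, 1, C) from a scale_shift_table.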
25
+
26
+
27
+ class MultiHeadCrossAttention(nn.Module):
28
+ def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs):
29
+ super(MultiHeadCrossAttention, self).__init__()
30
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
31
+
32
+ self.d_model = d_model
33
+ self.num_heads = num_heads
34
+ self.head_dim = d_model // num_heads
35
+
36
+ self.q_linear = nn.Linear(d_model, d_model)
37
+ self.kv_linear = nn.Linear(d_model, d_model * 2)
38
+ self.attn_drop = nn.Dropout(attn_drop)
39
+ self.proj = nn.Linear(d_model, d_model)
40
+ self.proj_drop = nn.Dropout(proj_drop)
41
+
42
+ def forward(self, x, cond, mask=None):
43
+ # query: img tokens; key/value: condition; mask: if padding tokens
44
+ B, N, C = x.shape
45
+
46
+ q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
47
+ kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
48
+
49
+ k, v = kv.unbind(2)
50
+ attn_bias = None
51
+ if mask is not None:
52
+ attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
53
+ x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
54
+ x = x.view(B, -1, C)
55
+ x = self.proj(x)
56
+ x = self.proj_drop(x)
57
+
58
+ # q = self.q_linear(x).reshape(B, -1, self.num_heads, self.head_dim)
59
+ # kv = self.kv_linear(cond).reshape(B, -1, 2, self.num_heads, self.head_dim)
60
+ # k, v = kv.unbind(2)
61
+ # attn_bias = None
62
+ # if mask is not None:
63
+ # attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
64
+ # attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
65
+ # x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
66
+ # x = x.contiguous().reshape(B, -1, C)
67
+ # x = self.proj(x)
68
+ # x = self.proj_drop(x)
69
+
70
+ return x
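+
+ # Shape sketch (illustrative values): x is (B, N, C) image tokens; cond holds the condition
+ # tokens for all samples, internally flattened to a single (1, sum(L_i), C) sequence; mask is
+ # the list [L_1, ..., L_B] of valid condition lengths, from which BlockDiagonalMask restricts
+ # each sample's queries to its own condition tokens. For example:
+ #   attn = MultiHeadCrossAttention(d_model=1152, num_heads=16)
+ #   out = attn(x, cond, mask=[cond_len] * B)   # out: (B, N, C)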
71
+
72
+
73
+ class WindowAttention(Attention_):
74
+ """Multi-head Attention block with relative position embeddings."""
75
+
76
+ def __init__(
77
+ self,
78
+ dim,
79
+ num_heads=8,
80
+ qkv_bias=True,
81
+ use_rel_pos=False,
82
+ rel_pos_zero_init=True,
83
+ input_size=None,
84
+ **block_kwargs,
85
+ ):
86
+ """
87
+ Args:
88
+ dim (int): Number of input channels.
89
+ num_heads (int): Number of attention heads.
90
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
91
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
92
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
93
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative positional
94
+ parameter size.
95
+ """
96
+ super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs)
97
+
98
+ self.use_rel_pos = use_rel_pos
99
+ if self.use_rel_pos:
100
+ # initialize relative positional embeddings
101
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
102
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
103
+
104
+ if not rel_pos_zero_init:
105
+ nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
106
+ nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
107
+
108
+ def forward(self, x, mask=None):
109
+ B, N, C = x.shape
110
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
111
+ q, k, v = qkv.unbind(2)
112
+ if use_fp32_attention := getattr(self, 'fp32_attention', False):
113
+ q, k, v = q.float(), k.float(), v.float()
114
+
115
+ attn_bias = None
116
+ if mask is not None:
117
+ attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
118
+ attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
119
+ x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
120
+
121
+ x = x.view(B, N, C)
122
+ x = self.proj(x)
123
+ x = self.proj_drop(x)
124
+ return x
125
+
126
+
127
+ #################################################################################
128
+ # AMP attention with fp32 softmax to fix loss NaN problem during training #
129
+ #################################################################################
130
+ class Attention(Attention_):
131
+ def forward(self, x):
132
+ B, N, C = x.shape
133
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
134
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
135
+ use_fp32_attention = getattr(self, 'fp32_attention', False)
136
+ if use_fp32_attention:
137
+ q, k = q.float(), k.float()
138
+ with torch.cuda.amp.autocast(enabled=not use_fp32_attention):
139
+ attn = (q @ k.transpose(-2, -1)) * self.scale
140
+ attn = attn.softmax(dim=-1)
141
+
142
+ attn = self.attn_drop(attn)
143
+
144
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
145
+ x = self.proj(x)
146
+ x = self.proj_drop(x)
147
+ return x
148
+
149
+ class AttentionTest(Attention_):
150
+ def forward(self, x, mask=None):
151
+ B, N, C = x.shape
152
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
153
+ q, k, v = qkv.unbind(2)
154
+
155
+
156
+ attn_bias = None
157
+ if mask is not None:
158
+ attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
159
+ attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
160
+ x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
161
+
162
+ x = x.view(B, N, C)
163
+ x = self.proj(x)
164
+ x = self.proj_drop(x)
165
+ return x
166
+
167
+ class FinalLayer(nn.Module):
168
+ """
169
+ The final layer of PixArt.
170
+ """
171
+
172
+ def __init__(self, hidden_size, patch_size, out_channels):
173
+ super().__init__()
174
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
175
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
176
+ self.adaLN_modulation = nn.Sequential(
177
+ nn.SiLU(),
178
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
179
+ )
180
+
181
+ def forward(self, x, c):
182
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
183
+ x = modulate(self.norm_final(x), shift, scale)
184
+ x = self.linear(x)
185
+ return x
186
+
187
+
188
+ class T2IFinalLayer(nn.Module):
189
+ """
190
+ The final layer of PixArt.
191
+ """
192
+
193
+ def __init__(self, hidden_size, patch_size, out_channels):
194
+ super().__init__()
195
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
196
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
197
+ self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
198
+ self.out_channels = out_channels
199
+
200
+ def forward(self, x, t):
201
+ shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
202
+ x = t2i_modulate(self.norm_final(x), shift, scale)
203
+ x = self.linear(x)
204
+ return x
205
+
206
+
207
+ class MaskFinalLayer(nn.Module):
208
+ """
209
+ The final layer of PixArt.
210
+ """
211
+
212
+ def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
213
+ super().__init__()
214
+ self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
215
+ self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
216
+ self.adaLN_modulation = nn.Sequential(
217
+ nn.SiLU(),
218
+ nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
219
+ )
220
+
221
+ def forward(self, x, t):
222
+ shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
223
+ x = modulate(self.norm_final(x), shift, scale)
224
+ x = self.linear(x)
225
+ return x
226
+
227
+
228
+ class DecoderLayer(nn.Module):
229
+ """
230
+ The final layer of PixArt.
231
+ """
232
+
233
+ def __init__(self, hidden_size, decoder_hidden_size):
234
+ super().__init__()
235
+ self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
236
+ self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
237
+ self.adaLN_modulation = nn.Sequential(
238
+ nn.SiLU(),
239
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
240
+ )
241
+
242
+ def forward(self, x, t):
243
+ shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
244
+ x = modulate(self.norm_decoder(x), shift, scale)
245
+ x = self.linear(x)
246
+ return x
247
+
248
+
249
+ #################################################################################
250
+ # Embedding Layers for Timesteps and Class Labels #
251
+ #################################################################################
252
+ class TimestepEmbedder(nn.Module):
253
+ """
254
+ Embeds scalar timesteps into vector representations.
255
+ """
256
+
257
+ def __init__(self, hidden_size, frequency_embedding_size=256):
258
+ super().__init__()
259
+ self.mlp = nn.Sequential(
260
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
261
+ nn.SiLU(),
262
+ nn.Linear(hidden_size, hidden_size, bias=True),
263
+ )
264
+ self.frequency_embedding_size = frequency_embedding_size
265
+
266
+ @staticmethod
267
+ def timestep_embedding(t, dim, max_period=10000):
268
+ """
269
+ Create sinusoidal timestep embeddings.
270
+ :param t: a 1-D Tensor of N indices, one per batch element.
271
+ These may be fractional.
272
+ :param dim: the dimension of the output.
273
+ :param max_period: controls the minimum frequency of the embeddings.
274
+ :return: an (N, D) Tensor of positional embeddings.
275
+ """
276
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
277
+ half = dim // 2
278
+ freqs = torch.exp(
279
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
280
+ args = t[:, None].float() * freqs[None]
281
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
282
+ if dim % 2:
283
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
284
+ return embedding
285
+
286
+ def forward(self, t):
287
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype)
288
+ return self.mlp(t_freq)
289
+
290
+ @property
291
+ def dtype(self):
292
+ # Return the dtype of the model parameters
293
+ return next(self.parameters()).dtype
294
+
295
+
296
+ class SizeEmbedder(TimestepEmbedder):
297
+ """
298
+ Embeds scalar size values into vector representations.
299
+ """
300
+
301
+ def __init__(self, hidden_size, frequency_embedding_size=256):
302
+ super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
303
+ self.mlp = nn.Sequential(
304
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
305
+ nn.SiLU(),
306
+ nn.Linear(hidden_size, hidden_size, bias=True),
307
+ )
308
+ self.frequency_embedding_size = frequency_embedding_size
309
+ self.outdim = hidden_size
310
+
311
+ def forward(self, s, bs):
312
+ if s.ndim == 1:
313
+ s = s[:, None]
314
+ assert s.ndim == 2
315
+ if s.shape[0] != bs:
316
+ s = s.repeat(bs // s.shape[0], 1)
317
+ assert s.shape[0] == bs
318
+ b, dims = s.shape[0], s.shape[1]
319
+ s = rearrange(s, "b d -> (b d)")
320
+ s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
321
+ s_emb = self.mlp(s_freq)
322
+ s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
323
+ return s_emb
324
+
325
+ @property
326
+ def dtype(self):
327
+ # Return the dtype of the model parameters
328
+ return next(self.parameters()).dtype
329
+
330
+
331
+ class LabelEmbedder(nn.Module):
332
+ """
333
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
334
+ """
335
+
336
+ def __init__(self, num_classes, hidden_size, dropout_prob):
337
+ super().__init__()
338
+ use_cfg_embedding = dropout_prob > 0
339
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
340
+ self.num_classes = num_classes
341
+ self.dropout_prob = dropout_prob
342
+
343
+ def token_drop(self, labels, force_drop_ids=None):
344
+ """
345
+ Drops labels to enable classifier-free guidance.
346
+ """
347
+ if force_drop_ids is None:
348
+ drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
349
+ else:
350
+ drop_ids = force_drop_ids == 1
351
+ labels = torch.where(drop_ids, self.num_classes, labels)
352
+ return labels
353
+
354
+ def forward(self, labels, train, force_drop_ids=None):
355
+ use_dropout = self.dropout_prob > 0
356
+ if (train and use_dropout) or (force_drop_ids is not None):
357
+ labels = self.token_drop(labels, force_drop_ids)
358
+ return self.embedding_table(labels)
359
+
360
+
361
+ def FeedForward(dim, mult=4):
362
+ inner_dim = int(dim * mult)
363
+ return nn.Sequential(
364
+ nn.LayerNorm(dim),
365
+ nn.Linear(dim, inner_dim, bias=False),
366
+ nn.GELU(),
367
+ nn.Linear(inner_dim, dim, bias=False),
368
+ )
369
+
370
+
371
+ def reshape_tensor(x, heads):
372
+ bs, length, width = x.shape
373
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
374
+ x = x.view(bs, length, heads, -1)
375
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
376
+ x = x.transpose(1, 2)
377
+ # reshape keeps (bs, n_heads, length, dim_per_head) but makes the tensor contiguous after the transpose
378
+ x = x.reshape(bs, heads, length, -1)
379
+ return x
380
+
381
+
382
+ class PerceiverAttention(nn.Module):
383
+ def __init__(self, *, dim, dim_head=64, heads=8):
384
+ super().__init__()
385
+ self.scale = dim_head ** -0.5
386
+ self.dim_head = dim_head
387
+ self.heads = heads
388
+ inner_dim = dim_head * heads
389
+
390
+ self.norm1 = nn.LayerNorm(dim)
391
+ self.norm2 = nn.LayerNorm(dim)
392
+
393
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
394
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
395
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
396
+
397
+ def forward(self, x, latents):
398
+ """
399
+ Args:
400
+ x (torch.Tensor): image features
401
+ shape (b, n1, D)
402
+ latent (torch.Tensor): latent features
403
+ shape (b, n2, D)
404
+ """
405
+ x = self.norm1(x)
406
+ latents = self.norm2(latents)
407
+
408
+ b, l, _ = latents.shape
409
+
410
+ q = self.to_q(latents)
411
+ kv_input = torch.cat((x, latents), dim=-2)
412
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
413
+
414
+ q = reshape_tensor(q, self.heads)
415
+ k = reshape_tensor(k, self.heads)
416
+ v = reshape_tensor(v, self.heads)
417
+
418
+ # attention
419
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
420
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
421
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
422
+ out = weight @ v
423
+
424
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
425
+
426
+ return self.to_out(out)
427
+
428
+
429
+ class ImageCaptionEmbedder(nn.Module):
430
+ """
431
+ Resamples image (caption) features into a fixed number of condition tokens via Perceiver-style cross-attention.
432
+ """
433
+
434
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), depth=4,
435
+ dim_head=64, heads=12, ff_mult=4, token_num=4):
436
+ super().__init__()
437
+ self.latents = nn.Parameter(torch.randn(1, token_num, hidden_size) / hidden_size ** 0.5)
438
+
439
+ self.proj_in = nn.Linear(in_channels, hidden_size)
440
+
441
+ self.proj_out = Mlp(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size,
442
+ act_layer=act_layer, drop=0)
443
+ self.norm_out = nn.LayerNorm(hidden_size)
444
+
445
+ self.layers = nn.ModuleList([])
446
+ for _ in range(depth):
447
+ self.layers.append(
448
+ nn.ModuleList(
449
+ [
450
+ PerceiverAttention(dim=hidden_size, dim_head=dim_head, heads=heads),
451
+ FeedForward(dim=hidden_size, mult=ff_mult),
452
+ ]
453
+ )
454
+ )
455
+ self.uncond_prob = uncond_prob
456
+
457
+ def forward(self, x, train, force_drop_ids=None):
458
+ latents = self.latents.repeat(x.size(0), 1, 1)
459
+ x = self.proj_in(x)
460
+
461
+ for attn, ff in self.layers:
462
+ latents = attn(x, latents) + latents
463
+ latents = ff(latents) + latents
464
+
465
+ latents = self.proj_out(latents)
466
+ latents = self.norm_out(latents)
467
+ image_caption = latents.unsqueeze(1) # (N, 1, L, D)
468
+ return image_caption
469
+
470
+
471
+
472
+ class DinoFeatureEmbedderQFormer(nn.Module):
473
+ """
474
+ Resamples DINO image features into a fixed number of condition tokens via Perceiver-style cross-attention.
475
+ """
476
+
477
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=257, depth=4,
478
+ dim_head=64, heads=12, ff_mult=4 ):
479
+ super().__init__()
480
+ self.latents = nn.Parameter(torch.randn(1, token_num, hidden_size) / hidden_size ** 0.5)
481
+
482
+ self.proj_in = nn.Linear(in_channels, hidden_size)
483
+
484
+ self.proj_out = Mlp(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size,
485
+ act_layer=act_layer, drop=0)
486
+ self.norm_out = nn.LayerNorm(hidden_size)
487
+
488
+ self.layers = nn.ModuleList([])
489
+ for _ in range(depth):
490
+ self.layers.append(
491
+ nn.ModuleList(
492
+ [
493
+ PerceiverAttention(dim=hidden_size, dim_head=dim_head, heads=heads),
494
+ FeedForward(dim=hidden_size, mult=ff_mult),
495
+ ]
496
+ )
497
+ )
498
+ def forward(self, x, train, force_drop_ids=None):
499
+ latents = self.latents.repeat(x.size(0), 1, 1)
500
+ x = self.proj_in(x)
501
+
502
+ for attn, ff in self.layers:
503
+ latents = attn(x, latents) + latents
504
+ latents = ff(latents) + latents
505
+
506
+ latents = self.proj_out(latents)
507
+ latents = self.norm_out(latents)
508
+ image_caption = latents.unsqueeze(1) # (N, 1, L, D)
509
+ return image_caption
510
+
511
+
512
+
513
+ class DinoFeatureEmbedderV2(nn.Module):
514
+ """
515
+ Projects DINO image features into the model's hidden dimension.
516
+ """
517
+
518
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=257, use_drop=True, dino_norm=False):
519
+ super().__init__()
520
+ self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
521
+ act_layer=act_layer, drop=0)
522
+ self.dino_norm = dino_norm
523
+ if self.dino_norm:
524
+ self.norm_out = nn.LayerNorm(hidden_size)
525
+
526
+ def forward(self, dino_feature):
527
+ dino_feature = dino_feature.unsqueeze(1)
528
+ dino_feature = self.y_proj(dino_feature)
529
+ if self.dino_norm:
530
+ dino_feature = self.norm_out(dino_feature)
531
+
532
+ return dino_feature
533
+
534
+
535
+ class DinoFeatureEmbedder(nn.Module):
536
+ """
537
+ Embeds DINO image features into vector representations. Also handles feature dropout for classifier-free guidance.
538
+ """
539
+
540
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=257 ):
541
+ super().__init__()
542
+ self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
543
+ act_layer=act_layer, drop=0)
544
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
545
+ self.uncond_prob = uncond_prob
546
+
547
+ def token_drop(self, dino_feature, force_drop_ids=None):
548
+ """
549
+ Drops labels to enable classifier-free guidance.
550
+ """
551
+ if force_drop_ids is None:
552
+ drop_ids = torch.rand(dino_feature.shape[0]).cuda() < self.uncond_prob
553
+ else:
554
+ force_drop_ids = torch.tensor(force_drop_ids).cuda()
555
+ drop_ids = force_drop_ids == 1
556
+ dino_feature = torch.where(drop_ids[:, None, None, None], self.y_embedding, dino_feature)
557
+ return dino_feature
558
+
559
+ def forward(self, dino_feature, train, force_drop_ids=None):
560
+ # print("dino_2", dino_feature)
561
+ dino_feature = dino_feature.unsqueeze(1)
562
+
563
+ if train:
564
+ assert dino_feature.shape[2:] == self.y_embedding.shape
565
+ use_dropout = self.uncond_prob > 0
566
+
567
+ if (train and use_dropout) or (force_drop_ids is not None and force_drop_ids != {} and len(force_drop_ids) != 0):
568
+ dino_feature = self.token_drop(dino_feature, force_drop_ids)
569
+ dino_feature = self.y_proj(dino_feature)
570
+ # print("dino_3", dino_feature)
571
+
572
+ return dino_feature
573
+
574
+
575
+ class FusionEmbedder(nn.Module):
576
+ """
577
+ Projects fused conditioning features into the model's hidden dimension.
578
+ """
579
+
580
+ def __init__(self, in_channels, hidden_size, act_layer=nn.GELU(approximate='tanh')):
581
+ super().__init__()
582
+ self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
583
+ act_layer=act_layer, drop=0)
584
+
585
+ def forward(self, fusion_feature):
586
+ dino_feature = self.y_proj(fusion_feature)
587
+ return dino_feature
588
+
589
+
590
+ class CaptionEmbedder(nn.Module):
591
+ """
592
+ Embeds caption features into vector representations. Also handles dropout for classifier-free guidance.
593
+ """
594
+
595
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
596
+ super().__init__()
597
+ self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
598
+ act_layer=act_layer, drop=0)
599
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
600
+ self.uncond_prob = uncond_prob
601
+
602
+ def token_drop(self, caption, force_drop_ids=None):
603
+ """
604
+ Drops labels to enable classifier-free guidance.
605
+ """
606
+ if force_drop_ids is None:
607
+ drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
608
+ else:
609
+ drop_ids = force_drop_ids == 1
610
+ caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
611
+ return caption
612
+
613
+ def forward(self, caption, train, force_drop_ids=None):
614
+ if train:
615
+ assert caption.shape[2:] == self.y_embedding.shape
616
+ use_dropout = self.uncond_prob > 0
617
+ if (train and use_dropout) or (force_drop_ids is not None):
618
+ caption = self.token_drop(caption, force_drop_ids)
619
+ caption = self.y_proj(caption)
620
+ return caption
621
+
622
+
623
+ class CaptionEmbedderDoubleBr(nn.Module):
624
+ """
625
+ Embeds caption features into global and token-level representations. Also handles dropout for classifier-free guidance.
626
+ """
627
+
628
+ def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
629
+ super().__init__()
630
+ self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size,
631
+ act_layer=act_layer, drop=0)
632
+ self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
633
+ self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
634
+ self.uncond_prob = uncond_prob
635
+
636
+ def token_drop(self, global_caption, caption, force_drop_ids=None):
637
+ """
638
+ Drops labels to enable classifier-free guidance.
639
+ """
640
+ if force_drop_ids is None:
641
+ drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
642
+ else:
643
+ drop_ids = force_drop_ids == 1
644
+ global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
645
+ caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
646
+ return global_caption, caption
647
+
648
+ def forward(self, caption, train, force_drop_ids=None):
649
+ assert caption.shape[2:] == self.y_embedding.shape
650
+ global_caption = caption.mean(dim=2).squeeze()
651
+ use_dropout = self.uncond_prob > 0
652
+ if (train and use_dropout) or (force_drop_ids is not None):
653
+ global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
654
+ y_embed = self.proj(global_caption)
655
+ return y_embed, caption
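A minimal standalone sketch of the sinusoidal timestep embedding implemented by `TimestepEmbedder.timestep_embedding` above (even-`dim` case), assuming only `torch` is available; useful for checking shapes outside the model:

```python
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    # Half the channels carry cos(t * freq), the other half sin(t * freq);
    # frequencies decay geometrically from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0., 10., 999.]), dim=256)
print(emb.shape)  # torch.Size([3, 256])
```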
DiT_VAE/diffusion/model/nets/TriDitCLIPDINO.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+ import torch
12
+ import torch.nn as nn
13
+ import os
14
+ import numpy as np
15
+ from timm.models.layers import DropPath
16
+ from timm.models.vision_transformer import PatchEmbed, Mlp
17
+ from DiT_VAE.diffusion.model.builder import MODELS
18
+ from DiT_VAE.diffusion.model.utils import auto_grad_checkpoint, to_2tuple
19
+ from DiT_VAE.diffusion.model.nets.PixArt_blocks import t2i_modulate, WindowAttention, MultiHeadCrossAttention, \
20
+ T2IFinalLayer, TimestepEmbedder, ImageCaptionEmbedder, DinoFeatureEmbedderQFormer
21
+ from DiT_VAE.diffusion.utils.logger import get_root_logger
22
+
23
+
24
+ class PixArtBlock(nn.Module):
25
+ """
26
+ A PixArt block with adaptive layer norm (adaLN-single) conditioning.
27
+ """
28
+
29
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None,
30
+ use_rel_pos=False, **block_kwargs):
31
+ super().__init__()
32
+ self.hidden_size = hidden_size
33
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
34
+ self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
35
+ input_size=input_size if window_size == 0 else (window_size, window_size),
36
+ use_rel_pos=use_rel_pos, **block_kwargs)
37
+ self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
38
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
39
+ # to be compatible with lower version pytorch
40
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
41
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
42
+ drop=0)
43
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
44
+ self.window_size = window_size
45
+ self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
46
+
47
+ def forward(self, x, y, t, mask=None, img_feature=None, **kwargs):
48
+ B, N, C = x.shape
49
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
50
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
51
+ if img_feature is None:
52
+ x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
53
+ else:
54
+ x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
55
+ img_feature = img_feature.squeeze(1)
56
+ N_new = N + img_feature.shape[1]
57
+ x_m = self.attn(torch.cat([x_m, img_feature], dim=1)).reshape(B, N_new, C)
58
+ x_m = x_m[:,:N, :]
59
+ x = x + self.drop_path(gate_msa * x_m)
60
+ x = x + self.cross_attn(x, y, mask)
61
+ x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
62
+
63
+ return x
64
+
65
+
66
+ #################################################################################
67
+ # Core PixArt Model #
68
+ #################################################################################
69
+ @MODELS.register_module()
70
+ class TriDitCLIPDINO(nn.Module):
71
+ """
72
+ Diffusion model with a Transformer backbone.
73
+ """
74
+
75
+ def __init__(self, input_size, patch_size=2, in_channels=8, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0,
76
+ class_dropout_prob=0.1, pred_sigma=True, drop_path: float = 0., window_size=0,
77
+ window_block_indexes=None, use_rel_pos=False, caption_channels=1280, lewei_scale=1.0, config=None, dino_channels=768, img_feature_self_attention=False, dino_norm=False,
78
+ model_max_length=257, **kwargs):
79
+ if window_block_indexes is None:
80
+ window_block_indexes = []
81
+ super().__init__()
82
+ self.img_feature_self_attention = img_feature_self_attention
83
+ self.pred_sigma = pred_sigma
84
+ self.in_channels = in_channels
85
+ self.out_channels = in_channels * 2 if pred_sigma else in_channels
86
+ self.patch_size = patch_size
87
+ self.num_heads = num_heads
88
+ self.lewei_scale = lewei_scale,
89
+ assert isinstance(input_size, tuple)
90
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
91
+ self.t_embedder = TimestepEmbedder(hidden_size)
92
+ num_patches = self.x_embedder.num_patches
93
+ self.base_size_h = input_size[0] // self.patch_size
94
+ self.base_size_w = input_size[1] // self.patch_size
95
+ self.h = self.base_size_h
96
+ self.w = self.base_size_w
97
+ # Will use fixed sin-cos embedding:
98
+ self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
99
+
100
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
101
+ self.t_block = nn.Sequential(
102
+ nn.SiLU(),
103
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
104
+ )
105
+ self.dino_embedder = DinoFeatureEmbedderQFormer(in_channels=dino_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=256)
106
+ self.y_embedder = ImageCaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size,
107
+ uncond_prob=class_dropout_prob, act_layer=approx_gelu,
108
+ token_num=16)
109
+ drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
110
+ self.blocks = nn.ModuleList([
111
+ PixArtBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
112
+ input_size=(input_size[0] // patch_size, input_size[1] // patch_size),
113
+ window_size=window_size if i in window_block_indexes else 0,
114
+ use_rel_pos=use_rel_pos if i in window_block_indexes else False)
115
+ for i in range(depth)
116
+ ])
117
+ self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)
118
+
119
+ self.initialize_weights()
120
+
121
+ if config:
122
+ logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
123
+ logger.warning(
124
+ f"lewei scale: {self.lewei_scale}, base size h: {self.base_size_h} base size w: {self.base_size_w}")
125
+ else:
126
+ print(
127
+ f'Warning: lewei scale: {self.lewei_scale}, base size h: {self.base_size_h} base size w: {self.base_size_w}')
128
+
129
+ def forward(self, x, timestep, y, img_feature, drop_img_mask=None, mask=None, data_info=None, **kwargs):
130
+ """
131
+ Forward pass of PixArt.
132
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
133
+ t: (N,) tensor of diffusion timesteps
134
+ y: (N, 1, L, C) tensor of conditioning (caption) features
135
+ """
136
+ x = x.to(self.dtype)
137
+ timestep = timestep.to(self.dtype)
138
+ y = y.to(self.dtype)
139
+ img_feature = img_feature.to(self.dtype)
140
+ pos_embed = self.pos_embed.to(self.dtype)
141
+ self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
142
+ x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
143
+ t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
144
+ t0 = self.t_block(t)
145
+ y = self.y_embedder(y, self.training) # (N, 1, L, D)
146
+ img_embedding = self.dino_embedder(img_feature, self.training)
147
+ # y_fusion = y
148
+ if mask is not None:
149
+ if mask.shape[0] != y.shape[0]:
150
+ mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
151
+ mask = mask.squeeze(1).squeeze(1)
152
+ y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
153
+ y_lens = mask.sum(dim=1).tolist()
154
+ else:
155
+ y_lens = [y.shape[2]] * y.shape[0]
156
+ y = y.squeeze(1).view(1, -1, x.shape[-1])
157
+ for block in self.blocks:
158
+ x = auto_grad_checkpoint(block, x, y, t0, y_lens, img_embedding) # (N, T, D) #support grad checkpoint
159
+ x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
160
+ x = self.unpatchify(x) # (N, out_channels, H, W)
161
+ return x
162
+
163
+ def forward_with_dpmsolver(self, x, timestep, y, img_feature, mask=None, **kwargs):
164
+ """
165
+ dpm solver does not need variance prediction
166
+ """
167
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
168
+ model_out = self.forward(x, timestep, y, img_feature)
169
+ return model_out.chunk(2, dim=1)[0]
170
+
171
+ def forward_with_cfg(self, x, timestep, y, img_feature, cfg_scale, mask=None, **kwargs):
172
+ """
173
+ Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
174
+ """
175
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
176
+ half = x[: len(x) // 2]
177
+ combined = torch.cat([half, half], dim=0)
178
+ model_out = self.forward(combined, timestep, y, img_feature, **kwargs)
179
+ model_out = model_out['x'] if isinstance(model_out, dict) else model_out
180
+ eps, rest = torch.split(model_out, self.in_channels, dim=1)
181
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
182
+
183
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
184
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
185
+ eps = torch.cat([half_eps, half_eps], dim=0)
186
+ return torch.cat([eps, rest], dim=1)
187
+
188
+ def unpatchify(self, x):
189
+ """
190
+ x: (N, T, patch_size**2 * C)
191
+ imgs: (N, C, H, W)
192
+ """
193
+ c = self.out_channels
194
+ p = self.x_embedder.patch_size[0]
195
+ h = int(x.shape[1] ** 0.5 * 2)
196
+ w = int(x.shape[1] ** 0.5 / 2)
197
+ assert h * w == x.shape[1]
198
+
199
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
200
+ x = torch.einsum('nhwpqc->nchpwq', x)
201
+ return x.reshape(shape=(x.shape[0], c, h * p, w * p))
202
+
203
+ def initialize_weights(self):
204
+ # Initialize transformer layers:
205
+ def _basic_init(module):
206
+ if isinstance(module, nn.Linear):
207
+ torch.nn.init.xavier_uniform_(module.weight)
208
+ if module.bias is not None:
209
+ nn.init.constant_(module.bias, 0)
210
+
211
+ self.apply(_basic_init)
212
+
213
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
214
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5),
215
+ lewei_scale=self.lewei_scale, base_size_h=self.base_size_h,
216
+ base_size_w=self.base_size_w)
217
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
218
+
219
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
220
+ w = self.x_embedder.proj.weight.data
221
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
222
+
223
+ # Initialize timestep embedding MLP:
224
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
225
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
226
+ nn.init.normal_(self.t_block[1].weight, std=0.02)
227
+
228
+ # Initialize caption embedding MLP:
229
+ nn.init.normal_(self.y_embedder.proj_out.fc1.weight, std=0.02)
230
+ nn.init.normal_(self.y_embedder.proj_out.fc2.weight, std=0.02)
231
+ nn.init.normal_(self.y_embedder.proj_in.weight, std=0.02)
232
+
233
+
234
+ # Initialize dino embedding MLP:
235
+ # nn.init.normal_(self.dino_embedder.y_proj.fc1.weight, std=0.02)
236
+ # nn.init.normal_(self.dino_embedder.y_proj.fc2.weight, std=0.02)
237
+ nn.init.normal_(self.dino_embedder.proj_out.fc1.weight, std=0.02)
238
+ nn.init.normal_(self.dino_embedder.proj_out.fc2.weight, std=0.02)
239
+ nn.init.normal_(self.dino_embedder.proj_in.weight, std=0.02)
240
+ # if not self.img_feature_self_attention:
241
+ # # Initialize fusion embedding MLP:
242
+ # nn.init.normal_(self.fusion_embedder.y_proj.fc1.weight, std=0.02)
243
+ # nn.init.normal_(self.fusion_embedder.y_proj.fc2.weight, std=0.02)
244
+
245
+ # Zero-out adaLN modulation layers in PixArt blocks:
246
+ for block in self.blocks:
247
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
248
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
249
+
250
+ # Zero-out output layers:
251
+ nn.init.constant_(self.final_layer.linear.weight, 0)
252
+ nn.init.constant_(self.final_layer.linear.bias, 0)
253
+
254
+ @property
255
+ def dtype(self):
256
+ return next(self.parameters()).dtype
257
+
258
+
259
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0, base_size_h=16,
260
+ base_size_w=16):
261
+ """
262
+ grid_size: int of the grid height and width
263
+ return:
264
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
265
+ """
266
+ if isinstance(grid_size, int):
267
+ grid_size = to_2tuple(grid_size)
268
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size_h) / lewei_scale
269
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size_w) / lewei_scale
270
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
271
+ grid = np.stack(grid, axis=0)
272
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
273
+
274
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
275
+ if cls_token and extra_tokens > 0:
276
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
277
+ return pos_embed
278
+
279
+
280
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
281
+ assert embed_dim % 2 == 0
282
+
283
+ # use half of dimensions to encode grid_h
284
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
285
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
286
+
287
+ return np.concatenate([emb_h, emb_w], axis=1)
288
+
289
+
290
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
291
+ """
292
+ embed_dim: output dimension for each position
293
+ pos: a list of positions to be encoded: size (M,)
294
+ out: (M, D)
295
+ """
296
+ assert embed_dim % 2 == 0
297
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
298
+ omega /= embed_dim / 2.
299
+ omega = 1. / 10000 ** omega # (D/2,)
300
+
301
+ pos = pos.reshape(-1) # (M,)
302
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
303
+
304
+ emb_sin = np.sin(out) # (M, D/2)
305
+ emb_cos = np.cos(out) # (M, D/2)
306
+
307
+ return np.concatenate([emb_sin, emb_cos], axis=1)
308
+
309
+
310
+ #################################################################################
311
+ # PixArt Configs #
312
+ #################################################################################
313
+ @MODELS.register_module()
314
+ def TriDitCLIPDINO_XL_2(**kwargs):
315
+ return TriDitCLIPDINO(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
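The guidance combination inside `TriDitCLIPDINO.forward_with_cfg` follows the standard classifier-free guidance rule. A small sketch of just that mixing step (the helper name `cfg_mix` is ours, for illustration only):

```python
import torch

def cfg_mix(cond_eps, uncond_eps, cfg_scale):
    # Same rule as forward_with_cfg: move the unconditional prediction
    # towards the conditional one by a factor of cfg_scale.
    return uncond_eps + cfg_scale * (cond_eps - uncond_eps)

cond_eps = torch.randn(2, 8, 16, 16)
uncond_eps = torch.randn(2, 8, 16, 16)

print(torch.allclose(cfg_mix(cond_eps, uncond_eps, 1.0), cond_eps))    # True: scale 1 keeps the conditional eps
print(torch.allclose(cfg_mix(cond_eps, uncond_eps, 0.0), uncond_eps))  # True: scale 0 ignores the condition
```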
DiT_VAE/diffusion/model/nets/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .TriDitCLIPDINO import TriDitCLIPDINO_XL_2, TriDitCLIPDINO
DiT_VAE/diffusion/model/respace.py ADDED
@@ -0,0 +1,131 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
52
+ cur_idx = 0.0
53
+ taken_steps = []
54
+ for _ in range(section_count):
55
+ taken_steps.append(start_idx + round(cur_idx))
56
+ cur_idx += frac_stride
57
+ all_steps += taken_steps
58
+ start_idx += size
59
+ return set(all_steps)
60
+
61
+
62
+ class SpacedDiffusion(GaussianDiffusion):
63
+ """
64
+ A diffusion process which can skip steps in a base diffusion process.
65
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
66
+ original diffusion process to retain.
67
+ :param kwargs: the kwargs to create the base diffusion process.
68
+ """
69
+
70
+ def __init__(self, use_timesteps, **kwargs):
71
+ self.use_timesteps = set(use_timesteps)
72
+ self.timestep_map = []
73
+ self.original_num_steps = len(kwargs["betas"])
74
+
75
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
76
+ last_alpha_cumprod = 1.0
77
+ new_betas = []
78
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
79
+ if i in self.use_timesteps:
80
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
81
+ last_alpha_cumprod = alpha_cumprod
82
+ self.timestep_map.append(i)
83
+ kwargs["betas"] = np.array(new_betas)
84
+ super().__init__(**kwargs)
85
+
86
+ def p_mean_variance(
87
+ self, model, *args, **kwargs
88
+ ): # pylint: disable=signature-differs
89
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
90
+
91
+ def training_losses(
92
+ self, model, *args, **kwargs
93
+ ): # pylint: disable=signature-differs
94
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
95
+
96
+ def training_losses_diffusers(
97
+ self, model, *args, **kwargs
98
+ ): # pylint: disable=signature-differs
99
+ return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)
100
+
101
+ def condition_mean(self, cond_fn, *args, **kwargs):
102
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
103
+
104
+ def condition_score(self, cond_fn, *args, **kwargs):
105
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
106
+
107
+ def _wrap_model(self, model):
108
+ if isinstance(model, _WrappedModel):
109
+ return model
110
+ return _WrappedModel(
111
+ model, self.timestep_map, self.original_num_steps
112
+ )
113
+
114
+ def _scale_timesteps(self, t):
115
+ # Scaling is done by the wrapped model.
116
+ return t
117
+
118
+
119
+ class _WrappedModel:
120
+ def __init__(self, model, timestep_map, original_num_steps):
121
+ self.model = model
122
+ self.timestep_map = timestep_map
123
+ # self.rescale_timesteps = rescale_timesteps
124
+ self.original_num_steps = original_num_steps
125
+
126
+ def __call__(self, x, timestep, **kwargs):
127
+ map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
128
+ new_ts = map_tensor[timestep]
129
+ # if self.rescale_timesteps:
130
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
131
+ return self.model(x, timestep=new_ts, **kwargs)
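A quick usage sketch of the `space_timesteps` helper defined above, mirroring the example from its docstring (assumes the repository root is on `PYTHONPATH`):

```python
from DiT_VAE.diffusion.model.respace import space_timesteps

# 300 original steps, three equal sections kept at 10 / 15 / 20 steps each
steps = space_timesteps(300, [10, 15, 20])
print(len(steps))               # 45

# DDIM-style fixed stride: exactly 25 evenly strided steps out of 1000
ddim_steps = space_timesteps(1000, "ddim25")
print(sorted(ddim_steps)[:5])   # [0, 40, 80, 120, 160]
```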
DiT_VAE/diffusion/model/sa_solver.py ADDED
@@ -0,0 +1,1129 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from tqdm import tqdm
5
+
6
+
7
+ class NoiseScheduleVP:
8
+ def __init__(
9
+ self,
10
+ schedule='discrete',
11
+ betas=None,
12
+ alphas_cumprod=None,
13
+ continuous_beta_0=0.1,
14
+ continuous_beta_1=20.,
15
+ dtype=torch.float32,
16
+ ):
17
+ """Thanks to DPM-Solver for their code base"""
18
+ """Create a wrapper class for the forward SDE (VP type).
19
+ ***
20
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
21
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
22
+ ***
23
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
24
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
25
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
30
+ t = self.inverse_lambda(lambda_t)
31
+ ===============================================================
32
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
33
+ 1. For discrete-time DPMs:
34
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
35
+ t_i = (i + 1) / N
36
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
37
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
38
+ Args:
39
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
40
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
41
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
42
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
43
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
44
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
45
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
46
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
47
+ and
48
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
49
+ 2. For continuous-time DPMs:
50
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
51
+ schedule are the default settings in DDPM and improved-DDPM:
52
+ Args:
53
+ beta_min: A `float` number. The smallest beta for the linear schedule.
54
+ beta_max: A `float` number. The largest beta for the linear schedule.
55
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
56
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
57
+ T: A `float` number. The ending time of the forward process.
58
+ ===============================================================
59
+ Args:
60
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
61
+ 'linear' or 'cosine' for continuous-time DPMs.
62
+ Returns:
63
+ A wrapper object of the forward SDE (VP type).
64
+
65
+ ===============================================================
66
+ Example:
67
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
68
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
69
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
70
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
71
+ # For continuous-time DPMs (VPSDE), linear schedule:
72
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
73
+ """
74
+
75
+ if schedule not in ['discrete', 'linear', 'cosine']:
76
+ raise ValueError(
77
+ f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'"
78
+ )
79
+
80
+ self.schedule = schedule
81
+ if schedule == 'discrete':
82
+ if betas is not None:
83
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
84
+ else:
85
+ assert alphas_cumprod is not None
86
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
87
+ self.total_N = len(log_alphas)
88
+ self.T = 1.
89
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
90
+ self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
91
+ else:
92
+ self.total_N = 1000
93
+ self.beta_0 = continuous_beta_0
94
+ self.beta_1 = continuous_beta_1
95
+ self.cosine_s = 0.008
96
+ self.cosine_beta_max = 999.
97
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
98
+ 1. + self.cosine_s) / math.pi - self.cosine_s
99
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
100
+ self.schedule = schedule
101
+ self.T = 0.9946 if schedule == 'cosine' else 1.
102
+
103
+ def marginal_log_mean_coeff(self, t):
104
+ """
105
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
106
+ """
107
+ if self.schedule == 'discrete':
108
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
109
+ self.log_alpha_array.to(t.device)).reshape((-1))
110
+ elif self.schedule == 'linear':
111
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
112
+ elif self.schedule == 'cosine':
113
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
114
+ return log_alpha_fn(t) - self.cosine_log_alpha_0
115
+
116
+ def marginal_alpha(self, t):
117
+ """
118
+ Compute alpha_t of a given continuous-time label t in [0, T].
119
+ """
120
+ return torch.exp(self.marginal_log_mean_coeff(t))
121
+
122
+ def marginal_std(self, t):
123
+ """
124
+ Compute sigma_t of a given continuous-time label t in [0, T].
125
+ """
126
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
127
+
128
+ def marginal_lambda(self, t):
129
+ """
130
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
131
+ """
132
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
133
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
134
+ return log_mean_coeff - log_std
135
+
136
+ def inverse_lambda(self, lamb):
137
+ """
138
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
139
+ """
140
+ if self.schedule == 'linear':
141
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
142
+ Delta = self.beta_0 ** 2 + tmp
143
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
144
+ elif self.schedule == 'discrete':
145
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
146
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
147
+ torch.flip(self.t_array.to(lamb.device), [1]))
148
+ return t.reshape((-1,))
149
+ else:
150
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
151
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
152
+ 1. + self.cosine_s) / math.pi - self.cosine_s
153
+ return t_fn(log_alpha)
154
+
155
+ def edm_sigma(self, t):
156
+ return self.marginal_std(t) / self.marginal_alpha(t)
157
+
158
+ def edm_inverse_sigma(self, edmsigma):
159
+ alpha = 1 / (edmsigma ** 2 + 1).sqrt()
160
+ sigma = alpha * edmsigma
161
+ lambda_t = torch.log(alpha / sigma)
162
+ return self.inverse_lambda(lambda_t)
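A small numerical check of the schedule identities described in the `NoiseScheduleVP` docstring, using the closed-form linear schedule (a sketch; the import path follows this commit's file layout and assumes the repo's dependencies such as tqdm are installed):

```python
import torch
from DiT_VAE.diffusion.model.sa_solver import NoiseScheduleVP

ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
t = torch.tensor([0.25, 0.5, 0.75])

alpha_t = ns.marginal_alpha(t)      # alpha_t = exp(log_alpha_t)
sigma_t = ns.marginal_std(t)        # sigma_t = sqrt(1 - alpha_t ** 2)
lambda_t = ns.marginal_lambda(t)    # half-logSNR

# lambda_t = log(alpha_t) - log(sigma_t), and inverse_lambda maps it back to t
assert torch.allclose(lambda_t, torch.log(alpha_t) - torch.log(sigma_t), atol=1e-5)
assert torch.allclose(ns.inverse_lambda(lambda_t), t, atol=1e-4)
```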
163
+
164
+
165
+ def model_wrapper(
166
+ model,
167
+ noise_schedule,
168
+ model_type="noise",
169
+ model_kwargs={},
170
+ guidance_type="uncond",
171
+ condition=None,
172
+ unconditional_condition=None,
173
+ guidance_scale=1.,
174
+ classifier_fn=None,
175
+ classifier_kwargs={},
176
+ ):
177
+ """Thanks to DPM-Solver for their code base"""
178
+ """Create a wrapper function for the noise prediction model.
179
+ SA-Solver needs to solve the continuous-time diffusion SDEs. For DPMs trained on discrete-time labels, we need to
180
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
181
+ We support four types of the diffusion model by setting `model_type`:
182
+ 1. "noise": noise prediction model. (Trained by predicting noise).
183
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
184
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
185
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
186
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
187
+ arXiv preprint arXiv:2202.00512 (2022).
188
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
189
+ arXiv preprint arXiv:2210.02303 (2022).
190
+
191
+ 4. "score": marginal score function. (Trained by denoising score matching).
192
+ Note that the score function and the noise prediction model follows a simple relationship:
193
+ ```
194
+ noise(x_t, t) = -sigma_t * score(x_t, t)
195
+ ```
196
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
197
+ 1. "uncond": unconditional sampling by DPMs.
198
+ The input `model` has the following format:
199
+ ``
200
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
201
+ ``
202
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
203
+ The input `model` has the following format:
204
+ ``
205
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
206
+ ``
207
+ The input `classifier_fn` has the following format:
208
+ ``
209
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
210
+ ``
211
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
212
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
213
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
214
+ The input `model` has the following format:
215
+ ``
216
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
217
+ ``
218
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
219
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
220
+ arXiv preprint arXiv:2207.12598 (2022).
221
+
222
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
223
+ or continuous-time labels (i.e. epsilon to T).
224
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
225
+ ``
226
+ def model_fn(x, t_continuous) -> noise:
227
+ t_input = get_model_input_time(t_continuous)
228
+ return noise_pred(model, x, t_input, **model_kwargs)
229
+ ``
230
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for SA-Solver.
231
+ ===============================================================
232
+ Args:
233
+ model: A diffusion model with the corresponding format described above.
234
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
235
+ model_type: A `str`. The parameterization type of the diffusion model.
236
+ "noise" or "x_start" or "v" or "score".
237
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
238
+ guidance_type: A `str`. The type of the guidance for sampling.
239
+ "uncond" or "classifier" or "classifier-free".
240
+ condition: A pytorch tensor. The condition for the guided sampling.
241
+ Only used for "classifier" or "classifier-free" guidance type.
242
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
243
+ Only used for "classifier-free" guidance type.
244
+ guidance_scale: A `float`. The scale for the guided sampling.
245
+ classifier_fn: A classifier function. Only used for the classifier guidance.
246
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
247
+ Returns:
248
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
249
+ """
250
+
251
+ def get_model_input_time(t_continuous):
252
+ """
253
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
254
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
255
+ For continuous-time DPMs, we just use `t_continuous`.
256
+ """
257
+ if noise_schedule.schedule == 'discrete':
258
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
259
+ else:
260
+ return t_continuous
261
+
262
+ def noise_pred_fn(x, t_continuous, cond=None, cond_2=None):
263
+ t_input = get_model_input_time(t_continuous)
264
+ if cond is None:
265
+ output = model(x, t_input, **model_kwargs)
266
+ else:
267
+ output = model(x, t_input, cond, cond_2, **model_kwargs)
268
+ if model_type == "noise":
269
+ return output
270
+ elif model_type == "x_start":
271
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
272
+ return (x - alpha_t[0] * output) / sigma_t[0]
273
+ elif model_type == "v":
274
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
275
+ return alpha_t[0] * output + sigma_t[0] * x
276
+ elif model_type == "score":
277
+ sigma_t = noise_schedule.marginal_std(t_continuous)
278
+ return -sigma_t[0] * output
279
+
280
+ def cond_grad_fn(x, t_input):
281
+ """
282
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
283
+ """
284
+ with torch.enable_grad():
285
+ x_in = x.detach().requires_grad_(True)
286
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
287
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
288
+
289
+ def model_fn(x, t_continuous):
290
+ """
291
+ The noise prediction model function that is used for SA-Solver.
292
+ """
293
+ if guidance_type == "uncond":
294
+ return noise_pred_fn(x, t_continuous)
295
+ elif guidance_type == "classifier":
296
+ assert classifier_fn is not None
297
+ t_input = get_model_input_time(t_continuous)
298
+ cond_grad = cond_grad_fn(x, t_input)
299
+ sigma_t = noise_schedule.marginal_std(t_continuous)
300
+ noise = noise_pred_fn(x, t_continuous)
301
+ return noise - guidance_scale * sigma_t * cond_grad
302
+ elif guidance_type == "classifier-free":
303
+ if guidance_scale == 1. or unconditional_condition is None:
304
+ return noise_pred_fn(x, t_continuous, cond=condition)
305
+ x_in = torch.cat([x] * 2)
306
+ t_in = torch.cat([t_continuous] * 2)
307
+ # c_in = torch.cat([unconditional_condition, condition])
308
+ c_in_y = torch.cat([unconditional_condition[0], condition[0]])
309
+ c_in_dino = torch.cat([unconditional_condition[1], condition[1]])
310
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in_y, cond_2=c_in_dino).chunk(2)
311
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
312
+
313
+ assert model_type in ["noise", "x_start", "v", "score"]
314
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
315
+ return model_fn
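For orientation, here is a minimal usage sketch of the wrapper above. The wrapper's own name and the `NoiseScheduleVP` schedule class are not shown in this diff, so both are assumptions borrowed from the reference DPM-Solver/SA-Solver code; the two-part condition mirrors the `cond` / `cond_2` split used in the classifier-free branch.

```python
# Hypothetical sketch -- `model_wrapper` and `NoiseScheduleVP` are assumed names from the
# reference DPM-Solver / SA-Solver code; `model`, `betas` and the embeddings are placeholders.
import torch

ns = NoiseScheduleVP(schedule='discrete', betas=betas)        # training beta schedule
model_fn = model_wrapper(
    model, ns,
    model_type="noise",                                        # the network predicts epsilon
    guidance_type="classifier-free",
    condition=(y_embed, dino_embed),                           # (CLIP condition, DINO condition)
    unconditional_condition=(null_y, null_dino),
    guidance_scale=4.5,
)

x_t = torch.randn(4, 8, 32, 32, device=y_embed.device)        # noised latents
noise_hat = model_fn(x_t, torch.full((4,), ns.T, device=x_t.device))  # predicted noise at t = T
```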
316
+
317
+
318
+ class SASolver:
319
+ def __init__(
320
+ self,
321
+ model_fn,
322
+ noise_schedule,
323
+ algorithm_type="data_prediction",
324
+ correcting_x0_fn=None,
325
+ correcting_xt_fn=None,
326
+ thresholding_max_val=1.,
327
+ dynamic_thresholding_ratio=0.995
328
+ ):
329
+ """
330
+ Construct an SA-Solver.
331
+ The default value for algorithm_type is "data_prediction", and we recommend not changing it to
332
+ "noise_prediction". For details, please see Appendix A.2.4 in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
333
+ """
334
+
335
+ self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
336
+ self.noise_schedule = noise_schedule
337
+ assert algorithm_type in ["data_prediction", "noise_prediction"]
338
+
339
+ if correcting_x0_fn == "dynamic_thresholding":
340
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
341
+ else:
342
+ self.correcting_x0_fn = correcting_x0_fn
343
+
344
+ self.correcting_xt_fn = correcting_xt_fn
345
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
346
+ self.thresholding_max_val = thresholding_max_val
347
+
348
+ self.predict_x0 = algorithm_type == "data_prediction"
349
+
350
+ self.sigma_min = float(self.noise_schedule.edm_sigma(torch.tensor([1e-3])))
351
+ self.sigma_max = float(self.noise_schedule.edm_sigma(torch.tensor([1])))
352
+
353
+ def dynamic_thresholding_fn(self, x0, t=None):
354
+ """
355
+ The dynamic thresholding method.
356
+ """
357
+ dims = x0.dim()
358
+ p = self.dynamic_thresholding_ratio
359
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
360
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
361
+ x0 = torch.clamp(x0, -s, s) / s
362
+ return x0
363
+
364
+ def noise_prediction_fn(self, x, t):
365
+ """
366
+ Return the noise prediction model.
367
+ """
368
+ return self.model(x, t)
369
+
370
+ def data_prediction_fn(self, x, t):
371
+ """
372
+ Return the data prediction model (with corrector).
373
+ """
374
+ noise = self.noise_prediction_fn(x, t)
375
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
376
+ x0 = (x - sigma_t * noise) / alpha_t
377
+ if self.correcting_x0_fn is not None:
378
+ x0 = self.correcting_x0_fn(x0)
379
+ return x0
380
+
381
+ def model_fn(self, x, t):
382
+ """
383
+ Convert the model to the noise prediction model or the data prediction model.
384
+ """
385
+
386
+ if self.predict_x0:
387
+ return self.data_prediction_fn(x, t)
388
+ else:
389
+ return self.noise_prediction_fn(x, t)
390
+
391
+ def get_time_steps(self, skip_type, t_T, t_0, N, order, device):
392
+ """Compute the intermediate time steps for sampling.
393
+ """
394
+ if skip_type == 'logSNR':
395
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
396
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
397
+ logSNR_steps = lambda_T + torch.linspace(torch.tensor(0.).cpu().item(),
398
+ (lambda_0 - lambda_T).cpu().item() ** (1. / order), N + 1).pow(
399
+ order).to(device)
400
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
401
+ elif skip_type == 'time':
402
+ t = torch.linspace(t_T ** (1. / order), t_0 ** (1. / order), N + 1).pow(order).to(device)
403
+ return t
404
+ elif skip_type == 'karras':
405
+ sigma_min = max(0.002, self.sigma_min)
406
+ sigma_max = min(80, self.sigma_max)
407
+ sigma_steps = torch.linspace(sigma_max ** (1. / 7), sigma_min ** (1. / 7), N + 1).pow(7).to(device)
408
+ return self.noise_schedule.edm_inverse_sigma(sigma_steps)
409
+ else:
410
+ raise ValueError(
411
+ f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time' or 'karras'"
412
+ )
413
+
414
+ def denoise_to_zero_fn(self, x, s):
415
+ """
416
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
417
+ """
418
+ return self.data_prediction_fn(x, s)
419
+
420
+ def get_coefficients_exponential_negative(self, order, interval_start, interval_end):
421
+ """
422
+ Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end
423
+ For calculating the coefficient of gradient terms after the lagrange interpolation,
424
+ see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
425
+ For noise_prediction formula.
426
+ """
427
+ assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
428
+
429
+ if order == 0:
430
+ return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1)
431
+ elif order == 1:
432
+ return torch.exp(-interval_end) * (
433
+ (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1))
434
+ elif order == 2:
435
+ return torch.exp(-interval_end) * (
436
+ (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - (
437
+ interval_end ** 2 + 2 * interval_end + 2))
438
+ elif order == 3:
439
+ return torch.exp(-interval_end) * (
440
+ (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp(
441
+ interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6))
442
+
443
+ def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau):
444
+ """
445
+ Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
446
+ For calculating the coefficient of gradient terms after the lagrange interpolation,
447
+ see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
448
+ For data_prediction formula.
449
+ """
450
+ assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
451
+
452
+ # after change of variable(cov)
453
+ interval_end_cov = (1 + tau ** 2) * interval_end
454
+ interval_start_cov = (1 + tau ** 2) * interval_start
455
+
456
+ if order == 0:
457
+ return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (
458
+ (1 + tau ** 2))
459
+ elif order == 1:
460
+ return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(
461
+ -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2)
462
+ elif order == 2:
463
+ return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - (
464
+ interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp(
465
+ -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3)
466
+ elif order == 3:
467
+ return torch.exp(interval_end_cov) * (
468
+ (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - (
469
+ interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp(
470
+ -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4)
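For reference, the closed forms above come from the change of variables $u = (1+\tau^2)x$:

$$
\int_a^b e^{(1+\tau^2)x}\,x^n\,dx
= \frac{1}{(1+\tau^2)^{\,n+1}}\int_{(1+\tau^2)a}^{(1+\tau^2)b} e^{u}\,u^{n}\,du ,
\qquad
\int e^{u}u\,du = e^{u}(u-1)+C ,
$$

which, evaluated between the two rescaled endpoints, reproduces the `order == 1` branch; the higher orders follow by repeated integration by parts.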
471
+
472
+ def lagrange_polynomial_coefficient(self, order, lambda_list):
473
+ """
474
+ Calculate the coefficients of the Lagrange polynomial
475
+ For Lagrange interpolation
476
+ """
477
+ assert order in [0, 1, 2, 3]
478
+ assert order == len(lambda_list) - 1
479
+ if order == 0:
480
+ return [[1]]
481
+ elif order == 1:
482
+ return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
483
+ [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
484
+ elif order == 2:
485
+ denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
486
+ denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
487
+ denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
488
+ return [[1 / denominator1,
489
+ (-lambda_list[1] - lambda_list[2]) / denominator1,
490
+ lambda_list[1] * lambda_list[2] / denominator1],
491
+
492
+ [1 / denominator2,
493
+ (-lambda_list[0] - lambda_list[2]) / denominator2,
494
+ lambda_list[0] * lambda_list[2] / denominator2],
495
+
496
+ [1 / denominator3,
497
+ (-lambda_list[0] - lambda_list[1]) / denominator3,
498
+ lambda_list[0] * lambda_list[1] / denominator3]
499
+ ]
500
+ elif order == 3:
501
+ denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (
502
+ lambda_list[0] - lambda_list[3])
503
+ denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (
504
+ lambda_list[1] - lambda_list[3])
505
+ denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (
506
+ lambda_list[2] - lambda_list[3])
507
+ denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (
508
+ lambda_list[3] - lambda_list[2])
509
+ return [[1 / denominator1,
510
+ (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
511
+ (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[
512
+ 3]) / denominator1,
513
+ (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],
514
+
515
+ [1 / denominator2,
516
+ (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
517
+ (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[
518
+ 3]) / denominator2,
519
+ (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],
520
+
521
+ [1 / denominator3,
522
+ (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
523
+ (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[
524
+ 3]) / denominator3,
525
+ (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],
526
+
527
+ [1 / denominator4,
528
+ (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
529
+ (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[
530
+ 2]) / denominator4,
531
+ (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
532
+
533
+ ]
534
+
535
+ def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau):
536
+ """
537
+ Calculate the coefficient of gradients.
538
+ """
539
+ assert order in [1, 2, 3, 4]
540
+ assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
541
+ coefficients = []
542
+ lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list)
543
+ for i in range(order):
544
+ coefficient = sum(
545
+ lagrange_coefficient[i][j]
546
+ * self.get_coefficients_exponential_positive(
547
+ order - 1 - j, interval_start, interval_end, tau
548
+ )
549
+ if self.predict_x0
550
+ else lagrange_coefficient[i][j]
551
+ * self.get_coefficients_exponential_negative(
552
+ order - 1 - j, interval_start, interval_end
553
+ )
554
+ for j in range(order)
555
+ )
556
+ coefficients.append(coefficient)
557
+ assert len(coefficients) == order, 'the length of coefficients does not match the order'
558
+ return coefficients
559
+
560
+ def adams_bashforth_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
561
+ """
562
+ SA-Predictor, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
563
+ """
564
+ assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
565
+
566
+ # get noise schedule
567
+ ns = self.noise_schedule
568
+ alpha_t = ns.marginal_alpha(t)
569
+ sigma_t = ns.marginal_std(t)
570
+ lambda_t = ns.marginal_lambda(t)
571
+ alpha_prev = ns.marginal_alpha(t_prev_list[-1])
572
+ sigma_prev = ns.marginal_std(t_prev_list[-1])
573
+ gradient_part = torch.zeros_like(x)
574
+ h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
575
+ lambda_list = [ns.marginal_lambda(t_prev_list[-(i + 1)]) for i in range(order)]
576
+ gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
577
+ lambda_list, tau)
578
+
579
+ for i in range(order):
580
+ if self.predict_x0:
581
+ gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
582
+ i] * model_prev_list[-(i + 1)]
583
+ else:
584
+ gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
585
+
586
+ if self.predict_x0:
587
+ noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
588
+ else:
589
+ noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
590
+
591
+ if self.predict_x0:
592
+ x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
593
+ else:
594
+ x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
595
+
596
+ return x_t
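Written out, the data-prediction branch above computes (with $s$ the previous time, $h=\lambda_t-\lambda_s$, $\xi\sim\mathcal{N}(0,I)$, and $b_i$ the `gradient_coefficients`):

$$
x_t = e^{-\tau^2 h}\,\frac{\sigma_t}{\sigma_s}\,x_s
\;+\;(1+\tau^2)\,\sigma_t\,e^{-\tau^2\lambda_t}\sum_{i=1}^{\text{order}} b_i\,x_\theta(x_{s_i}, s_i)
\;+\;\sigma_t\sqrt{1-e^{-2\tau^2 h}}\,\xi ,
$$

and setting $\tau = 0$ removes the noise term, recovering a deterministic multistep update.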
597
+
598
+ def adams_moulton_update(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
599
+ """
600
+ SA-Corrector, without the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
601
+ """
602
+
603
+ assert order in [1, 2, 3, 4], "order of stochastic adams moulton method is only supported for 1, 2, 3 and 4"
604
+
605
+ # get noise schedule
606
+ ns = self.noise_schedule
607
+ alpha_t = ns.marginal_alpha(t)
608
+ sigma_t = ns.marginal_std(t)
609
+ lambda_t = ns.marginal_lambda(t)
610
+ alpha_prev = ns.marginal_alpha(t_prev_list[-1])
611
+ sigma_prev = ns.marginal_std(t_prev_list[-1])
612
+ gradient_part = torch.zeros_like(x)
613
+ h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
614
+ t_list = t_prev_list + [t]
615
+ lambda_list = [ns.marginal_lambda(t_list[-(i + 1)]) for i in range(order)]
616
+ gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
617
+ lambda_list, tau)
618
+
619
+ for i in range(order):
620
+ if self.predict_x0:
621
+ gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
622
+ i] * model_prev_list[-(i + 1)]
623
+ else:
624
+ gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
625
+
626
+ if self.predict_x0:
627
+ noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
628
+ else:
629
+ noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
630
+
631
+ if self.predict_x0:
632
+ x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
633
+ else:
634
+ x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
635
+
636
+ return x_t
637
+
638
+ def adams_bashforth_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
639
+ """
640
+ SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
641
+ """
642
+
643
+ assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
644
+
645
+ # get noise schedule
646
+ ns = self.noise_schedule
647
+ alpha_t = ns.marginal_alpha(t)
648
+ sigma_t = ns.marginal_std(t)
649
+ lambda_t = ns.marginal_lambda(t)
650
+ alpha_prev = ns.marginal_alpha(t_prev_list[-1])
651
+ sigma_prev = ns.marginal_std(t_prev_list[-1])
652
+ gradient_part = torch.zeros_like(x)
653
+ h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
654
+ lambda_list = [ns.marginal_lambda(t_prev_list[-(i + 1)]) for i in range(order)]
655
+ gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
656
+ lambda_list, tau)
657
+
658
+ if self.predict_x0:
659
+ if order == 2:  # For order == 2 we apply a modification (similar to UniPC) that does not affect the convergence order. Note: only used for few-step sampling.
660
+ # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
661
+ # ODE case
662
+ # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
663
+ # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
664
+ gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
665
+ h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
666
+ (1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(
667
+ t_prev_list[-2]))
668
+ gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
669
+ h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
670
+ (1 + tau ** 2) ** 2)) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(
671
+ t_prev_list[-2]))
672
+
673
+ for i in range(order):
674
+ if self.predict_x0:
675
+ gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
676
+ i] * model_prev_list[-(i + 1)]
677
+ else:
678
+ gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
679
+
680
+ if self.predict_x0:
681
+ noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
682
+ else:
683
+ noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
684
+
685
+ if self.predict_x0:
686
+ x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
687
+ else:
688
+ x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
689
+
690
+ return x_t
691
+
692
+ def adams_moulton_update_few_steps(self, order, x, tau, model_prev_list, t_prev_list, noise, t):
693
+ """
694
+ SA-Corrector, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
695
+ """
696
+
697
+ assert order in [1, 2, 3, 4], "order of stochastic adams moulton method is only supported for 1, 2, 3 and 4"
698
+
699
+ # get noise schedule
700
+ ns = self.noise_schedule
701
+ alpha_t = ns.marginal_alpha(t)
702
+ sigma_t = ns.marginal_std(t)
703
+ lambda_t = ns.marginal_lambda(t)
704
+ alpha_prev = ns.marginal_alpha(t_prev_list[-1])
705
+ sigma_prev = ns.marginal_std(t_prev_list[-1])
706
+ gradient_part = torch.zeros_like(x)
707
+ h = lambda_t - ns.marginal_lambda(t_prev_list[-1])
708
+ t_list = t_prev_list + [t]
709
+ lambda_list = [ns.marginal_lambda(t_list[-(i + 1)]) for i in range(order)]
710
+ gradient_coefficients = self.get_coefficients_fn(order, ns.marginal_lambda(t_prev_list[-1]), lambda_t,
711
+ lambda_list, tau)
712
+
713
+ if self.predict_x0:
714
+ if order == 2:  # For order == 2 we apply a modification (similar to UniPC) that does not affect the convergence order. Note: only used for few-step sampling.
715
+ # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
716
+ # ODE case
717
+ # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
718
+ # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
719
+ gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
720
+ h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
721
+ (1 + tau ** 2) ** 2 * h))
722
+ gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
723
+ h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
724
+ (1 + tau ** 2) ** 2 * h))
725
+
726
+ for i in range(order):
727
+ if self.predict_x0:
728
+ gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
729
+ i] * model_prev_list[-(i + 1)]
730
+ else:
731
+ gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
732
+
733
+ if self.predict_x0:
734
+ noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
735
+ else:
736
+ noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
737
+
738
+ if self.predict_x0:
739
+ x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_prev) * x + gradient_part + noise_part
740
+ else:
741
+ x_t = (alpha_t / alpha_prev) * x + gradient_part + noise_part
742
+
743
+ return x_t
744
+
745
+ def sample_few_steps(self, x, tau, steps=5, t_start=None, t_end=None, skip_type='time', skip_order=1,
746
+ predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False
747
+ ):
748
+ """
749
+ For the PC-mode, please refer to the wiki page
750
+ https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode
751
+ 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations
752
+ We recommend pc_mode='PEC' when the number of NFEs is limited. 'PECE' mode is only for testing with sufficient NFEs.
753
+ """
754
+
755
+ skip_first_step = False
756
+ skip_final_step = True
757
+ lower_order_final = True
758
+ denoise_to_zero = False
759
+
760
+ assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE'
761
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
762
+ t_T = self.noise_schedule.T if t_start is None else t_start
763
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
764
+
765
+ device = x.device
766
+ intermediates = []
767
+ with torch.no_grad():
768
+ assert steps >= max(predictor_order, corrector_order - 1)
769
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order,
770
+ device=device)
771
+ assert timesteps.shape[0] - 1 == steps
772
+ # Init the initial values.
773
+ step = 0
774
+ t = timesteps[step]
775
+ noise = torch.randn_like(x)
776
+ t_prev_list = [t]
777
+ # do not evaluate if skip_first_step
778
+ if skip_first_step:
779
+ if self.predict_x0:
780
+ alpha_t = self.noise_schedule.marginal_alpha(t)
781
+ sigma_t = self.noise_schedule.marginal_std(t)
782
+ model_prev_list = [(1 - sigma_t) / alpha_t * x]
783
+ else:
784
+ model_prev_list = [x]
785
+ else:
786
+ model_prev_list = [self.model_fn(x, t)]
787
+
788
+ if self.correcting_xt_fn is not None:
789
+ x = self.correcting_xt_fn(x, t, step)
790
+ if return_intermediate:
791
+ intermediates.append(x)
792
+
793
+ # determine the first several values
794
+ for step in tqdm(range(1, max(predictor_order, corrector_order - 1))):
795
+
796
+ t = timesteps[step]
797
+ predictor_order_used = min(predictor_order, step)
798
+ corrector_order_used = min(corrector_order, step + 1)
799
+ noise = torch.randn_like(x)
800
+ # predictor step
801
+ x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t),
802
+ model_prev_list=model_prev_list, t_prev_list=t_prev_list,
803
+ noise=noise, t=t)
804
+ # evaluation step
805
+ model_x = self.model_fn(x_p, t)
806
+
807
+ # update model_list
808
+ model_prev_list.append(model_x)
809
+ # corrector step
810
+ if corrector_order > 0:
811
+ x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t),
812
+ model_prev_list=model_prev_list, t_prev_list=t_prev_list,
813
+ noise=noise, t=t)
814
+ else:
815
+ x = x_p
816
+
817
+ # evaluation step if correction and mode = pece
818
+ if corrector_order > 0 and pc_mode == 'PECE':
819
+ model_x = self.model_fn(x, t)
820
+ del model_prev_list[-1]
821
+ model_prev_list.append(model_x)
822
+
823
+ if self.correcting_xt_fn is not None:
824
+ x = self.correcting_xt_fn(x, t, step)
825
+ if return_intermediate:
826
+ intermediates.append(x)
827
+
828
+ t_prev_list.append(t)
829
+
830
+ for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)):
831
+ if lower_order_final:
832
+ predictor_order_used = min(predictor_order, steps - step + 1)
833
+ corrector_order_used = min(corrector_order, steps - step + 2)
834
+
835
+ else:
836
+ predictor_order_used = predictor_order
837
+ corrector_order_used = corrector_order
838
+ t = timesteps[step]
839
+ noise = torch.randn_like(x)
840
+
841
+ # predictor step
842
+ if skip_final_step and step == steps and not denoise_to_zero:
843
+ x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=0,
844
+ model_prev_list=model_prev_list,
845
+ t_prev_list=t_prev_list, noise=noise, t=t)
846
+ else:
847
+ x_p = self.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau(t),
848
+ model_prev_list=model_prev_list,
849
+ t_prev_list=t_prev_list, noise=noise, t=t)
850
+
851
+ # evaluation step
852
+ # do not evaluate if skip_final_step and step = steps
853
+ if not skip_final_step or step < steps:
854
+ model_x = self.model_fn(x_p, t)
855
+
856
+ # update model_list
857
+ # do not update if skip_final_step and step = steps
858
+ if not skip_final_step or step < steps:
859
+ model_prev_list.append(model_x)
860
+
861
+ # corrector step
862
+ # do not correct if skip_final_step and step = steps
863
+ if corrector_order > 0 and (not skip_final_step or step < steps):
864
+ x = self.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau(t),
865
+ model_prev_list=model_prev_list,
866
+ t_prev_list=t_prev_list, noise=noise, t=t)
867
+ else:
868
+ x = x_p
869
+
870
+ # evaluation step if mode = pece and step != steps
871
+ if corrector_order > 0 and (pc_mode == 'PECE' and step < steps):
872
+ model_x = self.model_fn(x, t)
873
+ del model_prev_list[-1]
874
+ model_prev_list.append(model_x)
875
+
876
+ if self.correcting_xt_fn is not None:
877
+ x = self.correcting_xt_fn(x, t, step)
878
+ if return_intermediate:
879
+ intermediates.append(x)
880
+
881
+ t_prev_list.append(t)
882
+ del model_prev_list[0]
883
+
884
+ if denoise_to_zero:
885
+ t = torch.ones((1,)).to(device) * t_0
886
+ x = self.denoise_to_zero_fn(x, t)
887
+ if self.correcting_xt_fn is not None:
888
+ x = self.correcting_xt_fn(x, t, step + 1)
889
+ if return_intermediate:
890
+ intermediates.append(x)
891
+ return (x, intermediates) if return_intermediate else x
892
+
893
+ def sample_more_steps(self, x, tau, steps=20, t_start=None, t_end=None, skip_type='time', skip_order=1,
894
+ predictor_order=3, corrector_order=4, pc_mode='PEC', return_intermediate=False
895
+ ):
896
+ """
897
+ For the PC-mode, please refer to the wiki page
898
+ https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode
899
+ 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations
900
+ We recommend pc_mode='PEC' when the number of NFEs is limited. 'PECE' mode is only for testing with sufficient NFEs.
901
+ """
902
+
903
+ skip_first_step = False
904
+ skip_final_step = False
905
+ lower_order_final = True
906
+ denoise_to_zero = True
907
+
908
+ assert pc_mode in ['PEC', 'PECE'], 'Predictor-corrector mode only supports PEC and PECE'
909
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
910
+ t_T = self.noise_schedule.T if t_start is None else t_start
911
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
912
+
913
+ device = x.device
914
+ intermediates = []
915
+ with torch.no_grad():
916
+ assert steps >= max(predictor_order, corrector_order - 1)
917
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, order=skip_order,
918
+ device=device)
919
+ assert timesteps.shape[0] - 1 == steps
920
+ # Init the initial values.
921
+ step = 0
922
+ t = timesteps[step]
923
+ noise = torch.randn_like(x)
924
+ t_prev_list = [t]
925
+ # do not evaluate if skip_first_step
926
+ if skip_first_step:
927
+ if self.predict_x0:
928
+ alpha_t = self.noise_schedule.marginal_alpha(t)
929
+ sigma_t = self.noise_schedule.marginal_std(t)
930
+ model_prev_list = [(1 - sigma_t) / alpha_t * x]
931
+ else:
932
+ model_prev_list = [x]
933
+ else:
934
+ model_prev_list = [self.model_fn(x, t)]
935
+
936
+ if self.correcting_xt_fn is not None:
937
+ x = self.correcting_xt_fn(x, t, step)
938
+ if return_intermediate:
939
+ intermediates.append(x)
940
+
941
+ # determine the first several values
942
+ for step in tqdm(range(1, max(predictor_order, corrector_order - 1))):
943
+
944
+ t = timesteps[step]
945
+ predictor_order_used = min(predictor_order, step)
946
+ corrector_order_used = min(corrector_order, step + 1)
947
+ noise = torch.randn_like(x)
948
+ # predictor step
949
+ x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t),
950
+ model_prev_list=model_prev_list, t_prev_list=t_prev_list, noise=noise,
951
+ t=t)
952
+ # evaluation step
953
+ model_x = self.model_fn(x_p, t)
954
+
955
+ # update model_list
956
+ model_prev_list.append(model_x)
957
+ # corrector step
958
+ if corrector_order > 0:
959
+ x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t),
960
+ model_prev_list=model_prev_list, t_prev_list=t_prev_list, noise=noise,
961
+ t=t)
962
+ else:
963
+ x = x_p
964
+
965
+ # evaluation step if mode = pece
966
+ if corrector_order > 0 and pc_mode == 'PECE':
967
+ model_x = self.model_fn(x, t)
968
+ del model_prev_list[-1]
969
+ model_prev_list.append(model_x)
970
+ if self.correcting_xt_fn is not None:
971
+ x = self.correcting_xt_fn(x, t, step)
972
+ if return_intermediate:
973
+ intermediates.append(x)
974
+
975
+ t_prev_list.append(t)
976
+
977
+ for step in tqdm(range(max(predictor_order, corrector_order - 1), steps + 1)):
978
+ if lower_order_final:
979
+ predictor_order_used = min(predictor_order, steps - step + 1)
980
+ corrector_order_used = min(corrector_order, steps - step + 2)
981
+
982
+ else:
983
+ predictor_order_used = predictor_order
984
+ corrector_order_used = corrector_order
985
+ t = timesteps[step]
986
+ noise = torch.randn_like(x)
987
+
988
+ # predictor step
989
+ if skip_final_step and step == steps and not denoise_to_zero:
990
+ x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=0,
991
+ model_prev_list=model_prev_list, t_prev_list=t_prev_list,
992
+ noise=noise, t=t)
993
+ else:
994
+ x_p = self.adams_bashforth_update(order=predictor_order_used, x=x, tau=tau(t),
995
+ model_prev_list=model_prev_list, t_prev_list=t_prev_list,
996
+ noise=noise, t=t)
997
+
998
+ # evaluation step
999
+ # do not evaluate if skip_final_step and step = steps
1000
+ if not skip_final_step or step < steps:
1001
+ model_x = self.model_fn(x_p, t)
1002
+
1003
+ # update model_list
1004
+ # do not update if skip_final_step and step = steps
1005
+ if not skip_final_step or step < steps:
1006
+ model_prev_list.append(model_x)
1007
+
1008
+ # corrector step
1009
+ # do not correct if skip_final_step and step = steps
1010
+ if corrector_order > 0:
1011
+ if not skip_final_step or step < steps:
1012
+ x = self.adams_moulton_update(order=corrector_order_used, x=x, tau=tau(t),
1013
+ model_prev_list=model_prev_list, t_prev_list=t_prev_list,
1014
+ noise=noise, t=t)
1015
+ else:
1016
+ x = x_p
1017
+ else:
1018
+ x = x_p
1019
+
1020
+ # evaluation step if mode = pece and step != steps
1021
+ if corrector_order > 0 and (pc_mode == 'PECE' and step < steps):
1022
+ model_x = self.model_fn(x, t)
1023
+ del model_prev_list[-1]
1024
+ model_prev_list.append(model_x)
1025
+
1026
+ if self.correcting_xt_fn is not None:
1027
+ x = self.correcting_xt_fn(x, t, step)
1028
+ if return_intermediate:
1029
+ intermediates.append(x)
1030
+
1031
+ t_prev_list.append(t)
1032
+ del model_prev_list[0]
1033
+
1034
+ if denoise_to_zero:
1035
+ t = torch.ones((1,)).to(device) * t_0
1036
+ x = self.denoise_to_zero_fn(x, t)
1037
+ if self.correcting_xt_fn is not None:
1038
+ x = self.correcting_xt_fn(x, t, step + 1)
1039
+ if return_intermediate:
1040
+ intermediates.append(x)
1041
+ if return_intermediate:
1042
+ return x, intermediates
1043
+ else:
1044
+ return x
1045
+
1046
+ def sample(self, mode, x, tau, steps, t_start=None, t_end=None, skip_type='time', skip_order=1, predictor_order=3,
1047
+ corrector_order=4, pc_mode='PEC', return_intermediate=False
1048
+ ):
1049
+ """
1050
+ For the PC-mode, please refer to the wiki page
1051
+ https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method#PEC_mode_and_PECE_mode
1052
+ 'PEC' needs one model evaluation per step while 'PECE' needs two model evaluations
1053
+ We recommend pc_mode='PEC' when the number of NFEs is limited. 'PECE' mode is only for testing with sufficient NFEs.
1054
+
1055
+ 'few_steps' mode is recommended. The differences between 'few_steps' and 'more_steps' are as below:
1056
+ 1) 'few_steps' does not correct at the final step and does not denoise to zero, while 'more_steps' does both.
1057
+ Thus the NFEs for 'few_steps' = steps, while the NFEs for 'more_steps' = steps + 2.
1058
+ For most experiments and tasks, we find these two operations do not help sample quality much.
1059
+ 2) 'few_steps' uses the rescaling trick from Appendix D in the SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
1060
+ We find it slightly improves the sample quality, especially with few steps.
1061
+ """
1062
+ assert mode in ['few_steps', 'more_steps'], "mode must be either 'few_steps' or 'more_steps'"
1063
+ if mode == 'few_steps':
1064
+ return self.sample_few_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type,
1065
+ skip_order=skip_order, predictor_order=predictor_order,
1066
+ corrector_order=corrector_order, pc_mode=pc_mode,
1067
+ return_intermediate=return_intermediate)
1068
+ else:
1069
+ return self.sample_more_steps(x=x, tau=tau, steps=steps, t_start=t_start, t_end=t_end, skip_type=skip_type,
1070
+ skip_order=skip_order, predictor_order=predictor_order,
1071
+ corrector_order=corrector_order, pc_mode=pc_mode,
1072
+ return_intermediate=return_intermediate)
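Putting the pieces together, a minimal sampling sketch; the `model_fn` / `noise_schedule` construction is assumed to follow the wrapper above, and `tau` is a callable mapping the current time to the stochasticity level:

```python
# Hypothetical sketch; model_fn and noise_schedule are placeholders built as in the wrapper above.
import torch

solver = SASolver(model_fn, noise_schedule, algorithm_type="data_prediction")

x_T = torch.randn(4, 8, 32, 32, device=device)   # start from pure noise
samples = solver.sample(
    mode='few_steps',                # NFEs == steps in this mode
    x=x_T,
    tau=lambda t: 1.0,               # tau(t) > 0: SDE-like sampling; tau(t) == 0 recovers the ODE update
    steps=25,
    predictor_order=2,
    corrector_order=2,
    skip_type='time',
)
```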
1073
+
1074
+
1075
+ #############################################################
1076
+ # other utility functions
1077
+ #############################################################
1078
+
1079
+ def interpolate_fn(x, xp, yp):
1080
+ """
1081
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1082
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1083
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
1084
+ Args:
1085
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1086
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1087
+ yp: PyTorch tensor with shape [C, K].
1088
+ Returns:
1089
+ The function values f(x), with shape [N, C].
1090
+ """
1091
+ N, K = x.shape[0], xp.shape[1]
1092
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1093
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1094
+ x_idx = torch.argmin(x_indices, dim=2)
1095
+ cand_start_idx = x_idx - 1
1096
+ start_idx = torch.where(
1097
+ torch.eq(x_idx, 0),
1098
+ torch.tensor(1, device=x.device),
1099
+ torch.where(
1100
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1101
+ ),
1102
+ )
1103
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1104
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1105
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1106
+ start_idx2 = torch.where(
1107
+ torch.eq(x_idx, 0),
1108
+ torch.tensor(0, device=x.device),
1109
+ torch.where(
1110
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1111
+ ),
1112
+ )
1113
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1114
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1115
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1116
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1117
+ return cand
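A quick shape and behaviour check for `interpolate_fn` (a standalone sketch, not part of the diff):

```python
import torch

xp = torch.tensor([[0.0, 1.0, 2.0]])        # [C=1, K=3] keypoint x-coordinates
yp = torch.tensor([[0.0, 10.0, 20.0]])      # [C=1, K=3] keypoint y-values
x  = torch.tensor([[0.5], [1.5], [3.0]])    # [N=3, C=1] query points

y = interpolate_fn(x, xp, yp)               # tensor([[ 5.], [15.], [30.]])
# The last query lies beyond xp, so the outermost segment (1, 10)-(2, 20) is extrapolated.
```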
1118
+
1119
+
1120
+ def expand_dims(v, dims):
1121
+ """
1122
+ Expand the tensor `v` to the dim `dims`.
1123
+ Args:
1124
+ `v`: a PyTorch tensor with shape [N].
1125
+ `dims`: an `int`.
1126
+ Returns:
1127
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1128
+ """
1129
+ return v[(...,) + (None,) * (dims - 1)]
DiT_VAE/diffusion/model/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs, device=local_ts.device) for _ in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs, device=local_losses.device) for _ in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int is removed in NumPy >= 1.24
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
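A short sketch of how a schedule sampler is typically wired into a training step; the `diffusion` object, `model`, `x_start`, and the IDDPM-style `training_losses` call are placeholders, not APIs defined in this diff:

```python
# Hypothetical training-step sketch (placeholders: diffusion, model, x_start).
sampler = create_named_schedule_sampler("loss-second-moment", diffusion)

t, weights = sampler.sample(x_start.shape[0], x_start.device)    # importance-sampled timesteps
losses = diffusion.training_losses(model, x_start, t)["loss"]    # per-sample losses (IDDPM-style dict)
loss = (losses * weights).mean()                                 # reweighting keeps the objective unbiased

if isinstance(sampler, LossAwareSampler):                        # feed losses back for resampling
    sampler.update_with_local_losses(t, losses.detach())
```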
DiT_VAE/diffusion/model/utils.py ADDED
@@ -0,0 +1,510 @@
1
+ import os
2
+ import sys
3
+ import torch.nn as nn
4
+ from torch.utils.checkpoint import checkpoint, checkpoint_sequential
5
+ import torch.nn.functional as F
6
+ import torch
7
+ import torch.distributed as dist
8
+ import re
9
+ import math
10
+ from collections.abc import Iterable
11
+ from itertools import repeat
12
+ from torchvision import transforms as T
13
+ import random
14
+ from PIL import Image
15
+
16
+
17
+ def _ntuple(n):
18
+ def parse(x):
19
+ if isinstance(x, Iterable) and not isinstance(x, str):
20
+ return x
21
+ return tuple(repeat(x, n))
22
+ return parse
23
+
24
+
25
+ to_1tuple = _ntuple(1)
26
+ to_2tuple = _ntuple(2)
27
+
28
+ def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
29
+ assert isinstance(model, nn.Module)
30
+
31
+ def set_attr(module):
32
+ module.grad_checkpointing = True
33
+ module.fp32_attention = use_fp32_attention
34
+ module.grad_checkpointing_step = gc_step
35
+ model.apply(set_attr)
36
+
37
+
38
+ def auto_grad_checkpoint(module, *args, **kwargs):
39
+ if getattr(module, 'grad_checkpointing', False):
40
+ if not isinstance(module, Iterable):
41
+ return checkpoint(module, *args, **kwargs)
42
+ gc_step = module[0].grad_checkpointing_step
43
+ return checkpoint_sequential(module, gc_step, *args, **kwargs)
44
+ return module(*args, **kwargs)
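A small sketch of the intended checkpointing pattern: `set_grad_checkpoint` flags every submodule, and `auto_grad_checkpoint` is called inside the forward pass (the `Block` and `TinyNet` modules below are hypothetical):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 64)

    def forward(self, x):
        return torch.relu(self.fc(x))

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block() for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            # Recomputes the block during backward instead of storing its activations.
            x = auto_grad_checkpoint(block, x)
        return x

net = TinyNet()
set_grad_checkpoint(net)                                  # sets .grad_checkpointing on every submodule
out = net(torch.randn(2, 64, requires_grad=True))
out.sum().backward()
```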
45
+
46
+
47
+ def checkpoint_sequential(functions, step, input, *args, **kwargs):
48
+
49
+ # Hack for keyword-only parameter in a python 2.7-compliant way
50
+ preserve = kwargs.pop('preserve_rng_state', True)
51
+ if kwargs:
52
+ raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
53
+
54
+ def run_function(start, end, functions):
55
+ def forward(input):
56
+ for j in range(start, end + 1):
57
+ input = functions[j](input, *args)
58
+ return input
59
+ return forward
60
+
61
+ if isinstance(functions, torch.nn.Sequential):
62
+ functions = list(functions.children())
63
+
64
+ # the last chunk has to be non-volatile
65
+ end = -1
66
+ segment = len(functions) // step
67
+ for start in range(0, step * (segment - 1), step):
68
+ end = start + step - 1
69
+ input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
70
+ return run_function(end + 1, len(functions) - 1, functions)(input)
71
+
72
+
73
+ def window_partition(x, window_size):
74
+ """
75
+ Partition into non-overlapping windows with padding if needed.
76
+ Args:
77
+ x (tensor): input tokens with [B, H, W, C].
78
+ window_size (int): window size.
79
+
80
+ Returns:
81
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
82
+ (Hp, Wp): padded height and width before partition
83
+ """
84
+ B, H, W, C = x.shape
85
+
86
+ pad_h = (window_size - H % window_size) % window_size
87
+ pad_w = (window_size - W % window_size) % window_size
88
+ if pad_h > 0 or pad_w > 0:
89
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
90
+ Hp, Wp = H + pad_h, W + pad_w
91
+
92
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
93
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
94
+ return windows, (Hp, Wp)
95
+
96
+
97
+ def window_unpartition(windows, window_size, pad_hw, hw):
98
+ """
99
+ Window unpartition into original sequences and removing padding.
100
+ Args:
101
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
102
+ window_size (int): window size.
103
+ pad_hw (Tuple): padded height and width (Hp, Wp).
104
+ hw (Tuple): original height and width (H, W) before padding.
105
+
106
+ Returns:
107
+ x: unpartitioned sequences with [B, H, W, C].
108
+ """
109
+ Hp, Wp = pad_hw
110
+ H, W = hw
111
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
112
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
113
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
114
+
115
+ if Hp > H or Wp > W:
116
+ x = x[:, :H, :W, :].contiguous()
117
+ return x
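A round-trip sketch showing that `window_unpartition` undoes `window_partition`, including the padding path:

```python
import torch

x = torch.randn(2, 10, 14, 32)                        # [B, H, W, C]; H is not a multiple of the window
windows, pad_hw = window_partition(x, window_size=7)  # -> [2 * 2 * 2, 7, 7, 32], pad_hw == (14, 14)
x_rec = window_unpartition(windows, 7, pad_hw, (10, 14))
assert torch.allclose(x, x_rec)                       # padding is stripped again
```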
118
+
119
+
120
+ def get_rel_pos(q_size, k_size, rel_pos):
121
+ """
122
+ Get relative positional embeddings according to the relative positions of
123
+ query and key sizes.
124
+ Args:
125
+ q_size (int): size of query q.
126
+ k_size (int): size of key k.
127
+ rel_pos (Tensor): relative position embeddings (L, C).
128
+
129
+ Returns:
130
+ Extracted positional embeddings according to relative positions.
131
+ """
132
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
133
+ # Interpolate rel pos if needed.
134
+ if rel_pos.shape[0] != max_rel_dist:
135
+ # Interpolate rel pos.
136
+ rel_pos_resized = F.interpolate(
137
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
138
+ size=max_rel_dist,
139
+ mode="linear",
140
+ )
141
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
142
+ else:
143
+ rel_pos_resized = rel_pos
144
+
145
+ # Scale the coords with short length if shapes for q and k are different.
146
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
147
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
148
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
149
+
150
+ return rel_pos_resized[relative_coords.long()]
151
+
152
+
153
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
154
+ """
155
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
156
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
157
+ Args:
158
+ attn (Tensor): attention map.
159
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
160
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
161
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
162
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
163
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
164
+
165
+ Returns:
166
+ attn (Tensor): attention map with added relative positional embeddings.
167
+ """
168
+ q_h, q_w = q_size
169
+ k_h, k_w = k_size
170
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
171
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
172
+
173
+ B, _, dim = q.shape
174
+ r_q = q.reshape(B, q_h, q_w, dim)
175
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
176
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
177
+
178
+ attn = (
179
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
180
+ ).view(B, q_h * q_w, k_h * k_w)
181
+
182
+ return attn
183
+
184
+ def mean_flat(tensor):
185
+ return tensor.mean(dim=list(range(1, tensor.ndim)))
186
+
187
+
188
+ #################################################################################
189
+ # Token Masking and Unmasking #
190
+ #################################################################################
191
+ def get_mask(batch, length, mask_ratio, device, mask_type=None, data_info=None, extra_len=0):
192
+ """
193
+ Get the binary mask for the input sequence.
194
+ Args:
195
+ - batch: batch size
196
+ - length: sequence length
197
+ - mask_ratio: ratio of tokens to mask
198
+ - data_info: dictionary with info for reconstruction
199
+ return:
200
+ mask_dict with following keys:
201
+ - mask: binary mask, 0 is keep, 1 is remove
202
+ - ids_keep: indices of tokens to keep
203
+ - ids_restore: indices to restore the original order
204
+ """
205
+ assert mask_type in ['random', 'fft', 'laplacian', 'group']
206
+ mask = torch.ones([batch, length], device=device)
207
+ len_keep = int(length * (1 - mask_ratio)) - extra_len
208
+
209
+ if mask_type in ['random', 'group']:
210
+ noise = torch.rand(batch, length, device=device) # noise in [0, 1]
211
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
212
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
213
+ # keep the first subset
214
+ ids_keep = ids_shuffle[:, :len_keep]
215
+ ids_removed = ids_shuffle[:, len_keep:]
216
+
217
+ elif mask_type in ['fft', 'laplacian']:
218
+ if 'strength' in data_info:
219
+ strength = data_info['strength']
220
+
221
+ else:
222
+ N = data_info['N'][0]
223
+ img = data_info['ori_img']
224
+ # Get the size of the original image
225
+ _, C, H, W = img.shape
226
+ if mask_type == 'fft':
227
+ # Reshape the image into patches (3, H/N, N, W/N, N)
228
+ reshaped_image = img.reshape((batch, -1, H // N, N, W // N, N))
229
+ fft_image = torch.fft.fftn(reshaped_image, dim=(3, 5))
230
+ # Take the absolute value and sum to get the frequency strength
231
+ strength = torch.sum(torch.abs(fft_image), dim=(1, 3, 5)).reshape((batch, -1,))
232
+ elif mask_type == 'laplacian':  # comparing the builtin `type` here would never match
233
+ laplacian_kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32).reshape(1, 1, 3, 3)
234
+ laplacian_kernel = laplacian_kernel.repeat(C, 1, 1, 1)
235
+ # Reshape the image into patches (3, H/N, N, W/N, N)
236
+ reshaped_image = img.reshape(-1, C, H // N, N, W // N, N).permute(0, 2, 4, 1, 3, 5).reshape(-1, C, N, N)
237
+ laplacian_response = F.conv2d(reshaped_image, laplacian_kernel, padding=1, groups=C)
238
+ strength = laplacian_response.sum(dim=[1, 2, 3]).reshape((batch, -1,))
239
+
240
+ # Normalize the frequency strength, then sample with torch.multinomial
241
+ probabilities = strength / (strength.max(dim=1)[0][:, None]+1e-5)
242
+ ids_shuffle = torch.multinomial(probabilities.clip(1e-5, 1), length, replacement=False)
243
+ ids_keep = ids_shuffle[:, :len_keep]
244
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
245
+ ids_removed = ids_shuffle[:, len_keep:]
246
+
247
+ mask[:, :len_keep] = 0
248
+ mask = torch.gather(mask, dim=1, index=ids_restore)
249
+
250
+ return {'mask': mask,
251
+ 'ids_keep': ids_keep,
252
+ 'ids_restore': ids_restore,
253
+ 'ids_removed': ids_removed}
254
+
255
+
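A sketch of the content-aware ('fft') masking path; the shapes, patch size, and token dimension below are illustrative only, and `data_info` carries the keys read by the code above:

```python
import torch

B, N, H, W = 2, 16, 256, 256                       # N: patch size -> length = (H // N) * (W // N) tokens
img = torch.randn(B, 3, H, W)

mask_dict = get_mask(
    batch=B, length=(H // N) * (W // N), mask_ratio=0.5,
    device=img.device, mask_type='fft',
    data_info={'N': [N], 'ori_img': img},          # patches with larger spectral energy are more likely kept
)

x = torch.randn(B, (H // N) * (W // N), 1152)      # token sequence [B, L, D]
x_kept = mask_out_token(x, mask_dict['ids_keep'])  # -> [B, L * (1 - mask_ratio), D]
```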
256
+ def mask_out_token(x, ids_keep, ids_removed=None):
257
+ """
258
+ Mask out the tokens specified by ids_keep.
259
+ Args:
260
+ - x: input sequence, [N, L, D]
261
+ - ids_keep: indices of tokens to keep
262
+ return:
263
+ - x_masked: masked sequence
264
+ """
265
+ N, L, D = x.shape # batch, length, dim
266
+ x_remain = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
267
+ if ids_removed is not None:
268
+ x_masked = torch.gather(x, dim=1, index=ids_removed.unsqueeze(-1).repeat(1, 1, D))
269
+ return x_remain, x_masked
270
+ else:
271
+ return x_remain
272
+
273
+
274
+ def mask_tokens(x, mask_ratio):
275
+ """
276
+ Perform per-sample random masking by per-sample shuffling.
277
+ Per-sample shuffling is done by argsort random noise.
278
+ x: [N, L, D], sequence
279
+ """
280
+ N, L, D = x.shape # batch, length, dim
281
+ len_keep = int(L * (1 - mask_ratio))
282
+
283
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
284
+
285
+ # sort noise for each sample
286
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
287
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
288
+
289
+ # keep the first subset
290
+ ids_keep = ids_shuffle[:, :len_keep]
291
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
292
+
293
+ # generate the binary mask: 0 is keep, 1 is remove
294
+ mask = torch.ones([N, L], device=x.device)
295
+ mask[:, :len_keep] = 0
296
+ mask = torch.gather(mask, dim=1, index=ids_restore)
297
+
298
+ return x_masked, mask, ids_restore
299
+
300
+
301
+ def unmask_tokens(x, ids_restore, mask_token):
302
+ # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D]
303
+ mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
304
+ x = torch.cat([x, mask_tokens], dim=1)
305
+ x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
306
+ return x
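A quick round trip of mask_tokens / unmask_tokens on dummy data (the zero mask_token stands in for the learnable parameter):

```python
import torch

N, L, D = 2, 8, 16
x = torch.randn(N, L, D)

# Randomly drop 75% of the tokens per sample.
x_masked, mask, ids_restore = mask_tokens(x, mask_ratio=0.75)
assert x_masked.shape == (N, 2, D)                # 2 of 8 tokens kept
assert mask.sum(dim=1).eq(6).all()                # 6 positions marked as removed

# Fill removed positions with a mask token and restore the original order.
mask_token = torch.zeros(1, 1, D)                 # placeholder for a learnable parameter
x_full = unmask_tokens(x_masked, ids_restore, mask_token)
assert x_full.shape == (N, L, D)
```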
307
+
308
+
309
+ # Parse 'None' to None and others to float value
310
+ def parse_float_none(s):
311
+ assert isinstance(s, str)
312
+ return None if s == 'None' else float(s)
313
+
314
+
315
+ #----------------------------------------------------------------------------
316
+ # Parse a comma separated list of numbers or ranges and return a list of ints.
317
+ # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
318
+
319
+ def parse_int_list(s):
320
+ if isinstance(s, list): return s
321
+ ranges = []
322
+ range_re = re.compile(r'^(\d+)-(\d+)$')
323
+ for p in s.split(','):
324
+ if m := range_re.match(p):
325
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
326
+ else:
327
+ ranges.append(int(p))
328
+ return ranges
329
+
330
+
331
+ def init_processes(fn, args):
332
+ """ Initialize the distributed environment. """
333
+ os.environ['MASTER_ADDR'] = args.master_address
334
+ os.environ['MASTER_PORT'] = str(random.randint(2000, 6000))
335
+ print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}')
336
+ print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}')
337
+ torch.cuda.set_device(args.local_rank)
338
+ dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size)
339
+ fn(args)
340
+ if args.global_size > 1:
341
+ cleanup()
342
+
343
+
344
+ def mprint(*args, **kwargs):
345
+ """
346
+ Print only from rank 0.
347
+ """
348
+ if dist.get_rank() == 0:
349
+ print(*args, **kwargs)
350
+
351
+
352
+ def cleanup():
353
+ """
354
+ End DDP training.
355
+ """
356
+ dist.barrier()
357
+ mprint("Done!")
358
+ dist.barrier()
359
+ dist.destroy_process_group()
360
+
361
+
362
+ #----------------------------------------------------------------------------
363
+ # logging info.
364
+ class Logger(object):
365
+ """
366
+ Redirect stderr to stdout, optionally print stdout to a file,
367
+ and optionally force flushing on both stdout and the file.
368
+ """
369
+
370
+ def __init__(self, file_name=None, file_mode="w", should_flush=True):
371
+ self.file = None
372
+
373
+ if file_name is not None:
374
+ self.file = open(file_name, file_mode)
375
+
376
+ self.should_flush = should_flush
377
+ self.stdout = sys.stdout
378
+ self.stderr = sys.stderr
379
+
380
+ sys.stdout = self
381
+ sys.stderr = self
382
+
383
+ def __enter__(self):
384
+ return self
385
+
386
+ def __exit__(self, exc_type, exc_value, traceback):
387
+ self.close()
388
+
389
+ def write(self, text):
390
+ """Write text to stdout (and a file) and optionally flush."""
391
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
392
+ return
393
+
394
+ if self.file is not None:
395
+ self.file.write(text)
396
+
397
+ self.stdout.write(text)
398
+
399
+ if self.should_flush:
400
+ self.flush()
401
+
402
+ def flush(self):
403
+ """Flush written text to both stdout and a file, if open."""
404
+ if self.file is not None:
405
+ self.file.flush()
406
+
407
+ self.stdout.flush()
408
+
409
+ def close(self):
410
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
411
+ self.flush()
412
+
413
+ # if using multiple loggers, prevent closing in wrong order
414
+ if sys.stdout is self:
415
+ sys.stdout = self.stdout
416
+ if sys.stderr is self:
417
+ sys.stderr = self.stderr
418
+
419
+ if self.file is not None:
420
+ self.file.close()
421
+
422
+
423
+ class StackedRandomGenerator:
424
+ def __init__(self, device, seeds):
425
+ super().__init__()
426
+ self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
427
+
428
+ def randn(self, size, **kwargs):
429
+ assert size[0] == len(self.generators)
430
+ return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
431
+
432
+ def randn_like(self, input):
433
+ return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
434
+
435
+ def randint(self, *args, size, **kwargs):
436
+ assert size[0] == len(self.generators)
437
+ return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
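Usage sketch for StackedRandomGenerator: one torch.Generator per sample makes each row's noise depend only on its own seed, so any sample can be regenerated outside the batch bit-for-bit (device and shapes here are illustrative):

```python
import torch

rnd = StackedRandomGenerator(device='cpu', seeds=[0, 1, 2, 3])
latents = rnd.randn([4, 4, 32, 32], device='cpu')

# Regenerating seed 2 alone reproduces row 2 of the batch exactly.
single = StackedRandomGenerator(device='cpu', seeds=[2]).randn([1, 4, 32, 32], device='cpu')
assert torch.equal(latents[2], single[0])
```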
438
+
439
+
440
+ def prepare_prompt_ar(prompt, ratios, device='cpu', show=True):
441
+ # get aspect_ratio or ar
442
+ aspect_ratios = re.findall(r"--aspect_ratio\s+(\d+:\d+)", prompt)
443
+ ars = re.findall(r"--ar\s+(\d+:\d+)", prompt)
444
+ custom_hw = re.findall(r"--hw\s+(\d+:\d+)", prompt)
445
+ if show:
446
+ print("aspect_ratios:", aspect_ratios, "ars:", ars, "hws:", custom_hw)
447
+ prompt_clean = prompt.split("--aspect_ratio")[0].split("--ar")[0].split("--hw")[0]
448
+ if len(aspect_ratios) + len(ars) + len(custom_hw) == 0 and show:
449
+ print( "Wrong prompt format. Set to default ar: 1. change your prompt into format '--ar h:w or --hw h:w' for correct generating")
450
+ if len(aspect_ratios) != 0:
451
+ ar = float(aspect_ratios[0].split(':')[0]) / float(aspect_ratios[0].split(':')[1])
452
+ elif len(ars) != 0:
453
+ ar = float(ars[0].split(':')[0]) / float(ars[0].split(':')[1])
454
+ else:
455
+ ar = 1.
456
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
457
+ if len(custom_hw) != 0:
458
+ custom_hw = [float(custom_hw[0].split(':')[0]), float(custom_hw[0].split(':')[1])]
459
+ else:
460
+ custom_hw = ratios[closest_ratio]
461
+ default_hw = ratios[closest_ratio]
462
+ prompt_show = f'prompt: {prompt_clean.strip()}\nSize: --ar {closest_ratio}, --bin hw {ratios[closest_ratio]}, --custom hw {custom_hw}'
463
+ return prompt_clean, prompt_show, torch.tensor(default_hw, device=device)[None], torch.tensor([float(closest_ratio)], device=device)[None], torch.tensor(custom_hw, device=device)[None]
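For example, with a hypothetical ratio table keyed by h/w (the real bins live elsewhere in the repo), a prompt carrying '--ar 9:16' is snapped to the closest bin:

```python
ratios = {'0.5': [512, 1024], '1.0': [1024, 1024], '2.0': [1024, 512]}  # hypothetical: ratio -> [h, w]
prompt = "a portrait of an astronaut --ar 9:16"
clean, prompt_show, default_hw, closest_ar, custom_hw = prepare_prompt_ar(prompt, ratios, show=False)
# clean == "a portrait of an astronaut ", closest bin is 0.5, default_hw == tensor([[512, 1024]])
```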
464
+
465
+
466
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int):
467
+ orig_hw = torch.tensor([samples.shape[2], samples.shape[3]], dtype=torch.int)
468
+ custom_hw = torch.tensor([int(new_height), int(new_width)], dtype=torch.int)
469
+
470
+ if (orig_hw != custom_hw).all():
471
+ ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1])
472
+ resized_width = int(orig_hw[1] * ratio)
473
+ resized_height = int(orig_hw[0] * ratio)
474
+
475
+ transform = T.Compose([
476
+ T.Resize((resized_height, resized_width)),
477
+ T.CenterCrop(custom_hw.tolist())
478
+ ])
479
+ return transform(samples)
480
+ else:
481
+ return samples
482
+
483
+
484
+ def resize_and_crop_img(img: Image, new_width, new_height):
485
+ orig_width, orig_height = img.size
486
+
487
+ ratio = max(new_width/orig_width, new_height/orig_height)
488
+ resized_width = int(orig_width * ratio)
489
+ resized_height = int(orig_height * ratio)
490
+
491
+ img = img.resize((resized_width, resized_height), Image.LANCZOS)
492
+
493
+ left = (resized_width - new_width)/2
494
+ top = (resized_height - new_height)/2
495
+ right = (resized_width + new_width)/2
496
+ bottom = (resized_height + new_height)/2
497
+
498
+ img = img.crop((left, top, right, bottom))
499
+
500
+ return img
501
+
502
+
503
+
504
+ def mask_feature(emb, mask):
505
+ if emb.shape[0] == 1:
506
+ keep_index = mask.sum().item()
507
+ return emb[:, :, :keep_index, :], keep_index
508
+ else:
509
+ masked_feature = emb * mask[:, None, :, None]
510
+ return masked_feature, emb.shape[2]
DiT_VAE/diffusion/sa_sampler.py ADDED
@@ -0,0 +1,66 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from DiT_VAE.diffusion.model.sa_solver import NoiseScheduleVP, model_wrapper, SASolver
7
+ from .model import gaussian_diffusion as gd
8
+
9
+
10
+ class SASolverSampler(object):
11
+ def __init__(self, model,
12
+ noise_schedule="linear",
13
+ diffusion_steps=1000,
14
+ device='cpu',
15
+ ):
16
+ super().__init__()
17
+ self.model = model
18
+ self.device = device
19
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(device)
20
+ betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))
21
+ alphas = 1.0 - betas
22
+ self.register_buffer('alphas_cumprod', to_torch(np.cumprod(alphas, axis=0)))
23
+
24
+ def register_buffer(self, name, attr):
25
+ if type(attr) == torch.Tensor and attr.device != torch.device("cuda"):
26
+ attr = attr.to(torch.device("cuda"))
27
+ setattr(self, name, attr)
28
+
29
+ @torch.no_grad()
30
+ def sample(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, model_kwargs=None, **kwargs):
31
+ if model_kwargs is None:
32
+ model_kwargs = {}
33
+ if conditioning is not None:
34
+ if isinstance(conditioning, dict):
35
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
36
+ if cbs != batch_size:
37
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
38
+ elif conditioning.shape[0] != batch_size:
39
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
40
+
41
+ # sampling
42
+ C, H, W = shape
43
+ size = (batch_size, C, H, W)
44
+
45
+ device = self.device
46
+ img = torch.randn(size, device=device) if x_T is None else x_T
47
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
48
+
49
+ model_fn = model_wrapper(
50
+ self.model,
51
+ ns,
52
+ model_type="noise",
53
+ guidance_type="classifier-free",
54
+ condition=conditioning,
55
+ unconditional_condition=unconditional_conditioning,
56
+ guidance_scale=unconditional_guidance_scale,
57
+ model_kwargs=model_kwargs,
58
+ )
59
+
60
+ sasolver = SASolver(model_fn, ns, algorithm_type="data_prediction")
61
+
62
+ tau_t = lambda t: eta if 0.2 <= t <= 0.8 else 0
63
+
64
+ x = sasolver.sample(mode='few_steps', x=img, tau=tau_t, steps=S, skip_type='time', skip_order=1, predictor_order=2, corrector_order=2, pc_mode='PEC', return_intermediate=False)
65
+
66
+ return x.to(device), None
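A hedged usage sketch of SASolverSampler (the denoiser `dit`, the conditioning tensors, and the latent shape are placeholders, not this repo's actual pipeline wiring):

```python
import torch

# `dit` is assumed to be a noise-prediction network taking (x, timestep, **model_kwargs).
sampler = SASolverSampler(dit, noise_schedule="linear", diffusion_steps=1000, device='cuda')

cond = torch.randn(4, 77, 1152, device='cuda')   # placeholder conditioning features
uncond = torch.zeros_like(cond)                  # placeholder null conditioning
latents, _ = sampler.sample(
    S=25,                                        # sampling steps
    batch_size=4,
    shape=(4, 32, 32),                           # latent (C, H, W)
    conditioning=cond,
    unconditional_conditioning=uncond,
    unconditional_guidance_scale=4.5,
    eta=1.0,
)
```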
DiT_VAE/diffusion/sa_solver_diffusers.py ADDED
@@ -0,0 +1,840 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ # DISCLAIMER: check https://arxiv.org/abs/2309.05019
14
+ # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
15
+
16
+ import math
17
+ from typing import List, Optional, Tuple, Union, Callable
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils.torch_utils import randn_tensor
24
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
25
+
26
+
27
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
28
+ def betas_for_alpha_bar(
29
+ num_diffusion_timesteps,
30
+ max_beta=0.999,
31
+ alpha_transform_type="cosine",
32
+ ):
33
+ """
34
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
35
+ (1-beta) over time from t = [0,1].
36
+
37
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
38
+ to that part of the diffusion process.
39
+
40
+
41
+ Args:
42
+ num_diffusion_timesteps (`int`): the number of betas to produce.
43
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
44
+ prevent singularities.
45
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
46
+ Choose from `cosine` or `exp`
47
+
48
+ Returns:
49
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
50
+ """
51
+ if alpha_transform_type == "cosine":
52
+
53
+ def alpha_bar_fn(t):
54
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
55
+
56
+ elif alpha_transform_type == "exp":
57
+
58
+ def alpha_bar_fn(t):
59
+ return math.exp(t * -12.0)
60
+
61
+ else:
62
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
63
+
64
+ betas = []
65
+ for i in range(num_diffusion_timesteps):
66
+ t1 = i / num_diffusion_timesteps
67
+ t2 = (i + 1) / num_diffusion_timesteps
68
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
69
+ return torch.tensor(betas, dtype=torch.float32)
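Quick sanity check of the cosine schedule produced above: the cumulative product of (1 - beta) decays monotonically from ~1 towards 0.

```python
import torch

betas = betas_for_alpha_bar(1000)                  # cosine schedule, 1000 steps
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
assert betas.shape == (1000,)
assert (alphas_cumprod[1:] <= alphas_cumprod[:-1]).all()
assert alphas_cumprod[0] > 0.99 and alphas_cumprod[-1] < 1e-3
```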
70
+
71
+
72
+ class SASolverScheduler(SchedulerMixin, ConfigMixin):
73
+ """
74
+ `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs.
75
+
76
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
77
+ methods the library implements for all schedulers such as loading and saving.
78
+
79
+ Args:
80
+ num_train_timesteps (`int`, defaults to 1000):
81
+ The number of diffusion steps to train the model.
82
+ beta_start (`float`, defaults to 0.0001):
83
+ The starting `beta` value of inference.
84
+ beta_end (`float`, defaults to 0.02):
85
+ The final `beta` value.
86
+ beta_schedule (`str`, defaults to `"linear"`):
87
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
88
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
89
+ trained_betas (`np.ndarray`, *optional*):
90
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
91
+ predictor_order (`int`, defaults to 2):
92
+ The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for guided
93
+ sampling, and `predictor_order=3` for unconditional sampling.
94
+ corrector_order (`int`, defaults to 2):
95
+ The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for guided
96
+ sampling, and `corrector_order=3` for unconditional sampling.
97
+ predictor_corrector_mode (`str`, defaults to `PEC`):
98
+ The predictor-corrector mode can be `PEC` or 'PECE'. It is recommended to use `PEC` mode for fast
99
+ sampling, and `PECE` for high-quality sampling (PECE needs around twice model evaluations as PEC).
100
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
101
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
102
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
103
+ Video](https://imagen.research.google/video/paper.pdf) paper).
104
+ thresholding (`bool`, defaults to `False`):
105
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion vae such
106
+ as Stable Diffusion.
107
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
108
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
109
+ sample_max_value (`float`, defaults to 1.0):
110
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
111
+ `algorithm_type="dpmsolver++"`.
112
+ algorithm_type (`str`, defaults to `data_prediction`):
113
+ Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use `data_prediction`
114
+ with `solver_order=2` for guided sampling like in Stable Diffusion.
115
+ lower_order_final (`bool`, defaults to `True`):
116
+ Whether to use lower-order solvers in the final steps. Default = True.
117
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
118
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
119
+ the sigmas are determined according to a sequence of noise levels {σi}.
120
+ lambda_min_clipped (`float`, defaults to `-inf`):
121
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
122
+ cosine (`squaredcos_cap_v2`) noise schedule.
123
+ variance_type (`str`, *optional*):
124
+ Set to "learned" or "learned_range" for diffusion vae that predict variance. If set, the model's output
125
+ contains the predicted Gaussian variance.
126
+ timestep_spacing (`str`, defaults to `"linspace"`):
127
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
128
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
129
+ steps_offset (`int`, defaults to 0):
130
+ An offset added to the inference steps. You can use a combination of `offset=1` and
131
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
132
+ Diffusion.
133
+ """
134
+
135
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
136
+ order = 1
137
+
138
+ @register_to_config
139
+ def __init__(
140
+ self,
141
+ num_train_timesteps: int = 1000,
142
+ beta_start: float = 0.0001,
143
+ beta_end: float = 0.02,
144
+ beta_schedule: str = "linear",
145
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
146
+ predictor_order: int = 2,
147
+ corrector_order: int = 2,
148
+ predictor_corrector_mode: str = 'PEC',
149
+ prediction_type: str = "epsilon",
150
+ tau_func: Callable = lambda t: 1 if t >= 200 and t <= 800 else 0,
151
+ thresholding: bool = False,
152
+ dynamic_thresholding_ratio: float = 0.995,
153
+ sample_max_value: float = 1.0,
154
+ algorithm_type: str = "data_prediction",
155
+ lower_order_final: bool = True,
156
+ use_karras_sigmas: Optional[bool] = False,
157
+ lambda_min_clipped: float = -float("inf"),
158
+ variance_type: Optional[str] = None,
159
+ timestep_spacing: str = "linspace",
160
+ steps_offset: int = 0,
161
+ ):
162
+ if trained_betas is not None:
163
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
164
+ elif beta_schedule == "linear":
165
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
166
+ elif beta_schedule == "scaled_linear":
167
+ # this schedule is very specific to the latent diffusion model.
168
+ self.betas = (
169
+ torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
170
+ )
171
+ elif beta_schedule == "squaredcos_cap_v2":
172
+ # Glide cosine schedule
173
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
174
+ else:
175
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
176
+
177
+ self.alphas = 1.0 - self.betas
178
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
179
+ # Currently we only support VP-type noise schedule
180
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
181
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
182
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
183
+
184
+ # standard deviation of the initial noise distribution
185
+ self.init_noise_sigma = 1.0
186
+
187
+ if algorithm_type not in ["data_prediction", "noise_prediction"]:
188
+ raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
189
+
190
+ # setable values
191
+ self.num_inference_steps = None
192
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
193
+ self.timesteps = torch.from_numpy(timesteps)
194
+ self.timestep_list = [None] * max(predictor_order, corrector_order - 1)
195
+ self.model_outputs = [None] * max(predictor_order, corrector_order - 1)
196
+
197
+ self.tau_func = tau_func
198
+ self.predict_x0 = algorithm_type == "data_prediction"
199
+ self.lower_order_nums = 0
200
+ self.last_sample = None
201
+
202
+ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
203
+ """
204
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
205
+
206
+ Args:
207
+ num_inference_steps (`int`):
208
+ The number of diffusion steps used when generating samples with a pre-trained model.
209
+ device (`str` or `torch.device`, *optional*):
210
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
211
+ """
212
+ # Clipping the minimum of all lambda(t) for numerical stability.
213
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
214
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
215
+ last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
216
+
217
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
218
+ if self.config.timestep_spacing == "linspace":
219
+ timesteps = (
220
+ np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
221
+ )
222
+
223
+ elif self.config.timestep_spacing == "leading":
224
+ step_ratio = last_timestep // (num_inference_steps + 1)
225
+ # creates integer timesteps by multiplying by ratio
226
+ # casting to int to avoid issues when num_inference_step is power of 3
227
+ timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
228
+ timesteps += self.config.steps_offset
229
+ elif self.config.timestep_spacing == "trailing":
230
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
231
+ # creates integer timesteps by multiplying by ratio
232
+ # casting to int to avoid issues when num_inference_step is power of 3
233
+ timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
234
+ timesteps -= 1
235
+ else:
236
+ raise ValueError(
237
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
238
+ )
239
+
240
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
241
+ if self.config.use_karras_sigmas:
242
+ log_sigmas = np.log(sigmas)
243
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
244
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
245
+ timesteps = np.flip(timesteps).copy().astype(np.int64)
246
+
247
+ self.sigmas = torch.from_numpy(sigmas)
248
+
249
+ # when num_inference_steps == num_train_timesteps, we can end up with
250
+ # duplicates in timesteps.
251
+ _, unique_indices = np.unique(timesteps, return_index=True)
252
+ timesteps = timesteps[np.sort(unique_indices)]
253
+
254
+ self.timesteps = torch.from_numpy(timesteps).to(device)
255
+
256
+ self.num_inference_steps = len(timesteps)
257
+
258
+ self.model_outputs = [
259
+ None,
260
+ ] * max(self.config.predictor_order, self.config.corrector_order - 1)
261
+ self.lower_order_nums = 0
262
+ self.last_sample = None
263
+
264
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
265
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
266
+ """
267
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
268
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
269
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
270
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
271
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
272
+
273
+ https://arxiv.org/abs/2205.11487
274
+ """
275
+ dtype = sample.dtype
276
+ batch_size, channels, height, width = sample.shape
277
+
278
+ if dtype not in (torch.float32, torch.float64):
279
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
280
+
281
+ # Flatten sample for doing quantile calculation along each image
282
+ sample = sample.reshape(batch_size, channels * height * width)
283
+
284
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
285
+
286
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
287
+ s = torch.clamp(
288
+ s, min=1, max=self.config.sample_max_value
289
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
290
+
291
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
292
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
293
+
294
+ sample = sample.reshape(batch_size, channels, height, width)
295
+ sample = sample.to(dtype)
296
+
297
+ return sample
298
+
299
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
300
+ def _sigma_to_t(self, sigma, log_sigmas):
301
+ # get log sigma
302
+ log_sigma = np.log(sigma)
303
+
304
+ # get distribution
305
+ dists = log_sigma - log_sigmas[:, np.newaxis]
306
+
307
+ # get sigmas range
308
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
309
+ high_idx = low_idx + 1
310
+
311
+ low = log_sigmas[low_idx]
312
+ high = log_sigmas[high_idx]
313
+
314
+ # interpolate sigmas
315
+ w = (low - log_sigma) / (low - high)
316
+ w = np.clip(w, 0, 1)
317
+
318
+ # transform interpolation to time range
319
+ t = (1 - w) * low_idx + w * high_idx
320
+ t = t.reshape(sigma.shape)
321
+ return t
322
+
323
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
324
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
325
+ """Constructs the noise schedule of Karras et al. (2022)."""
326
+
327
+ sigma_min: float = in_sigmas[-1].item()
328
+ sigma_max: float = in_sigmas[0].item()
329
+
330
+ rho = 7.0 # 7.0 is the value used in the paper
331
+ ramp = np.linspace(0, 1, num_inference_steps)
332
+ min_inv_rho = sigma_min ** (1 / rho)
333
+ max_inv_rho = sigma_max ** (1 / rho)
334
+ return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
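For reference, the schedule computed above is the Karras et al. (2022) rule, with sigma_max = in_sigmas[0] and sigma_min = in_sigmas[-1]:

```latex
\sigma_i = \Bigl(\sigma_{\max}^{1/\rho} + \tfrac{i}{n-1}\bigl(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\bigr)\Bigr)^{\rho},
\qquad i = 0,\dots,n-1,\quad \rho = 7.
```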
335
+
336
+ def convert_model_output(
337
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
338
+ ) -> torch.FloatTensor:
339
+ """
340
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
341
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
342
+ integral of the data prediction model.
343
+
344
+ <Tip>
345
+
346
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
347
+ prediction and data prediction vae.
348
+
349
+ </Tip>
350
+
351
+ Args:
352
+ model_output (`torch.FloatTensor`):
353
+ The direct output from the learned diffusion model.
354
+ timestep (`int`):
355
+ The current discrete timestep in the diffusion chain.
356
+ sample (`torch.FloatTensor`):
357
+ A current instance of a sample created by the diffusion process.
358
+
359
+ Returns:
360
+ `torch.FloatTensor`:
361
+ The converted model output.
362
+ """
363
+
364
+ # SA-Solver_data_prediction needs to solve an integral of the data prediction model.
365
+ if self.config.algorithm_type in ["data_prediction"]:
366
+ if self.config.prediction_type == "epsilon":
367
+ # SA-Solver only needs the "mean" output.
368
+ if self.config.variance_type in ["learned", "learned_range"]:
369
+ model_output = model_output[:, :3]
370
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
371
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
372
+ elif self.config.prediction_type == "sample":
373
+ x0_pred = model_output
374
+ elif self.config.prediction_type == "v_prediction":
375
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
376
+ x0_pred = alpha_t * sample - sigma_t * model_output
377
+ else:
378
+ raise ValueError(
379
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
380
+ " `v_prediction` for the SASolverScheduler."
381
+ )
382
+
383
+ if self.config.thresholding:
384
+ x0_pred = self._threshold_sample(x0_pred)
385
+
386
+ return x0_pred
387
+
388
+ # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model.
389
+ elif self.config.algorithm_type in ["noise_prediction"]:
390
+ if self.config.prediction_type == "epsilon":
391
+ # SA-Solver only needs the "mean" output.
392
+ if self.config.variance_type in ["learned", "learned_range"]:
393
+ epsilon = model_output[:, :3]
394
+ else:
395
+ epsilon = model_output
396
+ elif self.config.prediction_type == "sample":
397
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
398
+ epsilon = (sample - alpha_t * model_output) / sigma_t
399
+ elif self.config.prediction_type == "v_prediction":
400
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
401
+ epsilon = alpha_t * model_output + sigma_t * sample
402
+ else:
403
+ raise ValueError(
404
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
405
+ " `v_prediction` for the SASolverScheduler."
406
+ )
407
+
408
+ if self.config.thresholding:
409
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
410
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
411
+ x0_pred = self._threshold_sample(x0_pred)
412
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
413
+
414
+ return epsilon
415
+
416
+ def get_coefficients_exponential_negative(self, order, interval_start, interval_end):
417
+ """
418
+ Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end
419
+ """
420
+ assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
421
+
422
+ if order == 0:
423
+ return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1)
424
+ elif order == 1:
425
+ return torch.exp(-interval_end) * (
426
+ (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1))
427
+ elif order == 2:
428
+ return torch.exp(-interval_end) * (
429
+ (interval_start ** 2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - (
430
+ interval_end ** 2 + 2 * interval_end + 2))
431
+ elif order == 3:
432
+ return torch.exp(-interval_end) * (
433
+ (interval_start ** 3 + 3 * interval_start ** 2 + 6 * interval_start + 6) * torch.exp(
434
+ interval_end - interval_start) - (interval_end ** 3 + 3 * interval_end ** 2 + 6 * interval_end + 6))
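Sanity check: the closed forms above are exactly the integrals of x^n e^{-x} over [a, b], with a = interval_start and b = interval_end:

```latex
\int_a^b e^{-x}\,dx = e^{-a} - e^{-b}, \qquad
\int_a^b x\,e^{-x}\,dx = (a+1)e^{-a} - (b+1)e^{-b},
\int_a^b x^2 e^{-x}\,dx = (a^2+2a+2)e^{-a} - (b^2+2b+2)e^{-b},
\int_a^b x^3 e^{-x}\,dx = (a^3+3a^2+6a+6)e^{-a} - (b^3+3b^2+6b+6)e^{-b}.
```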
435
+
436
+ def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau):
437
+ """
438
+ Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
439
+ """
440
+ assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"
441
+
442
+ # after change of variable(cov)
443
+ interval_end_cov = (1 + tau ** 2) * interval_end
444
+ interval_start_cov = (1 + tau ** 2) * interval_start
445
+
446
+ if order == 0:
447
+ return torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (
448
+ (1 + tau ** 2))
449
+ elif order == 1:
450
+ return torch.exp(interval_end_cov) * ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(
451
+ -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 2)
452
+ elif order == 2:
453
+ return torch.exp(interval_end_cov) * ((interval_end_cov ** 2 - 2 * interval_end_cov + 2) - (
454
+ interval_start_cov ** 2 - 2 * interval_start_cov + 2) * torch.exp(
455
+ -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 3)
456
+ elif order == 3:
457
+ return torch.exp(interval_end_cov) * (
458
+ (interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6) - (
459
+ interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6) * torch.exp(
460
+ -(interval_end_cov - interval_start_cov))) / ((1 + tau ** 2) ** 4)
461
+
462
+ def lagrange_polynomial_coefficient(self, order, lambda_list):
463
+ """
464
+ Calculate the coefficient of lagrange polynomial
465
+ """
466
+
467
+ assert order in [0, 1, 2, 3]
468
+ assert order == len(lambda_list) - 1
469
+ if order == 0:
470
+ return [[1]]
471
+ elif order == 1:
472
+ return [[1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
473
+ [1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
474
+ elif order == 2:
475
+ denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
476
+ denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
477
+ denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
478
+ return [[1 / denominator1,
479
+ (-lambda_list[1] - lambda_list[2]) / denominator1,
480
+ lambda_list[1] * lambda_list[2] / denominator1],
481
+
482
+ [1 / denominator2,
483
+ (-lambda_list[0] - lambda_list[2]) / denominator2,
484
+ lambda_list[0] * lambda_list[2] / denominator2],
485
+
486
+ [1 / denominator3,
487
+ (-lambda_list[0] - lambda_list[1]) / denominator3,
488
+ lambda_list[0] * lambda_list[1] / denominator3]
489
+ ]
490
+ elif order == 3:
491
+ denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (
492
+ lambda_list[0] - lambda_list[3])
493
+ denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (
494
+ lambda_list[1] - lambda_list[3])
495
+ denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (
496
+ lambda_list[2] - lambda_list[3])
497
+ denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (
498
+ lambda_list[3] - lambda_list[2])
499
+ return [[1 / denominator1,
500
+ (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
501
+ (lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[
502
+ 3]) / denominator1,
503
+ (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],
504
+
505
+ [1 / denominator2,
506
+ (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
507
+ (lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[
508
+ 3]) / denominator2,
509
+ (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],
510
+
511
+ [1 / denominator3,
512
+ (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
513
+ (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[
514
+ 3]) / denominator3,
515
+ (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],
516
+
517
+ [1 / denominator4,
518
+ (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
519
+ (lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[
520
+ 2]) / denominator4,
521
+ (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
522
+
523
+ ]
524
+
525
+ def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau):
526
+ assert order in [1, 2, 3, 4]
527
+ assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
528
+ coefficients = []
529
+ lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list)
530
+ for i in range(order):
531
+ coefficient = sum(
532
+ lagrange_coefficient[i][j]
533
+ * self.get_coefficients_exponential_positive(
534
+ order - 1 - j, interval_start, interval_end, tau
535
+ )
536
+ if self.predict_x0
537
+ else lagrange_coefficient[i][j]
538
+ * self.get_coefficients_exponential_negative(
539
+ order - 1 - j, interval_start, interval_end
540
+ )
541
+ for j in range(order)
542
+ )
543
+ coefficients.append(coefficient)
544
+ assert len(coefficients) == order, 'the length of coefficients does not match the order'
545
+ return coefficients
546
+
547
+ def stochastic_adams_bashforth_update(
548
+ self,
549
+ model_output: torch.FloatTensor,
550
+ prev_timestep: int,
551
+ sample: torch.FloatTensor,
552
+ noise: torch.FloatTensor,
553
+ order: int,
554
+ tau: torch.FloatTensor,
555
+ ) -> torch.FloatTensor:
556
+ """
557
+ One step for the SA-Predictor.
558
+
559
+ Args:
560
+ model_output (`torch.FloatTensor`):
561
+ The direct output from the learned diffusion model at the current timestep.
562
+ prev_timestep (`int`):
563
+ The previous discrete timestep in the diffusion chain.
564
+ sample (`torch.FloatTensor`):
565
+ A current instance of a sample created by the diffusion process.
566
+ order (`int`):
567
+ The order of SA-Predictor at this timestep.
568
+
569
+ Returns:
570
+ `torch.FloatTensor`:
571
+ The sample tensor at the previous timestep.
572
+ """
573
+
574
+ assert noise is not None
575
+ timestep_list = self.timestep_list
576
+ model_output_list = self.model_outputs
577
+ s0, t = self.timestep_list[-1], prev_timestep
578
+ lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
579
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
580
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
581
+ gradient_part = torch.zeros_like(sample)
582
+ h = lambda_t - lambda_s0
583
+ lambda_list = [self.lambda_t[timestep_list[-(i + 1)]] for i in range(order)]
584
+ gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)
585
+
586
+ x = sample
587
+
588
+ if self.predict_x0 and order == 2:
589
+ gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
590
+ h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
591
+ (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[
592
+ timestep_list[-2]])
593
+ gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
594
+ h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
595
+ (1 + tau ** 2) ** 2)) / (self.lambda_t[timestep_list[-1]] - self.lambda_t[
596
+ timestep_list[-2]])
597
+
598
+ for i in range(order):
599
+ if self.predict_x0:
600
+
601
+ gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
602
+ i] * model_output_list[-(i + 1)]
603
+ else:
604
+ gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)]
605
+
606
+ if self.predict_x0:
607
+ noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * noise
608
+ else:
609
+ noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise
610
+
611
+ if self.predict_x0:
612
+ x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
613
+ else:
614
+ x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part
615
+
616
+ x_t = x_t.to(x.dtype)
617
+ return x_t
618
+
619
+ def stochastic_adams_moulton_update(
620
+ self,
621
+ this_model_output: torch.FloatTensor,
622
+ this_timestep: int,
623
+ last_sample: torch.FloatTensor,
624
+ last_noise: torch.FloatTensor,
625
+ this_sample: torch.FloatTensor,
626
+ order: int,
627
+ tau: torch.FloatTensor,
628
+ ) -> torch.FloatTensor:
629
+ """
630
+ One step for the SA-Corrector.
631
+
632
+ Args:
633
+ this_model_output (`torch.FloatTensor`):
634
+ The model outputs at `x_t`.
635
+ this_timestep (`int`):
636
+ The current timestep `t`.
637
+ last_sample (`torch.FloatTensor`):
638
+ The generated sample before the last predictor `x_{t-1}`.
639
+ this_sample (`torch.FloatTensor`):
640
+ The generated sample after the last predictor `x_{t}`.
641
+ order (`int`):
642
+ The order of SA-Corrector at this step.
643
+
644
+ Returns:
645
+ `torch.FloatTensor`:
646
+ The corrected sample tensor at the current timestep.
647
+ """
648
+
649
+ assert last_noise is not None
650
+ timestep_list = self.timestep_list
651
+ model_output_list = self.model_outputs
652
+ s0, t = self.timestep_list[-1], this_timestep
653
+ lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
654
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
655
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
656
+ gradient_part = torch.zeros_like(this_sample)
657
+ h = lambda_t - lambda_s0
658
+ t_list = timestep_list + [this_timestep]
659
+ lambda_list = [self.lambda_t[t_list[-(i + 1)]] for i in range(order)]
660
+ model_prev_list = model_output_list + [this_model_output]
661
+
662
+ gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)
663
+
664
+ x = last_sample
665
+
666
+ if self.predict_x0 and order == 2:
667
+ gradient_coefficients[0] += 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
668
+ h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
669
+ (1 + tau ** 2) ** 2 * h))
670
+ gradient_coefficients[1] -= 1.0 * torch.exp((1 + tau ** 2) * lambda_t) * (
671
+ h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / (
672
+ (1 + tau ** 2) ** 2 * h))
673
+
674
+ for i in range(order):
675
+ if self.predict_x0:
676
+ gradient_part += (1 + tau ** 2) * sigma_t * torch.exp(- tau ** 2 * lambda_t) * gradient_coefficients[
677
+ i] * model_prev_list[-(i + 1)]
678
+ else:
679
+ gradient_part += -(1 + tau ** 2) * alpha_t * gradient_coefficients[i] * model_prev_list[-(i + 1)]
680
+
681
+ if self.predict_x0:
682
+ noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau ** 2 * h)) * last_noise
683
+ else:
684
+ noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise
685
+
686
+ if self.predict_x0:
687
+ x_t = torch.exp(-tau ** 2 * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
688
+ else:
689
+ x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part
690
+
691
+ x_t = x_t.to(x.dtype)
692
+ return x_t
693
+
694
+ def step(
695
+ self,
696
+ model_output: torch.FloatTensor,
697
+ timestep: int,
698
+ sample: torch.FloatTensor,
699
+ generator=None,
700
+ return_dict: bool = True,
701
+ ) -> Union[SchedulerOutput, Tuple]:
702
+ """
703
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
704
+ the SA-Solver.
705
+
706
+ Args:
707
+ model_output (`torch.FloatTensor`):
708
+ The direct output from learned diffusion model.
709
+ timestep (`int`):
710
+ The current discrete timestep in the diffusion chain.
711
+ sample (`torch.FloatTensor`):
712
+ A current instance of a sample created by the diffusion process.
713
+ generator (`torch.Generator`, *optional*):
714
+ A random number generator.
715
+ return_dict (`bool`):
716
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
717
+
718
+ Returns:
719
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
720
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
721
+ tuple is returned where the first element is the sample tensor.
722
+
723
+ """
724
+ if self.num_inference_steps is None:
725
+ raise ValueError(
726
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
727
+ )
728
+
729
+ if isinstance(timestep, torch.Tensor):
730
+ timestep = timestep.to(self.timesteps.device)
731
+ step_index = (self.timesteps == timestep).nonzero()
732
+ if len(step_index) == 0:
733
+ step_index = len(self.timesteps) - 1
734
+ else:
735
+ step_index = step_index.item()
736
+
737
+ use_corrector = (
738
+ step_index > 0 and self.last_sample is not None
739
+ )
740
+
741
+ model_output_convert = self.convert_model_output(model_output, timestep, sample)
742
+
743
+ if use_corrector:
744
+ current_tau = self.tau_func(self.timestep_list[-1])
745
+ sample = self.stochastic_adams_moulton_update(
746
+ this_model_output=model_output_convert,
747
+ this_timestep=timestep,
748
+ last_sample=self.last_sample,
749
+ last_noise=self.last_noise,
750
+ this_sample=sample,
751
+ order=self.this_corrector_order,
752
+ tau=current_tau,
753
+ )
754
+
755
+ prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
756
+
757
+ for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1):
758
+ self.model_outputs[i] = self.model_outputs[i + 1]
759
+ self.timestep_list[i] = self.timestep_list[i + 1]
760
+
761
+ self.model_outputs[-1] = model_output_convert
762
+ self.timestep_list[-1] = timestep
763
+
764
+ noise = randn_tensor(
765
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
766
+ )
767
+
768
+ if self.config.lower_order_final:
769
+ this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - step_index)
770
+ this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - step_index + 1)
771
+ else:
772
+ this_predictor_order = self.config.predictor_order
773
+ this_corrector_order = self.config.corrector_order
774
+
775
+ self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1) # warmup for multistep
776
+ self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2) # warmup for multistep
777
+ assert self.this_predictor_order > 0
778
+ assert self.this_corrector_order > 0
779
+
780
+ self.last_sample = sample
781
+ self.last_noise = noise
782
+
783
+ current_tau = self.tau_func(self.timestep_list[-1])
784
+ prev_sample = self.stochastic_adams_bashforth_update(
785
+ model_output=model_output_convert,
786
+ prev_timestep=prev_timestep,
787
+ sample=sample,
788
+ noise=noise,
789
+ order=self.this_predictor_order,
790
+ tau=current_tau,
791
+ )
792
+
793
+ if self.lower_order_nums < max(self.config.predictor_order, self.config.corrector_order - 1):
794
+ self.lower_order_nums += 1
795
+
796
+ if not return_dict:
797
+ return (prev_sample,)
798
+
799
+ return SchedulerOutput(prev_sample=prev_sample)
800
+
801
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
802
+ """
803
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
804
+ current timestep.
805
+
806
+ Args:
807
+ sample (`torch.FloatTensor`):
808
+ The input sample.
809
+
810
+ Returns:
811
+ `torch.FloatTensor`:
812
+ A scaled input sample.
813
+ """
814
+ return sample
815
+
816
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
817
+ def add_noise(
818
+ self,
819
+ original_samples: torch.FloatTensor,
820
+ noise: torch.FloatTensor,
821
+ timesteps: torch.IntTensor,
822
+ ) -> torch.FloatTensor:
823
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
824
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
825
+ timesteps = timesteps.to(original_samples.device)
826
+
827
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
828
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
829
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
830
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
831
+
832
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
833
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
834
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
835
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
836
+
837
+ return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
838
+
839
+ def __len__(self):
840
+ return self.config.num_train_timesteps
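A minimal sketch of driving this scheduler in a diffusers-style denoising loop (`denoiser` and the latent shape are placeholders, not this repo's actual pipeline):

```python
import torch

scheduler = SASolverScheduler(num_train_timesteps=1000, prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=20, device='cpu')

latents = torch.randn(1, 4, 32, 32) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(latents, t)
    with torch.no_grad():
        noise_pred = denoiser(model_input, t)     # placeholder epsilon-prediction model
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```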
DiT_VAE/diffusion/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # from .logger import get_root_logger
DiT_VAE/diffusion/utils/checkpoint.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+ import re
3
+ import torch
4
+
5
+ from DiT_VAE.diffusion.utils.logger import get_root_logger
6
+
7
+
8
+ def save_checkpoint(work_dir,
9
+ epoch,
10
+ model,
11
+ model_ema=None,
12
+ optimizer=None,
13
+ lr_scheduler=None,
14
+ keep_last=False,
15
+ step=None,
16
+ ):
17
+ os.makedirs(work_dir, exist_ok=True)
18
+ state_dict = dict(state_dict=model.state_dict())
19
+ if model_ema is not None:
20
+ state_dict['state_dict_ema'] = model_ema.state_dict()
21
+ if optimizer is not None:
22
+ state_dict['optimizer'] = optimizer.state_dict()
23
+ if lr_scheduler is not None:
24
+ state_dict['scheduler'] = lr_scheduler.state_dict()
25
+ if epoch is not None:
26
+ state_dict['epoch'] = epoch
27
+ file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
28
+ if step is not None:
29
+ file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
30
+ logger = get_root_logger()
31
+ torch.save(state_dict, file_path)
32
+ logger.info(f'Saved checkpoint of epoch {epoch} to {file_path.format(epoch)}.')
33
+ if keep_last:
34
+ for i in range(epoch):
35
+ previous_ckgt = file_path.format(i)
36
+ if os.path.exists(previous_ckgt):
37
+ os.remove(previous_ckgt)
38
+
39
+
40
+ def load_checkpoint(checkpoint,
41
+ model,
42
+ model_ema=None,
43
+ optimizer=None,
44
+ lr_scheduler=None,
45
+ load_ema=False,
46
+ resume_optimizer=True,
47
+ resume_lr_scheduler=True
48
+ ):
49
+ assert isinstance(checkpoint, str)
50
+ ckpt_file = checkpoint
51
+ checkpoint = torch.load(ckpt_file, map_location="cpu")
52
+
53
+ state_dict_keys = ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']
54
+ for key in state_dict_keys:
55
+ if key in checkpoint['state_dict']:
56
+ del checkpoint['state_dict'][key]
57
+ if 'state_dict_ema' in checkpoint and key in checkpoint['state_dict_ema']:
58
+ del checkpoint['state_dict_ema'][key]
59
+ break
60
+
61
+ if load_ema:
62
+ state_dict = checkpoint['state_dict_ema']
63
+ else:
64
+ state_dict = checkpoint.get('state_dict', checkpoint) # to be compatible with the official checkpoint
65
+ # model.load_state_dict(state_dict)
66
+ missing, unexpect = model.load_state_dict(state_dict, strict=False)
67
+ if model_ema is not None:
68
+ model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
69
+ if optimizer is not None and resume_optimizer:
70
+ optimizer.load_state_dict(checkpoint['optimizer'])
71
+ if lr_scheduler is not None and resume_lr_scheduler:
72
+ lr_scheduler.load_state_dict(checkpoint['scheduler'])
73
+ logger = get_root_logger()
74
+ if optimizer is not None:
75
+ epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*\.pth', ckpt_file).group(1))
76
+ logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
77
+ f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
78
+ return epoch, missing, unexpect
79
+ logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
80
+ return missing, unexpect
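Usage sketch (the work dir path, `model`, and `optimizer` are placeholders):

```python
# Save: writes {work_dir}/epoch_{epoch}.pth containing model (and optional EMA/optimizer/scheduler) state.
save_checkpoint('output/checkpoints', epoch=10, model=model, optimizer=optimizer)

# Resume: when an optimizer is passed, returns (epoch, missing_keys, unexpected_keys).
epoch, missing, unexpected = load_checkpoint(
    'output/checkpoints/epoch_10.pth', model, optimizer=optimizer)
```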
DiT_VAE/diffusion/utils/data_sampler.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import Sequence
4
+ from torch.utils.data import BatchSampler, Sampler, Dataset
5
+ from random import shuffle, choice
6
+ from copy import deepcopy
7
+ from DiT_VAE.diffusion.utils.logger import get_root_logger
8
+
9
+
10
+ class AspectRatioBatchSampler(BatchSampler):
11
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
12
+
13
+ Args:
14
+ sampler (Sampler): Base sampler.
15
+ dataset (Dataset): Dataset providing data information.
16
+ batch_size (int): Size of mini-batch.
17
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
18
+ its size would be less than ``batch_size``.
19
+ aspect_ratios (dict): The predefined aspect ratios.
20
+ """
21
+
22
+ def __init__(self,
23
+ sampler: Sampler,
24
+ dataset: Dataset,
25
+ batch_size: int,
26
+ aspect_ratios: dict,
27
+ drop_last: bool = False,
28
+ config=None,
29
+ valid_num=0, # take as valid aspect-ratio when sample number >= valid_num
30
+ **kwargs) -> None:
31
+ if not isinstance(sampler, Sampler):
32
+ raise TypeError('sampler should be an instance of ``Sampler``, '
33
+ f'but got {sampler}')
34
+ if not isinstance(batch_size, int) or batch_size <= 0:
35
+ raise ValueError('batch_size should be a positive integer value, '
36
+ f'but got batch_size={batch_size}')
37
+ self.sampler = sampler
38
+ self.dataset = dataset
39
+ self.batch_size = batch_size
40
+ self.aspect_ratios = aspect_ratios
41
+ self.drop_last = drop_last
42
+ self.ratio_nums_gt = kwargs.get('ratio_nums', None)
43
+ self.config = config
44
+ assert self.ratio_nums_gt
45
+ # buckets for each aspect ratio
46
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
47
+ self.current_available_bucket_keys = [str(k) for k, v in self.ratio_nums_gt.items() if v >= valid_num]
48
+ logger = get_root_logger() if config is None else get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
49
+ logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
50
+
51
+ def __iter__(self) -> Sequence[int]:
52
+ for idx in self.sampler:
53
+ data_info = self.dataset.get_data_info(idx)
54
+ height, width = data_info['height'], data_info['width']
55
+ ratio = height / width
56
+ # find the closest aspect ratio
57
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
58
+ if closest_ratio not in self.current_available_bucket_keys:
59
+ continue
60
+ bucket = self._aspect_ratio_buckets[closest_ratio]
61
+ bucket.append(idx)
62
+ # yield a batch of indices in the same aspect ratio group
63
+ if len(bucket) == self.batch_size:
64
+ yield bucket[:]
65
+ del bucket[:]
66
+
67
+ # yield the rest data and reset the buckets
68
+ for bucket in self._aspect_ratio_buckets.values():
69
+ while len(bucket) > 0:
70
+ if len(bucket) <= self.batch_size:
71
+ if not self.drop_last:
72
+ yield bucket[:]
73
+ bucket = []
74
+ else:
75
+ yield bucket[:self.batch_size]
76
+ bucket = bucket[self.batch_size:]
77
+
78
+
79
+ class BalancedAspectRatioBatchSampler(AspectRatioBatchSampler):
80
+ def __init__(self, *args, **kwargs):
81
+ super().__init__(*args, **kwargs)
82
+ # Assign samples to each bucket
83
+ self.ratio_nums_gt = kwargs.get('ratio_nums', None)
84
+ assert self.ratio_nums_gt
85
+ self._aspect_ratio_buckets = {float(ratio): [] for ratio in self.aspect_ratios.keys()}
86
+ self.original_buckets = {}
87
+ self.current_available_bucket_keys = [k for k, v in self.ratio_nums_gt.items() if v >= 3000]
88
+ self.all_available_keys = deepcopy(self.current_available_bucket_keys)
89
+ self.exhausted_bucket_keys = []
90
+ self.total_batches = len(self.sampler) // self.batch_size
91
+ self._aspect_ratio_count = {}
92
+ for k in self.all_available_keys:
93
+ self._aspect_ratio_count[float(k)] = 0
94
+ self.original_buckets[float(k)] = []
95
+ logger = get_root_logger(os.path.join(self.config.work_dir, 'train_log.log'))
96
+ logger.warning(f"Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
97
+
98
+ def __iter__(self) -> Sequence[int]:
99
+ i = 0
100
+ for idx in self.sampler:
101
+ data_info = self.dataset.get_data_info(idx)
102
+ height, width = data_info['height'], data_info['width']
103
+ ratio = height / width
104
+ closest_ratio = float(min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)))
105
+ if closest_ratio not in self.all_available_keys:
106
+ continue
107
+ if self._aspect_ratio_count[closest_ratio] < self.ratio_nums_gt[closest_ratio]:
108
+ self._aspect_ratio_count[closest_ratio] += 1
109
+ self._aspect_ratio_buckets[closest_ratio].append(idx)
110
+ self.original_buckets[closest_ratio].append(idx) # Save the original samples for each bucket
111
+ if not self.current_available_bucket_keys:
112
+ self.current_available_bucket_keys, self.exhausted_bucket_keys = self.exhausted_bucket_keys, []
113
+
114
+ if closest_ratio not in self.current_available_bucket_keys:
115
+ continue
116
+ key = closest_ratio
117
+ bucket = self._aspect_ratio_buckets[key]
118
+ if len(bucket) == self.batch_size:
119
+ yield bucket[:self.batch_size]
120
+ del bucket[:self.batch_size]
121
+ i += 1
122
+ self.exhausted_bucket_keys.append(key)
123
+ self.current_available_bucket_keys.remove(key)
124
+
125
+ for _ in range(self.total_batches - i):
126
+ key = choice(self.all_available_keys)
127
+ bucket = self._aspect_ratio_buckets[key]
128
+ if len(bucket) >= self.batch_size:
129
+ yield bucket[:self.batch_size]
130
+ del bucket[:self.batch_size]
131
+
132
+ # If a bucket is exhausted
133
+ if not bucket:
134
+ self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
135
+ shuffle(self._aspect_ratio_buckets[key])
136
+ else:
137
+ self._aspect_ratio_buckets[key] = deepcopy(self.original_buckets[key][:])
138
+ shuffle(self._aspect_ratio_buckets[key])
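A minimal usage sketch for `AspectRatioBatchSampler` (the dataset class, ratio table and sample counts below are illustrative placeholders; the only real requirement is that the dataset implements `get_data_info(idx)` returning `height`/`width`):

```python
from torch.utils.data import DataLoader, RandomSampler

dataset = MyTriplaneDataset(...)  # hypothetical dataset exposing get_data_info(idx) -> {'height': ..., 'width': ...}
aspect_ratios = {'0.5': [512, 1024], '1.0': [768, 768], '2.0': [1024, 512]}   # illustrative buckets
ratio_nums = {'0.5': 1200, '1.0': 50000, '2.0': 900}                          # samples per ratio key

batch_sampler = AspectRatioBatchSampler(
    sampler=RandomSampler(dataset),
    dataset=dataset,
    batch_size=16,
    aspect_ratios=aspect_ratios,
    drop_last=True,
    ratio_nums=ratio_nums,   # required kwarg, see the assert in __init__
    valid_num=1000,          # ratios with fewer than 1000 samples are skipped
)
loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)
```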
DiT_VAE/diffusion/utils/dist_utils.py ADDED
@@ -0,0 +1,303 @@
1
+ """
2
+ This file contains primitives for multi-gpu communication.
3
+ This is useful when doing distributed training.
4
+ """
5
+ import os
6
+ import pickle
7
+ import shutil
8
+
9
+ import gc
10
+ import mmcv
11
+ import torch
12
+ import torch.distributed as dist
13
+ from mmcv.runner import get_dist_info
14
+
15
+
16
+ def is_distributed():
17
+ return get_world_size() > 1
18
+
19
+
20
+ def get_world_size():
21
+ if not dist.is_available():
22
+ return 1
23
+ return dist.get_world_size() if dist.is_initialized() else 1
24
+
25
+
26
+ def get_rank():
27
+ if not dist.is_available():
28
+ return 0
29
+ return dist.get_rank() if dist.is_initialized() else 0
30
+
31
+
32
+ def get_local_rank():
33
+ if not dist.is_available():
34
+ return 0
35
+ return int(os.getenv('LOCAL_RANK', 0)) if dist.is_initialized() else 0
36
+
37
+
38
+ def is_master():
39
+ return get_rank() == 0
40
+
41
+
42
+ def is_local_master():
43
+ return get_local_rank() == 0
44
+
45
+
46
+ def get_local_proc_group(group_size=8):
47
+ world_size = get_world_size()
48
+ if world_size <= group_size or group_size == 1:
49
+ return None
50
+ assert world_size % group_size == 0, f'world size ({world_size}) should be evenly divided by group size ({group_size}).'
51
+ process_groups = getattr(get_local_proc_group, 'process_groups', {})
52
+ if group_size not in process_groups:
53
+ num_groups = dist.get_world_size() // group_size
54
+ groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]
55
+ process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]})
56
+ get_local_proc_group.process_groups = process_groups
57
+
58
+ group_idx = get_rank() // group_size
59
+ return get_local_proc_group.process_groups.get(group_size)[group_idx]
60
+
61
+
62
+ def synchronize():
63
+ """
64
+ Helper function to synchronize (barrier) among all processes when
65
+ using distributed training
66
+ """
67
+ if not dist.is_available():
68
+ return
69
+ if not dist.is_initialized():
70
+ return
71
+ world_size = dist.get_world_size()
72
+ if world_size == 1:
73
+ return
74
+ dist.barrier()
75
+
76
+
77
+ def all_gather(data):
78
+ """
79
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
80
+ Args:
81
+ data: any picklable object
82
+ Returns:
83
+ list[data]: list of data gathered from each rank
84
+ """
85
+ to_device = torch.device("cuda")
86
+ # to_device = torch.device("cpu")
87
+
88
+ world_size = get_world_size()
89
+ if world_size == 1:
90
+ return [data]
91
+
92
+ # serialized to a Tensor
93
+ buffer = pickle.dumps(data)
94
+ storage = torch.ByteStorage.from_buffer(buffer)
95
+ tensor = torch.ByteTensor(storage).to(to_device)
96
+
97
+ # obtain Tensor size of each rank
98
+ local_size = torch.LongTensor([tensor.numel()]).to(to_device)
99
+ size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)]
100
+ dist.all_gather(size_list, local_size)
101
+ size_list = [int(size.item()) for size in size_list]
102
+ max_size = max(size_list)
103
+
104
+ tensor_list = [
105
+ torch.ByteTensor(size=(max_size,)).to(to_device) for _ in size_list
106
+ ]
107
+ if local_size != max_size:
108
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device)
109
+ tensor = torch.cat((tensor, padding), dim=0)
110
+ dist.all_gather(tensor_list, tensor)
111
+
112
+ data_list = []
113
+ for size, tensor in zip(size_list, tensor_list):
114
+ buffer = tensor.cpu().numpy().tobytes()[:size]
115
+ data_list.append(pickle.loads(buffer))
116
+
117
+ return data_list
118
+
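A small sketch of `all_gather` collecting arbitrary picklable per-rank results (assumes `torch.distributed` is initialized with a CUDA-capable backend, since the helper stages the pickled buffers on `cuda`; the metric values are illustrative):

```python
local_metrics = {"rank": get_rank(), "num_samples": 128, "loss": 0.42}  # illustrative per-rank object
gathered = all_gather(local_metrics)   # list with world_size entries, identical on every rank
if is_master():
    total = sum(m["num_samples"] for m in gathered)
    print(f"Evaluated {total} samples across {get_world_size()} ranks")
```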
119
+
120
+ def reduce_dict(input_dict, average=True):
121
+ """
122
+ Args:
123
+ input_dict (dict): all the values will be reduced
124
+ average (bool): whether to do average or sum
125
+ Reduce the values in the dictionary from all processes so that process with rank
126
+ 0 has the averaged results. Returns a dict with the same fields as
127
+ input_dict, after reduction.
128
+ """
129
+ world_size = get_world_size()
130
+ if world_size < 2:
131
+ return input_dict
132
+ with torch.no_grad():
133
+ reduced_dict = _reduce_dict_impl(input_dict, average, world_size)
134
+ return reduced_dict
135
+
136
+
137
+ def _reduce_dict_impl(input_dict, average, world_size):
139
+ names = []
140
+ values = []
141
+ # sort the keys so that they are consistent across processes
142
+ for k in sorted(input_dict.keys()):
143
+ names.append(k)
144
+ values.append(input_dict[k])
145
+ values = torch.stack(values, dim=0)
146
+ dist.reduce(values, dst=0)
147
+ if dist.get_rank() == 0 and average:
148
+ # only main process gets accumulated, so only divide by
149
+ # world_size in this case
150
+ values /= world_size
151
+ return dict(zip(names, values))
152
+
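And a sketch of `reduce_dict` averaging a loss dictionary onto rank 0 (the loss tensors are hypothetical; all ranks must pass the same keys with scalar tensors on the same device):

```python
loss_dict = {"loss_kl": kl_loss.detach(), "loss_recon": recon_loss.detach()}  # hypothetical scalar tensors
reduced = reduce_dict(loss_dict, average=True)
if is_master():
    print({k: v.item() for k, v in reduced.items()})  # averaged values, valid on rank 0
```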
153
+
154
+ def broadcast(data, **kwargs):
155
+ if get_world_size() == 1:
156
+ return data
157
+ data = [data]
158
+ dist.broadcast_object_list(data, **kwargs)
159
+ return data[0]
160
+
161
+
162
+ def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True):
163
+ rank, world_size = get_dist_info()
164
+ if tmpdir is None:
165
+ tmpdir = './tmp'
166
+ if rank == 0:
167
+ mmcv.mkdir_or_exist(tmpdir)
168
+ synchronize()
169
+ # dump the part result to the dir
170
+ mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
171
+ synchronize()
172
+ if collect_by_master and rank != 0:
173
+ return None
174
+ # load results of all parts from tmp dir
175
+ results = []
176
+ for i in range(world_size):
177
+ part_file = os.path.join(tmpdir, f'part_{i}.pkl')
178
+ results.append(mmcv.load(part_file))
179
+ if not collect_by_master:
180
+ synchronize()
181
+ # remove tmp dir
182
+ if rank == 0:
183
+ shutil.rmtree(tmpdir)
184
+ return results
185
+
186
+ def all_gather_tensor(tensor, group_size=None, group=None):
187
+ if group_size is None:
188
+ group_size = get_world_size()
189
+ if group_size == 1:
190
+ output = [tensor]
191
+ else:
192
+ output = [torch.zeros_like(tensor) for _ in range(group_size)]
193
+ dist.all_gather(output, tensor, group=group)
194
+ return output
195
+
196
+
197
+ def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None):
198
+ world_size = get_world_size()
199
+ if world_size == 1:
200
+ return feat if concat else [feat]
201
+ num_samples, *feat_dim = feat.size()
202
+ # padding to max number of samples
203
+ feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim))
204
+ feat_padding[:num_samples] = feat
205
+ # gather
206
+ feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size)
207
+ for r, num in enumerate(num_samples_list):
208
+ feat_gather[r] = feat_gather[r][:num]
209
+ if concat:
210
+ feat_gather = torch.cat(feat_gather)
211
+ return feat_gather
212
+
213
+
214
+ class GatherLayer(torch.autograd.Function):
215
+ '''Gather tensors from all process, supporting backward propagation.
216
+ '''
217
+
218
+ @staticmethod
219
+ def forward(ctx, input):
220
+ ctx.save_for_backward(input)
221
+ num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device)
222
+ ctx.num_samples_list = all_gather_tensor(num_samples)
223
+ output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False)
224
+ return tuple(output)
225
+
226
+ @staticmethod
227
+ def backward(ctx, *grads): # tuple(output)'s grad
228
+ input, = ctx.saved_tensors
229
+ num_samples_list = ctx.num_samples_list
230
+ rank = get_rank()
231
+ start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1])
232
+ grads = torch.cat(grads)
233
+ if is_distributed():
234
+ dist.all_reduce(grads)
235
+ grad_out = torch.zeros_like(input)
236
+ grad_out[:] = grads[start:end]
237
+ return grad_out, None, None
238
+
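A hedged sketch of the gradient-aware gather, e.g. for a contrastive loss computed over the global batch (`model`, `images`, `labels` and `contrastive_loss` are placeholders):

```python
embeddings = model(images)                                # (B, D) local embeddings, requires grad
global_emb = torch.cat(GatherLayer.apply(embeddings), 0)  # embeddings from all ranks, grads still flow locally
loss = contrastive_loss(global_emb, labels)
loss.backward()                                           # gradients reach only this rank's rows
```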
239
+
240
+ class GatherLayerWithGroup(torch.autograd.Function):
241
+ '''Gather tensors from all process, supporting backward propagation.
242
+ '''
243
+
244
+ @staticmethod
245
+ def forward(ctx, input, group, group_size):
246
+ ctx.save_for_backward(input)
247
+ ctx.group_size = group_size
248
+ output = all_gather_tensor(input, group=group, group_size=group_size)
249
+ return tuple(output)
250
+
251
+ @staticmethod
252
+ def backward(ctx, *grads): # tuple(output)'s grad
253
+ input, = ctx.saved_tensors
254
+ grads = torch.stack(grads)
255
+ if is_distributed():
256
+ dist.all_reduce(grads)
257
+ grad_out = torch.zeros_like(input)
258
+ grad_out[:] = grads[get_rank() % ctx.group_size]
259
+ return grad_out, None, None
260
+
261
+
262
+ def gather_layer_with_group(data, group=None, group_size=None):
263
+ if group_size is None:
264
+ group_size = get_world_size()
265
+ return GatherLayerWithGroup.apply(data, group, group_size)
266
+
267
+ from typing import Union
268
+ import math
269
+ # from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm
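+ # NOTE: `clip_grad_norm_` below is written as an FSDP method (it takes `self`) and relies on the private
+ # `TrainingState_` / `_calc_grad_norm` helpers from the commented import above; it is only usable when that
+ # import resolves and the function is bound to an FSDP root instance.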
270
+
271
+ @torch.no_grad()
272
+ def clip_grad_norm_(
273
+ self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
274
+ ) -> None:
275
+ self._lazy_init()
276
+ self._wait_for_previous_optim_step()
277
+ assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
278
+ self._assert_state(TrainingState_.IDLE)
279
+
280
+ max_norm = float(max_norm)
281
+ norm_type = float(norm_type)
282
+ # Computes the max norm for this shard's gradients and sync's across workers
283
+ local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda() # type: ignore[arg-type]
284
+ if norm_type == math.inf:
285
+ total_norm = local_norm
286
+ dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
287
+ else:
288
+ total_norm = local_norm ** norm_type
289
+ dist.all_reduce(total_norm, group=self.process_group)
290
+ total_norm = total_norm ** (1.0 / norm_type)
291
+
292
+ clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
293
+ if clip_coef < 1:
294
+ # multiply by clip_coef, aka, (max_norm/total_norm).
295
+ for p in self.params_with_grad:
296
+ assert p.grad is not None
297
+ p.grad.detach().mul_(clip_coef.to(p.grad.device))
298
+ return total_norm
299
+
300
+
301
+ def flush():
302
+ gc.collect()
303
+ torch.cuda.empty_cache()
DiT_VAE/diffusion/utils/logger.py ADDED
@@ -0,0 +1,94 @@
1
+ import logging
2
+ import os
3
+ import torch.distributed as dist
4
+ from datetime import datetime
5
+ from .dist_utils import is_local_master
6
+ from mmcv.utils.logging import logger_initialized
7
+
8
+
9
+ def get_root_logger(log_file=None, log_level=logging.INFO, name='PixArt'):
10
+ """Get root logger.
11
+
12
+ Args:
13
+ log_file (str, optional): File path of log. Defaults to None.
14
+ log_level (int, optional): The level of logger.
15
+ Defaults to logging.INFO.
16
+ name (str): logger name
17
+ Returns:
18
+ :obj:`logging.Logger`: The obtained logger
19
+ """
20
+ if log_file is None:
21
+ log_file = '/dev/null'
22
+ return get_logger(name=name, log_file=log_file, log_level=log_level)
23
+
24
+
25
+ def get_logger(name, log_file=None, log_level=logging.INFO):
26
+ """Initialize and get a logger by name.
27
+
28
+ If the logger has not been initialized, this method will initialize the
29
+ logger by adding one or two handlers, otherwise the initialized logger will
30
+ be directly returned. During initialization, a StreamHandler will always be
31
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
32
+ will also be added.
33
+
34
+ Args:
35
+ name (str): Logger name.
36
+ log_file (str | None): The log filename. If specified, a FileHandler
37
+ will be added to the logger.
38
+ log_level (int): The logger level. Note that only the process of
39
+ rank 0 is affected, and other processes will set the level to
40
+ "Error" thus be silent most of the time.
41
+
42
+ Returns:
43
+ logging.Logger: The expected logger.
44
+ """
45
+ logger = logging.getLogger(name)
46
+ logger.propagate = False # disable root logger to avoid duplicate logging
47
+
48
+ if name in logger_initialized:
49
+ return logger
50
+ # handle hierarchical names
51
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
52
+ # initialization since it is a child of "a".
53
+ for logger_name in logger_initialized:
54
+ if name.startswith(logger_name):
55
+ return logger
56
+
57
+ stream_handler = logging.StreamHandler()
58
+ handlers = [stream_handler]
59
+
60
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
61
+ # only rank 0 will add a FileHandler
62
+ if rank == 0 and log_file is not None:
63
+ file_handler = logging.FileHandler(log_file, 'w')
64
+ handlers.append(file_handler)
65
+
66
+ formatter = logging.Formatter(
67
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
68
+ for handler in handlers:
69
+ handler.setFormatter(formatter)
70
+ handler.setLevel(log_level)
71
+ logger.addHandler(handler)
72
+
73
+ # only rank0 for each node will print logs
74
+ log_level = log_level if is_local_master() else logging.ERROR
75
+ logger.setLevel(log_level)
76
+
77
+ logger_initialized[name] = True
78
+
79
+ return logger
80
+
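A brief usage sketch (the work directory path is illustrative); only rank 0 attaches the file handler, and non-local-master ranks are silenced below ERROR:

```python
logger = get_root_logger(os.path.join("output/exp1", "train_log.log"))  # illustrative path
logger.info("starting training")   # written to the log file on rank 0 only
```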
81
+ def rename_file_with_creation_time(file_path):
82
+ # Get the file's creation time
83
+ creation_time = os.path.getctime(file_path)
84
+ creation_time_str = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d_%H-%M-%S')
85
+
86
+ # Build the new file name
87
+ dir_name, file_name = os.path.split(file_path)
88
+ name, ext = os.path.splitext(file_name)
89
+ new_file_name = f"{name}_{creation_time_str}{ext}"
90
+ new_file_path = os.path.join(dir_name, new_file_name)
91
+
92
+ # Rename the file
93
+ os.rename(file_path, new_file_path)
94
+ print(f"File renamed to: {new_file_path}")
DiT_VAE/diffusion/utils/lr_scheduler.py ADDED
@@ -0,0 +1,89 @@
1
+ from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
2
+ from torch.optim import Optimizer
3
+ from torch.optim.lr_scheduler import LambdaLR
4
+ import math
5
+
6
+ from DiT_VAE.diffusion.utils.logger import get_root_logger
7
+
8
+
9
+ def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
10
+ if not config.get('lr_schedule_args', None):
11
+ config.lr_schedule_args = {}
12
+ if config.get('lr_warmup_steps', None):
13
+ config['num_warmup_steps'] = config.get('lr_warmup_steps') # for compatibility with old version
14
+
15
+ logger = get_root_logger()
16
+ logger.info(
17
+ f'Lr schedule: {config.lr_schedule}, ' + ",".join(
18
+ [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
19
+ if config.lr_schedule == 'cosine':
20
+ lr_scheduler = get_cosine_schedule_with_warmup(
21
+ optimizer=optimizer,
22
+ **config.lr_schedule_args,
23
+ num_training_steps=(len(train_dataloader) * config.num_epochs),
24
+ )
25
+ elif config.lr_schedule == 'constant':
26
+ lr_scheduler = get_constant_schedule_with_warmup(
27
+ optimizer=optimizer,
28
+ **config.lr_schedule_args,
29
+ )
30
+ elif config.lr_schedule == 'cosine_decay_to_constant':
31
+ assert lr_scale_ratio >= 1
32
+ lr_scheduler = get_cosine_decay_to_constant_with_warmup(
33
+ optimizer=optimizer,
34
+ **config.lr_schedule_args,
35
+ final_lr=1 / lr_scale_ratio,
36
+ num_training_steps=(len(train_dataloader) * config.num_epochs),
37
+ )
38
+ else:
39
+ raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
40
+ return lr_scheduler
41
+
42
+
43
+ def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
44
+ num_warmup_steps: int,
45
+ num_training_steps: int,
46
+ final_lr: float = 0.0,
47
+ num_decay: float = 0.667,
48
+ num_cycles: float = 0.5,
49
+ last_epoch: int = -1
50
+ ):
51
+ """
52
+ Create a schedule with a cosine annealing lr followed by a constant lr.
53
+
54
+ Args:
55
+ optimizer ([`~torch.optim.Optimizer`]):
56
+ The optimizer for which to schedule the learning rate.
57
+ num_warmup_steps (`int`):
58
+ The number of steps for the warmup phase.
59
+ num_training_steps (`int`):
60
+ The number of total training steps.
61
+ final_lr (`int`):
62
+ The final constant lr after cosine decay.
63
+ num_decay (`int`):
64
+ The
65
+ last_epoch (`int`, *optional*, defaults to -1):
66
+ The index of the last epoch when resuming training.
67
+
68
+ Return:
69
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
70
+ """
71
+
72
+ def lr_lambda(current_step):
73
+ if current_step < num_warmup_steps:
74
+ return float(current_step) / float(max(1, num_warmup_steps))
75
+
76
+ num_decay_steps = int(num_training_steps * num_decay)
77
+ if current_step > num_decay_steps:
78
+ return final_lr
79
+
80
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
81
+ return (
82
+ max(
83
+ 0.0,
84
+ 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)),
85
+ )
86
+ * (1 - final_lr)
87
+ ) + final_lr
88
+
89
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
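A quick sketch of the resulting multiplier curve, probing `lr_lambda` at a few steps with a throwaway parameter and optimizer (all numbers are illustrative):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-4)
sched = get_cosine_decay_to_constant_with_warmup(
    opt, num_warmup_steps=100, num_training_steps=10_000, final_lr=0.1)
for step in (0, 100, 3_000, 6_667, 9_999):
    # LambdaLR multiplies the base lr by this factor: linear warmup, cosine decay, then a constant 0.1
    print(step, sched.lr_lambdas[0](step))
```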
DiT_VAE/diffusion/utils/misc.py ADDED
@@ -0,0 +1,366 @@
1
+ import collections
2
+ import datetime
3
+ import os
4
+ import random
5
+ import time
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.distributed as dist
10
+ from mmcv import Config
11
+ from mmcv.runner import get_dist_info
12
+
13
+ from .logger import get_root_logger
14
+
15
+ os.environ["MOX_SILENT_MODE"] = "1" # mute moxing log
16
+
17
+
18
+ def read_config(file):
19
+ # solve config loading conflict when multi-processes
20
+ import time
21
+ while True:
22
+ config = Config.fromfile(file)
23
+ if len(config) == 0:
24
+ time.sleep(0.1)
25
+ continue
26
+ break
27
+ return config
28
+
29
+
30
+ def init_random_seed(seed=None, device='cuda'):
31
+ """Initialize random seed.
32
+
33
+ If the seed is not set, the seed will be automatically randomized,
34
+ and then broadcast to all processes to prevent some potential bugs.
35
+
36
+ Args:
37
+ seed (int, Optional): The seed. Default to None.
38
+ device (str): The device where the seed will be put on.
39
+ Default to 'cuda'.
40
+
41
+ Returns:
42
+ int: Seed to be used.
43
+ """
44
+ if seed is not None:
45
+ return seed
46
+
47
+ # Make sure all ranks share the same random seed to prevent
48
+ # some potential bugs. Please refer to
49
+ # https://github.com/open-mmlab/mmdetection/issues/6339
50
+ rank, world_size = get_dist_info()
51
+ seed = np.random.randint(2 ** 31)
52
+ if world_size == 1:
53
+ return seed
54
+
55
+ if rank == 0:
56
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
57
+ else:
58
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
59
+ dist.broadcast(random_num, src=0)
60
+ return random_num.item()
61
+
62
+
63
+ def set_random_seed(seed, deterministic=False):
64
+ """Set random seed.
65
+
66
+ Args:
67
+ seed (int): Seed to be used.
68
+ deterministic (bool): Whether to set the deterministic option for
69
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
70
+ to True and `torch.backends.cudnn.benchmark` to False.
71
+ Default: False.
72
+ """
73
+ random.seed(seed)
74
+ np.random.seed(seed)
75
+ torch.manual_seed(seed)
76
+ torch.cuda.manual_seed_all(seed)
77
+ if deterministic:
78
+ torch.backends.cudnn.deterministic = True
79
+ torch.backends.cudnn.benchmark = False
80
+
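Typical combination of the two seed helpers in a distributed run; offsetting the seed by the rank is a common convention (an assumption here, not required by the code above):

```python
seed = init_random_seed(None, device='cuda')        # rank 0 draws a seed and broadcasts it
rank, _ = get_dist_info()
set_random_seed(seed + rank, deterministic=False)   # distinct but reproducible stream per rank
```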
81
+ class SimpleTimer:
82
+ def __init__(self, num_tasks, log_interval=1, desc="Process"):
83
+ self.num_tasks = num_tasks
84
+ self.desc = desc
85
+ self.count = 0
86
+ self.log_interval = log_interval
87
+ self.start_time = time.time()
88
+ self.logger = get_root_logger()
89
+
90
+ def log(self):
91
+ self.count += 1
92
+ if (self.count % self.log_interval) == 0 or self.count == self.num_tasks:
93
+ time_elapsed = time.time() - self.start_time
94
+ avg_time = time_elapsed / self.count
95
+ eta_sec = avg_time * (self.num_tasks - self.count)
96
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
97
+ elapsed_str = str(datetime.timedelta(seconds=int(time_elapsed)))
98
+ log_info = f"{self.desc} [{self.count}/{self.num_tasks}], elapsed_time:{elapsed_str}," \
99
+ f" avg_time: {avg_time}, eta: {eta_str}."
100
+ self.logger.info(log_info)
101
+
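A minimal sketch of `SimpleTimer` wrapped around a processing loop (`files` and `process_one` are placeholders):

```python
timer = SimpleTimer(num_tasks=len(files), log_interval=50, desc="Extract triplanes")
for path in files:
    process_one(path)   # hypothetical per-item work
    timer.log()         # logs elapsed time, average time per item and ETA every 50 items
```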
102
+
103
+ class DebugUnderflowOverflow:
104
+ """
105
+ This debug class helps detect and understand where the model starts getting very large or very small, and more
106
+ importantly `nan` or `inf` weight and activation elements.
107
+ There are 2 working modes:
108
+ 1. Underflow/overflow detection (default)
109
+ 2. Specific batch absolute min/max tracing without detection
110
+ Mode 1: Underflow/overflow detection
111
+ To activate the underflow/overflow detection, initialize the object with the model :
112
+ ```python
113
+ debug_overflow = DebugUnderflowOverflow(model)
114
+ ```
115
+ then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or
116
+ output elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this
117
+ event, each frame reporting
118
+ 1. the fully qualified module name plus the class name whose `forward` was run
119
+ 2. the absolute min and max value of all elements for each module weights, and the inputs and output
120
+ For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16 mixed precision :
121
+ ```
122
+ Detected inf/nan during batch_number=0
123
+ Last 21 forward frames:
124
+ abs min abs max metadata
125
+ [...]
126
+ encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
127
+ 2.17e-07 4.50e+00 weight
128
+ 1.79e-06 4.65e+00 input[0]
129
+ 2.68e-06 3.70e+01 output
130
+ encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
131
+ 8.08e-07 2.66e+01 weight
132
+ 1.79e-06 4.65e+00 input[0]
133
+ 1.27e-04 2.37e+02 output
134
+ encoder.block.2.layer.1.DenseReluDense.wo Linear
135
+ 1.01e-06 6.44e+00 weight
136
+ 0.00e+00 9.74e+03 input[0]
137
+ 3.18e-04 6.27e+04 output
138
+ encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
139
+ 1.79e-06 4.65e+00 input[0]
140
+ 3.18e-04 6.27e+04 output
141
+ encoder.block.2.layer.1.dropout Dropout
142
+ 3.18e-04 6.27e+04 input[0]
143
+ 0.00e+00 inf output
144
+ ```
145
+ You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value
146
+ was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
147
+ renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
148
+ 64K, and we get an overflow.
149
+ As you can see, it's the previous frames that we need to look into when the numbers start getting very large for
+ fp16 numbers.
151
+ The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
152
+ By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :
153
+ ```python
154
+ debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
155
+ ```
156
+ To validate that you have set up this debugging feature correctly, and you intend to use it in a training that may
157
+ take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in the next
158
+ section.
159
+ Mode 2. Specific batch absolute min/max tracing without detection
160
+ The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
161
+ Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
162
+ given batch, and only do that for batches 1 and 3. Then you instantiate this class as :
163
+ ```python
164
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3])
165
+ ```
166
+ And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
167
+ This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
168
+ fast-forward right to that area.
169
+ Early stopping:
170
+ You can also specify the batch number after which to stop the training, with :
171
+ ```python
172
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1,3], abort_after_batch_num=3)
173
+ ```
174
+ This feature is mainly useful in the tracing mode, but you can use it for any mode.
175
+ **Performance**:
176
+ As this module measures absolute `min`/`max` of each weight of the model on every forward it'll slow the
177
+ training down. Therefore remember to turn it off once the debugging needs have been met.
178
+ Args:
179
+ model (`nn.Module`):
180
+ The model to debug.
181
+ max_frames_to_save (`int`, *optional*, defaults to 21):
182
+ How many frames back to record
183
+ trace_batch_nums(`List[int]`, *optional*, defaults to `[]`):
184
+ Which batch numbers to trace (turns detection off)
185
+ abort_after_batch_num (`int``, *optional*):
186
+ Whether to abort after a certain batch number has finished
187
+ """
188
+
189
+ def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
190
+ if trace_batch_nums is None:
191
+ trace_batch_nums = []
192
+ self.model = model
193
+ self.trace_batch_nums = trace_batch_nums
194
+ self.abort_after_batch_num = abort_after_batch_num
195
+
196
+ # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
197
+ self.frames = collections.deque([], max_frames_to_save)
198
+ self.frame = []
199
+ self.batch_number = 0
200
+ self.total_calls = 0
201
+ self.detected_overflow = False
202
+ self.prefix = " "
203
+
204
+ self.analyse_model()
205
+
206
+ self.register_forward_hook()
207
+
208
+ def save_frame(self, frame=None):
209
+ if frame is not None:
210
+ self.expand_frame(frame)
211
+ self.frames.append("\n".join(self.frame))
212
+ self.frame = [] # start a new frame
213
+
214
+ def expand_frame(self, line):
215
+ self.frame.append(line)
216
+
217
+ def trace_frames(self):
218
+ print("\n".join(self.frames))
219
+ self.frames = []
220
+
221
+ def reset_saved_frames(self):
222
+ self.frames = []
223
+
224
+ def dump_saved_frames(self):
225
+ print(f"\nDetected inf/nan during batch_number={self.batch_number} "
226
+ f"Last {len(self.frames)} forward frames:"
227
+ f"{'abs min':8} {'abs max':8} metadata"
228
+ f"'\n'.join(self.frames)"
229
+ f"\n\n")
230
+ self.frames = []
231
+
232
+ def analyse_model(self):
233
+ # extract the fully qualified module names, to be able to report at run time. e.g.:
234
+ # encoder.block.2.layer.0.SelfAttention.o
235
+ #
236
+ # for shared weights only the first shared module name will be registered
237
+ self.module_names = {m: name for name, m in self.model.named_modules()}
238
+ # self.longest_module_name = max(len(v) for v in self.module_names.values())
239
+
240
+ def analyse_variable(self, var, ctx):
241
+ if torch.is_tensor(var):
242
+ self.expand_frame(self.get_abs_min_max(var, ctx))
243
+ if self.detect_overflow(var, ctx):
244
+ self.detected_overflow = True
245
+ elif var is None:
246
+ self.expand_frame(f"{'None':>17} {ctx}")
247
+ else:
248
+ self.expand_frame(f"{'not a tensor':>17} {ctx}")
249
+
250
+ def batch_start_frame(self):
251
+ self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
252
+ self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
253
+
254
+ def batch_end_frame(self):
255
+ self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number - 1} ***\n\n")
256
+
257
+ def create_frame(self, module, input, output):
258
+ self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
259
+
260
+ # params
261
+ for name, p in module.named_parameters(recurse=False):
262
+ self.analyse_variable(p, name)
263
+
264
+ # inputs
265
+ if isinstance(input, tuple):
266
+ for i, x in enumerate(input):
267
+ self.analyse_variable(x, f"input[{i}]")
268
+ else:
269
+ self.analyse_variable(input, "input")
270
+
271
+ # outputs
272
+ if isinstance(output, tuple):
273
+ for i, x in enumerate(output):
274
+ # possibly a tuple of tuples
275
+ if isinstance(x, tuple):
276
+ for j, y in enumerate(x):
277
+ self.analyse_variable(y, f"output[{i}][{j}]")
278
+ else:
279
+ self.analyse_variable(x, f"output[{i}]")
280
+ else:
281
+ self.analyse_variable(output, "output")
282
+
283
+ self.save_frame()
284
+
285
+ def register_forward_hook(self):
286
+ self.model.apply(self._register_forward_hook)
287
+
288
+ def _register_forward_hook(self, module):
289
+ module.register_forward_hook(self.forward_hook)
290
+
291
+ def forward_hook(self, module, input, output):
292
+ # - input is a tuple of packed inputs (could be non-Tensors)
293
+ # - output could be a Tensor or a tuple of Tensors and non-Tensors
294
+
295
+ last_frame_of_batch = False
296
+
297
+ trace_mode = self.batch_number in self.trace_batch_nums
298
+ if trace_mode:
299
+ self.reset_saved_frames()
300
+
301
+ if self.total_calls == 0:
302
+ self.batch_start_frame()
303
+ self.total_calls += 1
304
+
305
+ # count batch numbers - the very first forward hook of the batch will be called when the
306
+ # batch completes - i.e. it gets called very last - we know this batch has finished
307
+ if module == self.model:
308
+ self.batch_number += 1
309
+ last_frame_of_batch = True
310
+
311
+ self.create_frame(module, input, output)
312
+
313
+ # if last_frame_of_batch:
314
+ # self.batch_end_frame()
315
+
316
+ if trace_mode:
317
+ self.trace_frames()
318
+
319
+ if last_frame_of_batch:
320
+ self.batch_start_frame()
321
+
322
+ if self.detected_overflow and not trace_mode:
323
+ self.dump_saved_frames()
324
+
325
+ # now we can abort, as it's pointless to continue running
326
+ raise ValueError(
327
+ "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
328
+ "Please scroll up above this traceback to see the activation values prior to this event."
329
+ )
330
+
331
+ # abort after certain batch if requested to do so
332
+ if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
333
+ raise ValueError(
334
+ f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to `abort_after_batch_num={self.abort_after_batch_num}` arg"
335
+ )
336
+
337
+ @staticmethod
338
+ def get_abs_min_max(var, ctx):
339
+ abs_var = var.abs()
340
+ return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
341
+
342
+ @staticmethod
343
+ def detect_overflow(var, ctx):
344
+ """
345
+ Report whether the tensor contains any `nan` or `inf` entries.
346
+ This is useful for detecting overflows/underflows and best to call right after the function that did some math that
347
+ modified the tensor in question.
348
+ This function contains a few other helper features that you can enable and tweak directly if you want to track
349
+ various other things.
350
+ Args:
351
+ var: the tensor variable to check
352
+ ctx: the message to print as a context
353
+ Return:
354
+ `True` if `inf` or `nan` was detected, `False` otherwise
355
+ """
356
+ detected = False
357
+ if torch.isnan(var).any().item():
358
+ detected = True
359
+ print(f"{ctx} has nans")
360
+ if torch.isinf(var).any().item():
361
+ detected = True
362
+ print(f"{ctx} has infs")
363
+ if var.dtype == torch.float32 and torch.ge(var.abs(), 65535).any().item():
364
+ detected = True
365
+ print(f"{ctx} has overflow values {var.abs().max().item()}.")
366
+ return detected
DiT_VAE/diffusion/utils/optimizer.py ADDED
@@ -0,0 +1,237 @@
1
+ import math
2
+
3
+ from mmcv import Config
4
+ from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \
5
+ OPTIMIZERS
6
+ from mmcv.utils import _BatchNorm, _InstanceNorm
7
+ from torch.nn import GroupNorm, LayerNorm
8
+
9
+ from .logger import get_root_logger
10
+
11
+ from typing import Tuple, Optional, Callable
12
+
13
+ import torch
14
+ from torch.optim.optimizer import Optimizer
15
+
16
+
17
+ def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
18
+ assert rule in ['linear', 'sqrt']
19
+ logger = get_root_logger()
20
+ # scale by world size
21
+ if rule == 'sqrt':
22
+ scale_ratio = math.sqrt(effective_bs / base_batch_size)
23
+ elif rule == 'linear':
24
+ scale_ratio = effective_bs / base_batch_size
25
+ optimizer_cfg['lr'] *= scale_ratio
26
+ logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.7f} (using {rule} scaling rule).')
27
+ return scale_ratio
28
+
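For example (numbers are illustrative): with a per-GPU batch of 32 on 16 GPUs and gradient accumulation of 2, the effective batch size is 1024 and the linear rule multiplies the configured lr by 1024 / 256 = 4:

```python
optimizer_cfg = dict(type='AdamW', lr=1e-4, weight_decay=3e-2)
effective_bs = 32 * 16 * 2   # batch_per_gpu * num_gpus * grad_accumulation (illustrative)
ratio = auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256)
# optimizer_cfg['lr'] is now 4e-4; `ratio` (4.0) can be passed on to the lr scheduler
```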
29
+
30
+ @OPTIMIZER_BUILDERS.register_module()
31
+ class MyOptimizerConstructor(DefaultOptimizerConstructor):
32
+
33
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
34
+ """Add all parameters of module to the params list.
35
+
36
+ The parameters of the given module will be added to the list of param
37
+ groups, with specific rules defined by paramwise_cfg.
38
+
39
+ Args:
40
+ params (list[dict]): A list of param groups, it will be modified
41
+ in place.
42
+ module (nn.Module): The module to be added.
43
+ prefix (str): The prefix of the module
44
+
45
+ """
46
+ # get param-wise options
47
+ custom_keys = self.paramwise_cfg.get('custom_keys', {})
48
+ # first sort with alphabet order and then sort with reversed len of str
49
+ # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
50
+
51
+ bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
52
+ bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
53
+ norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
54
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
55
+
56
+ # special rules for norm layers and depth-wise conv layers
57
+ is_norm = isinstance(module,
58
+ (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
59
+
60
+ for name, param in module.named_parameters(recurse=False):
61
+ base_lr = self.base_lr
62
+ if name == 'bias' and not is_norm and not is_dcn_module:
63
+ base_lr *= bias_lr_mult
64
+
65
+ # apply weight decay policies
66
+ base_wd = self.base_wd
67
+ # norm decay
68
+ if is_norm:
69
+ if self.base_wd is not None:
70
+ base_wd *= norm_decay_mult
71
+ elif name == 'bias' and not is_dcn_module:
72
+ if self.base_wd is not None:
73
+ # TODO: current bias_decay_mult will have an effect on DCN
74
+ base_wd *= bias_decay_mult
75
+
76
+ param_group = {'params': [param]}
77
+ if not param.requires_grad:
78
+ param_group['requires_grad'] = False
79
+ params.append(param_group)
80
+ continue
81
+ if bypass_duplicate and self._is_in(param_group, params):
82
+ logger = get_root_logger()
83
+ logger.warning(f'{prefix} is a duplicate. It is skipped since '
84
+ f'bypass_duplicate={bypass_duplicate}')
85
+ continue
86
+ # if the parameter match one of the custom keys, ignore other rules
87
+ is_custom = False
88
+ for key in custom_keys:
89
+ scope, key_name = key if isinstance(key, tuple) else (None, key)
90
+ if scope is not None and scope not in f'{prefix}':
91
+ continue
92
+ if key_name in f'{prefix}.{name}':
93
+ is_custom = True
94
+ if 'lr_mult' in custom_keys[key]:
95
+ # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
96
+ # param_group['lr'] = self.base_lr
97
+ # else:
98
+ param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
99
+ elif 'lr' not in param_group:
100
+ param_group['lr'] = base_lr
101
+ if self.base_wd is not None:
102
+ if 'decay_mult' in custom_keys[key]:
103
+ param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
104
+ elif 'weight_decay' not in param_group:
105
+ param_group['weight_decay'] = base_wd
106
+
107
+ if not is_custom:
108
+ # bias_lr_mult affects all bias parameters
109
+ # except for norm.bias dcn.conv_offset.bias
110
+ if base_lr != self.base_lr:
111
+ param_group['lr'] = base_lr
112
+ if base_wd != self.base_wd:
113
+ param_group['weight_decay'] = base_wd
114
+ params.append(param_group)
115
+
116
+ for child_name, child_mod in module.named_children():
117
+ child_prefix = f'{prefix}.{child_name}' if prefix else child_name
118
+ self.add_params(
119
+ params,
120
+ child_mod,
121
+ prefix=child_prefix,
122
+ is_dcn_module=is_dcn_module)
123
+
124
+
125
+ def build_optimizer(model, optimizer_cfg):
126
+ # default parameter-wise config
127
+ logger = get_root_logger()
128
+
129
+ if hasattr(model, 'module'):
130
+ model = model.module
131
+ # set optimizer constructor
132
+ optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
133
+ # parameter-wise setting: cancel weight decay for some specific modules
134
+ custom_keys = dict()
135
+ for name, module in model.named_modules():
136
+ if hasattr(module, 'zero_weight_decay'):
137
+ custom_keys |= {
138
+ (name, key): dict(decay_mult=0)
139
+ for key in module.zero_weight_decay
140
+ }
141
+
142
+ paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
143
+ if given_cfg := optimizer_cfg.get('paramwise_cfg'):
144
+ paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
145
+ optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg
146
+ # build optimizer
147
+ optimizer = mm_build_optimizer(model, optimizer_cfg)
148
+
149
+ weight_decay_groups = dict()
150
+ lr_groups = dict()
151
+ for group in optimizer.param_groups:
152
+ if not group.get('requires_grad', True): continue
153
+ lr_groups.setdefault(group['lr'], []).append(group)
154
+ weight_decay_groups.setdefault(group['weight_decay'], []).append(group)
155
+
156
+ learnable_count, fix_count = 0, 0
157
+ for p in model.parameters():
158
+ if p.requires_grad:
159
+ learnable_count += 1
160
+ else:
161
+ fix_count += 1
162
+ fix_info = f"{learnable_count} are learnable, {fix_count} are fix"
163
+ lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
164
+ wd_info = "Weight decay group: " + ", ".join(
165
+ [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
166
+ opt_info = f"Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
167
+ logger.info(opt_info)
168
+
169
+ return optimizer
170
+
171
+
172
+ @OPTIMIZERS.register_module()
173
+ class Lion(Optimizer):
174
+ def __init__(
175
+ self,
176
+ params,
177
+ lr: float = 1e-4,
178
+ betas: Tuple[float, float] = (0.9, 0.99),
179
+ weight_decay: float = 0.0,
180
+ ):
181
+ assert lr > 0.
182
+ assert all(0. <= beta <= 1. for beta in betas)
183
+
184
+ defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
185
+
186
+ super().__init__(params, defaults)
187
+
188
+ @staticmethod
189
+ def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
190
+ # stepweight decay
191
+ p.data.mul_(1 - lr * wd)
192
+
193
+ # weight update
194
+ update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
195
+ p.add_(update, alpha=-lr)
196
+
197
+ # decay the momentum running average coefficient
198
+ exp_avg.lerp_(grad, 1 - beta2)
199
+
200
+ @staticmethod
201
+ def exists(val):
202
+ return val is not None
203
+
204
+ @torch.no_grad()
205
+ def step(
206
+ self,
207
+ closure: Optional[Callable] = None
208
+ ):
209
+
210
+ loss = None
211
+ if self.exists(closure):
212
+ with torch.enable_grad():
213
+ loss = closure()
214
+
215
+ for group in self.param_groups:
216
+ for p in filter(lambda p: self.exists(p.grad), group['params']):
217
+
218
+ grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
219
+ self.state[p]
220
+
221
+ # init state - exponential moving average of gradient values
222
+ if len(state) == 0:
223
+ state['exp_avg'] = torch.zeros_like(p)
224
+
225
+ exp_avg = state['exp_avg']
226
+
227
+ self.update_fn(
228
+ p,
229
+ grad,
230
+ exp_avg,
231
+ lr,
232
+ wd,
233
+ beta1,
234
+ beta2
235
+ )
236
+
237
+ return loss
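A minimal sketch of using `Lion` directly, outside the mmcv registry (the model and data are placeholders; Lion is usually run with a smaller lr and larger weight decay than AdamW):

```python
model = torch.nn.Linear(128, 128)
opt = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)
loss = model(torch.randn(4, 128)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```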
DiT_VAE/train_diffusion.py ADDED
@@ -0,0 +1,5 @@
1
+ # TODO: Implement model training and evaluation.
2
+ # This script will load data, train a deep learning model, and evaluate its performance.
3
+ # Future improvements may include hyperparameter tuning and multi-GPU training.
4
+
5
+
DiT_VAE/train_vae.py ADDED
@@ -0,0 +1,369 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+
5
+ import sys
6
+
7
+ current_path = os.path.abspath(__file__)
8
+ father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")
9
+ sys.path.append((os.path.join(father_path, 'Next3d')))
10
+
11
+ from typing import Dict, Optional, Tuple
12
+ from omegaconf import OmegaConf
13
+ import torch
14
+ import logging
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint
17
+ from torch.utils.data import Dataset
18
+ import inspect
19
+ from accelerate import Accelerator
20
+ from accelerate.logging import get_logger
21
+ from accelerate.utils import set_seed
22
+ import dnnlib
23
+ from diffusers.optimization import get_scheduler
24
+ from tqdm.auto import tqdm
25
+ from vae.triplane_vae import AutoencoderKL, AutoencoderKLRollOut
26
+ from vae.data.dataset_online_vae import TriplaneDataset
27
+ from einops import rearrange
28
+ from vae.utils.common_utils import instantiate_from_config
29
+ from Next3d.training_avatar_texture.triplane_generation import TriPlaneGenerator
30
+ import Next3d.legacy as legacy
31
+
32
+ from torch_utils import misc
33
+ import datetime
34
+
35
+ logger = get_logger(__name__, log_level="INFO")
36
+
37
+
38
+ def collate_fn(data):
39
+ model_names = [example["data_model_name"] for example in data]
40
+ zs = torch.cat([example["data_z"] for example in data], dim=0)
41
+ verts = torch.cat([example["data_vert"] for example in data], dim=0)
42
+
43
+ return {
44
+ 'model_names': model_names,
45
+ 'zs': zs,
46
+ 'verts': verts
47
+ }
48
+
49
+
50
+ def rollout_fn(triplane):
51
+ triplane = rearrange(triplane, "b c f h w -> b f c h w")
52
+ b, f, c, h, w = triplane.shape
53
+ triplane = triplane.permute(0, 2, 3, 1, 4).reshape(-1, c, h, f * w)
54
+ return triplane
55
+
56
+
57
+ def unrollout_fn(triplane):
58
+ res = triplane.shape[-2]
59
+ ch = triplane.shape[1]
60
+ triplane = triplane.reshape(-1, ch // 3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, 3, ch, res, res)
61
+ triplane = rearrange(triplane, "b f c h w -> b c f h w")
62
+ return triplane
63
+
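A shape-only sketch of the two helpers above (the channel count is illustrative; `unrollout_fn` needs it to be divisible by 3):

```python
x = torch.randn(2, 24, 3, 64, 64)   # (B, C, 3 planes, H, W) triplane batch
rolled = rollout_fn(x)              # -> torch.Size([2, 24, 64, 192]): planes laid out side by side along width
restored = unrollout_fn(rolled)     # -> torch.Size([2, 24, 3, 64, 64])
print(rolled.shape, restored.shape)
```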
64
+
65
+ def triplane_generate(G_model, z, conditioning_params, std, mean, truncation_psi=0.7, truncation_cutoff=14):
66
+ w = G_model.mapping(z, conditioning_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
67
+ triplane = G_model.synthesis(w, noise_mode='const')
68
+ triplane = (triplane - mean) / std
69
+ return triplane
70
+
71
+
72
+ def gan_model(gan_models, device, gan_model_base_dir):
73
+ gan_model_dict = gan_models
74
+ gan_model_load = {}
75
+ for model_name in gan_model_dict.keys():
76
+ model_pkl = os.path.join(gan_model_base_dir, model_name + '.pkl')
77
+ with dnnlib.util.open_url(model_pkl) as f:
78
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
79
+ G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
80
+ misc.copy_params_and_buffers(G, G_new, require_all=True)
81
+ G_new.neural_rendering_resolution = G.neural_rendering_resolution
82
+ G_new.rendering_kwargs = G.rendering_kwargs
83
+ gan_model_load[model_name] = G_new
84
+ return gan_model_load
85
+
86
+
87
+ def main(vae_config: str,
88
+ gan_model_config: str,
89
+ output_dir: str,
90
+ std_dir: str,
91
+ mean_dir: str,
92
+ conditioning_params_dir: str,
93
+ gan_model_base_dir: str,
94
+ train_data: Dict,
95
+ train_batch_size: int = 2,
96
+ max_train_steps: int = 500,
97
+ learning_rate: float = 3e-5,
98
+ scale_lr: bool = False,
99
+ lr_scheduler: str = "constant",
100
+ lr_warmup_steps: int = 0,
101
+ adam_beta1: float = 0.5,
102
+ adam_beta2: float = 0.9,
103
+ adam_weight_decay: float = 1e-2,
104
+ adam_epsilon: float = 1e-08,
105
+ max_grad_norm: float = 1.0,
106
+ gradient_accumulation_steps: int = 1,
107
+ gradient_checkpointing: bool = True,
108
+ checkpointing_steps: int = 500,
109
+ pretrained_model_path_zero123: str = None,
110
+ resume_from_checkpoint: Optional[str] = None,
111
+ mixed_precision: Optional[str] = "fp16",
112
+ use_8bit_adam: bool = False,
113
+ rollout: bool = False,
114
+ enable_xformers_memory_efficient_attention: bool = True,
115
+ seed: Optional[int] = None, ):
116
+ *_, config = inspect.getargvalues(inspect.currentframe())
117
+ base_dir = output_dir
118
+
119
+ accelerator = Accelerator(
120
+ gradient_accumulation_steps=gradient_accumulation_steps,
121
+ mixed_precision=mixed_precision,
122
+ )
123
+ logging.basicConfig(
124
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
125
+ datefmt="%m/%d/%Y %H:%M:%S",
126
+ level=logging.INFO,
127
+ )
128
+ logger.info(accelerator.state, main_process_only=False)
129
+ # If passed along, set the training seed now.
130
+ if seed is not None:
131
+ set_seed(seed)
132
+ if accelerator.is_main_process:
133
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
134
+ output_dir = os.path.join(output_dir, now)
135
+ os.makedirs(output_dir, exist_ok=True)
136
+ os.makedirs(f"{output_dir}/samples", exist_ok=True)
137
+ os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
138
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
139
+
140
+ config_vae = OmegaConf.load(vae_config)
141
+
142
+ if rollout:
143
+ vae = AutoencoderKLRollOut(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
144
+
145
+ else:
146
+ vae = AutoencoderKL(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
147
+ print(f"VAE total params = {len(list(vae.named_parameters()))} ")
148
+ if 'perceptual_weight' in config_vae['lossconfig']['params'].keys():
149
+ config_vae['lossconfig']['params']['device'] = str(accelerator.device)
150
+ loss_fn = instantiate_from_config(config_vae['lossconfig'])
151
+ conditioning_params = torch.load(conditioning_params_dir).to(str(accelerator.device))
152
+ data_std = torch.load(std_dir).to(str(accelerator.device)).reshape(1, -1, 1, 1, 1)
153
+
154
+ data_mean = torch.load(mean_dir).to(str(accelerator.device)).reshape(1, -1, 1, 1, 1)
155
+
156
+ # define the gan model
157
+ print("########## gan model load ##########")
158
+ config_gan_model = OmegaConf.load(gan_model_config)
159
+ gan_model_all = gan_model(config_gan_model['gan_models'], str(accelerator.device), gan_model_base_dir)
160
+ print("########## gan model loaded ##########")
161
+ if scale_lr:
162
+ learning_rate = (
163
+ learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
164
+ )
165
+
166
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
167
+ if use_8bit_adam:
168
+ try:
169
+ import bitsandbytes as bnb
170
+ except ImportError:
171
+ raise ImportError(
172
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
173
+ )
174
+
175
+ optimizer_cls = bnb.optim.AdamW8bit
176
+ else:
177
+ optimizer_cls = torch.optim.AdamW
178
+
179
+ optimizer = optimizer_cls(
180
+ vae.parameters(),
181
+ lr=learning_rate,
182
+ betas=(adam_beta1, adam_beta2),
183
+ weight_decay=adam_weight_decay,
184
+ eps=adam_epsilon,
185
+ )
186
+
187
+ train_dataset = TriplaneDataset(**train_data)
188
+
189
+ # Preprocessing the dataset
190
+
191
+ # DataLoaders creation:
192
+ train_dataloader = torch.utils.data.DataLoader(
193
+ train_dataset, batch_size=train_batch_size, collate_fn=collate_fn, shuffle=True, num_workers=2
194
+ )
195
+
196
+ lr_scheduler = get_scheduler(
197
+ lr_scheduler,
198
+ optimizer=optimizer,
199
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
200
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
201
+ )
202
+
203
+ vae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
204
+ vae, optimizer, train_dataloader, lr_scheduler
205
+ )
206
+
207
+ weight_dtype = torch.float32
208
+
209
+ # Move text_encode and vae to gpu and cast to weight_dtype
210
+
211
+ if accelerator.mixed_precision == "fp16":
212
+ weight_dtype = torch.float16
213
+ elif accelerator.mixed_precision == "bf16":
214
+ weight_dtype = torch.bfloat16
215
+
216
+ vae.to(accelerator.device, dtype=weight_dtype)
217
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
218
+ # Afterwards we recalculate our number of training epochs
219
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
220
+
221
+ # We need to initialize the trackers we use, and also store our configuration.
222
+ # The trackers initializes automatically on the main process.
223
+ if accelerator.is_main_process:
224
+ accelerator.init_trackers("trainvae", config=vars(args))
225
+
226
+ # Train!
227
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
228
+
229
+ logger.info("***** Running training *****")
230
+ logger.info(f" Num examples = {len(train_dataset)}")
231
+ logger.info(f" Num Epochs = {num_train_epochs}")
232
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
233
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
234
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
235
+ logger.info(f" Total optimization steps = {max_train_steps}")
236
+ global_step = 0
237
+ first_epoch = 0
238
+
239
+ # Potentially load in the weights and states from a previous save
240
+ if resume_from_checkpoint:
241
+ if resume_from_checkpoint != "latest":
242
+ path = os.path.basename(resume_from_checkpoint)
243
+ else:
244
+ # Get the most recent checkpoint
245
+ dirs = os.listdir(output_dir)
246
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
247
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
248
+ path = dirs[-1]
249
+ accelerator.print(f"Resuming from checkpoint {path}")
250
+ if resume_from_checkpoint != "latest":
251
+ accelerator.load_state(resume_from_checkpoint)
252
+ else:
253
+ accelerator.load_state(os.path.join(output_dir, path))
254
+
255
+ global_step = int(path.split("-")[1])
256
+
257
+ first_epoch = global_step // num_update_steps_per_epoch
258
+ resume_step = global_step % num_update_steps_per_epoch
259
+ else:
260
+ all_final_training_dirs = []
261
+ dirs = os.listdir(base_dir)
262
+ if len(dirs) != 0:
263
+ dirs = [d for d in dirs if d.startswith("2024")] # specific years
264
+ if len(dirs) != 0:
265
+ base_resume_paths = [os.path.join(base_dir, d) for d in dirs]
266
+ for base_resume_path in base_resume_paths:
267
+ checkpoint_file_names = os.listdir(base_resume_path)
268
+ checkpoint_file_names = [d for d in checkpoint_file_names if d.startswith("checkpoint")]
269
+ if len(checkpoint_file_names) != 0:
270
+ for checkpoint_file_name in checkpoint_file_names:
271
+ final_training_dir = os.path.join(base_resume_path, checkpoint_file_name)
272
+ all_final_training_dirs.append(final_training_dir)
273
+ if len(all_final_training_dirs) != 0:
274
+ sorted_all_final_training_dirs = sorted(all_final_training_dirs, key=lambda x: int(x.split("-")[1]))
275
+ latest_dir = sorted_all_final_training_dirs[-1]
276
+ path = os.path.basename(latest_dir)
277
+ accelerator.print(f"Resuming from checkpoint {path}")
278
+ accelerator.load_state(latest_dir)
279
+ global_step = int(path.split("-")[1])
280
+
281
+ first_epoch = global_step // num_update_steps_per_epoch
282
+ resume_step = global_step % num_update_steps_per_epoch
283
+ else:
284
+ accelerator.print(f"Training from start")
285
+ else:
286
+ accelerator.print(f"Training from start")
287
+ else:
288
+ accelerator.print(f"Training from start")
289
+
290
+ # Only show the progress bar once on each machine.
291
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
292
+ progress_bar.set_description("Steps")
293
+
294
+ for epoch in range(first_epoch, num_train_epochs):
295
+ vae.train()
296
+ train_loss = 0.0
297
+ for step, batch in enumerate(train_dataloader):
298
+ # if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
299
+ # print(epoch)
300
+ # print(first_epoch)
301
+ # print(step)
302
+ # if step % gradient_accumulation_steps == 0:
303
+ # progress_bar.update(1)
304
+ # continue
305
+ with accelerator.accumulate(vae):
306
+ # Convert images to latent space
307
+ z_values = batch["zs"].to(weight_dtype)
308
+ model_names = batch["model_names"]
309
+
310
+ triplane_values = []
311
+ with torch.no_grad():
312
+ for z_id in range(z_values.shape[0]):
313
+ z_value = z_values[z_id].unsqueeze(0)
314
+ model_name = model_names[z_id]
315
+ triplane_value = triplane_generate(gan_model_all[model_name], z_value,
316
+ conditioning_params, data_std, data_mean)
317
+ triplane_values.append(triplane_value)
318
+ triplane_values = torch.cat(triplane_values, dim=0)
319
+ vert_values = batch["verts"].to(weight_dtype)
320
+ triplane_values = rearrange(triplane_values, "b f c h w -> b c f h w")
321
+ if rollout:
322
+ triplane_values_roll = rollout_fn(triplane_values.clone())
323
+ reconstructions, posterior = vae(triplane_values_roll)
324
+ reconstructions_unroll = unrollout_fn(reconstructions)
325
+ loss, log_dict_ae = loss_fn(triplane_values, reconstructions_unroll, posterior, vert_values,
326
+ split="train")
327
+ else:
328
+ reconstructions, posterior = vae(triplane_values)
329
+ loss, log_dict_ae = loss_fn(triplane_values, reconstructions, posterior, vert_values,
330
+ split="train")
331
+ accelerator.backward(loss)
332
+ if accelerator.sync_gradients:
333
+ accelerator.clip_grad_norm_(vae.parameters(), max_grad_norm)
334
+ optimizer.step()
335
+ lr_scheduler.step()
336
+ optimizer.zero_grad()
337
+
338
+ # Checks if the accelerator has performed an optimization step behind the scenes
339
+ if accelerator.sync_gradients:
340
+ progress_bar.update(1)
341
+ global_step += 1
342
+ accelerator.log({"train_loss": train_loss}, step=global_step)
343
+ train_loss = 0.0
344
+
345
+ if global_step % checkpointing_steps == 0:
346
+ if accelerator.is_main_process:
347
+ save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
348
+ accelerator.save_state(save_path)
349
+ logger.info(f"Saved state to {save_path}")
350
+
351
+ # logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
352
+
353
+ logs = log_dict_ae
354
+ progress_bar.set_postfix(**logs)
355
+ accelerator.log(logs, step=global_step)
356
+
357
+ if global_step >= max_train_steps:
358
+ break
359
+
360
+ accelerator.wait_for_everyone()
361
+
362
+ accelerator.end_training()
363
+
364
+
365
+ if __name__ == "__main__":
366
+ parser = argparse.ArgumentParser()
367
+ parser.add_argument("--config", type=str, default="./configs/triplane_vae.yaml")
368
+ args = parser.parse_args()
369
+ main(**OmegaConf.load(args.config))
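A minimal usage sketch for the entry point above, assuming the default config path from the argparse definition; it only illustrates that every top-level key of the YAML file becomes a keyword argument of main(), so the config has to mirror main()'s signature.

from omegaconf import OmegaConf

# Hedged sketch: load the same config the script would receive via --config.
cfg = OmegaConf.load("./configs/triplane_vae.yaml")   # argparse default shown above
print(OmegaConf.to_container(cfg))                    # these keys are splatted into main(**cfg)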
DiT_VAE/util.py ADDED
@@ -0,0 +1,217 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+
6
+ import torch
7
+ import torchvision
8
+
9
+ from tqdm import tqdm
10
+ from einops import rearrange
11
+
12
+
13
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
14
+ videos = rearrange(videos, "b c t h w -> t b c h w")
15
+ outputs = []
16
+ for x in videos:
17
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
18
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
19
+ if rescale:
20
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
21
+ x = (x * 255).numpy().astype(np.uint8)
22
+ outputs.append(x)
23
+
24
+ os.makedirs(os.path.dirname(path), exist_ok=True)
25
+ imageio.mimsave(path, outputs, fps=fps)
26
+
27
+
28
+ # DDIM Inversion
29
+ @torch.no_grad()
30
+ def init_prompt(prompt, pipeline):
31
+ uncond_input = pipeline.tokenizer(
32
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
33
+ return_tensors="pt"
34
+ )
35
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
36
+ text_input = pipeline.tokenizer(
37
+ [prompt],
38
+ padding="max_length",
39
+ max_length=pipeline.tokenizer.model_max_length,
40
+ truncation=True,
41
+ return_tensors="pt",
42
+ )
43
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
44
+ context = torch.cat([uncond_embeddings, text_embeddings])
45
+
46
+ return context
47
+
48
+
49
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
50
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
51
+ timestep, next_timestep = min(
52
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
53
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
54
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
55
+ beta_prod_t = 1 - alpha_prod_t
56
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
57
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
58
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
59
+ return next_sample
60
+
61
+
62
+ def get_noise_pred_single(latents, t, context, unet):
63
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
64
+ return noise_pred
65
+
66
+
67
+ @torch.no_grad()
68
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
69
+ context = init_prompt(prompt, pipeline)
70
+ uncond_embeddings, cond_embeddings = context.chunk(2)
71
+ all_latent = [latent]
72
+ latent = latent.clone().detach()
73
+ for i in tqdm(range(num_inv_steps)):
74
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
75
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
76
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
77
+ all_latent.append(latent)
78
+ return all_latent
79
+
80
+
81
+ @torch.no_grad()
82
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
83
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
84
+ return ddim_latents
85
+
86
+ def rendering():
87
+ pass
88
+
89
+
90
+
91
+ def force_zero_snr(betas):
92
+ alphas = 1 - betas
93
+ alphas_bar = torch.cumprod(alphas, dim=0)
94
+ alphas_bar_sqrt = alphas_bar ** (1/2)
95
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
96
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - 1e-6
97
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
98
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
99
+ alphas_bar = alphas_bar_sqrt ** 2
100
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
101
+ alphas = torch.cat([alphas_bar[0:1], alphas], 0)
102
+ betas = 1 - alphas
103
+ return betas
104
+
105
+ def make_beta_schedule(schedule="scaled_linear", n_timestep=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3, shift_scale=None):
106
+ if schedule == "scaled_linear":
107
+ betas = (
108
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
109
+ )
110
+ elif schedule == 'linear':
111
+ betas = (
112
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
113
+ )
114
+ elif schedule == "cosine":
115
+ timesteps = (
116
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
117
+ )
118
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
119
+ alphas = torch.cos(alphas).pow(2)
120
+ alphas = alphas / alphas[0]
121
+ betas = 1 - alphas[1:] / alphas[:-1]
122
+ betas = np.clip(betas, a_min=0, a_max=0.999)
123
+
124
+ elif schedule == "sqrt_linear":
125
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
126
+ elif schedule == "sqrt":
127
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
128
+ elif schedule == 'linear_force_zero_snr':
129
+ betas = (
130
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
131
+ )
132
+ betas = force_zero_snr(betas)
133
+ elif schedule == 'linear_100':
134
+ betas = (
135
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
136
+ )
137
+ betas = betas[:100]
138
+ else:
139
+ raise ValueError(f"schedule '{schedule}' unknown.")
140
+
141
+ if shift_scale is not None:
142
+ print("shift_scale")
143
+ betas = shift_schedule(betas, shift_scale)
144
+
145
+ return betas.numpy()
146
+
147
+ def shift_schedule(base_betas, shift_scale):
148
+ alphas = 1 - base_betas
149
+ alphas_bar = torch.cumprod(alphas, dim=0)
150
+ snr = alphas_bar / (1 - alphas_bar) # snr(1-ab)=ab; snr-snr*ab=ab; snr=(1+snr)ab; ab=snr/(1+snr)
151
+ shifted_snr = snr * ((1 / shift_scale) ** 2)
152
+ shifted_alphas_bar = shifted_snr / (1 + shifted_snr)
153
+ shifted_alphas = shifted_alphas_bar[1:] / shifted_alphas_bar[:-1]
154
+ shifted_alphas = torch.cat([shifted_alphas_bar[0:1], shifted_alphas], 0)
155
+ shifted_betas = 1 - shifted_alphas
156
+ return shifted_betas
157
+
158
+
159
+ def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
160
+ n_dims = len(x.shape)
161
+ if src_dim < 0:
162
+ src_dim = n_dims + src_dim
163
+ if dest_dim < 0:
164
+ dest_dim = n_dims + dest_dim
165
+
166
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
167
+
168
+ dims = list(range(n_dims))
169
+ del dims[src_dim]
170
+
171
+ permutation = []
172
+ ctr = 0
173
+ for i in range(n_dims):
174
+ if i == dest_dim:
175
+ permutation.append(src_dim)
176
+ else:
177
+ permutation.append(dims[ctr])
178
+ ctr += 1
179
+ x = x.permute(permutation)
180
+ if make_contiguous:
181
+ x = x.contiguous()
182
+ return x
183
+
184
+
185
+ # reshapes tensor start from dim i (inclusive)
186
+ # to dim j (exclusive) to the desired shape
187
+ # e.g. if x.shape = (b, thw, c) then
188
+ # view_range(x, 1, 2, (t, h, w)) returns
189
+ # x of shape (b, t, h, w, c)
190
+ def view_range(x, i, j, shape):
191
+ shape = tuple(shape)
192
+
193
+ n_dims = len(x.shape)
194
+ if i < 0:
195
+ i = n_dims + i
196
+
197
+ if j is None:
198
+ j = n_dims
199
+ elif j < 0:
200
+ j = n_dims + j
201
+
202
+ assert 0 <= i < j <= n_dims
203
+
204
+ x_shape = x.shape
205
+ target_shape = x_shape[:i] + shape + x_shape[j:]
206
+ return x.view(target_shape)
207
+
208
+
209
+ def tensor_slice(x, begin, size):
210
+ assert all([b >= 0 for b in begin])
211
+ size = [l - b if s == -1 else s
212
+ for s, b, l in zip(size, begin, x.shape)]
213
+ assert all([s >= 0 for s in size])
214
+
215
+ slices = [slice(b, b + s) for b, s in zip(begin, size)]
216
+ return x[slices]
217
+
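A small self-check of the schedule helpers above, assuming it runs in the same module so the functions are in scope; it verifies the relation stated in the shift_schedule comment, i.e. shifting by s divides the per-step SNR by s**2.

import torch

betas = torch.from_numpy(make_beta_schedule("scaled_linear", n_timestep=1000))  # numpy -> tensor
shifted = shift_schedule(betas, shift_scale=2.0)

ab = torch.cumprod(1 - betas, dim=0)
ab_shifted = torch.cumprod(1 - shifted, dim=0)
snr = ab / (1 - ab)
snr_shifted = ab_shifted / (1 - ab_shifted)
print(torch.allclose(snr_shifted, snr / 2.0 ** 2))  # True up to floating-point error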
DiT_VAE/vae/__init__.py ADDED
File without changes
DiT_VAE/vae/aemodules3d.py ADDED
@@ -0,0 +1,368 @@
1
+ # TATS
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import math
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from .attention_vae import MultiHeadAttention
11
+
12
+
13
+ def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
14
+ n_dims = len(x.shape)
15
+ if src_dim < 0:
16
+ src_dim = n_dims + src_dim
17
+ if dest_dim < 0:
18
+ dest_dim = n_dims + dest_dim
19
+
20
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
21
+
22
+ dims = list(range(n_dims))
23
+ del dims[src_dim]
24
+
25
+ permutation = []
26
+ ctr = 0
27
+ for i in range(n_dims):
28
+ if i == dest_dim:
29
+ permutation.append(src_dim)
30
+ else:
31
+ permutation.append(dims[ctr])
32
+ ctr += 1
33
+ x = x.permute(permutation)
34
+ if make_contiguous:
35
+ x = x.contiguous()
36
+ return x
37
+
38
+ def silu(x):
39
+ return x * torch.sigmoid(x)
40
+
41
+
42
+ class SiLU(nn.Module):
43
+ def __init__(self):
44
+ super(SiLU, self).__init__()
45
+
46
+ def forward(self, x):
47
+ return silu(x)
48
+
49
+
50
+ def hinge_d_loss(logits_real, logits_fake):
51
+ loss_real = torch.mean(F.relu(1. - logits_real))
52
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
53
+ d_loss = 0.5 * (loss_real + loss_fake)
54
+ return d_loss
55
+
56
+
57
+ def vanilla_d_loss(logits_real, logits_fake):
58
+ d_loss = 0.5 * (
59
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
60
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
61
+ return d_loss
62
+
63
+
64
+ def Normalize(in_channels, norm_type='group'):
65
+ assert norm_type in ['group', 'batch']
66
+ if norm_type == 'group':
67
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
68
+ elif norm_type == 'batch':
69
+ return torch.nn.SyncBatchNorm(in_channels)
70
+
71
+
72
+ class ResBlock(nn.Module):
73
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group',
74
+ padding_type='replicate'):
75
+ super().__init__()
76
+ self.in_channels = in_channels
77
+ out_channels = in_channels if out_channels is None else out_channels
78
+ self.out_channels = out_channels
79
+ self.use_conv_shortcut = conv_shortcut
80
+
81
+ self.norm1 = Normalize(in_channels, norm_type)
82
+ self.conv1 = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
83
+ self.dropout = torch.nn.Dropout(dropout)
84
+ self.norm2 = Normalize(in_channels, norm_type)
85
+ self.conv2 = SamePadConv3d(out_channels, out_channels, kernel_size=3, padding_type=padding_type)
86
+ if self.in_channels != self.out_channels:
87
+ self.conv_shortcut = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
88
+
89
+ def forward(self, x):
90
+ h = x
91
+ h = self.norm1(h)
92
+ h = silu(h)
93
+ h = self.conv1(h)
94
+ h = self.norm2(h)
95
+ h = silu(h)
96
+ h = self.conv2(h)
97
+
98
+ if self.in_channels != self.out_channels:
99
+ x = self.conv_shortcut(x)
100
+
101
+ return x + h
102
+
103
+
104
+ # Does not support dilation
105
+ class SamePadConv3d(nn.Module):
106
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
107
+ super().__init__()
108
+ if isinstance(kernel_size, int):
109
+ kernel_size = (kernel_size,) * 3
110
+ if isinstance(stride, int):
111
+ stride = (stride,) * 3
112
+
113
+ # assumes that the input shape is divisible by stride
114
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
115
+ pad_input = []
116
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
117
+ pad_input.append((p // 2 + p % 2, p // 2))
118
+ pad_input = sum(pad_input, tuple())
119
+
120
+ self.pad_input = pad_input
121
+ self.padding_type = padding_type
122
+
123
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
124
+ stride=stride, padding=0, bias=bias)
125
+ self.weight = self.conv.weight
126
+
127
+ def forward(self, x):
128
+ return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
129
+
130
+
131
+ class SamePadConvTranspose3d(nn.Module):
132
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
133
+ super().__init__()
134
+ if isinstance(kernel_size, int):
135
+ kernel_size = (kernel_size,) * 3
136
+ if isinstance(stride, int):
137
+ stride = (stride,) * 3
138
+
139
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
140
+ pad_input = []
141
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
142
+ pad_input.append((p // 2 + p % 2, p // 2))
143
+ pad_input = sum(pad_input, tuple())
144
+ self.pad_input = pad_input
145
+ self.padding_type = padding_type
146
+ self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
147
+ stride=stride, bias=bias,
148
+ padding=tuple([k - 1 for k in kernel_size]))
149
+
150
+ def forward(self, x):
151
+ return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
152
+
153
+
154
+ class AxialBlock(nn.Module):
155
+ def __init__(self, n_hiddens, n_head):
156
+ super().__init__()
157
+ kwargs = dict(shape=(0,) * 3, dim_q=n_hiddens,
158
+ dim_kv=n_hiddens, n_head=n_head,
159
+ n_layer=1, causal=False, attn_type='axial')
160
+ self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2),
161
+ **kwargs)
162
+ self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3),
163
+ **kwargs)
164
+ self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4),
165
+ **kwargs)
166
+
167
+ def forward(self, x):
168
+ x = shift_dim(x, 1, -1)
169
+ x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x)
170
+ x = shift_dim(x, -1, 1)
171
+ return x
172
+ class AttentionResidualBlock(nn.Module):
173
+ def __init__(self, n_hiddens):
174
+ super().__init__()
175
+ self.block = nn.Sequential(
176
+ Normalize(n_hiddens),
177
+ SiLU(),
178
+ SamePadConv3d(n_hiddens, n_hiddens // 2, 3, bias=False),
179
+ Normalize(n_hiddens // 2),
180
+ SiLU(),
181
+ SamePadConv3d(n_hiddens // 2, n_hiddens, 1, bias=False),
182
+ Normalize(n_hiddens),
183
+ SiLU(),
184
+ AxialBlock(n_hiddens, 2)
185
+ )
186
+
187
+ def forward(self, x):
188
+ return x + self.block(x)
189
+
190
+ class Encoder(nn.Module):
191
+ def __init__(self, n_hiddens, downsample, z_channels, double_z, image_channel=3, norm_type='group',
192
+ padding_type='replicate', res_num=1):
193
+ super().__init__()
194
+ n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
195
+ self.conv_blocks = nn.ModuleList()
196
+ max_ds = n_times_downsample.max()
197
+ self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
198
+
199
+ for i in range(max_ds):
200
+ block = nn.Module()
201
+ in_channels = n_hiddens * 2 ** i
202
+ out_channels = n_hiddens * 2 ** (i + 1)
203
+ stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
204
+ stride = list(stride)
205
+ stride[0] = 1
206
+ stride = tuple(stride)
207
+ block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
208
+
209
+ block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)
210
+ self.conv_blocks.append(block)
211
+ n_times_downsample -= 1
212
+
213
+ self.final_block = nn.Sequential(
214
+ Normalize(out_channels, norm_type),
215
+ SiLU(),
216
+ SamePadConv3d(out_channels, 2 * z_channels if double_z else z_channels,
217
+ kernel_size=3,
218
+ stride=1,
219
+ padding_type=padding_type)
220
+ )
221
+ self.out_channels = out_channels
222
+
223
+
224
+ def forward(self, x):
225
+ h = self.conv_first(x)
226
+ for block in self.conv_blocks:
227
+ h = block.down(h)
228
+ h = block.res(h)
229
+ h = self.final_block(h)
230
+ return h
231
+
232
+
233
+ class Decoder(nn.Module):
234
+ def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group'):
235
+ super().__init__()
236
+
237
+ n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
238
+ max_us = n_times_upsample.max()
239
+ in_channels = z_channels
240
+ self.conv_blocks = nn.ModuleList()
241
+ for i in range(max_us):
242
+ block = nn.Module()
243
+ in_channels = in_channels if i == 0 else n_hiddens * 2 ** (max_us - i + 1)
244
+ out_channels = n_hiddens * 2 ** (max_us - i)
245
+ us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
246
+ us = list(us)
247
+ us[0] = 1
248
+ us = tuple(us)
249
+ block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=us)
250
+ block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
251
+ block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
252
+ self.conv_blocks.append(block)
253
+ n_times_upsample -= 1
254
+
255
+ self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)
256
+
257
+ def forward(self, x):
258
+ h = x
259
+ for i, block in enumerate(self.conv_blocks):
260
+ h = block.up(h)
261
+ h = block.res1(h)
262
+ h = block.res2(h)
263
+ h = self.conv_out(h)
264
+ return h
265
+
266
+
267
+ class EncoderRe(nn.Module):
268
+ def __init__(self, n_hiddens, downsample, z_channels, double_z, image_channel=3, norm_type='group',
269
+ padding_type='replicate', n_res_layers=2):
270
+ super().__init__()
271
+ # n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
272
+ self.conv_blocks = nn.ModuleList()
273
+ # max_ds = n_times_downsample.max()
274
+ self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
275
+
276
+ for i, step in enumerate(downsample):
277
+ block = nn.Module()
278
+ in_channels = n_hiddens
279
+ out_channels = n_hiddens
280
+ stride = [1, downsample[i], downsample[i]]
281
+ stride = tuple(stride)
282
+ block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
283
+ block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
284
+ block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
285
+ self.conv_blocks.append(block)
286
+
287
+
288
+ self.res_stack = nn.Sequential(
289
+ *[AttentionResidualBlock(out_channels)
290
+ for _ in range(n_res_layers)]
291
+ )
292
+ self.final_block = nn.Sequential(
293
+ Normalize(out_channels, norm_type),
294
+ SiLU(),
295
+ SamePadConv3d(out_channels, 2 * z_channels if double_z else z_channels,
296
+ kernel_size=3,
297
+ stride=1,
298
+ padding_type=padding_type)
299
+ )
300
+ self.out_channels = out_channels
301
+
302
+ def forward(self, x):
303
+ h = self.conv_first(x)
304
+ for block in self.conv_blocks:
305
+ h = block.down(h)
306
+ h = block.res1(h)
307
+ h = block.res2(h)
308
+ h = self.res_stack(h)
309
+ h = self.final_block(h)
310
+ return h
311
+
312
+
313
+ class DecoderRe(nn.Module):
314
+ def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group', padding_type='replicate', n_res_layers=2):
315
+ super().__init__()
316
+ self.conv_first = SamePadConv3d(z_channels, n_hiddens, kernel_size=3, padding_type=padding_type)
317
+ self.res_stack = nn.Sequential(
318
+ *[AttentionResidualBlock(n_hiddens)
319
+ for _ in range(n_res_layers)]
320
+ )
321
+ # n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
322
+ # max_us = n_times_upsample.max()
323
+ # in_channels = n_hiddens
324
+ self.conv_blocks = nn.ModuleList()
325
+ for i, step in enumerate(upsample):
326
+ block = nn.Module()
327
+ in_channels = n_hiddens
328
+ out_channels = n_hiddens
329
+ stride = [1, upsample[i], upsample[i]]
330
+ stride = tuple(stride)
331
+
332
+ block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=stride)
333
+ block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
334
+ block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
335
+ self.conv_blocks.append(block)
336
+
337
+ self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)
338
+
339
+ def forward(self, x):
340
+ h = x
341
+ h = self.conv_first(h)
342
+ h = self.res_stack(h)
343
+ for i, block in enumerate(self.conv_blocks):
344
+ h = block.up(h)
345
+ h = block.res1(h)
346
+ h = block.res2(h)
347
+ h = self.conv_out(h)
348
+ return h
349
+
350
+
351
+ # unit test
352
+ if __name__ == '__main__':
353
+ encoder = EncoderRe(n_hiddens=320, downsample=[1, 2, 2, 2], z_channels=8, double_z=True, image_channel=96,
354
+ norm_type='group', padding_type='replicate')
355
+ encoder = encoder.cuda()
356
+ en_input = torch.rand(1, 96, 3, 256, 256).cuda()
357
+ out = encoder(en_input)
358
+ print(out.shape)
359
+ mean, logvar = torch.chunk(out, 2, dim=1)
360
+ # print(mean.shape)
361
+ decoder = DecoderRe(n_hiddens=320, upsample=[2, 2, 2, 1], z_channels=8, image_channel=96,
362
+ norm_type='group' )
363
+
364
+ decoder = decoder.cuda()
365
+ out = decoder(mean)
366
+ print(out.shape)
367
+ # logvar = nn.Parameter(torch.ones(size=()) * 0.0)
368
+ # print(logvar)
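A CPU-sized variant of the unit test above, assuming the same module; n_hiddens is kept divisible by 64 so every GroupNorm(32, ...) created by Normalize stays valid, and the stride pattern leaves the plane axis untouched while halving H and W three times.

import torch

enc = EncoderRe(n_hiddens=64, downsample=[1, 2, 2, 2], z_channels=8, double_z=True, image_channel=96)
dec = DecoderRe(n_hiddens=64, upsample=[2, 2, 2, 1], z_channels=8, image_channel=96)
x = torch.rand(1, 96, 3, 64, 64)              # (batch, triplane channels, planes, H, W)
moments = enc(x)                              # (1, 16, 3, 8, 8): mean and log-variance on dim 1
mean, logvar = torch.chunk(moments, 2, dim=1)
print(dec(mean).shape)                        # torch.Size([1, 96, 3, 64, 64])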
DiT_VAE/vae/attention_vae.py ADDED
@@ -0,0 +1,620 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.utils.checkpoint import checkpoint
7
+ def tensor_slice(x, begin, size):
8
+ assert all([b >= 0 for b in begin])
9
+ size = [l - b if s == -1 else s
10
+ for s, b, l in zip(size, begin, x.shape)]
11
+ assert all([s >= 0 for s in size])
12
+
13
+ slices = [slice(b, b + s) for b, s in zip(begin, size)]
14
+ return x[slices]
15
+
16
+
17
+ # reshapes tensor start from dim i (inclusive)
18
+ # to dim j (exclusive) to the desired shape
19
+ # e.g. if x.shape = (b, thw, c) then
20
+ # view_range(x, 1, 2, (t, h, w)) returns
21
+ # x of shape (b, t, h, w, c)
22
+ def view_range(x, i, j, shape):
23
+ shape = tuple(shape)
24
+
25
+ n_dims = len(x.shape)
26
+ if i < 0:
27
+ i = n_dims + i
28
+
29
+ if j is None:
30
+ j = n_dims
31
+ elif j < 0:
32
+ j = n_dims + j
33
+
34
+ assert 0 <= i < j <= n_dims
35
+
36
+ x_shape = x.shape
37
+ target_shape = x_shape[:i] + shape + x_shape[j:]
38
+ return x.view(target_shape)
39
+
40
+ def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
41
+ n_dims = len(x.shape)
42
+ if src_dim < 0:
43
+ src_dim = n_dims + src_dim
44
+ if dest_dim < 0:
45
+ dest_dim = n_dims + dest_dim
46
+
47
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
48
+
49
+ dims = list(range(n_dims))
50
+ del dims[src_dim]
51
+
52
+ permutation = []
53
+ ctr = 0
54
+ for i in range(n_dims):
55
+ if i == dest_dim:
56
+ permutation.append(src_dim)
57
+ else:
58
+ permutation.append(dims[ctr])
59
+ ctr += 1
60
+ x = x.permute(permutation)
61
+ if make_contiguous:
62
+ x = x.contiguous()
63
+ return x
64
+ class AttentionStack(nn.Module):
65
+ def __init__(
66
+ self, shape, embd_dim, n_head, n_layer, dropout,
67
+ attn_type, attn_dropout, class_cond_dim, frame_cond_shape,
68
+ ):
69
+ super().__init__()
70
+ self.shape = shape
71
+ self.embd_dim = embd_dim
72
+ self.use_frame_cond = frame_cond_shape is not None
73
+
74
+ self.right_shift = RightShift(embd_dim)
75
+ self.pos_embd = AddBroadcastPosEmbed(
76
+ shape=shape, embd_dim=embd_dim
77
+ )
78
+
79
+ self.attn_nets = nn.ModuleList(
80
+ [
81
+ AttentionBlock(
82
+ shape=shape,
83
+ embd_dim=embd_dim,
84
+ n_head=n_head,
85
+ n_layer=n_layer,
86
+ dropout=dropout,
87
+ attn_type=attn_type,
88
+ attn_dropout=attn_dropout,
89
+ class_cond_dim=class_cond_dim,
90
+ frame_cond_shape=frame_cond_shape
91
+ )
92
+ for i in range(n_layer)
93
+ ]
94
+ )
95
+
96
+ def forward(self, x, cond, decode_step, decode_idx):
97
+ """
98
+ Args
99
+ ------
100
+ x: (b, d1, d2, ..., dn, embd_dim)
101
+ cond: a dictionary of conditioning tensors
102
+
103
+ (below is used only when sampling for fast decoding)
104
+ decode: the enumerated rasterscan order of the current idx being sampled
105
+ decode_step: a tuple representing the current idx being sampled
106
+ """
107
+ x = self.right_shift(x, decode_step)
108
+ x = self.pos_embd(x, decode_step, decode_idx)
109
+ for net in self.attn_nets:
110
+ x = net(x, cond, decode_step, decode_idx)
111
+
112
+ return x
113
+
114
+
115
+ class AttentionBlock(nn.Module):
116
+ def __init__(self, shape, embd_dim, n_head, n_layer, dropout,
117
+ attn_type, attn_dropout, class_cond_dim, frame_cond_shape):
118
+ super().__init__()
119
+ self.use_frame_cond = frame_cond_shape is not None
120
+
121
+ self.pre_attn_norm = LayerNorm(embd_dim, class_cond_dim)
122
+ self.post_attn_dp = nn.Dropout(dropout)
123
+ self.attn = MultiHeadAttention(shape, embd_dim, embd_dim, n_head,
124
+ n_layer, causal=True, attn_type=attn_type,
125
+ attn_kwargs=dict(attn_dropout=attn_dropout))
126
+
127
+ if frame_cond_shape is not None:
128
+ enc_len = np.prod(frame_cond_shape[:-1])
129
+ self.pre_enc_norm = LayerNorm(embd_dim, class_cond_dim)
130
+ self.post_enc_dp = nn.Dropout(dropout)
131
+ self.enc_attn = MultiHeadAttention(shape, embd_dim, frame_cond_shape[-1],
132
+ n_head, n_layer, attn_type='full',
133
+ attn_kwargs=dict(attn_dropout=0.), causal=False)
134
+
135
+ self.pre_fc_norm = LayerNorm(embd_dim, class_cond_dim)
136
+ self.post_fc_dp = nn.Dropout(dropout)
137
+ self.fc_block = nn.Sequential(
138
+ nn.Linear(in_features=embd_dim, out_features=embd_dim * 4),
139
+ GeLU2(),
140
+ nn.Linear(in_features=embd_dim * 4, out_features=embd_dim),
141
+ )
142
+
143
+ def forward(self, x, cond, decode_step, decode_idx):
144
+ h = self.pre_attn_norm(x, cond)
145
+ if self.training:
146
+ h = checkpoint(self.attn, h, h, h, decode_step, decode_idx)
147
+ else:
148
+ h = self.attn(h, h, h, decode_step, decode_idx)
149
+ h = self.post_attn_dp(h)
150
+ x = x + h
151
+
152
+ if self.use_frame_cond:
153
+ h = self.pre_enc_norm(x, cond)
154
+ if self.training:
155
+ h = checkpoint(self.enc_attn, h, cond['frame_cond'], cond['frame_cond'],
156
+ decode_step, decode_idx)
157
+ else:
158
+ h = self.enc_attn(h, cond['frame_cond'], cond['frame_cond'],
159
+ decode_step, decode_idx)
160
+ h = self.post_enc_dp(h)
161
+ x = x + h
162
+
163
+ h = self.pre_fc_norm(x, cond)
164
+ if self.training:
165
+ h = checkpoint(self.fc_block, h)
166
+ else:
167
+ h = self.fc_block(h)
168
+ h = self.post_fc_dp(h)
169
+ x = x + h
170
+
171
+ return x
172
+
173
+
174
+ class MultiHeadAttention(nn.Module):
175
+ def __init__(self, shape, dim_q, dim_kv, n_head, n_layer,
176
+ causal, attn_type, attn_kwargs):
177
+ super().__init__()
178
+ self.causal = causal
179
+ self.shape = shape
180
+
181
+ self.d_k = dim_q // n_head
182
+ self.d_v = dim_kv // n_head
183
+ self.n_head = n_head
184
+
185
+ self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False) # q
186
+ self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q))
187
+
188
+ self.w_ks = nn.Linear(dim_kv, n_head * self.d_k, bias=False) # k
189
+ self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_kv))
190
+
191
+ self.w_vs = nn.Linear(dim_kv, n_head * self.d_v, bias=False) # v
192
+ self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_kv))
193
+
194
+ self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True) # c
195
+ self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer))
196
+
197
+ if attn_type == 'full':
198
+ self.attn = FullAttention(shape, causal, **attn_kwargs)
199
+ elif attn_type == 'axial':
200
+ assert not causal, 'causal axial attention is not supported'
201
+ self.attn = AxialAttention(len(shape), **attn_kwargs)
202
+ elif attn_type == 'sparse':
203
+ self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs)
204
+
205
+ self.cache = None
206
+
207
+ def forward(self, q, k, v, decode_step=None, decode_idx=None):
208
+ """ Compute multi-head attention
209
+ Args
210
+ q, k, v: a [b, d1, ..., dn, c] tensor or
211
+ a [b, 1, ..., 1, c] tensor if decode_step is not None
212
+
213
+ Returns
214
+ The output after performing attention
215
+ """
216
+
217
+ # compute k, q, v
218
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
219
+ q = view_range(self.w_qs(q), -1, None, (n_head, d_k))
220
+ k = view_range(self.w_ks(k), -1, None, (n_head, d_k))
221
+ v = view_range(self.w_vs(v), -1, None, (n_head, d_v))
222
+
223
+ # b x n_head x seq_len x d
224
+ # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d)
225
+ q = shift_dim(q, -2, 1)
226
+ k = shift_dim(k, -2, 1)
227
+ v = shift_dim(v, -2, 1)
228
+
229
+ # fast decoding
230
+ if decode_step is not None:
231
+ if decode_step == 0:
232
+ if self.causal:
233
+ k_shape = (q.shape[0], n_head, *self.shape, self.d_k)
234
+ v_shape = (q.shape[0], n_head, *self.shape, self.d_v)
235
+ self.cache = dict(k=torch.zeros(k_shape, dtype=k.dtype, device=q.device),
236
+ v=torch.zeros(v_shape, dtype=v.dtype, device=q.device))
237
+ else:
238
+ # cache only once in the non-causal case
239
+ self.cache = dict(k=k.clone(), v=v.clone())
240
+ if self.causal:
241
+ idx = (slice(None, None), slice(None, None), *[slice(i, i+ 1) for i in decode_idx])
242
+ self.cache['k'][idx] = k
243
+ self.cache['v'][idx] = v
244
+ k, v = self.cache['k'], self.cache['v']
245
+
246
+ a = self.attn(q, k, v, decode_step, decode_idx)
247
+
248
+ # (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d)
249
+ a = shift_dim(a, 1, -2).flatten(start_dim=-2)
250
+ a = self.fc(a) # (b x seq_len x embd_dim)
251
+
252
+ return a
253
+
254
+ ############## Attention #######################
255
+ class FullAttention(nn.Module):
256
+ def __init__(self, shape, causal, attn_dropout):
257
+ super().__init__()
258
+ self.causal = causal
259
+ self.attn_dropout = attn_dropout
260
+
261
+ seq_len = np.prod(shape)
262
+ if self.causal:
263
+ self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)))
264
+
265
+ def forward(self, q, k, v, decode_step, decode_idx):
266
+ mask = self.mask if self.causal else None
267
+ if decode_step is not None and mask is not None:
268
+ mask = mask[[decode_step]]
269
+
270
+ old_shape = q.shape[2:-1]
271
+ q = q.flatten(start_dim=2, end_dim=-2)
272
+ k = k.flatten(start_dim=2, end_dim=-2)
273
+ v = v.flatten(start_dim=2, end_dim=-2)
274
+
275
+ out = scaled_dot_product_attention(q, k, v, mask=mask,
276
+ attn_dropout=self.attn_dropout,
277
+ training=self.training)
278
+
279
+ return view_range(out, 2, 3, old_shape)
280
+
281
+ class AxialAttention(nn.Module):
282
+ def __init__(self, n_dim, axial_dim):
283
+ super().__init__()
284
+ if axial_dim < 0:
285
+ axial_dim = 2 + n_dim + 1 + axial_dim
286
+ else:
287
+ axial_dim += 2 # account for batch, head, dim
288
+ self.axial_dim = axial_dim
289
+
290
+ def forward(self, q, k, v, decode_step, decode_idx):
291
+ q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3)
292
+ k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3)
293
+ v = shift_dim(v, self.axial_dim, -2)
294
+ old_shape = list(v.shape)
295
+ v = v.flatten(end_dim=-3)
296
+
297
+ out = scaled_dot_product_attention(q, k, v, training=self.training)
298
+ out = out.view(*old_shape)
299
+ out = shift_dim(out, -2, self.axial_dim)
300
+ return out
301
+
302
+
303
+ class SparseAttention(nn.Module):
304
+ ops = dict()
305
+ attn_mask = dict()
306
+ block_layout = dict()
307
+
308
+ def __init__(self, shape, n_head, causal, num_local_blocks=4, block=32,
309
+ attn_dropout=0.): # does not use attn_dropout
310
+ super().__init__()
311
+ self.causal = causal
312
+ self.shape = shape
313
+
314
+ self.sparsity_config = StridedSparsityConfig(shape=shape, n_head=n_head,
315
+ causal=causal, block=block,
316
+ num_local_blocks=num_local_blocks)
317
+
318
+ if self.shape not in SparseAttention.block_layout:
319
+ SparseAttention.block_layout[self.shape] = self.sparsity_config.make_layout()
320
+ if causal and self.shape not in SparseAttention.attn_mask:
321
+ SparseAttention.attn_mask[self.shape] = self.sparsity_config.make_sparse_attn_mask()
322
+
323
+ def get_ops(self):
324
+ try:
325
+ from deepspeed.ops.sparse_attention import MatMul, Softmax
326
+ except:
327
+ raise Exception('Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`')
328
+ if self.shape not in SparseAttention.ops:
329
+ sparsity_layout = self.sparsity_config.make_layout()
330
+ sparse_dot_sdd_nt = MatMul(sparsity_layout,
331
+ self.sparsity_config.block,
332
+ 'sdd',
333
+ trans_a=False,
334
+ trans_b=True)
335
+
336
+ sparse_dot_dsd_nn = MatMul(sparsity_layout,
337
+ self.sparsity_config.block,
338
+ 'dsd',
339
+ trans_a=False,
340
+ trans_b=False)
341
+
342
+ sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)
343
+
344
+ SparseAttention.ops[self.shape] = (sparse_dot_sdd_nt,
345
+ sparse_dot_dsd_nn,
346
+ sparse_softmax)
347
+ return SparseAttention.ops[self.shape]
348
+
349
+ def forward(self, q, k, v, decode_step, decode_idx):
350
+ if self.training and self.shape not in SparseAttention.ops:
351
+ self.get_ops()
352
+
353
+ SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[self.shape].to(q)
354
+ if self.causal:
355
+ SparseAttention.attn_mask[self.shape] = SparseAttention.attn_mask[self.shape].to(q).type_as(q)
356
+ attn_mask = SparseAttention.attn_mask[self.shape] if self.causal else None
357
+
358
+ old_shape = q.shape[2:-1]
359
+ q = q.flatten(start_dim=2, end_dim=-2)
360
+ k = k.flatten(start_dim=2, end_dim=-2)
361
+ v = v.flatten(start_dim=2, end_dim=-2)
362
+
363
+ if decode_step is not None:
364
+ mask = self.sparsity_config.get_non_block_layout_row(SparseAttention.block_layout[self.shape], decode_step)
365
+ out = scaled_dot_product_attention(q, k, v, mask=mask, training=self.training)
366
+ else:
367
+ if q.shape != k.shape or k.shape != v.shape:
368
+ raise Exception('SparseAttention only support self-attention')
369
+ sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops()
370
+ scaling = float(q.shape[-1]) ** -0.5
371
+
372
+ attn_output_weights = sparse_dot_sdd_nt(q, k)
373
+ if attn_mask is not None:
374
+ attn_output_weights = attn_output_weights.masked_fill(attn_mask == 0,
375
+ float('-inf'))
376
+ attn_output_weights = sparse_softmax(
377
+ attn_output_weights,
378
+ scale=scaling
379
+ )
380
+
381
+ out = sparse_dot_dsd_nn(attn_output_weights, v)
382
+
383
+ return view_range(out, 2, 3, old_shape)
384
+
385
+
386
+ class StridedSparsityConfig(object):
387
+ """
388
+ Strided Sparse configuration specified in https://arxiv.org/abs/1904.10509 that
389
+ generalizes to arbitrary dimensions
390
+ """
391
+ def __init__(self, shape, n_head, causal, block, num_local_blocks):
392
+ self.n_head = n_head
393
+ self.shape = shape
394
+ self.causal = causal
395
+ self.block = block
396
+ self.num_local_blocks = num_local_blocks
397
+
398
+ assert self.num_local_blocks >= 1, 'Must have at least 1 local block'
399
+ assert self.seq_len % self.block == 0, 'seq len must be divisible by block size'
400
+
401
+ self._block_shape = self._compute_block_shape()
402
+ self._block_shape_cum = self._block_shape_cum_sizes()
403
+
404
+ @property
405
+ def seq_len(self):
406
+ return np.prod(self.shape)
407
+
408
+ @property
409
+ def num_blocks(self):
410
+ return self.seq_len // self.block
411
+
412
+ def set_local_layout(self, layout):
413
+ num_blocks = self.num_blocks
414
+ for row in range(0, num_blocks):
415
+ end = min(row + self.num_local_blocks, num_blocks)
416
+ for col in range(
417
+ max(0, row - self.num_local_blocks),
418
+ (row + 1 if self.causal else end)):
419
+ layout[:, row, col] = 1
420
+ return layout
421
+
422
+ def set_global_layout(self, layout):
423
+ num_blocks = self.num_blocks
424
+ n_dim = len(self._block_shape)
425
+ for row in range(num_blocks):
426
+ assert self._to_flattened_idx(self._to_unflattened_idx(row)) == row
427
+ cur_idx = self._to_unflattened_idx(row)
428
+ # no strided attention over last dim
429
+ for d in range(n_dim - 1):
430
+ end = self._block_shape[d]
431
+ for i in range(0, (cur_idx[d] + 1 if self.causal else end)):
432
+ new_idx = list(cur_idx)
433
+ new_idx[d] = i
434
+ new_idx = tuple(new_idx)
435
+
436
+ col = self._to_flattened_idx(new_idx)
437
+ layout[:, row, col] = 1
438
+
439
+ return layout
440
+
441
+ def make_layout(self):
442
+ layout = torch.zeros((self.n_head, self.num_blocks, self.num_blocks), dtype=torch.int64)
443
+ layout = self.set_local_layout(layout)
444
+ layout = self.set_global_layout(layout)
445
+ return layout
446
+
447
+ def make_sparse_attn_mask(self):
448
+ block_layout = self.make_layout()
449
+ assert block_layout.shape[1] == block_layout.shape[2] == self.num_blocks
450
+
451
+ num_dense_blocks = block_layout.sum().item()
452
+ attn_mask = torch.ones(num_dense_blocks, self.block, self.block)
453
+ counter = 0
454
+ for h in range(self.n_head):
455
+ for i in range(self.num_blocks):
456
+ for j in range(self.num_blocks):
457
+ elem = block_layout[h, i, j].item()
458
+ if elem == 1:
459
+ assert i >= j
460
+ if i == j: # need to mask within block on diagonals
461
+ attn_mask[counter] = torch.tril(attn_mask[counter])
462
+ counter += 1
463
+ assert counter == num_dense_blocks
464
+
465
+ return attn_mask.unsqueeze(0)
466
+
467
+ def get_non_block_layout_row(self, block_layout, row):
468
+ block_row = row // self.block
469
+ block_row = block_layout[:, [block_row]] # n_head x 1 x n_blocks
470
+ block_row = block_row.repeat_interleave(self.block, dim=-1)
471
+ block_row[:, :, row + 1:] = 0.
472
+ return block_row
473
+
474
+ ############# Helper functions ##########################
475
+
476
+ def _compute_block_shape(self):
477
+ n_dim = len(self.shape)
478
+ cum_prod = 1
479
+ for i in range(n_dim - 1, -1, -1):
480
+ cum_prod *= self.shape[i]
481
+ if cum_prod > self.block:
482
+ break
483
+ assert cum_prod % self.block == 0
484
+ new_shape = (*self.shape[:i], cum_prod // self.block)
485
+
486
+ assert np.prod(new_shape) == np.prod(self.shape) // self.block
487
+
488
+ return new_shape
489
+
490
+ def _block_shape_cum_sizes(self):
491
+ bs = np.flip(np.array(self._block_shape))
492
+ return tuple(np.flip(np.cumprod(bs)[:-1])) + (1,)
493
+
494
+ def _to_flattened_idx(self, idx):
495
+ assert len(idx) == len(self._block_shape), f"{len(idx)} != {len(self._block_shape)}"
496
+ flat_idx = 0
497
+ for i in range(len(self._block_shape)):
498
+ flat_idx += idx[i] * self._block_shape_cum[i]
499
+ return flat_idx
500
+
501
+ def _to_unflattened_idx(self, flat_idx):
502
+ assert flat_idx < np.prod(self._block_shape)
503
+ idx = []
504
+ for i in range(len(self._block_shape)):
505
+ idx.append(flat_idx // self._block_shape_cum[i])
506
+ flat_idx %= self._block_shape_cum[i]
507
+ return tuple(idx)
508
+
509
+
510
+ ################ Spatiotemporal broadcasted positional embeddings ###############
511
+ class AddBroadcastPosEmbed(nn.Module):
512
+ def __init__(self, shape, embd_dim, dim=-1):
513
+ super().__init__()
514
+ assert dim in [-1, 1] # only first or last dim supported
515
+ self.shape = shape
516
+ self.n_dim = n_dim = len(shape)
517
+ self.embd_dim = embd_dim
518
+ self.dim = dim
519
+
520
+ assert embd_dim % n_dim == 0, f"{embd_dim} % {n_dim} != 0"
521
+ self.emb = nn.ParameterDict({
522
+ f'd_{i}': nn.Parameter(torch.randn(shape[i], embd_dim // n_dim) * 0.01
523
+ if dim == -1 else
524
+ torch.randn(embd_dim // n_dim, shape[i]) * 0.01)
525
+ for i in range(n_dim)
526
+ })
527
+
528
+ def forward(self, x, decode_step=None, decode_idx=None):
529
+ embs = []
530
+ for i in range(self.n_dim):
531
+ e = self.emb[f'd_{i}']
532
+ if self.dim == -1:
533
+ # (1, 1, ..., 1, self.shape[i], 1, ..., -1)
534
+ e = e.view(1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)), -1)
535
+ e = e.expand(1, *self.shape, -1)
536
+ else:
537
+ e = e.view(1, -1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)))
538
+ e = e.expand(1, -1, *self.shape)
539
+ embs.append(e)
540
+
541
+ embs = torch.cat(embs, dim=self.dim)
542
+ if decode_step is not None:
543
+ embs = tensor_slice(embs, [0, *decode_idx, 0],
544
+ [x.shape[0], *(1,) * self.n_dim, x.shape[-1]])
545
+
546
+ return x + embs
547
+
548
+ ################# Helper Functions ###################################
549
+ def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout=0., training=True):
550
+ # Performs scaled dot-product attention over the second to last dimension dn
551
+
552
+ # (b, n_head, d1, ..., dn, d)
553
+ attn = torch.matmul(q, k.transpose(-1, -2))
554
+ attn = attn / np.sqrt(q.shape[-1])
555
+ if mask is not None:
556
+ attn = attn.masked_fill(mask == 0, float('-inf'))
557
+ attn_float = F.softmax(attn, dim=-1)
558
+ attn = attn_float.type_as(attn) # b x n_head x d1 x ... x dn x d
559
+ attn = F.dropout(attn, p=attn_dropout, training=training)
560
+
561
+ a = torch.matmul(attn, v) # b x n_head x d1 x ... x dn x d
562
+
563
+ return a
564
+
565
+
566
+ class RightShift(nn.Module):
567
+ def __init__(self, embd_dim):
568
+ super().__init__()
569
+ self.embd_dim = embd_dim
570
+ self.sos = nn.Parameter(torch.FloatTensor(embd_dim).normal_(std=0.02), requires_grad=True)
571
+
572
+ def forward(self, x, decode_step):
573
+ if decode_step is not None and decode_step > 0:
574
+ return x
575
+
576
+ x_shape = list(x.shape)
577
+ x = x.flatten(start_dim=1, end_dim=-2) # (b, seq_len, embd_dim)
578
+ sos = torch.ones(x_shape[0], 1, self.embd_dim, dtype=torch.float32).to(self.sos) * self.sos
579
+ sos = sos.type_as(x)
580
+ x = torch.cat([sos, x[:, :-1, :]], axis=1)
581
+ x = x.view(*x_shape)
582
+
583
+ return x
584
+
585
+
586
+ class GeLU2(nn.Module):
587
+ def forward(self, x):
588
+ return (1.702 * x).sigmoid() * x
589
+
590
+
591
+ class LayerNorm(nn.Module):
592
+ def __init__(self, embd_dim, class_cond_dim):
593
+ super().__init__()
594
+ self.conditional = class_cond_dim is not None
595
+
596
+ if self.conditional:
597
+ self.w = nn.Linear(class_cond_dim, embd_dim, bias=False)
598
+ nn.init.constant_(self.w.weight.data, 1. / np.sqrt(class_cond_dim))
599
+ self.wb = nn.Linear(class_cond_dim, embd_dim, bias=False)
600
+ else:
601
+ self.g = nn.Parameter(torch.ones(embd_dim, dtype=torch.float32), requires_grad=True)
602
+ self.b = nn.Parameter(torch.zeros(embd_dim, dtype=torch.float32), requires_grad=True)
603
+
604
+ def forward(self, x, cond):
605
+ if self.conditional: # (b, cond_dim)
606
+ g = 1 + self.w(cond['class_cond']).view(x.shape[0], *(1,)*(len(x.shape)-2), x.shape[-1]) # (b, ..., embd_dim)
607
+ b = self.wb(cond['class_cond']).view(x.shape[0], *(1,)*(len(x.shape)-2), x.shape[-1])
608
+ else:
609
+ g = self.g # (embd_dim,)
610
+ b = self.b
611
+
612
+ x_float = x.float()
613
+
614
+ mu = x_float.mean(dim=-1, keepdims=True)
615
+ s = (x_float - mu).square().mean(dim=-1, keepdims=True)
616
+ x_float = (x_float - mu) * (1e-5 + s.rsqrt()) # (b, ..., embd_dim)
617
+ x_float = x_float * g + b
618
+
619
+ x = x_float.type_as(x)
620
+ return x
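A shape check for scaled_dot_product_attention above, illustrative only: attention is taken over the second-to-last axis and the output keeps the query's shape.

import torch

q = k = v = torch.randn(2, 4, 3, 8, 8, 16)    # (b, n_head, d1, d2, d3, d)
out = scaled_dot_product_attention(q, k, v, training=False)
print(out.shape)                               # torch.Size([2, 4, 3, 8, 8, 16])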
DiT_VAE/vae/data/__init__.py ADDED
File without changes
DiT_VAE/vae/data/dataset_online_vae.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+
3
+ import numpy
4
+ import json
5
+ import zipfile
6
+ import torch
7
+ from PIL import Image
8
+ # from transformers import CLIPImageProcessor
9
+ from torch.utils.data import Dataset
10
+ import io
11
+ from omegaconf import OmegaConf
12
+ import numpy as np
13
+ # from torchvision import transforms
14
+ # from einops import rearrange
15
+ # import random
16
+ # import os
17
+ # from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDIMScheduler
18
+ # import time
19
+ # import io
20
+ # import array
21
+ # import numpy as np
22
+ #
23
+ # from training.triplane import TriPlaneGenerator
24
+
25
+
26
+ def to_rgb_image(maybe_rgba: Image.Image):
27
+ if maybe_rgba.mode == 'RGB':
28
+ return maybe_rgba
29
+ elif maybe_rgba.mode == 'RGBA':
30
+ rgba = maybe_rgba
31
+ img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
32
+ img = Image.fromarray(img, 'RGB')
33
+ img.paste(rgba, mask=rgba.getchannel('A'))
34
+ return img
35
+ else:
36
+ raise ValueError("Unsupported image type.", maybe_rgba.mode)
37
+
38
+
39
+
40
+ # image(contain style),z,pose,text
41
+ class TriplaneDataset(Dataset):
42
+ # image, triplane, ref_feature
43
+ def __init__(self, json_file, data_base_dir, model_names):
44
+ super().__init__()
45
+ self.dict_data_image = json.load(open(json_file)) # {'image_name': pose}
46
+ self.data_base_dir = data_base_dir
47
+ self.data_list = list(self.dict_data_image.keys())
48
+ self.zip_file_dict = {}
49
+ config_gan_model = OmegaConf.load(model_names)
50
+ all_models = config_gan_model['gan_models'].keys()
51
+ for model_name in all_models:
52
+ zipfile_path = os.path.join(self.data_base_dir, model_name+'.zip')
53
+ zipfile_load = zipfile.ZipFile(zipfile_path)
54
+ self.zip_file_dict[model_name] = zipfile_load
55
+
56
+ def getdata(self, idx):
57
+ # need z and expression and model name
58
+ # image:"seed0035.png"
59
+ # data_each_dict = {
60
+ # 'vert_dir': vert_dir,
61
+ # 'z_dir': z_dir,
62
+ # 'pose_dir': pose_dir,
63
+ # 'img_dir': img_dir,
64
+ # 'model_name': model_name
65
+ # }
66
+ data_name = self.data_list[idx]
67
+ data_model_name = self.dict_data_image[data_name]['model_name']
68
+ zipfile_loaded = self.zip_file_dict[data_model_name]
69
+ # zipfile_path = os.path.join(self.data_base_dir, data_model_name)
70
+ # zipfile_loaded = zipfile.ZipFile(zipfile_path)
71
+ with zipfile_loaded.open(self.dict_data_image[data_name]['z_dir'], 'r') as f:
72
+ buffer = io.BytesIO(f.read())
73
+ data_z = torch.load(buffer)
74
+ buffer.close()
75
+ f.close()
76
+ with zipfile_loaded.open(self.dict_data_image[data_name]['vert_dir'], 'r') as ff:
77
+ buffer_v = io.BytesIO(ff.read())
78
+ data_vert = torch.load(buffer_v)
79
+ buffer_v.close()
80
+ ff.close()
81
+ # raw_image = to_rgb_image(Image.open(f))
82
+ #
83
+ # data_model_name = self.dict_data_image[data_name]['model_name']
84
+ # data_z_dir = os.path.join(self.data_base_dir, data_model_name, self.dict_data_image[data_name]['z_dir'])
85
+ # data_vert_dir = os.path.join(self.data_base_dir, data_model_name, self.dict_data_image[data_name]['vert_dir'])
86
+ # data_z = torch.load(data_z_dir)
87
+ # data_vert = torch.load(data_vert_dir)
88
+
89
+ return {
90
+ "data_z": data_z,
91
+ "data_vert": data_vert,
92
+ "data_model_name": data_model_name
93
+ }
94
+
95
+ def __getitem__(self, idx):
96
+ for _ in range(20):
97
+ try:
98
+ return self.getdata(idx)
99
+ except Exception as e:
100
+ print(f"Error details: {str(e)}")
101
+ idx = np.random.randint(len(self))
102
+ raise RuntimeError('Too many bad data.')
103
+
104
+
105
+ def __len__(self):
106
+ return len(self.data_list)
107
+
108
+ # for zip files
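A hedged usage sketch for the dataset above; all paths are placeholders rather than files from this commit. One zip archive per GAN model listed in the model_names config is opened once, and each item returns the latent code, the vertex tensor and the model name.

from torch.utils.data import DataLoader

ds = TriplaneDataset(json_file="data/train.json",          # placeholder paths
                     data_base_dir="data/zips",
                     model_names="configs/gan_models.yaml")
loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=2)
batch = next(iter(loader))  # keys: "data_z", "data_vert", "data_model_name"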
DiT_VAE/vae/distributions.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self, noise=None):
36
+ if noise is None:
37
+ noise = torch.randn(self.mean.shape)
38
+ x = self.mean + self.std * noise.to(device=self.parameters.device, dtype=self.parameters.dtype)
39
+ return x
40
+
41
+ def kl(self, other=None):
42
+ if self.deterministic:
43
+ return torch.Tensor([0.])
44
+ else:
45
+ if other is None:
46
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
47
+ + self.var - 1.0 - self.logvar,
48
+ dim=[1, 2, 3])
49
+ else:
50
+ return 0.5 * torch.sum(
51
+ torch.pow(self.mean - other.mean, 2) / other.var
52
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
53
+ dim=[1, 2, 3])
54
+
55
+ def nll(self, sample, dims=[1,2,3]):
56
+ if self.deterministic:
57
+ return torch.Tensor([0.])
58
+ logtwopi = np.log(2.0 * np.pi)
59
+ return 0.5 * torch.sum(
60
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
61
+ dim=dims)
62
+
63
+ def mode(self):
64
+ return self.mean
65
+
66
+
67
+ def normal_kl(mean1, logvar1, mean2, logvar2):
68
+ """
69
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
70
+ Compute the KL divergence between two gaussians.
71
+ Shapes are automatically broadcasted, so batches can be compared to
72
+ scalars, among other use cases.
73
+ """
74
+ tensor = None
75
+ for obj in (mean1, logvar1, mean2, logvar2):
76
+ if isinstance(obj, torch.Tensor):
77
+ tensor = obj
78
+ break
79
+ assert tensor is not None, "at least one argument must be a Tensor"
80
+
81
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
82
+ # Tensors, but it does not work for torch.exp().
83
+ logvar1, logvar2 = [
84
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
85
+ for x in (logvar1, logvar2)
86
+ ]
87
+
88
+ return 0.5 * (
89
+ -1.0
90
+ + logvar2
91
+ - logvar1
92
+ + torch.exp(logvar1 - logvar2)
93
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
94
+ )
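A brief sketch of the posterior class above, using a 4-D latent purely for illustration: the encoder output is split channel-wise into mean and log-variance, sample() applies the reparameterisation trick, and kl() reduces over dims [1, 2, 3].

import torch

moments = torch.randn(2, 16, 8, 8)                 # 2*C channels: mean and log-variance
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                             # (2, 8, 8, 8): mean + std * noise
kl = posterior.kl()                                # shape (2,), summed over dims [1, 2, 3]
nll = posterior.nll(z)                             # per-sample negative log-likelihood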
DiT_VAE/vae/losses/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .contperceptual import LPIPSithTVLoss