Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	
		Ji4chenLi
		
	commited on
		
		
					Commit 
							
							·
						
						5bec700
	
1
								Parent(s):
							
							4906e77
								
initialize demo
Browse files- app.py +226 -0
- checkpoints/VideoCrafter2_model.ckpt +3 -0
- checkpoints/unet_mg.pt +3 -0
- configs/inference_t2v_512_v2.0.yaml +78 -0
- lvdm/__pycache__/basics.cpython-312.pyc +0 -0
- lvdm/__pycache__/common.cpython-312.pyc +0 -0
- lvdm/__pycache__/distributions.cpython-312.pyc +0 -0
- lvdm/__pycache__/ema.cpython-312.pyc +0 -0
- lvdm/basics.py +102 -0
- lvdm/common.py +112 -0
- lvdm/distributions.py +103 -0
- lvdm/ema.py +84 -0
- lvdm/models/__pycache__/autoencoder.cpython-312.pyc +0 -0
- lvdm/models/__pycache__/ddpm3d.cpython-312.pyc +0 -0
- lvdm/models/__pycache__/utils_diffusion.cpython-312.pyc +0 -0
- lvdm/models/autoencoder.py +276 -0
- lvdm/models/ddpm3d.py +967 -0
- lvdm/models/samplers/ddim.py +493 -0
- lvdm/models/utils_diffusion.py +130 -0
- lvdm/modules/__pycache__/attention.cpython-312.pyc +0 -0
- lvdm/modules/attention.py +612 -0
- lvdm/modules/encoders/__pycache__/condition.cpython-312.pyc +0 -0
- lvdm/modules/encoders/__pycache__/ip_resampler.cpython-312.pyc +0 -0
- lvdm/modules/encoders/condition.py +512 -0
- lvdm/modules/encoders/ip_resampler.py +148 -0
- lvdm/modules/networks/__pycache__/ae_modules.cpython-312.pyc +0 -0
- lvdm/modules/networks/__pycache__/openaimodel3d.cpython-312.pyc +0 -0
- lvdm/modules/networks/ae_modules.py +1025 -0
- lvdm/modules/networks/openaimodel3d.py +740 -0
- lvdm/modules/x_transformer.py +704 -0
- pipeline/__init__.py +0 -0
- pipeline/__pycache__/__init__.cpython-312.pyc +0 -0
- pipeline/__pycache__/t2v_turbo_vc2_pipeline.cpython-312.pyc +0 -0
- pipeline/t2v_turbo_vc2_pipeline.py +221 -0
- scheduler/__pycache__/t2v_turbo_scheduler.cpython-312.pyc +0 -0
- scheduler/t2v_turbo_scheduler.py +524 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-312.pyc +0 -0
- utils/__pycache__/common_utils.cpython-312.pyc +0 -0
- utils/__pycache__/lora.cpython-312.pyc +0 -0
- utils/__pycache__/lora_handler.cpython-312.pyc +0 -0
- utils/__pycache__/utils.cpython-312.pyc +0 -0
- utils/common_utils.py +511 -0
- utils/lora.py +1349 -0
- utils/lora_handler.py +153 -0
- utils/utils.py +99 -0
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,226 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import uuid
         | 
| 3 | 
            +
            from omegaconf import OmegaConf
         | 
| 4 | 
            +
            import spaces
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import imageio
         | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torchvision
         | 
| 11 | 
            +
            import gradio as gr
         | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            from gradio.components import Textbox, Video
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from utils.common_utils import load_model_checkpoint
         | 
| 16 | 
            +
            from utils.utils import instantiate_from_config
         | 
| 17 | 
            +
            from scheduler.t2v_turbo_scheduler import T2VTurboScheduler
         | 
| 18 | 
            +
            from pipeline.t2v_turbo_vc2_pipeline import T2VTurboVC2Pipeline
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            DESCRIPTION = """# T2V-Turbo 🚀
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            Our model is distilled from [VideoCrafter2](https://ailab-cvc.github.io/videocrafter2/).
         | 
| 23 | 
            +
            T2V-Turbo learns a LoRA on top of the base model by aligning to the reward feedback from [HPSv2.1](https://github.com/tgxs002/HPSv2/tree/master) and [InternVid2 Stage 2 Model](https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4).
         | 
| 24 | 
            +
            T2V-Turbo-v2 optimizes the training techniques by finetuning the full base model and further aligns to [CLIPScore](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            T2V-Turbo trains on pure WebVid-10M data, whereas T2V-Turbo-v2 carufully optimizes different learning objectives with a mixutre of VidGen-1M and WebVid-10M data.
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            Moreover, T2V-Turbo-v2 supports to distill motion priors from the training videos. 
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            [Project page for T2V-Turbo](https://t2v-turbo.github.io) 😄
         | 
| 31 | 
            +
            [Project page for T2V-Turbo-v2](https://t2v-turbo-v2.github.io) 🛫
         | 
| 32 | 
            +
            """
         | 
| 33 | 
            +
            if torch.cuda.is_available():
         | 
| 34 | 
            +
                DESCRIPTION += "\n<p>Running on CUDA 😀</p>"
         | 
| 35 | 
            +
            elif hasattr(torch, "xpu") and torch.xpu.is_available():
         | 
| 36 | 
            +
                DESCRIPTION += "\n<p>Running on XPU 🤓</p>"
         | 
| 37 | 
            +
            else:
         | 
| 38 | 
            +
                DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            MAX_SEED = np.iinfo(np.int32).max
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         | 
| 44 | 
            +
                if randomize_seed:
         | 
| 45 | 
            +
                    seed = random.randint(0, MAX_SEED)
         | 
| 46 | 
            +
                return seed
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def save_video(video_array, video_save_path, fps: int = 16):
         | 
| 50 | 
            +
                video = video_array.detach().cpu()
         | 
| 51 | 
            +
                video = torch.clamp(video.float(), -1.0, 1.0)
         | 
| 52 | 
            +
                video = video.permute(1, 0, 2, 3)  # t,c,h,w
         | 
| 53 | 
            +
                video = (video + 1.0) / 2.0
         | 
| 54 | 
            +
                video = (video * 255).to(torch.uint8).permute(0, 2, 3, 1)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                torchvision.io.write_video(
         | 
| 57 | 
            +
                    video_save_path, video, fps=fps, video_codec="h264", options={"crf": "10"}
         | 
| 58 | 
            +
                )
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            example_txt = [
         | 
| 61 | 
            +
                "An astronaut riding a horse.",
         | 
| 62 | 
            +
                "Darth vader surfing in waves.",
         | 
| 63 | 
            +
                "light wind, feathers moving, she moves her gaze, 4k",
         | 
| 64 | 
            +
                "a girl floating underwater.",
         | 
| 65 | 
            +
                "Pikachu snowboarding.",
         | 
| 66 | 
            +
                "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
         | 
| 67 | 
            +
                "A musician strums his guitar, serenading the moonlit night.",
         | 
| 68 | 
            +
            ]
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            examples = [[i, 7.5, 0.5, 16, 16, 0, True, "bf16"] for i in example_txt]
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            @spaces.GPU(duration=120)
         | 
| 73 | 
            +
            @torch.inference_mode()
         | 
| 74 | 
            +
            def generate(
         | 
| 75 | 
            +
                prompt: str,
         | 
| 76 | 
            +
                guidance_scale: float = 7.5,
         | 
| 77 | 
            +
                percentage: float = 0.5,
         | 
| 78 | 
            +
                num_inference_steps: int = 4,
         | 
| 79 | 
            +
                num_frames: int = 16,
         | 
| 80 | 
            +
                seed: int = 0,
         | 
| 81 | 
            +
                randomize_seed: bool = False,
         | 
| 82 | 
            +
                param_dtype="bf16",
         | 
| 83 | 
            +
                motion_gs: float = 0.05,
         | 
| 84 | 
            +
                fps: int = 8,
         | 
| 85 | 
            +
            ):
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                seed = randomize_seed_fn(seed, randomize_seed)
         | 
| 88 | 
            +
                torch.manual_seed(seed)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                if param_dtype == "bf16":
         | 
| 91 | 
            +
                    dtype = torch.bfloat16
         | 
| 92 | 
            +
                    unet.dtype = torch.bfloat16
         | 
| 93 | 
            +
                elif param_dtype == "fp16":
         | 
| 94 | 
            +
                    dtype = torch.float16
         | 
| 95 | 
            +
                    unet.dtype = torch.float16
         | 
| 96 | 
            +
                elif param_dtype == "fp32":
         | 
| 97 | 
            +
                    dtype = torch.float32
         | 
| 98 | 
            +
                    unet.dtype = torch.float32
         | 
| 99 | 
            +
                else:
         | 
| 100 | 
            +
                    raise ValueError(f"Unknown dtype: {param_dtype}")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                pipeline.unet.to(device, dtype)
         | 
| 103 | 
            +
                pipeline.text_encoder.to(device, dtype)
         | 
| 104 | 
            +
                pipeline.vae.to(device, dtype)
         | 
| 105 | 
            +
                pipeline.to(device, dtype)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                result = pipeline(
         | 
| 108 | 
            +
                    prompt=prompt,
         | 
| 109 | 
            +
                    frames=num_frames,
         | 
| 110 | 
            +
                    fps=fps,
         | 
| 111 | 
            +
                    guidance_scale=guidance_scale,
         | 
| 112 | 
            +
                    motion_gs=motion_gs,
         | 
| 113 | 
            +
                    use_motion_cond=True,
         | 
| 114 | 
            +
                    percentage=percentage,
         | 
| 115 | 
            +
                    num_inference_steps=num_inference_steps,
         | 
| 116 | 
            +
                    lcm_origin_steps=200,
         | 
| 117 | 
            +
                    num_videos_per_prompt=1,
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                torch.cuda.empty_cache()
         | 
| 121 | 
            +
                tmp_save_path = "tmp.mp4"
         | 
| 122 | 
            +
                root_path = "./videos/"
         | 
| 123 | 
            +
                os.makedirs(root_path, exist_ok=True)
         | 
| 124 | 
            +
                video_save_path = os.path.join(root_path, tmp_save_path)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                save_video(result[0], video_save_path, fps=fps)
         | 
| 127 | 
            +
                display_model_info = f"Video size: {num_frames}x320x512, Sampling Step: {num_inference_steps}, Guidance Scale: {guidance_scale}"
         | 
| 128 | 
            +
                return video_save_path, prompt, display_model_info, seed
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            block_css = """
         | 
| 132 | 
            +
            #buttons button {
         | 
| 133 | 
            +
                min-width: min(120px,100%);
         | 
| 134 | 
            +
            }
         | 
| 135 | 
            +
            """
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            if __name__ == "__main__":
         | 
| 139 | 
            +
                device = torch.device("cuda:0")
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                config = OmegaConf.load("configs/inference_t2v_512_v2.0.yaml")
         | 
| 142 | 
            +
                model_config = config.pop("model", OmegaConf.create())
         | 
| 143 | 
            +
                pretrained_t2v = instantiate_from_config(model_config)
         | 
| 144 | 
            +
                pretrained_t2v = load_model_checkpoint(pretrained_t2v, "checkpoints/VideoCrafter2_model.ckpt")
         | 
| 145 | 
            +
                
         | 
| 146 | 
            +
                unet_config = model_config["params"]["unet_config"]
         | 
| 147 | 
            +
                unet_config["params"]["use_checkpoint"] = False
         | 
| 148 | 
            +
                unet_config["params"]["time_cond_proj_dim"] = 256
         | 
| 149 | 
            +
                unet_config["params"]["motion_cond_proj_dim"] = 256
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                unet = instantiate_from_config(unet_config)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                unet.load_state_dict(torch.load("checkpoints/unet_mg.pt", map_location=device))
         | 
| 154 | 
            +
                unet.eval()
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                pretrained_t2v.model.diffusion_model = unet
         | 
| 157 | 
            +
                scheduler = T2VTurboScheduler(
         | 
| 158 | 
            +
                    linear_start=model_config["params"]["linear_start"],
         | 
| 159 | 
            +
                    linear_end=model_config["params"]["linear_end"],
         | 
| 160 | 
            +
                )
         | 
| 161 | 
            +
                pipeline = T2VTurboVC2Pipeline(pretrained_t2v, scheduler, model_config)
         | 
| 162 | 
            +
                pipeline.to(device)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                demo = gr.Interface(
         | 
| 165 | 
            +
                    fn=generate,
         | 
| 166 | 
            +
                    inputs=[
         | 
| 167 | 
            +
                        Textbox(label="", placeholder="Please enter your prompt. \n"),
         | 
| 168 | 
            +
                        gr.Slider(
         | 
| 169 | 
            +
                            label="Guidance scale",
         | 
| 170 | 
            +
                            minimum=2,
         | 
| 171 | 
            +
                            maximum=14,
         | 
| 172 | 
            +
                            step=0.1,
         | 
| 173 | 
            +
                            value=7.5,
         | 
| 174 | 
            +
                        ),
         | 
| 175 | 
            +
                        gr.Slider(
         | 
| 176 | 
            +
                            label="Percentage of steps to apply motion guidance (v2 w/ MG only)",
         | 
| 177 | 
            +
                            minimum=0.0,
         | 
| 178 | 
            +
                            maximum=0.5,
         | 
| 179 | 
            +
                            step=0.05,
         | 
| 180 | 
            +
                            value=0.5,
         | 
| 181 | 
            +
                        ),
         | 
| 182 | 
            +
                        gr.Slider(
         | 
| 183 | 
            +
                            label="Number of inference steps",
         | 
| 184 | 
            +
                            minimum=4,
         | 
| 185 | 
            +
                            maximum=50,
         | 
| 186 | 
            +
                            step=1,
         | 
| 187 | 
            +
                            value=16,
         | 
| 188 | 
            +
                        ),
         | 
| 189 | 
            +
                        gr.Slider(
         | 
| 190 | 
            +
                            label="Number of Video Frames",
         | 
| 191 | 
            +
                            minimum=16,
         | 
| 192 | 
            +
                            maximum=48,
         | 
| 193 | 
            +
                            step=8,
         | 
| 194 | 
            +
                            value=16,
         | 
| 195 | 
            +
                        ),
         | 
| 196 | 
            +
                        gr.Slider(
         | 
| 197 | 
            +
                            label="Seed",
         | 
| 198 | 
            +
                            minimum=0,
         | 
| 199 | 
            +
                            maximum=MAX_SEED,
         | 
| 200 | 
            +
                            step=1,
         | 
| 201 | 
            +
                            value=0,
         | 
| 202 | 
            +
                            randomize=True,
         | 
| 203 | 
            +
                        ),
         | 
| 204 | 
            +
                        gr.Checkbox(label="Randomize seed", value=True),
         | 
| 205 | 
            +
                        gr.Radio(
         | 
| 206 | 
            +
                            ["bf16", "fp16", "fp32"],
         | 
| 207 | 
            +
                            label="torch.dtype",
         | 
| 208 | 
            +
                            value="bf16",
         | 
| 209 | 
            +
                            interactive=True,
         | 
| 210 | 
            +
                            info="Dtype for inference. Default is bf16.",
         | 
| 211 | 
            +
                        )
         | 
| 212 | 
            +
                    ],
         | 
| 213 | 
            +
                    outputs=[
         | 
| 214 | 
            +
                        gr.Video(label="Generated Video", width=512, height=320, interactive=False, autoplay=True),
         | 
| 215 | 
            +
                        Textbox(label="input prompt"),
         | 
| 216 | 
            +
                        Textbox(label="model info"),
         | 
| 217 | 
            +
                        gr.Slider(label="seed"),
         | 
| 218 | 
            +
                    ],
         | 
| 219 | 
            +
                    description=DESCRIPTION,
         | 
| 220 | 
            +
                    theme=gr.themes.Default(),
         | 
| 221 | 
            +
                    css=block_css,
         | 
| 222 | 
            +
                    examples=examples,
         | 
| 223 | 
            +
                    cache_examples=False,
         | 
| 224 | 
            +
                    concurrency_limit=10,
         | 
| 225 | 
            +
                )
         | 
| 226 | 
            +
                demo.launch()
         | 
    	
        checkpoints/VideoCrafter2_model.ckpt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:1edf769ece3308e977228943eeeed39286806aba9da17350449a3fbf4324ccfb
         | 
| 3 | 
            +
            size 7404653244
         | 
    	
        checkpoints/unet_mg.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:92c8767b40a5b2737dd3c69f5f13dae222ead5bd4befbbf894ca870231db13bc
         | 
| 3 | 
            +
            size 5655143958
         | 
    	
        configs/inference_t2v_512_v2.0.yaml
    ADDED
    
    | @@ -0,0 +1,78 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            model:
         | 
| 2 | 
            +
              target: lvdm.models.ddpm3d.LatentDiffusion
         | 
| 3 | 
            +
              params:
         | 
| 4 | 
            +
                linear_start: 0.00085
         | 
| 5 | 
            +
                linear_end: 0.012
         | 
| 6 | 
            +
                num_timesteps_cond: 1
         | 
| 7 | 
            +
                timesteps: 1000
         | 
| 8 | 
            +
                first_stage_key: video
         | 
| 9 | 
            +
                cond_stage_key: caption
         | 
| 10 | 
            +
                cond_stage_trainable: false
         | 
| 11 | 
            +
                conditioning_key: crossattn
         | 
| 12 | 
            +
                image_size:
         | 
| 13 | 
            +
                - 40
         | 
| 14 | 
            +
                - 64
         | 
| 15 | 
            +
                channels: 4
         | 
| 16 | 
            +
                scale_by_std: false
         | 
| 17 | 
            +
                scale_factor: 0.18215
         | 
| 18 | 
            +
                use_ema: false
         | 
| 19 | 
            +
                uncond_type: empty_seq
         | 
| 20 | 
            +
                use_scale: true
         | 
| 21 | 
            +
                scale_b: 0.7
         | 
| 22 | 
            +
                unet_config:
         | 
| 23 | 
            +
                  target: lvdm.modules.networks.openaimodel3d.UNetModel
         | 
| 24 | 
            +
                  params:
         | 
| 25 | 
            +
                    in_channels: 4
         | 
| 26 | 
            +
                    out_channels: 4
         | 
| 27 | 
            +
                    model_channels: 320
         | 
| 28 | 
            +
                    attention_resolutions:
         | 
| 29 | 
            +
                    - 4
         | 
| 30 | 
            +
                    - 2
         | 
| 31 | 
            +
                    - 1
         | 
| 32 | 
            +
                    num_res_blocks: 2
         | 
| 33 | 
            +
                    channel_mult:
         | 
| 34 | 
            +
                    - 1
         | 
| 35 | 
            +
                    - 2
         | 
| 36 | 
            +
                    - 4
         | 
| 37 | 
            +
                    - 4
         | 
| 38 | 
            +
                    num_head_channels: 64
         | 
| 39 | 
            +
                    transformer_depth: 1
         | 
| 40 | 
            +
                    context_dim: 1024
         | 
| 41 | 
            +
                    use_linear: true
         | 
| 42 | 
            +
                    use_checkpoint: false
         | 
| 43 | 
            +
                    temporal_conv: true
         | 
| 44 | 
            +
                    temporal_attention: true
         | 
| 45 | 
            +
                    temporal_selfatt_only: true
         | 
| 46 | 
            +
                    use_relative_position: false
         | 
| 47 | 
            +
                    use_causal_attention: false
         | 
| 48 | 
            +
                    temporal_length: 16
         | 
| 49 | 
            +
                    addition_attention: true
         | 
| 50 | 
            +
                    fps_cond: true
         | 
| 51 | 
            +
                first_stage_config:
         | 
| 52 | 
            +
                  target: lvdm.models.autoencoder.AutoencoderKL
         | 
| 53 | 
            +
                  params:
         | 
| 54 | 
            +
                    embed_dim: 4
         | 
| 55 | 
            +
                    monitor: val/rec_loss
         | 
| 56 | 
            +
                    ddconfig:
         | 
| 57 | 
            +
                      double_z: true
         | 
| 58 | 
            +
                      z_channels: 4
         | 
| 59 | 
            +
                      resolution: 512
         | 
| 60 | 
            +
                      in_channels: 3
         | 
| 61 | 
            +
                      out_ch: 3
         | 
| 62 | 
            +
                      ch: 128
         | 
| 63 | 
            +
                      ch_mult:
         | 
| 64 | 
            +
                      - 1
         | 
| 65 | 
            +
                      - 2
         | 
| 66 | 
            +
                      - 4
         | 
| 67 | 
            +
                      - 4
         | 
| 68 | 
            +
                      num_res_blocks: 2
         | 
| 69 | 
            +
                      attn_resolutions: []
         | 
| 70 | 
            +
                      dropout: 0.0
         | 
| 71 | 
            +
                    lossconfig:
         | 
| 72 | 
            +
                      target: torch.nn.Identity
         | 
| 73 | 
            +
                cond_stage_config:
         | 
| 74 | 
            +
                  target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
         | 
| 75 | 
            +
                  params:
         | 
| 76 | 
            +
                    freeze: true
         | 
| 77 | 
            +
                    layer: penultimate
         | 
| 78 | 
            +
                    max_length: 200
         | 
    	
        lvdm/__pycache__/basics.cpython-312.pyc
    ADDED
    
    | Binary file (4.41 kB). View file | 
|  | 
    	
        lvdm/__pycache__/common.cpython-312.pyc
    ADDED
    
    | Binary file (6.16 kB). View file | 
|  | 
    	
        lvdm/__pycache__/distributions.cpython-312.pyc
    ADDED
    
    | Binary file (5.98 kB). View file | 
|  | 
    	
        lvdm/__pycache__/ema.cpython-312.pyc
    ADDED
    
    | Binary file (4.86 kB). View file | 
|  | 
    	
        lvdm/basics.py
    ADDED
    
    | @@ -0,0 +1,102 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # adopted from
         | 
| 2 | 
            +
            # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
         | 
| 3 | 
            +
            # and
         | 
| 4 | 
            +
            # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
         | 
| 5 | 
            +
            # and
         | 
| 6 | 
            +
            # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            # thanks!
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch.nn as nn
         | 
| 11 | 
            +
            from utils.utils import instantiate_from_config
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def disabled_train(self, mode=True):
         | 
| 15 | 
            +
                """Overwrite model.train with this function to make sure train/eval mode
         | 
| 16 | 
            +
                does not change anymore."""
         | 
| 17 | 
            +
                return self
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def zero_module(module):
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                Zero out the parameters of a module and return it.
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                for p in module.parameters():
         | 
| 25 | 
            +
                    p.detach().zero_()
         | 
| 26 | 
            +
                return module
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def scale_module(module, scale):
         | 
| 30 | 
            +
                """
         | 
| 31 | 
            +
                Scale the parameters of a module and return it.
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
                for p in module.parameters():
         | 
| 34 | 
            +
                    p.detach().mul_(scale)
         | 
| 35 | 
            +
                return module
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def conv_nd(dims, *args, **kwargs):
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
                Create a 1D, 2D, or 3D convolution module.
         | 
| 41 | 
            +
                """
         | 
| 42 | 
            +
                if dims == 1:
         | 
| 43 | 
            +
                    return nn.Conv1d(*args, **kwargs)
         | 
| 44 | 
            +
                elif dims == 2:
         | 
| 45 | 
            +
                    return nn.Conv2d(*args, **kwargs)
         | 
| 46 | 
            +
                elif dims == 3:
         | 
| 47 | 
            +
                    return nn.Conv3d(*args, **kwargs)
         | 
| 48 | 
            +
                raise ValueError(f"unsupported dimensions: {dims}")
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def linear(*args, **kwargs):
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
                Create a linear module.
         | 
| 54 | 
            +
                """
         | 
| 55 | 
            +
                return nn.Linear(*args, **kwargs)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            def avg_pool_nd(dims, *args, **kwargs):
         | 
| 59 | 
            +
                """
         | 
| 60 | 
            +
                Create a 1D, 2D, or 3D average pooling module.
         | 
| 61 | 
            +
                """
         | 
| 62 | 
            +
                if dims == 1:
         | 
| 63 | 
            +
                    return nn.AvgPool1d(*args, **kwargs)
         | 
| 64 | 
            +
                elif dims == 2:
         | 
| 65 | 
            +
                    return nn.AvgPool2d(*args, **kwargs)
         | 
| 66 | 
            +
                elif dims == 3:
         | 
| 67 | 
            +
                    return nn.AvgPool3d(*args, **kwargs)
         | 
| 68 | 
            +
                raise ValueError(f"unsupported dimensions: {dims}")
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            def nonlinearity(type="silu"):
         | 
| 72 | 
            +
                if type == "silu":
         | 
| 73 | 
            +
                    return nn.SiLU()
         | 
| 74 | 
            +
                elif type == "leaky_relu":
         | 
| 75 | 
            +
                    return nn.LeakyReLU()
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            class GroupNormSpecific(nn.GroupNorm):
         | 
| 79 | 
            +
                def forward(self, x):
         | 
| 80 | 
            +
                    return super().forward(x).type(x.dtype)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            def normalization(channels, num_groups=32):
         | 
| 84 | 
            +
                """
         | 
| 85 | 
            +
                Make a standard normalization layer.
         | 
| 86 | 
            +
                :param channels: number of input channels.
         | 
| 87 | 
            +
                :return: an nn.Module for normalization.
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                return GroupNormSpecific(num_groups, channels)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            class HybridConditioner(nn.Module):
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def __init__(self, c_concat_config, c_crossattn_config):
         | 
| 95 | 
            +
                    super().__init__()
         | 
| 96 | 
            +
                    self.concat_conditioner = instantiate_from_config(c_concat_config)
         | 
| 97 | 
            +
                    self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def forward(self, c_concat, c_crossattn):
         | 
| 100 | 
            +
                    c_concat = self.concat_conditioner(c_concat)
         | 
| 101 | 
            +
                    c_crossattn = self.crossattn_conditioner(c_crossattn)
         | 
| 102 | 
            +
                    return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
         | 
    	
        lvdm/common.py
    ADDED
    
    | @@ -0,0 +1,112 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            from inspect import isfunction
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch import nn
         | 
| 5 | 
            +
            import torch.distributed as dist
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def gather_data(data, return_np=True):
         | 
| 9 | 
            +
                """gather data from multiple processes to one list"""
         | 
| 10 | 
            +
                data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
         | 
| 11 | 
            +
                dist.all_gather(data_list, data)  # gather not supported with NCCL
         | 
| 12 | 
            +
                if return_np:
         | 
| 13 | 
            +
                    data_list = [data.cpu().numpy() for data in data_list]
         | 
| 14 | 
            +
                return data_list
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def autocast(f):
         | 
| 18 | 
            +
                def do_autocast(*args, **kwargs):
         | 
| 19 | 
            +
                    with torch.cuda.amp.autocast(
         | 
| 20 | 
            +
                        enabled=True,
         | 
| 21 | 
            +
                        dtype=torch.get_autocast_gpu_dtype(),
         | 
| 22 | 
            +
                        cache_enabled=torch.is_autocast_cache_enabled(),
         | 
| 23 | 
            +
                    ):
         | 
| 24 | 
            +
                        return f(*args, **kwargs)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                return do_autocast
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def extract_into_tensor(a, t, x_shape):
         | 
| 30 | 
            +
                b, *_ = t.shape
         | 
| 31 | 
            +
                out = a.gather(-1, t)
         | 
| 32 | 
            +
                return out.reshape(b, *((1,) * (len(x_shape) - 1)))
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def noise_like(shape, device, repeat=False):
         | 
| 36 | 
            +
                repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
         | 
| 37 | 
            +
                    shape[0], *((1,) * (len(shape) - 1))
         | 
| 38 | 
            +
                )
         | 
| 39 | 
            +
                noise = lambda: torch.randn(shape, device=device)
         | 
| 40 | 
            +
                return repeat_noise() if repeat else noise()
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def default(val, d):
         | 
| 44 | 
            +
                if exists(val):
         | 
| 45 | 
            +
                    return val
         | 
| 46 | 
            +
                return d() if isfunction(d) else d
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def exists(val):
         | 
| 50 | 
            +
                return val is not None
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def identity(*args, **kwargs):
         | 
| 54 | 
            +
                return nn.Identity()
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def uniq(arr):
         | 
| 58 | 
            +
                return {el: True for el in arr}.keys()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def mean_flat(tensor):
         | 
| 62 | 
            +
                """
         | 
| 63 | 
            +
                Take the mean over all non-batch dimensions.
         | 
| 64 | 
            +
                """
         | 
| 65 | 
            +
                return tensor.mean(dim=list(range(1, len(tensor.shape))))
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            def ismap(x):
         | 
| 69 | 
            +
                if not isinstance(x, torch.Tensor):
         | 
| 70 | 
            +
                    return False
         | 
| 71 | 
            +
                return (len(x.shape) == 4) and (x.shape[1] > 3)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def isimage(x):
         | 
| 75 | 
            +
                if not isinstance(x, torch.Tensor):
         | 
| 76 | 
            +
                    return False
         | 
| 77 | 
            +
                return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def max_neg_value(t):
         | 
| 81 | 
            +
                return -torch.finfo(t.dtype).max
         | 
| 82 | 
            +
             | 
| 83 | 
            +
             | 
| 84 | 
            +
            def shape_to_str(x):
         | 
| 85 | 
            +
                shape_str = "x".join([str(x) for x in x.shape])
         | 
| 86 | 
            +
                return shape_str
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            def init_(tensor):
         | 
| 90 | 
            +
                dim = tensor.shape[-1]
         | 
| 91 | 
            +
                std = 1 / math.sqrt(dim)
         | 
| 92 | 
            +
                tensor.uniform_(-std, std)
         | 
| 93 | 
            +
                return tensor
         | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
            ckpt = torch.utils.checkpoint.checkpoint
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            def checkpoint(func, inputs, params, flag):
         | 
| 100 | 
            +
                """
         | 
| 101 | 
            +
                Evaluate a function without caching intermediate activations, allowing for
         | 
| 102 | 
            +
                reduced memory at the expense of extra compute in the backward pass.
         | 
| 103 | 
            +
                :param func: the function to evaluate.
         | 
| 104 | 
            +
                :param inputs: the argument sequence to pass to `func`.
         | 
| 105 | 
            +
                :param params: a sequence of parameters `func` depends on but does not
         | 
| 106 | 
            +
                               explicitly take as arguments.
         | 
| 107 | 
            +
                :param flag: if False, disable gradient checkpointing.
         | 
| 108 | 
            +
                """
         | 
| 109 | 
            +
                if flag:
         | 
| 110 | 
            +
                    return ckpt(func, *inputs)
         | 
| 111 | 
            +
                else:
         | 
| 112 | 
            +
                    return func(*inputs)
         | 
    	
        lvdm/distributions.py
    ADDED
    
    | @@ -0,0 +1,103 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class AbstractDistribution:
         | 
| 6 | 
            +
                def sample(self):
         | 
| 7 | 
            +
                    raise NotImplementedError()
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                def mode(self):
         | 
| 10 | 
            +
                    raise NotImplementedError()
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class DiracDistribution(AbstractDistribution):
         | 
| 14 | 
            +
                def __init__(self, value):
         | 
| 15 | 
            +
                    self.value = value
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def sample(self):
         | 
| 18 | 
            +
                    return self.value
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def mode(self):
         | 
| 21 | 
            +
                    return self.value
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class DiagonalGaussianDistribution(object):
         | 
| 25 | 
            +
                def __init__(self, parameters, deterministic=False):
         | 
| 26 | 
            +
                    self.parameters = parameters
         | 
| 27 | 
            +
                    self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
         | 
| 28 | 
            +
                    self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
         | 
| 29 | 
            +
                    self.deterministic = deterministic
         | 
| 30 | 
            +
                    self.std = torch.exp(0.5 * self.logvar)
         | 
| 31 | 
            +
                    self.var = torch.exp(self.logvar)
         | 
| 32 | 
            +
                    if self.deterministic:
         | 
| 33 | 
            +
                        self.var = self.std = torch.zeros_like(self.mean).to(
         | 
| 34 | 
            +
                            device=self.parameters.device
         | 
| 35 | 
            +
                        )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def sample(self, noise=None):
         | 
| 38 | 
            +
                    if noise is None:
         | 
| 39 | 
            +
                        noise = torch.randn(self.mean.shape)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    x = self.mean + self.std * noise.to(device=self.parameters.device)
         | 
| 42 | 
            +
                    return x
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def kl(self, other=None):
         | 
| 45 | 
            +
                    if self.deterministic:
         | 
| 46 | 
            +
                        return torch.Tensor([0.0])
         | 
| 47 | 
            +
                    else:
         | 
| 48 | 
            +
                        if other is None:
         | 
| 49 | 
            +
                            return 0.5 * torch.sum(
         | 
| 50 | 
            +
                                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
         | 
| 51 | 
            +
                                dim=[1, 2, 3],
         | 
| 52 | 
            +
                            )
         | 
| 53 | 
            +
                        else:
         | 
| 54 | 
            +
                            return 0.5 * torch.sum(
         | 
| 55 | 
            +
                                torch.pow(self.mean - other.mean, 2) / other.var
         | 
| 56 | 
            +
                                + self.var / other.var
         | 
| 57 | 
            +
                                - 1.0
         | 
| 58 | 
            +
                                - self.logvar
         | 
| 59 | 
            +
                                + other.logvar,
         | 
| 60 | 
            +
                                dim=[1, 2, 3],
         | 
| 61 | 
            +
                            )
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def nll(self, sample, dims=[1, 2, 3]):
         | 
| 64 | 
            +
                    if self.deterministic:
         | 
| 65 | 
            +
                        return torch.Tensor([0.0])
         | 
| 66 | 
            +
                    logtwopi = np.log(2.0 * np.pi)
         | 
| 67 | 
            +
                    return 0.5 * torch.sum(
         | 
| 68 | 
            +
                        logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
         | 
| 69 | 
            +
                        dim=dims,
         | 
| 70 | 
            +
                    )
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def mode(self):
         | 
| 73 | 
            +
                    return self.mean
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def normal_kl(mean1, logvar1, mean2, logvar2):
         | 
| 77 | 
            +
                """
         | 
| 78 | 
            +
                source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
         | 
| 79 | 
            +
                Compute the KL divergence between two gaussians.
         | 
| 80 | 
            +
                Shapes are automatically broadcasted, so batches can be compared to
         | 
| 81 | 
            +
                scalars, among other use cases.
         | 
| 82 | 
            +
                """
         | 
| 83 | 
            +
                tensor = None
         | 
| 84 | 
            +
                for obj in (mean1, logvar1, mean2, logvar2):
         | 
| 85 | 
            +
                    if isinstance(obj, torch.Tensor):
         | 
| 86 | 
            +
                        tensor = obj
         | 
| 87 | 
            +
                        break
         | 
| 88 | 
            +
                assert tensor is not None, "at least one argument must be a Tensor"
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                # Force variances to be Tensors. Broadcasting helps convert scalars to
         | 
| 91 | 
            +
                # Tensors, but it does not work for torch.exp().
         | 
| 92 | 
            +
                logvar1, logvar2 = [
         | 
| 93 | 
            +
                    x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
         | 
| 94 | 
            +
                    for x in (logvar1, logvar2)
         | 
| 95 | 
            +
                ]
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                return 0.5 * (
         | 
| 98 | 
            +
                    -1.0
         | 
| 99 | 
            +
                    + logvar2
         | 
| 100 | 
            +
                    - logvar1
         | 
| 101 | 
            +
                    + torch.exp(logvar1 - logvar2)
         | 
| 102 | 
            +
                    + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
         | 
| 103 | 
            +
                )
         | 
    	
        lvdm/ema.py
    ADDED
    
    | @@ -0,0 +1,84 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class LitEma(nn.Module):
         | 
| 6 | 
            +
                def __init__(self, model, decay=0.9999, use_num_upates=True):
         | 
| 7 | 
            +
                    super().__init__()
         | 
| 8 | 
            +
                    if decay < 0.0 or decay > 1.0:
         | 
| 9 | 
            +
                        raise ValueError("Decay must be between 0 and 1")
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                    self.m_name2s_name = {}
         | 
| 12 | 
            +
                    self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
         | 
| 13 | 
            +
                    self.register_buffer(
         | 
| 14 | 
            +
                        "num_updates",
         | 
| 15 | 
            +
                        (
         | 
| 16 | 
            +
                            torch.tensor(0, dtype=torch.int)
         | 
| 17 | 
            +
                            if use_num_upates
         | 
| 18 | 
            +
                            else torch.tensor(-1, dtype=torch.int)
         | 
| 19 | 
            +
                        ),
         | 
| 20 | 
            +
                    )
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    for name, p in model.named_parameters():
         | 
| 23 | 
            +
                        if p.requires_grad:
         | 
| 24 | 
            +
                            # remove as '.'-character is not allowed in buffers
         | 
| 25 | 
            +
                            s_name = name.replace(".", "")
         | 
| 26 | 
            +
                            self.m_name2s_name.update({name: s_name})
         | 
| 27 | 
            +
                            self.register_buffer(s_name, p.clone().detach().data)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    self.collected_params = []
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                def forward(self, model):
         | 
| 32 | 
            +
                    decay = self.decay
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    if self.num_updates >= 0:
         | 
| 35 | 
            +
                        self.num_updates += 1
         | 
| 36 | 
            +
                        decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    one_minus_decay = 1.0 - decay
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    with torch.no_grad():
         | 
| 41 | 
            +
                        m_param = dict(model.named_parameters())
         | 
| 42 | 
            +
                        shadow_params = dict(self.named_buffers())
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                        for key in m_param:
         | 
| 45 | 
            +
                            if m_param[key].requires_grad:
         | 
| 46 | 
            +
                                sname = self.m_name2s_name[key]
         | 
| 47 | 
            +
                                shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
         | 
| 48 | 
            +
                                shadow_params[sname].sub_(
         | 
| 49 | 
            +
                                    one_minus_decay * (shadow_params[sname] - m_param[key])
         | 
| 50 | 
            +
                                )
         | 
| 51 | 
            +
                            else:
         | 
| 52 | 
            +
                                assert not key in self.m_name2s_name
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def copy_to(self, model):
         | 
| 55 | 
            +
                    m_param = dict(model.named_parameters())
         | 
| 56 | 
            +
                    shadow_params = dict(self.named_buffers())
         | 
| 57 | 
            +
                    for key in m_param:
         | 
| 58 | 
            +
                        if m_param[key].requires_grad:
         | 
| 59 | 
            +
                            m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
         | 
| 60 | 
            +
                        else:
         | 
| 61 | 
            +
                            assert not key in self.m_name2s_name
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def store(self, parameters):
         | 
| 64 | 
            +
                    """
         | 
| 65 | 
            +
                    Save the current parameters for restoring later.
         | 
| 66 | 
            +
                    Args:
         | 
| 67 | 
            +
                      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
         | 
| 68 | 
            +
                        temporarily stored.
         | 
| 69 | 
            +
                    """
         | 
| 70 | 
            +
                    self.collected_params = [param.clone() for param in parameters]
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def restore(self, parameters):
         | 
| 73 | 
            +
                    """
         | 
| 74 | 
            +
                    Restore the parameters stored with the `store` method.
         | 
| 75 | 
            +
                    Useful to validate the model with EMA parameters without affecting the
         | 
| 76 | 
            +
                    original optimization process. Store the parameters before the
         | 
| 77 | 
            +
                    `copy_to` method. After validation (or model saving), use this to
         | 
| 78 | 
            +
                    restore the former parameters.
         | 
| 79 | 
            +
                    Args:
         | 
| 80 | 
            +
                      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
         | 
| 81 | 
            +
                        updated with the stored parameters.
         | 
| 82 | 
            +
                    """
         | 
| 83 | 
            +
                    for c_param, param in zip(self.collected_params, parameters):
         | 
| 84 | 
            +
                        param.data.copy_(c_param.data)
         | 
    	
        lvdm/models/__pycache__/autoencoder.cpython-312.pyc
    ADDED
    
    | Binary file (13.1 kB). View file | 
|  | 
    	
        lvdm/models/__pycache__/ddpm3d.cpython-312.pyc
    ADDED
    
    | Binary file (42.3 kB). View file | 
|  | 
    	
        lvdm/models/__pycache__/utils_diffusion.cpython-312.pyc
    ADDED
    
    | Binary file (6.25 kB). View file | 
|  | 
    	
        lvdm/models/autoencoder.py
    ADDED
    
    | @@ -0,0 +1,276 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from contextlib import contextmanager
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from einops import rearrange
         | 
| 6 | 
            +
            import torch.nn.functional as F
         | 
| 7 | 
            +
            import pytorch_lightning as pl
         | 
| 8 | 
            +
            from lvdm.modules.networks.ae_modules import Encoder, Decoder
         | 
| 9 | 
            +
            from lvdm.distributions import DiagonalGaussianDistribution
         | 
| 10 | 
            +
            from utils.utils import instantiate_from_config
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class AutoencoderKL(pl.LightningModule):
         | 
| 14 | 
            +
                def __init__(
         | 
| 15 | 
            +
                    self,
         | 
| 16 | 
            +
                    ddconfig,
         | 
| 17 | 
            +
                    lossconfig,
         | 
| 18 | 
            +
                    embed_dim,
         | 
| 19 | 
            +
                    ckpt_path=None,
         | 
| 20 | 
            +
                    ignore_keys=[],
         | 
| 21 | 
            +
                    image_key="image",
         | 
| 22 | 
            +
                    colorize_nlabels=None,
         | 
| 23 | 
            +
                    monitor=None,
         | 
| 24 | 
            +
                    test=False,
         | 
| 25 | 
            +
                    logdir=None,
         | 
| 26 | 
            +
                    input_dim=4,
         | 
| 27 | 
            +
                    test_args=None,
         | 
| 28 | 
            +
                ):
         | 
| 29 | 
            +
                    super().__init__()
         | 
| 30 | 
            +
                    self.image_key = image_key
         | 
| 31 | 
            +
                    self.encoder = Encoder(**ddconfig)
         | 
| 32 | 
            +
                    self.decoder = Decoder(**ddconfig)
         | 
| 33 | 
            +
                    self.loss = instantiate_from_config(lossconfig)
         | 
| 34 | 
            +
                    assert ddconfig["double_z"]
         | 
| 35 | 
            +
                    self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
         | 
| 36 | 
            +
                    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         | 
| 37 | 
            +
                    self.embed_dim = embed_dim
         | 
| 38 | 
            +
                    self.input_dim = input_dim
         | 
| 39 | 
            +
                    self.test = test
         | 
| 40 | 
            +
                    self.test_args = test_args
         | 
| 41 | 
            +
                    self.logdir = logdir
         | 
| 42 | 
            +
                    if colorize_nlabels is not None:
         | 
| 43 | 
            +
                        assert type(colorize_nlabels) == int
         | 
| 44 | 
            +
                        self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
         | 
| 45 | 
            +
                    if monitor is not None:
         | 
| 46 | 
            +
                        self.monitor = monitor
         | 
| 47 | 
            +
                    if ckpt_path is not None:
         | 
| 48 | 
            +
                        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
         | 
| 49 | 
            +
                    if self.test:
         | 
| 50 | 
            +
                        self.init_test()
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def init_test(
         | 
| 53 | 
            +
                    self,
         | 
| 54 | 
            +
                ):
         | 
| 55 | 
            +
                    self.test = True
         | 
| 56 | 
            +
                    save_dir = os.path.join(self.logdir, "test")
         | 
| 57 | 
            +
                    if "ckpt" in self.test_args:
         | 
| 58 | 
            +
                        ckpt_name = (
         | 
| 59 | 
            +
                            os.path.basename(self.test_args.ckpt).split(".ckpt")[0]
         | 
| 60 | 
            +
                            + f"_epoch{self._cur_epoch}"
         | 
| 61 | 
            +
                        )
         | 
| 62 | 
            +
                        self.root = os.path.join(save_dir, ckpt_name)
         | 
| 63 | 
            +
                    else:
         | 
| 64 | 
            +
                        self.root = save_dir
         | 
| 65 | 
            +
                    if "test_subdir" in self.test_args:
         | 
| 66 | 
            +
                        self.root = os.path.join(save_dir, self.test_args.test_subdir)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    self.root_zs = os.path.join(self.root, "zs")
         | 
| 69 | 
            +
                    self.root_dec = os.path.join(self.root, "reconstructions")
         | 
| 70 | 
            +
                    self.root_inputs = os.path.join(self.root, "inputs")
         | 
| 71 | 
            +
                    os.makedirs(self.root, exist_ok=True)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    if self.test_args.save_z:
         | 
| 74 | 
            +
                        os.makedirs(self.root_zs, exist_ok=True)
         | 
| 75 | 
            +
                    if self.test_args.save_reconstruction:
         | 
| 76 | 
            +
                        os.makedirs(self.root_dec, exist_ok=True)
         | 
| 77 | 
            +
                    if self.test_args.save_input:
         | 
| 78 | 
            +
                        os.makedirs(self.root_inputs, exist_ok=True)
         | 
| 79 | 
            +
                    assert self.test_args is not None
         | 
| 80 | 
            +
                    self.test_maximum = getattr(self.test_args, "test_maximum", None)
         | 
| 81 | 
            +
                    self.count = 0
         | 
| 82 | 
            +
                    self.eval_metrics = {}
         | 
| 83 | 
            +
                    self.decodes = []
         | 
| 84 | 
            +
                    self.save_decode_samples = 2048
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                def init_from_ckpt(self, path, ignore_keys=list()):
         | 
| 87 | 
            +
                    sd = torch.load(path, map_location="cpu")
         | 
| 88 | 
            +
                    try:
         | 
| 89 | 
            +
                        self._cur_epoch = sd["epoch"]
         | 
| 90 | 
            +
                        sd = sd["state_dict"]
         | 
| 91 | 
            +
                    except:
         | 
| 92 | 
            +
                        self._cur_epoch = "null"
         | 
| 93 | 
            +
                    keys = list(sd.keys())
         | 
| 94 | 
            +
                    for k in keys:
         | 
| 95 | 
            +
                        for ik in ignore_keys:
         | 
| 96 | 
            +
                            if k.startswith(ik):
         | 
| 97 | 
            +
                                print("Deleting key {} from state_dict.".format(k))
         | 
| 98 | 
            +
                                del sd[k]
         | 
| 99 | 
            +
                    self.load_state_dict(sd, strict=False)
         | 
| 100 | 
            +
                    # self.load_state_dict(sd, strict=True)
         | 
| 101 | 
            +
                    print(f"Restored from {path}")
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def encode(self, x, **kwargs):
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    h = self.encoder(x)
         | 
| 106 | 
            +
                    moments = self.quant_conv(h)
         | 
| 107 | 
            +
                    posterior = DiagonalGaussianDistribution(moments)
         | 
| 108 | 
            +
                    return posterior
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def decode(self, z, **kwargs):
         | 
| 111 | 
            +
                    z = self.post_quant_conv(z)
         | 
| 112 | 
            +
                    dec = self.decoder(z)
         | 
| 113 | 
            +
                    return dec
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                def forward(self, input, sample_posterior=True):
         | 
| 116 | 
            +
                    posterior = self.encode(input)
         | 
| 117 | 
            +
                    if sample_posterior:
         | 
| 118 | 
            +
                        z = posterior.sample()
         | 
| 119 | 
            +
                    else:
         | 
| 120 | 
            +
                        z = posterior.mode()
         | 
| 121 | 
            +
                    dec = self.decode(z)
         | 
| 122 | 
            +
                    return dec, posterior
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                def get_input(self, batch, k):
         | 
| 125 | 
            +
                    x = batch[k]
         | 
| 126 | 
            +
                    if x.dim() == 5 and self.input_dim == 4:
         | 
| 127 | 
            +
                        b, c, t, h, w = x.shape
         | 
| 128 | 
            +
                        self.b = b
         | 
| 129 | 
            +
                        self.t = t
         | 
| 130 | 
            +
                        x = rearrange(x, "b c t h w -> (b t) c h w")
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    return x
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def training_step(self, batch, batch_idx, optimizer_idx):
         | 
| 135 | 
            +
                    inputs = self.get_input(batch, self.image_key)
         | 
| 136 | 
            +
                    reconstructions, posterior = self(inputs)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    if optimizer_idx == 0:
         | 
| 139 | 
            +
                        # train encoder+decoder+logvar
         | 
| 140 | 
            +
                        aeloss, log_dict_ae = self.loss(
         | 
| 141 | 
            +
                            inputs,
         | 
| 142 | 
            +
                            reconstructions,
         | 
| 143 | 
            +
                            posterior,
         | 
| 144 | 
            +
                            optimizer_idx,
         | 
| 145 | 
            +
                            self.global_step,
         | 
| 146 | 
            +
                            last_layer=self.get_last_layer(),
         | 
| 147 | 
            +
                            split="train",
         | 
| 148 | 
            +
                        )
         | 
| 149 | 
            +
                        self.log(
         | 
| 150 | 
            +
                            "aeloss",
         | 
| 151 | 
            +
                            aeloss,
         | 
| 152 | 
            +
                            prog_bar=True,
         | 
| 153 | 
            +
                            logger=True,
         | 
| 154 | 
            +
                            on_step=True,
         | 
| 155 | 
            +
                            on_epoch=True,
         | 
| 156 | 
            +
                        )
         | 
| 157 | 
            +
                        self.log_dict(
         | 
| 158 | 
            +
                            log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
         | 
| 159 | 
            +
                        )
         | 
| 160 | 
            +
                        return aeloss
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    if optimizer_idx == 1:
         | 
| 163 | 
            +
                        # train the discriminator
         | 
| 164 | 
            +
                        discloss, log_dict_disc = self.loss(
         | 
| 165 | 
            +
                            inputs,
         | 
| 166 | 
            +
                            reconstructions,
         | 
| 167 | 
            +
                            posterior,
         | 
| 168 | 
            +
                            optimizer_idx,
         | 
| 169 | 
            +
                            self.global_step,
         | 
| 170 | 
            +
                            last_layer=self.get_last_layer(),
         | 
| 171 | 
            +
                            split="train",
         | 
| 172 | 
            +
                        )
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                        self.log(
         | 
| 175 | 
            +
                            "discloss",
         | 
| 176 | 
            +
                            discloss,
         | 
| 177 | 
            +
                            prog_bar=True,
         | 
| 178 | 
            +
                            logger=True,
         | 
| 179 | 
            +
                            on_step=True,
         | 
| 180 | 
            +
                            on_epoch=True,
         | 
| 181 | 
            +
                        )
         | 
| 182 | 
            +
                        self.log_dict(
         | 
| 183 | 
            +
                            log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
         | 
| 184 | 
            +
                        )
         | 
| 185 | 
            +
                        return discloss
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                def validation_step(self, batch, batch_idx):
         | 
| 188 | 
            +
                    inputs = self.get_input(batch, self.image_key)
         | 
| 189 | 
            +
                    reconstructions, posterior = self(inputs)
         | 
| 190 | 
            +
                    aeloss, log_dict_ae = self.loss(
         | 
| 191 | 
            +
                        inputs,
         | 
| 192 | 
            +
                        reconstructions,
         | 
| 193 | 
            +
                        posterior,
         | 
| 194 | 
            +
                        0,
         | 
| 195 | 
            +
                        self.global_step,
         | 
| 196 | 
            +
                        last_layer=self.get_last_layer(),
         | 
| 197 | 
            +
                        split="val",
         | 
| 198 | 
            +
                    )
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    discloss, log_dict_disc = self.loss(
         | 
| 201 | 
            +
                        inputs,
         | 
| 202 | 
            +
                        reconstructions,
         | 
| 203 | 
            +
                        posterior,
         | 
| 204 | 
            +
                        1,
         | 
| 205 | 
            +
                        self.global_step,
         | 
| 206 | 
            +
                        last_layer=self.get_last_layer(),
         | 
| 207 | 
            +
                        split="val",
         | 
| 208 | 
            +
                    )
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
         | 
| 211 | 
            +
                    self.log_dict(log_dict_ae)
         | 
| 212 | 
            +
                    self.log_dict(log_dict_disc)
         | 
| 213 | 
            +
                    return self.log_dict
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                def configure_optimizers(self):
         | 
| 216 | 
            +
                    lr = self.learning_rate
         | 
| 217 | 
            +
                    opt_ae = torch.optim.Adam(
         | 
| 218 | 
            +
                        list(self.encoder.parameters())
         | 
| 219 | 
            +
                        + list(self.decoder.parameters())
         | 
| 220 | 
            +
                        + list(self.quant_conv.parameters())
         | 
| 221 | 
            +
                        + list(self.post_quant_conv.parameters()),
         | 
| 222 | 
            +
                        lr=lr,
         | 
| 223 | 
            +
                        betas=(0.5, 0.9),
         | 
| 224 | 
            +
                    )
         | 
| 225 | 
            +
                    opt_disc = torch.optim.Adam(
         | 
| 226 | 
            +
                        self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
         | 
| 227 | 
            +
                    )
         | 
| 228 | 
            +
                    return [opt_ae, opt_disc], []
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                def get_last_layer(self):
         | 
| 231 | 
            +
                    return self.decoder.conv_out.weight
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                @torch.no_grad()
         | 
| 234 | 
            +
                def log_images(self, batch, only_inputs=False, **kwargs):
         | 
| 235 | 
            +
                    log = dict()
         | 
| 236 | 
            +
                    x = self.get_input(batch, self.image_key)
         | 
| 237 | 
            +
                    x = x.to(self.device)
         | 
| 238 | 
            +
                    if not only_inputs:
         | 
| 239 | 
            +
                        xrec, posterior = self(x)
         | 
| 240 | 
            +
                        if x.shape[1] > 3:
         | 
| 241 | 
            +
                            # colorize with random projection
         | 
| 242 | 
            +
                            assert xrec.shape[1] > 3
         | 
| 243 | 
            +
                            x = self.to_rgb(x)
         | 
| 244 | 
            +
                            xrec = self.to_rgb(xrec)
         | 
| 245 | 
            +
                        log["samples"] = self.decode(torch.randn_like(posterior.sample()))
         | 
| 246 | 
            +
                        log["reconstructions"] = xrec
         | 
| 247 | 
            +
                    log["inputs"] = x
         | 
| 248 | 
            +
                    return log
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                def to_rgb(self, x):
         | 
| 251 | 
            +
                    assert self.image_key == "segmentation"
         | 
| 252 | 
            +
                    if not hasattr(self, "colorize"):
         | 
| 253 | 
            +
                        self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
         | 
| 254 | 
            +
                    x = F.conv2d(x, weight=self.colorize)
         | 
| 255 | 
            +
                    x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
         | 
| 256 | 
            +
                    return x
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            class IdentityFirstStage(torch.nn.Module):
         | 
| 260 | 
            +
                def __init__(self, *args, vq_interface=False, **kwargs):
         | 
| 261 | 
            +
                    self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
         | 
| 262 | 
            +
                    super().__init__()
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                def encode(self, x, *args, **kwargs):
         | 
| 265 | 
            +
                    return x
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def decode(self, x, *args, **kwargs):
         | 
| 268 | 
            +
                    return x
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                def quantize(self, x, *args, **kwargs):
         | 
| 271 | 
            +
                    if self.vq_interface:
         | 
| 272 | 
            +
                        return x, None, [None, None, None]
         | 
| 273 | 
            +
                    return x
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                def forward(self, x, *args, **kwargs):
         | 
| 276 | 
            +
                    return x
         | 
    	
        lvdm/models/ddpm3d.py
    ADDED
    
    | @@ -0,0 +1,967 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            wild mixture of
         | 
| 3 | 
            +
            https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
         | 
| 4 | 
            +
            https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
         | 
| 5 | 
            +
            https://github.com/CompVis/taming-transformers
         | 
| 6 | 
            +
            -- merci
         | 
| 7 | 
            +
            """
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from functools import partial
         | 
| 10 | 
            +
            from contextlib import contextmanager
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            from tqdm import tqdm
         | 
| 13 | 
            +
            from einops import rearrange, repeat
         | 
| 14 | 
            +
            import logging
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            mainlogger = logging.getLogger("mainlogger")
         | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            import torch.nn as nn
         | 
| 19 | 
            +
            from torchvision.utils import make_grid
         | 
| 20 | 
            +
            import pytorch_lightning as pl
         | 
| 21 | 
            +
            from utils.utils import instantiate_from_config
         | 
| 22 | 
            +
            from lvdm.ema import LitEma
         | 
| 23 | 
            +
            from lvdm.distributions import DiagonalGaussianDistribution
         | 
| 24 | 
            +
            from lvdm.models.utils_diffusion import make_beta_schedule
         | 
| 25 | 
            +
            from lvdm.modules.encoders.ip_resampler import ImageProjModel, Resampler
         | 
| 26 | 
            +
            from lvdm.basics import disabled_train
         | 
| 27 | 
            +
            from lvdm.common import extract_into_tensor, noise_like, exists, default
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class DDPM(pl.LightningModule):
         | 
| 34 | 
            +
                # classic DDPM with Gaussian diffusion, in image space
         | 
| 35 | 
            +
                def __init__(
         | 
| 36 | 
            +
                    self,
         | 
| 37 | 
            +
                    unet_config,
         | 
| 38 | 
            +
                    timesteps=1000,
         | 
| 39 | 
            +
                    beta_schedule="linear",
         | 
| 40 | 
            +
                    loss_type="l2",
         | 
| 41 | 
            +
                    ckpt_path=None,
         | 
| 42 | 
            +
                    ignore_keys=[],
         | 
| 43 | 
            +
                    load_only_unet=False,
         | 
| 44 | 
            +
                    monitor=None,
         | 
| 45 | 
            +
                    use_ema=True,
         | 
| 46 | 
            +
                    first_stage_key="image",
         | 
| 47 | 
            +
                    image_size=256,
         | 
| 48 | 
            +
                    channels=3,
         | 
| 49 | 
            +
                    log_every_t=100,
         | 
| 50 | 
            +
                    clip_denoised=True,
         | 
| 51 | 
            +
                    linear_start=1e-4,
         | 
| 52 | 
            +
                    linear_end=2e-2,
         | 
| 53 | 
            +
                    cosine_s=8e-3,
         | 
| 54 | 
            +
                    given_betas=None,
         | 
| 55 | 
            +
                    original_elbo_weight=0.0,
         | 
| 56 | 
            +
                    v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
         | 
| 57 | 
            +
                    l_simple_weight=1.0,
         | 
| 58 | 
            +
                    conditioning_key=None,
         | 
| 59 | 
            +
                    parameterization="eps",  # all assuming fixed variance schedules
         | 
| 60 | 
            +
                    scheduler_config=None,
         | 
| 61 | 
            +
                    use_positional_encodings=False,
         | 
| 62 | 
            +
                    learn_logvar=False,
         | 
| 63 | 
            +
                    logvar_init=0.0,
         | 
| 64 | 
            +
                ):
         | 
| 65 | 
            +
                    super().__init__()
         | 
| 66 | 
            +
                    assert parameterization in [
         | 
| 67 | 
            +
                        "eps",
         | 
| 68 | 
            +
                        "x0",
         | 
| 69 | 
            +
                    ], 'currently only supporting "eps" and "x0"'
         | 
| 70 | 
            +
                    self.parameterization = parameterization
         | 
| 71 | 
            +
                    mainlogger.info(
         | 
| 72 | 
            +
                        f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
         | 
| 73 | 
            +
                    )
         | 
| 74 | 
            +
                    self.cond_stage_model = None
         | 
| 75 | 
            +
                    self.clip_denoised = clip_denoised
         | 
| 76 | 
            +
                    self.log_every_t = log_every_t
         | 
| 77 | 
            +
                    self.first_stage_key = first_stage_key
         | 
| 78 | 
            +
                    self.channels = channels
         | 
| 79 | 
            +
                    self.temporal_length = unet_config.params.temporal_length
         | 
| 80 | 
            +
                    self.image_size = image_size
         | 
| 81 | 
            +
                    if isinstance(self.image_size, int):
         | 
| 82 | 
            +
                        self.image_size = [self.image_size, self.image_size]
         | 
| 83 | 
            +
                    self.use_positional_encodings = use_positional_encodings
         | 
| 84 | 
            +
                    self.model = DiffusionWrapper(unet_config, conditioning_key)
         | 
| 85 | 
            +
                    self.use_ema = use_ema
         | 
| 86 | 
            +
                    if self.use_ema:
         | 
| 87 | 
            +
                        self.model_ema = LitEma(self.model)
         | 
| 88 | 
            +
                        mainlogger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    self.use_scheduler = scheduler_config is not None
         | 
| 91 | 
            +
                    if self.use_scheduler:
         | 
| 92 | 
            +
                        self.scheduler_config = scheduler_config
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    self.v_posterior = v_posterior
         | 
| 95 | 
            +
                    self.original_elbo_weight = original_elbo_weight
         | 
| 96 | 
            +
                    self.l_simple_weight = l_simple_weight
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    if monitor is not None:
         | 
| 99 | 
            +
                        self.monitor = monitor
         | 
| 100 | 
            +
                    if ckpt_path is not None:
         | 
| 101 | 
            +
                        self.init_from_ckpt(
         | 
| 102 | 
            +
                            ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
         | 
| 103 | 
            +
                        )
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    self.register_schedule(
         | 
| 106 | 
            +
                        given_betas=given_betas,
         | 
| 107 | 
            +
                        beta_schedule=beta_schedule,
         | 
| 108 | 
            +
                        timesteps=timesteps,
         | 
| 109 | 
            +
                        linear_start=linear_start,
         | 
| 110 | 
            +
                        linear_end=linear_end,
         | 
| 111 | 
            +
                        cosine_s=cosine_s,
         | 
| 112 | 
            +
                    )
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    self.loss_type = loss_type
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    self.learn_logvar = learn_logvar
         | 
| 117 | 
            +
                    self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
         | 
| 118 | 
            +
                    if self.learn_logvar:
         | 
| 119 | 
            +
                        self.logvar = nn.Parameter(self.logvar, requires_grad=True)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                def register_schedule(
         | 
| 122 | 
            +
                    self,
         | 
| 123 | 
            +
                    given_betas=None,
         | 
| 124 | 
            +
                    beta_schedule="linear",
         | 
| 125 | 
            +
                    timesteps=1000,
         | 
| 126 | 
            +
                    linear_start=1e-4,
         | 
| 127 | 
            +
                    linear_end=2e-2,
         | 
| 128 | 
            +
                    cosine_s=8e-3,
         | 
| 129 | 
            +
                ):
         | 
| 130 | 
            +
                    if exists(given_betas):
         | 
| 131 | 
            +
                        betas = given_betas
         | 
| 132 | 
            +
                    else:
         | 
| 133 | 
            +
                        betas = make_beta_schedule(
         | 
| 134 | 
            +
                            beta_schedule,
         | 
| 135 | 
            +
                            timesteps,
         | 
| 136 | 
            +
                            linear_start=linear_start,
         | 
| 137 | 
            +
                            linear_end=linear_end,
         | 
| 138 | 
            +
                            cosine_s=cosine_s,
         | 
| 139 | 
            +
                        )
         | 
| 140 | 
            +
                    alphas = 1.0 - betas
         | 
| 141 | 
            +
                    alphas_cumprod = np.cumprod(alphas, axis=0)
         | 
| 142 | 
            +
                    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    (timesteps,) = betas.shape
         | 
| 145 | 
            +
                    self.num_timesteps = int(timesteps)
         | 
| 146 | 
            +
                    self.linear_start = linear_start
         | 
| 147 | 
            +
                    self.linear_end = linear_end
         | 
| 148 | 
            +
                    assert (
         | 
| 149 | 
            +
                        alphas_cumprod.shape[0] == self.num_timesteps
         | 
| 150 | 
            +
                    ), "alphas have to be defined for each timestep"
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    to_torch = partial(torch.tensor, dtype=torch.float32)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    self.register_buffer("betas", to_torch(betas))
         | 
| 155 | 
            +
                    self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
         | 
| 156 | 
            +
                    self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    # calculations for diffusion q(x_t | x_{t-1}) and others
         | 
| 159 | 
            +
                    self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
         | 
| 160 | 
            +
                    self.register_buffer(
         | 
| 161 | 
            +
                        "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
         | 
| 162 | 
            +
                    )
         | 
| 163 | 
            +
                    self.register_buffer(
         | 
| 164 | 
            +
                        "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
         | 
| 165 | 
            +
                    )
         | 
| 166 | 
            +
                    self.register_buffer(
         | 
| 167 | 
            +
                        "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
         | 
| 168 | 
            +
                    )
         | 
| 169 | 
            +
                    self.register_buffer(
         | 
| 170 | 
            +
                        "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
         | 
| 171 | 
            +
                    )
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    # calculations for posterior q(x_{t-1} | x_t, x_0)
         | 
| 174 | 
            +
                    posterior_variance = (1 - self.v_posterior) * betas * (
         | 
| 175 | 
            +
                        1.0 - alphas_cumprod_prev
         | 
| 176 | 
            +
                    ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
         | 
| 177 | 
            +
                    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
         | 
| 178 | 
            +
                    self.register_buffer("posterior_variance", to_torch(posterior_variance))
         | 
| 179 | 
            +
                    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
         | 
| 180 | 
            +
                    self.register_buffer(
         | 
| 181 | 
            +
                        "posterior_log_variance_clipped",
         | 
| 182 | 
            +
                        to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
         | 
| 183 | 
            +
                    )
         | 
| 184 | 
            +
                    self.register_buffer(
         | 
| 185 | 
            +
                        "posterior_mean_coef1",
         | 
| 186 | 
            +
                        to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
         | 
| 187 | 
            +
                    )
         | 
| 188 | 
            +
                    self.register_buffer(
         | 
| 189 | 
            +
                        "posterior_mean_coef2",
         | 
| 190 | 
            +
                        to_torch(
         | 
| 191 | 
            +
                            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
         | 
| 192 | 
            +
                        ),
         | 
| 193 | 
            +
                    )
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    if self.parameterization == "eps":
         | 
| 196 | 
            +
                        lvlb_weights = self.betas**2 / (
         | 
| 197 | 
            +
                            2
         | 
| 198 | 
            +
                            * self.posterior_variance
         | 
| 199 | 
            +
                            * to_torch(alphas)
         | 
| 200 | 
            +
                            * (1 - self.alphas_cumprod)
         | 
| 201 | 
            +
                        )
         | 
| 202 | 
            +
                    elif self.parameterization == "x0":
         | 
| 203 | 
            +
                        lvlb_weights = (
         | 
| 204 | 
            +
                            0.5
         | 
| 205 | 
            +
                            * np.sqrt(torch.Tensor(alphas_cumprod))
         | 
| 206 | 
            +
                            / (2.0 * 1 - torch.Tensor(alphas_cumprod))
         | 
| 207 | 
            +
                        )
         | 
| 208 | 
            +
                    else:
         | 
| 209 | 
            +
                        raise NotImplementedError("mu not supported")
         | 
| 210 | 
            +
                    # TODO how to choose this term
         | 
| 211 | 
            +
                    lvlb_weights[0] = lvlb_weights[1]
         | 
| 212 | 
            +
                    self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
         | 
| 213 | 
            +
                    assert not torch.isnan(self.lvlb_weights).all()
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                @contextmanager
         | 
| 216 | 
            +
                def ema_scope(self, context=None):
         | 
| 217 | 
            +
                    if self.use_ema:
         | 
| 218 | 
            +
                        self.model_ema.store(self.model.parameters())
         | 
| 219 | 
            +
                        self.model_ema.copy_to(self.model)
         | 
| 220 | 
            +
                        if context is not None:
         | 
| 221 | 
            +
                            mainlogger.info(f"{context}: Switched to EMA weights")
         | 
| 222 | 
            +
                    try:
         | 
| 223 | 
            +
                        yield None
         | 
| 224 | 
            +
                    finally:
         | 
| 225 | 
            +
                        if self.use_ema:
         | 
| 226 | 
            +
                            self.model_ema.restore(self.model.parameters())
         | 
| 227 | 
            +
                            if context is not None:
         | 
| 228 | 
            +
                                mainlogger.info(f"{context}: Restored training weights")
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
         | 
| 231 | 
            +
                    sd = torch.load(path, map_location="cpu")
         | 
| 232 | 
            +
                    if "state_dict" in list(sd.keys()):
         | 
| 233 | 
            +
                        sd = sd["state_dict"]
         | 
| 234 | 
            +
                    keys = list(sd.keys())
         | 
| 235 | 
            +
                    for k in keys:
         | 
| 236 | 
            +
                        for ik in ignore_keys:
         | 
| 237 | 
            +
                            if k.startswith(ik):
         | 
| 238 | 
            +
                                mainlogger.info("Deleting key {} from state_dict.".format(k))
         | 
| 239 | 
            +
                                del sd[k]
         | 
| 240 | 
            +
                    missing, unexpected = (
         | 
| 241 | 
            +
                        self.load_state_dict(sd, strict=False)
         | 
| 242 | 
            +
                        if not only_model
         | 
| 243 | 
            +
                        else self.model.load_state_dict(sd, strict=False)
         | 
| 244 | 
            +
                    )
         | 
| 245 | 
            +
                    mainlogger.info(
         | 
| 246 | 
            +
                        f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         | 
| 247 | 
            +
                    )
         | 
| 248 | 
            +
                    if len(missing) > 0:
         | 
| 249 | 
            +
                        mainlogger.info(f"Missing Keys: {missing}")
         | 
| 250 | 
            +
                    if len(unexpected) > 0:
         | 
| 251 | 
            +
                        mainlogger.info(f"Unexpected Keys: {unexpected}")
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                def q_mean_variance(self, x_start, t):
         | 
| 254 | 
            +
                    """
         | 
| 255 | 
            +
                    Get the distribution q(x_t | x_0).
         | 
| 256 | 
            +
                    :param x_start: the [N x C x ...] tensor of noiseless inputs.
         | 
| 257 | 
            +
                    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
         | 
| 258 | 
            +
                    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
         | 
| 259 | 
            +
                    """
         | 
| 260 | 
            +
                    mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
         | 
| 261 | 
            +
                    variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
         | 
| 262 | 
            +
                    log_variance = extract_into_tensor(
         | 
| 263 | 
            +
                        self.log_one_minus_alphas_cumprod, t, x_start.shape
         | 
| 264 | 
            +
                    )
         | 
| 265 | 
            +
                    return mean, variance, log_variance
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def predict_start_from_noise(self, x_t, t, noise):
         | 
| 268 | 
            +
                    return (
         | 
| 269 | 
            +
                        extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
         | 
| 270 | 
            +
                        - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
         | 
| 271 | 
            +
                        * noise
         | 
| 272 | 
            +
                    )
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                def q_posterior(self, x_start, x_t, t):
         | 
| 275 | 
            +
                    posterior_mean = (
         | 
| 276 | 
            +
                        extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
         | 
| 277 | 
            +
                        + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
         | 
| 278 | 
            +
                    )
         | 
| 279 | 
            +
                    posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
         | 
| 280 | 
            +
                    posterior_log_variance_clipped = extract_into_tensor(
         | 
| 281 | 
            +
                        self.posterior_log_variance_clipped, t, x_t.shape
         | 
| 282 | 
            +
                    )
         | 
| 283 | 
            +
                    return posterior_mean, posterior_variance, posterior_log_variance_clipped
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                def p_mean_variance(self, x, t, clip_denoised: bool):
         | 
| 286 | 
            +
                    model_out = self.model(x, t)
         | 
| 287 | 
            +
                    if self.parameterization == "eps":
         | 
| 288 | 
            +
                        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
         | 
| 289 | 
            +
                    elif self.parameterization == "x0":
         | 
| 290 | 
            +
                        x_recon = model_out
         | 
| 291 | 
            +
                    if clip_denoised:
         | 
| 292 | 
            +
                        x_recon.clamp_(-1.0, 1.0)
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
         | 
| 295 | 
            +
                        x_start=x_recon, x_t=x, t=t
         | 
| 296 | 
            +
                    )
         | 
| 297 | 
            +
                    return model_mean, posterior_variance, posterior_log_variance
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                @torch.no_grad()
         | 
| 300 | 
            +
                def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
         | 
| 301 | 
            +
                    b, *_, device = *x.shape, x.device
         | 
| 302 | 
            +
                    model_mean, _, model_log_variance = self.p_mean_variance(
         | 
| 303 | 
            +
                        x=x, t=t, clip_denoised=clip_denoised
         | 
| 304 | 
            +
                    )
         | 
| 305 | 
            +
                    noise = noise_like(x.shape, device, repeat_noise)
         | 
| 306 | 
            +
                    # no noise when t == 0
         | 
| 307 | 
            +
                    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         | 
| 308 | 
            +
                    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                @torch.no_grad()
         | 
| 311 | 
            +
                def p_sample_loop(self, shape, return_intermediates=False):
         | 
| 312 | 
            +
                    device = self.betas.device
         | 
| 313 | 
            +
                    b = shape[0]
         | 
| 314 | 
            +
                    img = torch.randn(shape, device=device)
         | 
| 315 | 
            +
                    intermediates = [img]
         | 
| 316 | 
            +
                    for i in tqdm(
         | 
| 317 | 
            +
                        reversed(range(0, self.num_timesteps)),
         | 
| 318 | 
            +
                        desc="Sampling t",
         | 
| 319 | 
            +
                        total=self.num_timesteps,
         | 
| 320 | 
            +
                    ):
         | 
| 321 | 
            +
                        img = self.p_sample(
         | 
| 322 | 
            +
                            img,
         | 
| 323 | 
            +
                            torch.full((b,), i, device=device, dtype=torch.long),
         | 
| 324 | 
            +
                            clip_denoised=self.clip_denoised,
         | 
| 325 | 
            +
                        )
         | 
| 326 | 
            +
                        if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
         | 
| 327 | 
            +
                            intermediates.append(img)
         | 
| 328 | 
            +
                    if return_intermediates:
         | 
| 329 | 
            +
                        return img, intermediates
         | 
| 330 | 
            +
                    return img
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                @torch.no_grad()
         | 
| 333 | 
            +
                def sample(self, batch_size=16, return_intermediates=False):
         | 
| 334 | 
            +
                    image_size = self.image_size
         | 
| 335 | 
            +
                    channels = self.channels
         | 
| 336 | 
            +
                    return self.p_sample_loop(
         | 
| 337 | 
            +
                        (batch_size, channels, image_size, image_size),
         | 
| 338 | 
            +
                        return_intermediates=return_intermediates,
         | 
| 339 | 
            +
                    )
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                def q_sample(self, x_start, t, noise=None):
         | 
| 342 | 
            +
                    noise = default(noise, lambda: torch.randn_like(x_start))
         | 
| 343 | 
            +
                    return (
         | 
| 344 | 
            +
                        extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
         | 
| 345 | 
            +
                        * x_start
         | 
| 346 | 
            +
                        * extract_into_tensor(self.scale_arr, t, x_start.shape)
         | 
| 347 | 
            +
                        + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
         | 
| 348 | 
            +
                        * noise
         | 
| 349 | 
            +
                    )
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                def get_input(self, batch, k):
         | 
| 352 | 
            +
                    x = batch[k]
         | 
| 353 | 
            +
                    x = x.to(memory_format=torch.contiguous_format).float()
         | 
| 354 | 
            +
                    return x
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                def _get_rows_from_list(self, samples):
         | 
| 357 | 
            +
                    n_imgs_per_row = len(samples)
         | 
| 358 | 
            +
                    denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
         | 
| 359 | 
            +
                    denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
         | 
| 360 | 
            +
                    denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
         | 
| 361 | 
            +
                    return denoise_grid
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                @torch.no_grad()
         | 
| 364 | 
            +
                def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
         | 
| 365 | 
            +
                    log = dict()
         | 
| 366 | 
            +
                    x = self.get_input(batch, self.first_stage_key)
         | 
| 367 | 
            +
                    N = min(x.shape[0], N)
         | 
| 368 | 
            +
                    n_row = min(x.shape[0], n_row)
         | 
| 369 | 
            +
                    x = x.to(self.device)[:N]
         | 
| 370 | 
            +
                    log["inputs"] = x
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    # get diffusion row
         | 
| 373 | 
            +
                    diffusion_row = list()
         | 
| 374 | 
            +
                    x_start = x[:n_row]
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    for t in range(self.num_timesteps):
         | 
| 377 | 
            +
                        if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
         | 
| 378 | 
            +
                            t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
         | 
| 379 | 
            +
                            t = t.to(self.device).long()
         | 
| 380 | 
            +
                            noise = torch.randn_like(x_start)
         | 
| 381 | 
            +
                            x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
         | 
| 382 | 
            +
                            diffusion_row.append(x_noisy)
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                    log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                    if sample:
         | 
| 387 | 
            +
                        # get denoise row
         | 
| 388 | 
            +
                        with self.ema_scope("Plotting"):
         | 
| 389 | 
            +
                            samples, denoise_row = self.sample(
         | 
| 390 | 
            +
                                batch_size=N, return_intermediates=True
         | 
| 391 | 
            +
                            )
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                        log["samples"] = samples
         | 
| 394 | 
            +
                        log["denoise_row"] = self._get_rows_from_list(denoise_row)
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                    if return_keys:
         | 
| 397 | 
            +
                        if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
         | 
| 398 | 
            +
                            return log
         | 
| 399 | 
            +
                        else:
         | 
| 400 | 
            +
                            return {key: log[key] for key in return_keys}
         | 
| 401 | 
            +
                    return log
         | 
| 402 | 
            +
             | 
| 403 | 
            +
             | 
| 404 | 
            +
            class LatentDiffusion(DDPM):
         | 
| 405 | 
            +
                """main class"""
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                def __init__(
         | 
| 408 | 
            +
                    self,
         | 
| 409 | 
            +
                    first_stage_config,
         | 
| 410 | 
            +
                    cond_stage_config,
         | 
| 411 | 
            +
                    num_timesteps_cond=None,
         | 
| 412 | 
            +
                    cond_stage_key="caption",
         | 
| 413 | 
            +
                    cond_stage_trainable=False,
         | 
| 414 | 
            +
                    cond_stage_forward=None,
         | 
| 415 | 
            +
                    conditioning_key=None,
         | 
| 416 | 
            +
                    uncond_prob=0.2,
         | 
| 417 | 
            +
                    uncond_type="empty_seq",
         | 
| 418 | 
            +
                    scale_factor=1.0,
         | 
| 419 | 
            +
                    scale_by_std=False,
         | 
| 420 | 
            +
                    encoder_type="2d",
         | 
| 421 | 
            +
                    only_model=False,
         | 
| 422 | 
            +
                    use_scale=False,
         | 
| 423 | 
            +
                    scale_a=1,
         | 
| 424 | 
            +
                    scale_b=0.3,
         | 
| 425 | 
            +
                    mid_step=400,
         | 
| 426 | 
            +
                    fix_scale_bug=False,
         | 
| 427 | 
            +
                    *args,
         | 
| 428 | 
            +
                    **kwargs,
         | 
| 429 | 
            +
                ):
         | 
| 430 | 
            +
                    self.num_timesteps_cond = default(num_timesteps_cond, 1)
         | 
| 431 | 
            +
                    self.scale_by_std = scale_by_std
         | 
| 432 | 
            +
                    assert self.num_timesteps_cond <= kwargs["timesteps"]
         | 
| 433 | 
            +
                    # for backwards compatibility after implementation of DiffusionWrapper
         | 
| 434 | 
            +
                    ckpt_path = kwargs.pop("ckpt_path", None)
         | 
| 435 | 
            +
                    ignore_keys = kwargs.pop("ignore_keys", [])
         | 
| 436 | 
            +
                    conditioning_key = default(conditioning_key, "crossattn")
         | 
| 437 | 
            +
                    super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                    self.cond_stage_trainable = cond_stage_trainable
         | 
| 440 | 
            +
                    self.cond_stage_key = cond_stage_key
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    # scale factor
         | 
| 443 | 
            +
                    self.use_scale = use_scale
         | 
| 444 | 
            +
                    if self.use_scale:
         | 
| 445 | 
            +
                        self.scale_a = scale_a
         | 
| 446 | 
            +
                        self.scale_b = scale_b
         | 
| 447 | 
            +
                        if fix_scale_bug:
         | 
| 448 | 
            +
                            scale_step = self.num_timesteps - mid_step
         | 
| 449 | 
            +
                        else:  # bug
         | 
| 450 | 
            +
                            scale_step = self.num_timesteps
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                        scale_arr1 = np.linspace(scale_a, scale_b, mid_step)
         | 
| 453 | 
            +
                        scale_arr2 = np.full(scale_step, scale_b)
         | 
| 454 | 
            +
                        scale_arr = np.concatenate((scale_arr1, scale_arr2))
         | 
| 455 | 
            +
                        scale_arr_prev = np.append(scale_a, scale_arr[:-1])
         | 
| 456 | 
            +
                        to_torch = partial(torch.tensor, dtype=torch.float32)
         | 
| 457 | 
            +
                        self.register_buffer("scale_arr", to_torch(scale_arr))
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    try:
         | 
| 460 | 
            +
                        self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
         | 
| 461 | 
            +
                    except:
         | 
| 462 | 
            +
                        self.num_downs = 0
         | 
| 463 | 
            +
                    if not scale_by_std:
         | 
| 464 | 
            +
                        self.scale_factor = scale_factor
         | 
| 465 | 
            +
                    else:
         | 
| 466 | 
            +
                        self.register_buffer("scale_factor", torch.tensor(scale_factor))
         | 
| 467 | 
            +
                    self.instantiate_first_stage(first_stage_config)
         | 
| 468 | 
            +
                    self.instantiate_cond_stage(cond_stage_config)
         | 
| 469 | 
            +
                    self.first_stage_config = first_stage_config
         | 
| 470 | 
            +
                    self.cond_stage_config = cond_stage_config
         | 
| 471 | 
            +
                    self.clip_denoised = False
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                    self.cond_stage_forward = cond_stage_forward
         | 
| 474 | 
            +
                    self.encoder_type = encoder_type
         | 
| 475 | 
            +
                    assert encoder_type in ["2d", "3d"]
         | 
| 476 | 
            +
                    self.uncond_prob = uncond_prob
         | 
| 477 | 
            +
                    self.classifier_free_guidance = True if uncond_prob > 0 else False
         | 
| 478 | 
            +
                    assert uncond_type in ["zero_embed", "empty_seq"]
         | 
| 479 | 
            +
                    self.uncond_type = uncond_type
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                    self.restarted_from_ckpt = False
         | 
| 482 | 
            +
                    if ckpt_path is not None:
         | 
| 483 | 
            +
                        self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model)
         | 
| 484 | 
            +
                        self.restarted_from_ckpt = True
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                def make_cond_schedule(
         | 
| 487 | 
            +
                    self,
         | 
| 488 | 
            +
                ):
         | 
| 489 | 
            +
                    self.cond_ids = torch.full(
         | 
| 490 | 
            +
                        size=(self.num_timesteps,),
         | 
| 491 | 
            +
                        fill_value=self.num_timesteps - 1,
         | 
| 492 | 
            +
                        dtype=torch.long,
         | 
| 493 | 
            +
                    )
         | 
| 494 | 
            +
                    ids = torch.round(
         | 
| 495 | 
            +
                        torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
         | 
| 496 | 
            +
                    ).long()
         | 
| 497 | 
            +
                    self.cond_ids[: self.num_timesteps_cond] = ids
         | 
| 498 | 
            +
             | 
| 499 | 
            +
                def q_sample(self, x_start, t, noise=None):
         | 
| 500 | 
            +
                    noise = default(noise, lambda: torch.randn_like(x_start))
         | 
| 501 | 
            +
                    if self.use_scale:
         | 
| 502 | 
            +
                        return (
         | 
| 503 | 
            +
                            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
         | 
| 504 | 
            +
                            * x_start
         | 
| 505 | 
            +
                            * extract_into_tensor(self.scale_arr, t, x_start.shape)
         | 
| 506 | 
            +
                            + extract_into_tensor(
         | 
| 507 | 
            +
                                self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
         | 
| 508 | 
            +
                            )
         | 
| 509 | 
            +
                            * noise
         | 
| 510 | 
            +
                        )
         | 
| 511 | 
            +
                    else:
         | 
| 512 | 
            +
                        return (
         | 
| 513 | 
            +
                            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
         | 
| 514 | 
            +
                            * x_start
         | 
| 515 | 
            +
                            + extract_into_tensor(
         | 
| 516 | 
            +
                                self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
         | 
| 517 | 
            +
                            )
         | 
| 518 | 
            +
                            * noise
         | 
| 519 | 
            +
                        )
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                def _freeze_model(self):
         | 
| 522 | 
            +
                    for name, para in self.model.diffusion_model.named_parameters():
         | 
| 523 | 
            +
                        para.requires_grad = False
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                def instantiate_first_stage(self, config):
         | 
| 526 | 
            +
                    model = instantiate_from_config(config)
         | 
| 527 | 
            +
                    self.first_stage_model = model.eval()
         | 
| 528 | 
            +
                    self.first_stage_model.train = disabled_train
         | 
| 529 | 
            +
                    for param in self.first_stage_model.parameters():
         | 
| 530 | 
            +
                        param.requires_grad = False
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                def instantiate_cond_stage(self, config):
         | 
| 533 | 
            +
                    if not self.cond_stage_trainable:
         | 
| 534 | 
            +
                        model = instantiate_from_config(config)
         | 
| 535 | 
            +
                        self.cond_stage_model = model.eval()
         | 
| 536 | 
            +
                        self.cond_stage_model.train = disabled_train
         | 
| 537 | 
            +
                        for param in self.cond_stage_model.parameters():
         | 
| 538 | 
            +
                            param.requires_grad = False
         | 
| 539 | 
            +
                    else:
         | 
| 540 | 
            +
                        model = instantiate_from_config(config)
         | 
| 541 | 
            +
                        self.cond_stage_model = model
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                def get_learned_conditioning(self, c):
         | 
| 544 | 
            +
                    if self.cond_stage_forward is None:
         | 
| 545 | 
            +
                        if hasattr(self.cond_stage_model, "encode") and callable(
         | 
| 546 | 
            +
                            self.cond_stage_model.encode
         | 
| 547 | 
            +
                        ):
         | 
| 548 | 
            +
                            c = self.cond_stage_model.encode(c)
         | 
| 549 | 
            +
                            if isinstance(c, DiagonalGaussianDistribution):
         | 
| 550 | 
            +
                                c = c.mode()
         | 
| 551 | 
            +
                        else:
         | 
| 552 | 
            +
                            c = self.cond_stage_model(c)
         | 
| 553 | 
            +
                    else:
         | 
| 554 | 
            +
                        assert hasattr(self.cond_stage_model, self.cond_stage_forward)
         | 
| 555 | 
            +
                        c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
         | 
| 556 | 
            +
                    return c
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                def get_first_stage_encoding(self, encoder_posterior, noise=None):
         | 
| 559 | 
            +
                    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
         | 
| 560 | 
            +
                        z = encoder_posterior.sample(noise=noise)
         | 
| 561 | 
            +
                    elif isinstance(encoder_posterior, torch.Tensor):
         | 
| 562 | 
            +
                        z = encoder_posterior
         | 
| 563 | 
            +
                    else:
         | 
| 564 | 
            +
                        raise NotImplementedError(
         | 
| 565 | 
            +
                            f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
         | 
| 566 | 
            +
                        )
         | 
| 567 | 
            +
                    return self.scale_factor * z
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                @torch.no_grad()
         | 
| 570 | 
            +
                def encode_first_stage(self, x):
         | 
| 571 | 
            +
                    if self.encoder_type == "2d" and x.dim() == 5:
         | 
| 572 | 
            +
                        b, _, t, _, _ = x.shape
         | 
| 573 | 
            +
                        x = rearrange(x, "b c t h w -> (b t) c h w")
         | 
| 574 | 
            +
                        reshape_back = True
         | 
| 575 | 
            +
                    else:
         | 
| 576 | 
            +
                        reshape_back = False
         | 
| 577 | 
            +
             | 
| 578 | 
            +
                    encoder_posterior = self.first_stage_model.encode(x)
         | 
| 579 | 
            +
                    results = self.get_first_stage_encoding(encoder_posterior).detach()
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                    if reshape_back:
         | 
| 582 | 
            +
                        results = rearrange(results, "(b t) c h w -> b c t h w", b=b, t=t)
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                    return results
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                @torch.no_grad()
         | 
| 587 | 
            +
                def encode_first_stage_2DAE(self, x):
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                    b, _, t, _, _ = x.shape
         | 
| 590 | 
            +
                    results = torch.cat(
         | 
| 591 | 
            +
                        [
         | 
| 592 | 
            +
                            self.get_first_stage_encoding(self.first_stage_model.encode(x[:, :, i]))
         | 
| 593 | 
            +
                            .detach()
         | 
| 594 | 
            +
                            .unsqueeze(2)
         | 
| 595 | 
            +
                            for i in range(t)
         | 
| 596 | 
            +
                        ],
         | 
| 597 | 
            +
                        dim=2,
         | 
| 598 | 
            +
                    )
         | 
| 599 | 
            +
             | 
| 600 | 
            +
                    return results
         | 
| 601 | 
            +
             | 
| 602 | 
            +
                def decode_core(self, z, **kwargs):
         | 
| 603 | 
            +
                    if self.encoder_type == "2d" and z.dim() == 5:
         | 
| 604 | 
            +
                        b, _, t, _, _ = z.shape
         | 
| 605 | 
            +
                        z = rearrange(z, "b c t h w -> (b t) c h w")
         | 
| 606 | 
            +
                        reshape_back = True
         | 
| 607 | 
            +
                    else:
         | 
| 608 | 
            +
                        reshape_back = False
         | 
| 609 | 
            +
             | 
| 610 | 
            +
                    z = 1.0 / self.scale_factor * z
         | 
| 611 | 
            +
             | 
| 612 | 
            +
                    results = self.first_stage_model.decode(z, **kwargs)
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                    if reshape_back:
         | 
| 615 | 
            +
                        results = rearrange(results, "(b t) c h w -> b c t h w", b=b, t=t)
         | 
| 616 | 
            +
                    return results
         | 
| 617 | 
            +
             | 
| 618 | 
            +
                @torch.no_grad()
         | 
| 619 | 
            +
                def decode_first_stage(self, z, **kwargs):
         | 
| 620 | 
            +
                    return self.decode_core(z, **kwargs)
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                def apply_model(self, x_noisy, t, cond, **kwargs):
         | 
| 623 | 
            +
                    if isinstance(cond, dict):
         | 
| 624 | 
            +
                        # hybrid case, cond is exptected to be a dict
         | 
| 625 | 
            +
                        pass
         | 
| 626 | 
            +
                    else:
         | 
| 627 | 
            +
                        if not isinstance(cond, list):
         | 
| 628 | 
            +
                            cond = [cond]
         | 
| 629 | 
            +
                        key = (
         | 
| 630 | 
            +
                            "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
         | 
| 631 | 
            +
                        )
         | 
| 632 | 
            +
                        cond = {key: cond}
         | 
| 633 | 
            +
             | 
| 634 | 
            +
                    x_recon = self.model(x_noisy, t, **cond, **kwargs)
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                    if isinstance(x_recon, tuple):
         | 
| 637 | 
            +
                        return x_recon[0]
         | 
| 638 | 
            +
                    else:
         | 
| 639 | 
            +
                        return x_recon
         | 
| 640 | 
            +
             | 
| 641 | 
            +
                def _get_denoise_row_from_list(self, samples, desc=""):
         | 
| 642 | 
            +
                    denoise_row = []
         | 
| 643 | 
            +
                    for zd in tqdm(samples, desc=desc):
         | 
| 644 | 
            +
                        denoise_row.append(self.decode_first_stage(zd.to(self.device)))
         | 
| 645 | 
            +
                    n_log_timesteps = len(denoise_row)
         | 
| 646 | 
            +
             | 
| 647 | 
            +
                    denoise_row = torch.stack(denoise_row)  # n_log_timesteps, b, C, H, W
         | 
| 648 | 
            +
             | 
| 649 | 
            +
                    if denoise_row.dim() == 5:
         | 
| 650 | 
            +
                        # img, num_imgs= n_log_timesteps * bs, grid_size=[bs,n_log_timesteps]
         | 
| 651 | 
            +
                        denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
         | 
| 652 | 
            +
                        denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
         | 
| 653 | 
            +
                        denoise_grid = make_grid(denoise_grid, nrow=n_log_timesteps)
         | 
| 654 | 
            +
                    elif denoise_row.dim() == 6:
         | 
| 655 | 
            +
                        # video, grid_size=[n_log_timesteps*bs, t]
         | 
| 656 | 
            +
                        video_length = denoise_row.shape[3]
         | 
| 657 | 
            +
                        denoise_grid = rearrange(denoise_row, "n b c t h w -> b n c t h w")
         | 
| 658 | 
            +
                        denoise_grid = rearrange(denoise_grid, "b n c t h w -> (b n) c t h w")
         | 
| 659 | 
            +
                        denoise_grid = rearrange(denoise_grid, "n c t h w -> (n t) c h w")
         | 
| 660 | 
            +
                        denoise_grid = make_grid(denoise_grid, nrow=video_length)
         | 
| 661 | 
            +
                    else:
         | 
| 662 | 
            +
                        raise ValueError
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                    return denoise_grid
         | 
| 665 | 
            +
             | 
| 666 | 
            +
                @torch.no_grad()
         | 
| 667 | 
            +
                def decode_first_stage_2DAE(self, z, **kwargs):
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                    b, _, t, _, _ = z.shape
         | 
| 670 | 
            +
                    z = 1.0 / self.scale_factor * z
         | 
| 671 | 
            +
                    results = torch.cat(
         | 
| 672 | 
            +
                        [
         | 
| 673 | 
            +
                            self.first_stage_model.decode(z[:, :, i], **kwargs).unsqueeze(2)
         | 
| 674 | 
            +
                            for i in range(t)
         | 
| 675 | 
            +
                        ],
         | 
| 676 | 
            +
                        dim=2,
         | 
| 677 | 
            +
                    )
         | 
| 678 | 
            +
             | 
| 679 | 
            +
                    return results
         | 
| 680 | 
            +
             | 
| 681 | 
            +
                def p_mean_variance(
         | 
| 682 | 
            +
                    self,
         | 
| 683 | 
            +
                    x,
         | 
| 684 | 
            +
                    c,
         | 
| 685 | 
            +
                    t,
         | 
| 686 | 
            +
                    clip_denoised: bool,
         | 
| 687 | 
            +
                    return_x0=False,
         | 
| 688 | 
            +
                    score_corrector=None,
         | 
| 689 | 
            +
                    corrector_kwargs=None,
         | 
| 690 | 
            +
                    **kwargs,
         | 
| 691 | 
            +
                ):
         | 
| 692 | 
            +
                    t_in = t
         | 
| 693 | 
            +
                    model_out = self.apply_model(x, t_in, c, **kwargs)
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                    if score_corrector is not None:
         | 
| 696 | 
            +
                        assert self.parameterization == "eps"
         | 
| 697 | 
            +
                        model_out = score_corrector.modify_score(
         | 
| 698 | 
            +
                            self, model_out, x, t, c, **corrector_kwargs
         | 
| 699 | 
            +
                        )
         | 
| 700 | 
            +
             | 
| 701 | 
            +
                    if self.parameterization == "eps":
         | 
| 702 | 
            +
                        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
         | 
| 703 | 
            +
                    elif self.parameterization == "x0":
         | 
| 704 | 
            +
                        x_recon = model_out
         | 
| 705 | 
            +
                    else:
         | 
| 706 | 
            +
                        raise NotImplementedError()
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                    if clip_denoised:
         | 
| 709 | 
            +
                        x_recon.clamp_(-1.0, 1.0)
         | 
| 710 | 
            +
             | 
| 711 | 
            +
                    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
         | 
| 712 | 
            +
                        x_start=x_recon, x_t=x, t=t
         | 
| 713 | 
            +
                    )
         | 
| 714 | 
            +
             | 
| 715 | 
            +
                    if return_x0:
         | 
| 716 | 
            +
                        return model_mean, posterior_variance, posterior_log_variance, x_recon
         | 
| 717 | 
            +
                    else:
         | 
| 718 | 
            +
                        return model_mean, posterior_variance, posterior_log_variance
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                @torch.no_grad()
         | 
| 721 | 
            +
                def p_sample(
         | 
| 722 | 
            +
                    self,
         | 
| 723 | 
            +
                    x,
         | 
| 724 | 
            +
                    c,
         | 
| 725 | 
            +
                    t,
         | 
| 726 | 
            +
                    clip_denoised=False,
         | 
| 727 | 
            +
                    repeat_noise=False,
         | 
| 728 | 
            +
                    return_x0=False,
         | 
| 729 | 
            +
                    temperature=1.0,
         | 
| 730 | 
            +
                    noise_dropout=0.0,
         | 
| 731 | 
            +
                    score_corrector=None,
         | 
| 732 | 
            +
                    corrector_kwargs=None,
         | 
| 733 | 
            +
                    **kwargs,
         | 
| 734 | 
            +
                ):
         | 
| 735 | 
            +
                    b, *_, device = *x.shape, x.device
         | 
| 736 | 
            +
                    outputs = self.p_mean_variance(
         | 
| 737 | 
            +
                        x=x,
         | 
| 738 | 
            +
                        c=c,
         | 
| 739 | 
            +
                        t=t,
         | 
| 740 | 
            +
                        clip_denoised=clip_denoised,
         | 
| 741 | 
            +
                        return_x0=return_x0,
         | 
| 742 | 
            +
                        score_corrector=score_corrector,
         | 
| 743 | 
            +
                        corrector_kwargs=corrector_kwargs,
         | 
| 744 | 
            +
                        **kwargs,
         | 
| 745 | 
            +
                    )
         | 
| 746 | 
            +
                    if return_x0:
         | 
| 747 | 
            +
                        model_mean, _, model_log_variance, x0 = outputs
         | 
| 748 | 
            +
                    else:
         | 
| 749 | 
            +
                        model_mean, _, model_log_variance = outputs
         | 
| 750 | 
            +
             | 
| 751 | 
            +
                    noise = noise_like(x.shape, device, repeat_noise) * temperature
         | 
| 752 | 
            +
                    if noise_dropout > 0.0:
         | 
| 753 | 
            +
                        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
         | 
| 754 | 
            +
                    # no noise when t == 0
         | 
| 755 | 
            +
                    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         | 
| 756 | 
            +
             | 
| 757 | 
            +
                    if return_x0:
         | 
| 758 | 
            +
                        return (
         | 
| 759 | 
            +
                            model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
         | 
| 760 | 
            +
                            x0,
         | 
| 761 | 
            +
                        )
         | 
| 762 | 
            +
                    else:
         | 
| 763 | 
            +
                        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                @torch.no_grad()
         | 
| 766 | 
            +
                def p_sample_loop(
         | 
| 767 | 
            +
                    self,
         | 
| 768 | 
            +
                    cond,
         | 
| 769 | 
            +
                    shape,
         | 
| 770 | 
            +
                    return_intermediates=False,
         | 
| 771 | 
            +
                    x_T=None,
         | 
| 772 | 
            +
                    verbose=True,
         | 
| 773 | 
            +
                    callback=None,
         | 
| 774 | 
            +
                    timesteps=None,
         | 
| 775 | 
            +
                    mask=None,
         | 
| 776 | 
            +
                    x0=None,
         | 
| 777 | 
            +
                    img_callback=None,
         | 
| 778 | 
            +
                    start_T=None,
         | 
| 779 | 
            +
                    log_every_t=None,
         | 
| 780 | 
            +
                    **kwargs,
         | 
| 781 | 
            +
                ):
         | 
| 782 | 
            +
             | 
| 783 | 
            +
                    if not log_every_t:
         | 
| 784 | 
            +
                        log_every_t = self.log_every_t
         | 
| 785 | 
            +
                    device = self.betas.device
         | 
| 786 | 
            +
                    b = shape[0]
         | 
| 787 | 
            +
                    # sample an initial noise
         | 
| 788 | 
            +
                    if x_T is None:
         | 
| 789 | 
            +
                        img = torch.randn(shape, device=device)
         | 
| 790 | 
            +
                    else:
         | 
| 791 | 
            +
                        img = x_T
         | 
| 792 | 
            +
             | 
| 793 | 
            +
                    intermediates = [img]
         | 
| 794 | 
            +
                    if timesteps is None:
         | 
| 795 | 
            +
                        timesteps = self.num_timesteps
         | 
| 796 | 
            +
                    if start_T is not None:
         | 
| 797 | 
            +
                        timesteps = min(timesteps, start_T)
         | 
| 798 | 
            +
             | 
| 799 | 
            +
                    iterator = (
         | 
| 800 | 
            +
                        tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
         | 
| 801 | 
            +
                        if verbose
         | 
| 802 | 
            +
                        else reversed(range(0, timesteps))
         | 
| 803 | 
            +
                    )
         | 
| 804 | 
            +
             | 
| 805 | 
            +
                    if mask is not None:
         | 
| 806 | 
            +
                        assert x0 is not None
         | 
| 807 | 
            +
                        assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match
         | 
| 808 | 
            +
             | 
| 809 | 
            +
                    for i in iterator:
         | 
| 810 | 
            +
                        ts = torch.full((b,), i, device=device, dtype=torch.long)
         | 
| 811 | 
            +
                        if self.shorten_cond_schedule:
         | 
| 812 | 
            +
                            assert self.model.conditioning_key != "hybrid"
         | 
| 813 | 
            +
                            tc = self.cond_ids[ts].to(cond.device)
         | 
| 814 | 
            +
                            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
         | 
| 815 | 
            +
             | 
| 816 | 
            +
                        img = self.p_sample(
         | 
| 817 | 
            +
                            img, cond, ts, clip_denoised=self.clip_denoised, **kwargs
         | 
| 818 | 
            +
                        )
         | 
| 819 | 
            +
                        if mask is not None:
         | 
| 820 | 
            +
                            img_orig = self.q_sample(x0, ts)
         | 
| 821 | 
            +
                            img = img_orig * mask + (1.0 - mask) * img
         | 
| 822 | 
            +
             | 
| 823 | 
            +
                        if i % log_every_t == 0 or i == timesteps - 1:
         | 
| 824 | 
            +
                            intermediates.append(img)
         | 
| 825 | 
            +
                        if callback:
         | 
| 826 | 
            +
                            callback(i)
         | 
| 827 | 
            +
                        if img_callback:
         | 
| 828 | 
            +
                            img_callback(img, i)
         | 
| 829 | 
            +
             | 
| 830 | 
            +
                    if return_intermediates:
         | 
| 831 | 
            +
                        return img, intermediates
         | 
| 832 | 
            +
                    return img
         | 
| 833 | 
            +
             | 
| 834 | 
            +
             | 
| 835 | 
            +
            class LatentVisualDiffusion(LatentDiffusion):
         | 
| 836 | 
            +
                def __init__(
         | 
| 837 | 
            +
                    self, cond_img_config, finegrained=False, random_cond=False, *args, **kwargs
         | 
| 838 | 
            +
                ):
         | 
| 839 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 840 | 
            +
                    self.random_cond = random_cond
         | 
| 841 | 
            +
                    self.instantiate_img_embedder(cond_img_config, freeze=True)
         | 
| 842 | 
            +
                    num_tokens = 16 if finegrained else 4
         | 
| 843 | 
            +
                    self.image_proj_model = self.init_projector(
         | 
| 844 | 
            +
                        use_finegrained=finegrained,
         | 
| 845 | 
            +
                        num_tokens=num_tokens,
         | 
| 846 | 
            +
                        input_dim=1024,
         | 
| 847 | 
            +
                        cross_attention_dim=1024,
         | 
| 848 | 
            +
                        dim=1280,
         | 
| 849 | 
            +
                    )
         | 
| 850 | 
            +
             | 
| 851 | 
            +
                def instantiate_img_embedder(self, config, freeze=True):
         | 
| 852 | 
            +
                    embedder = instantiate_from_config(config)
         | 
| 853 | 
            +
                    if freeze:
         | 
| 854 | 
            +
                        self.embedder = embedder.eval()
         | 
| 855 | 
            +
                        self.embedder.train = disabled_train
         | 
| 856 | 
            +
                        for param in self.embedder.parameters():
         | 
| 857 | 
            +
                            param.requires_grad = False
         | 
| 858 | 
            +
             | 
| 859 | 
            +
                def init_projector(
         | 
| 860 | 
            +
                    self, use_finegrained, num_tokens, input_dim, cross_attention_dim, dim
         | 
| 861 | 
            +
                ):
         | 
| 862 | 
            +
                    if not use_finegrained:
         | 
| 863 | 
            +
                        image_proj_model = ImageProjModel(
         | 
| 864 | 
            +
                            clip_extra_context_tokens=num_tokens,
         | 
| 865 | 
            +
                            cross_attention_dim=cross_attention_dim,
         | 
| 866 | 
            +
                            clip_embeddings_dim=input_dim,
         | 
| 867 | 
            +
                        )
         | 
| 868 | 
            +
                    else:
         | 
| 869 | 
            +
                        image_proj_model = Resampler(
         | 
| 870 | 
            +
                            dim=input_dim,
         | 
| 871 | 
            +
                            depth=4,
         | 
| 872 | 
            +
                            dim_head=64,
         | 
| 873 | 
            +
                            heads=12,
         | 
| 874 | 
            +
                            num_queries=num_tokens,
         | 
| 875 | 
            +
                            embedding_dim=dim,
         | 
| 876 | 
            +
                            output_dim=cross_attention_dim,
         | 
| 877 | 
            +
                            ff_mult=4,
         | 
| 878 | 
            +
                        )
         | 
| 879 | 
            +
                    return image_proj_model
         | 
| 880 | 
            +
             | 
| 881 | 
            +
                ## Never delete this func: it is used in log_images() and inference stage
         | 
| 882 | 
            +
                def get_image_embeds(self, batch_imgs):
         | 
| 883 | 
            +
                    ## img: b c h w
         | 
| 884 | 
            +
                    img_token = self.embedder(batch_imgs)
         | 
| 885 | 
            +
                    img_emb = self.image_proj_model(img_token)
         | 
| 886 | 
            +
                    return img_emb
         | 
| 887 | 
            +
             | 
| 888 | 
            +
             | 
| 889 | 
            +
            class DiffusionWrapper(pl.LightningModule):
         | 
| 890 | 
            +
                def __init__(self, diff_model_config, conditioning_key):
         | 
| 891 | 
            +
                    super().__init__()
         | 
| 892 | 
            +
                    self.diffusion_model = instantiate_from_config(diff_model_config)
         | 
| 893 | 
            +
                    self.conditioning_key = conditioning_key
         | 
| 894 | 
            +
             | 
| 895 | 
            +
                def forward(
         | 
| 896 | 
            +
                    self,
         | 
| 897 | 
            +
                    x,
         | 
| 898 | 
            +
                    t,
         | 
| 899 | 
            +
                    c_concat: list = None,
         | 
| 900 | 
            +
                    c_crossattn: list = None,
         | 
| 901 | 
            +
                    c_adm=None,
         | 
| 902 | 
            +
                    s=None,
         | 
| 903 | 
            +
                    mask=None,
         | 
| 904 | 
            +
                    **kwargs,
         | 
| 905 | 
            +
                ):
         | 
| 906 | 
            +
                    # temporal_context = fps is foNone
         | 
| 907 | 
            +
                    if self.conditioning_key is None:
         | 
| 908 | 
            +
                        out = self.diffusion_model(x, t)
         | 
| 909 | 
            +
                    elif self.conditioning_key == "concat":
         | 
| 910 | 
            +
                        xc = torch.cat([x] + c_concat, dim=1)
         | 
| 911 | 
            +
                        out = self.diffusion_model(xc, t, **kwargs)
         | 
| 912 | 
            +
                    elif self.conditioning_key == "crossattn":
         | 
| 913 | 
            +
                        cc = torch.cat(c_crossattn, 1)
         | 
| 914 | 
            +
                        out = self.diffusion_model(x, t, context=cc, **kwargs)
         | 
| 915 | 
            +
                    elif self.conditioning_key == "hybrid":
         | 
| 916 | 
            +
                        ## it is just right [b,c,t,h,w]: concatenate in channel dim
         | 
| 917 | 
            +
                        xc = torch.cat([x] + c_concat, dim=1)
         | 
| 918 | 
            +
                        cc = torch.cat(c_crossattn, 1)
         | 
| 919 | 
            +
                        out = self.diffusion_model(xc, t, context=cc)
         | 
| 920 | 
            +
                    elif self.conditioning_key == "resblockcond":
         | 
| 921 | 
            +
                        cc = c_crossattn[0]
         | 
| 922 | 
            +
                        out = self.diffusion_model(x, t, context=cc)
         | 
| 923 | 
            +
                    elif self.conditioning_key == "adm":
         | 
| 924 | 
            +
                        cc = c_crossattn[0]
         | 
| 925 | 
            +
                        out = self.diffusion_model(x, t, y=cc)
         | 
| 926 | 
            +
                    elif self.conditioning_key == "hybrid-adm":
         | 
| 927 | 
            +
                        assert c_adm is not None
         | 
| 928 | 
            +
                        xc = torch.cat([x] + c_concat, dim=1)
         | 
| 929 | 
            +
                        cc = torch.cat(c_crossattn, 1)
         | 
| 930 | 
            +
                        out = self.diffusion_model(xc, t, context=cc, y=c_adm)
         | 
| 931 | 
            +
                    elif self.conditioning_key == "hybrid-time":
         | 
| 932 | 
            +
                        assert s is not None
         | 
| 933 | 
            +
                        xc = torch.cat([x] + c_concat, dim=1)
         | 
| 934 | 
            +
                        cc = torch.cat(c_crossattn, 1)
         | 
| 935 | 
            +
                        out = self.diffusion_model(xc, t, context=cc, s=s)
         | 
| 936 | 
            +
                    elif self.conditioning_key == "concat-time-mask":
         | 
| 937 | 
            +
                        # assert s is not None
         | 
| 938 | 
            +
                        # mainlogger.info('x & mask:',x.shape,c_concat[0].shape)
         | 
| 939 | 
            +
                        xc = torch.cat([x] + c_concat, dim=1)
         | 
| 940 | 
            +
                        out = self.diffusion_model(xc, t, context=None, s=s, mask=mask)
         | 
| 941 | 
            +
                    elif self.conditioning_key == "concat-adm-mask":
         | 
| 942 | 
            +
                        # assert s is not None
         | 
| 943 | 
            +
                        # mainlogger.info('x & mask:',x.shape,c_concat[0].shape)
         | 
| 944 | 
            +
                        if c_concat is not None:
         | 
| 945 | 
            +
                            xc = torch.cat([x] + c_concat, dim=1)
         | 
| 946 | 
            +
                        else:
         | 
| 947 | 
            +
                            xc = x
         | 
| 948 | 
            +
                        out = self.diffusion_model(xc, t, context=None, y=s, mask=mask)
         | 
| 949 | 
            +
                    elif self.conditioning_key == "hybrid-adm-mask":
         | 
| 950 | 
            +
                        cc = torch.cat(c_crossattn, 1)
         | 
| 951 | 
            +
                        if c_concat is not None:
         | 
| 952 | 
            +
                            xc = torch.cat([x] + c_concat, dim=1)
         | 
| 953 | 
            +
                        else:
         | 
| 954 | 
            +
                            xc = x
         | 
| 955 | 
            +
                        out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask)
         | 
| 956 | 
            +
                    elif (
         | 
| 957 | 
            +
                        self.conditioning_key == "hybrid-time-adm"
         | 
| 958 | 
            +
                    ):  # adm means y, e.g., class index
         | 
| 959 | 
            +
                        # assert s is not None
         | 
| 960 | 
            +
                        assert c_adm is not None
         | 
| 961 | 
            +
                        xc = torch.cat([x] + c_concat, dim=1)
         | 
| 962 | 
            +
                        cc = torch.cat(c_crossattn, 1)
         | 
| 963 | 
            +
                        out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm)
         | 
| 964 | 
            +
                    else:
         | 
| 965 | 
            +
                        raise NotImplementedError()
         | 
| 966 | 
            +
             | 
| 967 | 
            +
                    return out
         | 
    	
        lvdm/models/samplers/ddim.py
    ADDED
    
    | @@ -0,0 +1,493 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            from tqdm import tqdm
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from lvdm.models.utils_diffusion import (
         | 
| 5 | 
            +
                make_ddim_sampling_parameters,
         | 
| 6 | 
            +
                make_ddim_timesteps,
         | 
| 7 | 
            +
            )
         | 
| 8 | 
            +
            from lvdm.common import noise_like
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class DDIMSampler(object):
         | 
| 12 | 
            +
                def __init__(self, model, schedule="linear", **kwargs):
         | 
| 13 | 
            +
                    super().__init__()
         | 
| 14 | 
            +
                    self.model = model
         | 
| 15 | 
            +
                    self.ddpm_num_timesteps = model.num_timesteps
         | 
| 16 | 
            +
                    self.schedule = schedule
         | 
| 17 | 
            +
                    self.counter = 0
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                def register_buffer(self, name, attr):
         | 
| 20 | 
            +
                    if type(attr) == torch.Tensor:
         | 
| 21 | 
            +
                        if attr.device != torch.device("cuda"):
         | 
| 22 | 
            +
                            attr = attr.to(torch.device("cuda"))
         | 
| 23 | 
            +
                    setattr(self, name, attr)
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def make_schedule(
         | 
| 26 | 
            +
                    self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
         | 
| 27 | 
            +
                ):
         | 
| 28 | 
            +
                    self.ddim_timesteps = make_ddim_timesteps(
         | 
| 29 | 
            +
                        ddim_discr_method=ddim_discretize,
         | 
| 30 | 
            +
                        num_ddim_timesteps=ddim_num_steps,
         | 
| 31 | 
            +
                        num_ddpm_timesteps=self.ddpm_num_timesteps,
         | 
| 32 | 
            +
                        verbose=verbose,
         | 
| 33 | 
            +
                    )
         | 
| 34 | 
            +
                    alphas_cumprod = self.model.alphas_cumprod
         | 
| 35 | 
            +
                    assert (
         | 
| 36 | 
            +
                        alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         | 
| 37 | 
            +
                    ), "alphas have to be defined for each timestep"
         | 
| 38 | 
            +
                    to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    self.register_buffer("betas", to_torch(self.model.betas))
         | 
| 41 | 
            +
                    self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
         | 
| 42 | 
            +
                    self.register_buffer(
         | 
| 43 | 
            +
                        "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
         | 
| 44 | 
            +
                    )
         | 
| 45 | 
            +
                    self.use_scale = self.model.use_scale
         | 
| 46 | 
            +
                    print("DDIM scale", self.use_scale)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    if self.use_scale:
         | 
| 49 | 
            +
                        self.register_buffer("scale_arr", to_torch(self.model.scale_arr))
         | 
| 50 | 
            +
                        ddim_scale_arr = self.scale_arr.cpu()[self.ddim_timesteps]
         | 
| 51 | 
            +
                        self.register_buffer("ddim_scale_arr", ddim_scale_arr)
         | 
| 52 | 
            +
                        ddim_scale_arr = np.asarray(
         | 
| 53 | 
            +
                            [self.scale_arr.cpu()[0]]
         | 
| 54 | 
            +
                            + self.scale_arr.cpu()[self.ddim_timesteps[:-1]].tolist()
         | 
| 55 | 
            +
                        )
         | 
| 56 | 
            +
                        self.register_buffer("ddim_scale_arr_prev", ddim_scale_arr)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    # calculations for diffusion q(x_t | x_{t-1}) and others
         | 
| 59 | 
            +
                    self.register_buffer(
         | 
| 60 | 
            +
                        "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
         | 
| 61 | 
            +
                    )
         | 
| 62 | 
            +
                    self.register_buffer(
         | 
| 63 | 
            +
                        "sqrt_one_minus_alphas_cumprod",
         | 
| 64 | 
            +
                        to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
                    self.register_buffer(
         | 
| 67 | 
            +
                        "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
         | 
| 68 | 
            +
                    )
         | 
| 69 | 
            +
                    self.register_buffer(
         | 
| 70 | 
            +
                        "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
         | 
| 71 | 
            +
                    )
         | 
| 72 | 
            +
                    self.register_buffer(
         | 
| 73 | 
            +
                        "sqrt_recipm1_alphas_cumprod",
         | 
| 74 | 
            +
                        to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
         | 
| 75 | 
            +
                    )
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    # ddim sampling parameters
         | 
| 78 | 
            +
                    ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
         | 
| 79 | 
            +
                        alphacums=alphas_cumprod.cpu(),
         | 
| 80 | 
            +
                        ddim_timesteps=self.ddim_timesteps,
         | 
| 81 | 
            +
                        eta=ddim_eta,
         | 
| 82 | 
            +
                        verbose=verbose,
         | 
| 83 | 
            +
                    )
         | 
| 84 | 
            +
                    self.register_buffer("ddim_sigmas", ddim_sigmas)
         | 
| 85 | 
            +
                    self.register_buffer("ddim_alphas", ddim_alphas)
         | 
| 86 | 
            +
                    self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
         | 
| 87 | 
            +
                    self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
         | 
| 88 | 
            +
                    sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
         | 
| 89 | 
            +
                        (1 - self.alphas_cumprod_prev)
         | 
| 90 | 
            +
                        / (1 - self.alphas_cumprod)
         | 
| 91 | 
            +
                        * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    self.register_buffer(
         | 
| 94 | 
            +
                        "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
         | 
| 95 | 
            +
                    )
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                @torch.no_grad()
         | 
| 98 | 
            +
                def sample(
         | 
| 99 | 
            +
                    self,
         | 
| 100 | 
            +
                    S,
         | 
| 101 | 
            +
                    batch_size,
         | 
| 102 | 
            +
                    shape,
         | 
| 103 | 
            +
                    conditioning=None,
         | 
| 104 | 
            +
                    callback=None,
         | 
| 105 | 
            +
                    normals_sequence=None,
         | 
| 106 | 
            +
                    img_callback=None,
         | 
| 107 | 
            +
                    quantize_x0=False,
         | 
| 108 | 
            +
                    eta=0.0,
         | 
| 109 | 
            +
                    mask=None,
         | 
| 110 | 
            +
                    x0=None,
         | 
| 111 | 
            +
                    temperature=1.0,
         | 
| 112 | 
            +
                    noise_dropout=0.0,
         | 
| 113 | 
            +
                    score_corrector=None,
         | 
| 114 | 
            +
                    corrector_kwargs=None,
         | 
| 115 | 
            +
                    verbose=True,
         | 
| 116 | 
            +
                    schedule_verbose=False,
         | 
| 117 | 
            +
                    x_T=None,
         | 
| 118 | 
            +
                    log_every_t=100,
         | 
| 119 | 
            +
                    unconditional_guidance_scale=1.0,
         | 
| 120 | 
            +
                    unconditional_conditioning=None,
         | 
| 121 | 
            +
                    # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
         | 
| 122 | 
            +
                    **kwargs,
         | 
| 123 | 
            +
                ):
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # check condition bs
         | 
| 126 | 
            +
                    if conditioning is not None:
         | 
| 127 | 
            +
                        if isinstance(conditioning, dict):
         | 
| 128 | 
            +
                            try:
         | 
| 129 | 
            +
                                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
         | 
| 130 | 
            +
                            except:
         | 
| 131 | 
            +
                                cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                            if cbs != batch_size:
         | 
| 134 | 
            +
                                print(
         | 
| 135 | 
            +
                                    f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
         | 
| 136 | 
            +
                                )
         | 
| 137 | 
            +
                        else:
         | 
| 138 | 
            +
                            if conditioning.shape[0] != batch_size:
         | 
| 139 | 
            +
                                print(
         | 
| 140 | 
            +
                                    f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
         | 
| 141 | 
            +
                                )
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=schedule_verbose)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    # make shape
         | 
| 146 | 
            +
                    if len(shape) == 3:
         | 
| 147 | 
            +
                        C, H, W = shape
         | 
| 148 | 
            +
                        size = (batch_size, C, H, W)
         | 
| 149 | 
            +
                    elif len(shape) == 4:
         | 
| 150 | 
            +
                        C, T, H, W = shape
         | 
| 151 | 
            +
                        size = (batch_size, C, T, H, W)
         | 
| 152 | 
            +
                    # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    samples, intermediates = self.ddim_sampling(
         | 
| 155 | 
            +
                        conditioning,
         | 
| 156 | 
            +
                        size,
         | 
| 157 | 
            +
                        callback=callback,
         | 
| 158 | 
            +
                        img_callback=img_callback,
         | 
| 159 | 
            +
                        quantize_denoised=quantize_x0,
         | 
| 160 | 
            +
                        mask=mask,
         | 
| 161 | 
            +
                        x0=x0,
         | 
| 162 | 
            +
                        ddim_use_original_steps=False,
         | 
| 163 | 
            +
                        noise_dropout=noise_dropout,
         | 
| 164 | 
            +
                        temperature=temperature,
         | 
| 165 | 
            +
                        score_corrector=score_corrector,
         | 
| 166 | 
            +
                        corrector_kwargs=corrector_kwargs,
         | 
| 167 | 
            +
                        x_T=x_T,
         | 
| 168 | 
            +
                        log_every_t=log_every_t,
         | 
| 169 | 
            +
                        unconditional_guidance_scale=unconditional_guidance_scale,
         | 
| 170 | 
            +
                        unconditional_conditioning=unconditional_conditioning,
         | 
| 171 | 
            +
                        verbose=verbose,
         | 
| 172 | 
            +
                        **kwargs,
         | 
| 173 | 
            +
                    )
         | 
| 174 | 
            +
                    return samples, intermediates
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                @torch.no_grad()
         | 
| 177 | 
            +
                def ddim_sampling(
         | 
| 178 | 
            +
                    self,
         | 
| 179 | 
            +
                    cond,
         | 
| 180 | 
            +
                    shape,
         | 
| 181 | 
            +
                    x_T=None,
         | 
| 182 | 
            +
                    ddim_use_original_steps=False,
         | 
| 183 | 
            +
                    callback=None,
         | 
| 184 | 
            +
                    timesteps=None,
         | 
| 185 | 
            +
                    quantize_denoised=False,
         | 
| 186 | 
            +
                    mask=None,
         | 
| 187 | 
            +
                    x0=None,
         | 
| 188 | 
            +
                    img_callback=None,
         | 
| 189 | 
            +
                    log_every_t=100,
         | 
| 190 | 
            +
                    temperature=1.0,
         | 
| 191 | 
            +
                    noise_dropout=0.0,
         | 
| 192 | 
            +
                    score_corrector=None,
         | 
| 193 | 
            +
                    corrector_kwargs=None,
         | 
| 194 | 
            +
                    unconditional_guidance_scale=1.0,
         | 
| 195 | 
            +
                    unconditional_conditioning=None,
         | 
| 196 | 
            +
                    verbose=True,
         | 
| 197 | 
            +
                    cond_tau=1.0,
         | 
| 198 | 
            +
                    target_size=None,
         | 
| 199 | 
            +
                    start_timesteps=None,
         | 
| 200 | 
            +
                    **kwargs,
         | 
| 201 | 
            +
                ):
         | 
| 202 | 
            +
                    device = self.model.betas.device
         | 
| 203 | 
            +
                    print("ddim device", device)
         | 
| 204 | 
            +
                    b = shape[0]
         | 
| 205 | 
            +
                    if x_T is None:
         | 
| 206 | 
            +
                        img = torch.randn(shape, device=device)
         | 
| 207 | 
            +
                    else:
         | 
| 208 | 
            +
                        img = x_T
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    if timesteps is None:
         | 
| 211 | 
            +
                        timesteps = (
         | 
| 212 | 
            +
                            self.ddpm_num_timesteps
         | 
| 213 | 
            +
                            if ddim_use_original_steps
         | 
| 214 | 
            +
                            else self.ddim_timesteps
         | 
| 215 | 
            +
                        )
         | 
| 216 | 
            +
                    elif timesteps is not None and not ddim_use_original_steps:
         | 
| 217 | 
            +
                        subset_end = (
         | 
| 218 | 
            +
                            int(
         | 
| 219 | 
            +
                                min(timesteps / self.ddim_timesteps.shape[0], 1)
         | 
| 220 | 
            +
                                * self.ddim_timesteps.shape[0]
         | 
| 221 | 
            +
                            )
         | 
| 222 | 
            +
                            - 1
         | 
| 223 | 
            +
                        )
         | 
| 224 | 
            +
                        timesteps = self.ddim_timesteps[:subset_end]
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    intermediates = {"x_inter": [img], "pred_x0": [img]}
         | 
| 227 | 
            +
                    time_range = (
         | 
| 228 | 
            +
                        reversed(range(0, timesteps))
         | 
| 229 | 
            +
                        if ddim_use_original_steps
         | 
| 230 | 
            +
                        else np.flip(timesteps)
         | 
| 231 | 
            +
                    )
         | 
| 232 | 
            +
                    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
         | 
| 233 | 
            +
                    if verbose:
         | 
| 234 | 
            +
                        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
         | 
| 235 | 
            +
                    else:
         | 
| 236 | 
            +
                        iterator = time_range
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                    init_x0 = False
         | 
| 239 | 
            +
                    clean_cond = kwargs.pop("clean_cond", False)
         | 
| 240 | 
            +
                    for i, step in enumerate(iterator):
         | 
| 241 | 
            +
                        index = total_steps - i - 1
         | 
| 242 | 
            +
                        ts = torch.full((b,), step, device=device, dtype=torch.long)
         | 
| 243 | 
            +
                        if start_timesteps is not None:
         | 
| 244 | 
            +
                            assert x0 is not None
         | 
| 245 | 
            +
                            if step > start_timesteps * time_range[0]:
         | 
| 246 | 
            +
                                continue
         | 
| 247 | 
            +
                            elif not init_x0:
         | 
| 248 | 
            +
                                img = self.model.q_sample(x0, ts)
         | 
| 249 | 
            +
                                init_x0 = True
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                        # use mask to blend noised original latent (img_orig) & new sampled latent (img)
         | 
| 252 | 
            +
                        if mask is not None:
         | 
| 253 | 
            +
                            assert x0 is not None
         | 
| 254 | 
            +
                            if clean_cond:
         | 
| 255 | 
            +
                                img_orig = x0
         | 
| 256 | 
            +
                            else:
         | 
| 257 | 
            +
                                img_orig = self.model.q_sample(
         | 
| 258 | 
            +
                                    x0, ts
         | 
| 259 | 
            +
                                )  # TODO: deterministic forward pass? <ddim inversion>
         | 
| 260 | 
            +
                            img = (
         | 
| 261 | 
            +
                                img_orig * mask + (1.0 - mask) * img
         | 
| 262 | 
            +
                            )  # keep original & modify use img
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                        index_clip = int((1 - cond_tau) * total_steps)
         | 
| 265 | 
            +
                        if index <= index_clip and target_size is not None:
         | 
| 266 | 
            +
                            target_size_ = [
         | 
| 267 | 
            +
                                target_size[0],
         | 
| 268 | 
            +
                                target_size[1] // 8,
         | 
| 269 | 
            +
                                target_size[2] // 8,
         | 
| 270 | 
            +
                            ]
         | 
| 271 | 
            +
                            img = torch.nn.functional.interpolate(
         | 
| 272 | 
            +
                                img,
         | 
| 273 | 
            +
                                size=target_size_,
         | 
| 274 | 
            +
                                mode="nearest",
         | 
| 275 | 
            +
                            )
         | 
| 276 | 
            +
                        outs = self.p_sample_ddim(
         | 
| 277 | 
            +
                            img,
         | 
| 278 | 
            +
                            cond,
         | 
| 279 | 
            +
                            ts,
         | 
| 280 | 
            +
                            index=index,
         | 
| 281 | 
            +
                            use_original_steps=ddim_use_original_steps,
         | 
| 282 | 
            +
                            quantize_denoised=quantize_denoised,
         | 
| 283 | 
            +
                            temperature=temperature,
         | 
| 284 | 
            +
                            noise_dropout=noise_dropout,
         | 
| 285 | 
            +
                            score_corrector=score_corrector,
         | 
| 286 | 
            +
                            corrector_kwargs=corrector_kwargs,
         | 
| 287 | 
            +
                            unconditional_guidance_scale=unconditional_guidance_scale,
         | 
| 288 | 
            +
                            unconditional_conditioning=unconditional_conditioning,
         | 
| 289 | 
            +
                            x0=x0,
         | 
| 290 | 
            +
                            **kwargs,
         | 
| 291 | 
            +
                        )
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                        img, pred_x0 = outs
         | 
| 294 | 
            +
                        if callback:
         | 
| 295 | 
            +
                            callback(i)
         | 
| 296 | 
            +
                        if img_callback:
         | 
| 297 | 
            +
                            img_callback(pred_x0, i)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                        if index % log_every_t == 0 or index == total_steps - 1:
         | 
| 300 | 
            +
                            intermediates["x_inter"].append(img)
         | 
| 301 | 
            +
                            intermediates["pred_x0"].append(pred_x0)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    return img, intermediates
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                @torch.no_grad()
         | 
| 306 | 
            +
                def p_sample_ddim(
         | 
| 307 | 
            +
                    self,
         | 
| 308 | 
            +
                    x,
         | 
| 309 | 
            +
                    c,
         | 
| 310 | 
            +
                    t,
         | 
| 311 | 
            +
                    index,
         | 
| 312 | 
            +
                    repeat_noise=False,
         | 
| 313 | 
            +
                    use_original_steps=False,
         | 
| 314 | 
            +
                    quantize_denoised=False,
         | 
| 315 | 
            +
                    temperature=1.0,
         | 
| 316 | 
            +
                    noise_dropout=0.0,
         | 
| 317 | 
            +
                    score_corrector=None,
         | 
| 318 | 
            +
                    corrector_kwargs=None,
         | 
| 319 | 
            +
                    unconditional_guidance_scale=1.0,
         | 
| 320 | 
            +
                    unconditional_conditioning=None,
         | 
| 321 | 
            +
                    uc_type=None,
         | 
| 322 | 
            +
                    conditional_guidance_scale_temporal=None,
         | 
| 323 | 
            +
                    **kwargs,
         | 
| 324 | 
            +
                ):
         | 
| 325 | 
            +
                    b, *_, device = *x.shape, x.device
         | 
| 326 | 
            +
                    if x.dim() == 5:
         | 
| 327 | 
            +
                        is_video = True
         | 
| 328 | 
            +
                    else:
         | 
| 329 | 
            +
                        is_video = False
         | 
| 330 | 
            +
                    if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
         | 
| 331 | 
            +
                        e_t = self.model.apply_model(x, t, c, **kwargs)  # unet denoiser
         | 
| 332 | 
            +
                    else:
         | 
| 333 | 
            +
                        # with unconditional condition
         | 
| 334 | 
            +
                        if isinstance(c, torch.Tensor):
         | 
| 335 | 
            +
                            e_t = self.model.apply_model(x, t, c, **kwargs)
         | 
| 336 | 
            +
                            e_t_uncond = self.model.apply_model(
         | 
| 337 | 
            +
                                x, t, unconditional_conditioning, **kwargs
         | 
| 338 | 
            +
                            )
         | 
| 339 | 
            +
                        elif isinstance(c, dict):
         | 
| 340 | 
            +
                            e_t = self.model.apply_model(x, t, c, **kwargs)
         | 
| 341 | 
            +
                            e_t_uncond = self.model.apply_model(
         | 
| 342 | 
            +
                                x, t, unconditional_conditioning, **kwargs
         | 
| 343 | 
            +
                            )
         | 
| 344 | 
            +
                        else:
         | 
| 345 | 
            +
                            raise NotImplementedError
         | 
| 346 | 
            +
                        # text cfg
         | 
| 347 | 
            +
                        if uc_type is None:
         | 
| 348 | 
            +
                            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
         | 
| 349 | 
            +
                        else:
         | 
| 350 | 
            +
                            if uc_type == "cfg_original":
         | 
| 351 | 
            +
                                e_t = e_t + unconditional_guidance_scale * (e_t - e_t_uncond)
         | 
| 352 | 
            +
                            elif uc_type == "cfg_ours":
         | 
| 353 | 
            +
                                e_t = e_t + unconditional_guidance_scale * (e_t_uncond - e_t)
         | 
| 354 | 
            +
                            else:
         | 
| 355 | 
            +
                                raise NotImplementedError
         | 
| 356 | 
            +
                        # temporal guidance
         | 
| 357 | 
            +
                        if conditional_guidance_scale_temporal is not None:
         | 
| 358 | 
            +
                            e_t_temporal = self.model.apply_model(x, t, c, **kwargs)
         | 
| 359 | 
            +
                            e_t_image = self.model.apply_model(
         | 
| 360 | 
            +
                                x, t, c, no_temporal_attn=True, **kwargs
         | 
| 361 | 
            +
                            )
         | 
| 362 | 
            +
                            e_t = e_t + conditional_guidance_scale_temporal * (
         | 
| 363 | 
            +
                                e_t_temporal - e_t_image
         | 
| 364 | 
            +
                            )
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    if score_corrector is not None:
         | 
| 367 | 
            +
                        assert self.model.parameterization == "eps"
         | 
| 368 | 
            +
                        e_t = score_corrector.modify_score(
         | 
| 369 | 
            +
                            self.model, e_t, x, t, c, **corrector_kwargs
         | 
| 370 | 
            +
                        )
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
         | 
| 373 | 
            +
                    alphas_prev = (
         | 
| 374 | 
            +
                        self.model.alphas_cumprod_prev
         | 
| 375 | 
            +
                        if use_original_steps
         | 
| 376 | 
            +
                        else self.ddim_alphas_prev
         | 
| 377 | 
            +
                    )
         | 
| 378 | 
            +
                    sqrt_one_minus_alphas = (
         | 
| 379 | 
            +
                        self.model.sqrt_one_minus_alphas_cumprod
         | 
| 380 | 
            +
                        if use_original_steps
         | 
| 381 | 
            +
                        else self.ddim_sqrt_one_minus_alphas
         | 
| 382 | 
            +
                    )
         | 
| 383 | 
            +
                    sigmas = (
         | 
| 384 | 
            +
                        self.model.ddim_sigmas_for_original_num_steps
         | 
| 385 | 
            +
                        if use_original_steps
         | 
| 386 | 
            +
                        else self.ddim_sigmas
         | 
| 387 | 
            +
                    )
         | 
| 388 | 
            +
                    # select parameters corresponding to the currently considered timestep
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    if is_video:
         | 
| 391 | 
            +
                        size = (b, 1, 1, 1, 1)
         | 
| 392 | 
            +
                    else:
         | 
| 393 | 
            +
                        size = (b, 1, 1, 1)
         | 
| 394 | 
            +
                    a_t = torch.full(size, alphas[index], device=device)
         | 
| 395 | 
            +
                    a_prev = torch.full(size, alphas_prev[index], device=device)
         | 
| 396 | 
            +
                    sigma_t = torch.full(size, sigmas[index], device=device)
         | 
| 397 | 
            +
                    sqrt_one_minus_at = torch.full(
         | 
| 398 | 
            +
                        size, sqrt_one_minus_alphas[index], device=device
         | 
| 399 | 
            +
                    )
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                    # current prediction for x_0
         | 
| 402 | 
            +
                    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         | 
| 403 | 
            +
                    if quantize_denoised:
         | 
| 404 | 
            +
                        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
         | 
| 405 | 
            +
                    # direction pointing to x_t
         | 
| 406 | 
            +
                    dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
         | 
| 409 | 
            +
                    if noise_dropout > 0.0:
         | 
| 410 | 
            +
                        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
         | 
| 413 | 
            +
                    if self.use_scale:
         | 
| 414 | 
            +
                        scale_arr = (
         | 
| 415 | 
            +
                            self.model.scale_arr if use_original_steps else self.ddim_scale_arr
         | 
| 416 | 
            +
                        )
         | 
| 417 | 
            +
                        scale_t = torch.full(size, scale_arr[index], device=device)
         | 
| 418 | 
            +
                        scale_arr_prev = (
         | 
| 419 | 
            +
                            self.model.scale_arr_prev
         | 
| 420 | 
            +
                            if use_original_steps
         | 
| 421 | 
            +
                            else self.ddim_scale_arr_prev
         | 
| 422 | 
            +
                        )
         | 
| 423 | 
            +
                        scale_t_prev = torch.full(size, scale_arr_prev[index], device=device)
         | 
| 424 | 
            +
                        pred_x0 /= scale_t
         | 
| 425 | 
            +
                        x_prev = a_prev.sqrt() * scale_t_prev * pred_x0 + dir_xt + noise
         | 
| 426 | 
            +
                    else:
         | 
| 427 | 
            +
                        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                    return x_prev, pred_x0
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                @torch.no_grad()
         | 
| 432 | 
            +
                def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
         | 
| 433 | 
            +
                    # fast, but does not allow for exact reconstruction
         | 
| 434 | 
            +
                    # t serves as an index to gather the correct alphas
         | 
| 435 | 
            +
                    if use_original_steps:
         | 
| 436 | 
            +
                        sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
         | 
| 437 | 
            +
                        sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
         | 
| 438 | 
            +
                    else:
         | 
| 439 | 
            +
                        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
         | 
| 440 | 
            +
                        sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    if noise is None:
         | 
| 443 | 
            +
                        noise = torch.randn_like(x0)
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    def extract_into_tensor(a, t, x_shape):
         | 
| 446 | 
            +
                        b, *_ = t.shape
         | 
| 447 | 
            +
                        out = a.gather(-1, t)
         | 
| 448 | 
            +
                        return out.reshape(b, *((1,) * (len(x_shape) - 1)))
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                    return (
         | 
| 451 | 
            +
                        extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
         | 
| 452 | 
            +
                        + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
         | 
| 453 | 
            +
                    )
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                @torch.no_grad()
         | 
| 456 | 
            +
                def decode(
         | 
| 457 | 
            +
                    self,
         | 
| 458 | 
            +
                    x_latent,
         | 
| 459 | 
            +
                    cond,
         | 
| 460 | 
            +
                    t_start,
         | 
| 461 | 
            +
                    unconditional_guidance_scale=1.0,
         | 
| 462 | 
            +
                    unconditional_conditioning=None,
         | 
| 463 | 
            +
                    use_original_steps=False,
         | 
| 464 | 
            +
                ):
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                    timesteps = (
         | 
| 467 | 
            +
                        np.arange(self.ddpm_num_timesteps)
         | 
| 468 | 
            +
                        if use_original_steps
         | 
| 469 | 
            +
                        else self.ddim_timesteps
         | 
| 470 | 
            +
                    )
         | 
| 471 | 
            +
                    timesteps = timesteps[:t_start]
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                    time_range = np.flip(timesteps)
         | 
| 474 | 
            +
                    total_steps = timesteps.shape[0]
         | 
| 475 | 
            +
                    print(f"Running DDIM Sampling with {total_steps} timesteps")
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
         | 
| 478 | 
            +
                    x_dec = x_latent
         | 
| 479 | 
            +
                    for i, step in enumerate(iterator):
         | 
| 480 | 
            +
                        index = total_steps - i - 1
         | 
| 481 | 
            +
                        ts = torch.full(
         | 
| 482 | 
            +
                            (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
         | 
| 483 | 
            +
                        )
         | 
| 484 | 
            +
                        x_dec, _ = self.p_sample_ddim(
         | 
| 485 | 
            +
                            x_dec,
         | 
| 486 | 
            +
                            cond,
         | 
| 487 | 
            +
                            ts,
         | 
| 488 | 
            +
                            index=index,
         | 
| 489 | 
            +
                            use_original_steps=use_original_steps,
         | 
| 490 | 
            +
                            unconditional_guidance_scale=unconditional_guidance_scale,
         | 
| 491 | 
            +
                            unconditional_conditioning=unconditional_conditioning,
         | 
| 492 | 
            +
                        )
         | 
| 493 | 
            +
                    return x_dec
         | 
    	
        lvdm/models/utils_diffusion.py
    ADDED
    
    | @@ -0,0 +1,130 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            from einops import repeat
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
         | 
| 9 | 
            +
                """
         | 
| 10 | 
            +
                Create sinusoidal timestep embeddings.
         | 
| 11 | 
            +
                :param timesteps: a 1-D Tensor of N indices, one per batch element.
         | 
| 12 | 
            +
                                  These may be fractional.
         | 
| 13 | 
            +
                :param dim: the dimension of the output.
         | 
| 14 | 
            +
                :param max_period: controls the minimum frequency of the embeddings.
         | 
| 15 | 
            +
                :return: an [N x dim] Tensor of positional embeddings.
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
                if not repeat_only:
         | 
| 18 | 
            +
                    half = dim // 2
         | 
| 19 | 
            +
                    freqs = torch.exp(
         | 
| 20 | 
            +
                        -math.log(max_period)
         | 
| 21 | 
            +
                        * torch.arange(start=0, end=half, dtype=torch.float32)
         | 
| 22 | 
            +
                        / half
         | 
| 23 | 
            +
                    ).to(device=timesteps.device)
         | 
| 24 | 
            +
                    args = timesteps[:, None].float() * freqs[None]
         | 
| 25 | 
            +
                    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         | 
| 26 | 
            +
                    if dim % 2:
         | 
| 27 | 
            +
                        embedding = torch.cat(
         | 
| 28 | 
            +
                            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
         | 
| 29 | 
            +
                        )
         | 
| 30 | 
            +
                else:
         | 
| 31 | 
            +
                    embedding = repeat(timesteps, "b -> b d", d=dim)
         | 
| 32 | 
            +
                return embedding
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def make_beta_schedule(
         | 
| 36 | 
            +
                schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
         | 
| 37 | 
            +
            ):
         | 
| 38 | 
            +
                if schedule == "linear":
         | 
| 39 | 
            +
                    betas = (
         | 
| 40 | 
            +
                        torch.linspace(
         | 
| 41 | 
            +
                            linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
         | 
| 42 | 
            +
                        )
         | 
| 43 | 
            +
                        ** 2
         | 
| 44 | 
            +
                    )
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                elif schedule == "cosine":
         | 
| 47 | 
            +
                    timesteps = (
         | 
| 48 | 
            +
                        torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
         | 
| 49 | 
            +
                    )
         | 
| 50 | 
            +
                    alphas = timesteps / (1 + cosine_s) * np.pi / 2
         | 
| 51 | 
            +
                    alphas = torch.cos(alphas).pow(2)
         | 
| 52 | 
            +
                    alphas = alphas / alphas[0]
         | 
| 53 | 
            +
                    betas = 1 - alphas[1:] / alphas[:-1]
         | 
| 54 | 
            +
                    betas = np.clip(betas, a_min=0, a_max=0.999)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                elif schedule == "sqrt_linear":
         | 
| 57 | 
            +
                    betas = torch.linspace(
         | 
| 58 | 
            +
                        linear_start, linear_end, n_timestep, dtype=torch.float64
         | 
| 59 | 
            +
                    )
         | 
| 60 | 
            +
                elif schedule == "sqrt":
         | 
| 61 | 
            +
                    betas = (
         | 
| 62 | 
            +
                        torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
         | 
| 63 | 
            +
                        ** 0.5
         | 
| 64 | 
            +
                    )
         | 
| 65 | 
            +
                else:
         | 
| 66 | 
            +
                    raise ValueError(f"schedule '{schedule}' unknown.")
         | 
| 67 | 
            +
                return betas.numpy()
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def make_ddim_timesteps(
         | 
| 71 | 
            +
                ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
         | 
| 72 | 
            +
            ):
         | 
| 73 | 
            +
                if ddim_discr_method == "uniform":
         | 
| 74 | 
            +
                    c = num_ddpm_timesteps // num_ddim_timesteps
         | 
| 75 | 
            +
                    ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
         | 
| 76 | 
            +
                elif ddim_discr_method == "quad":
         | 
| 77 | 
            +
                    ddim_timesteps = (
         | 
| 78 | 
            +
                        (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
         | 
| 79 | 
            +
                    ).astype(int)
         | 
| 80 | 
            +
                else:
         | 
| 81 | 
            +
                    raise NotImplementedError(
         | 
| 82 | 
            +
                        f'There is no ddim discretization method called "{ddim_discr_method}"'
         | 
| 83 | 
            +
                    )
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                # assert ddim_timesteps.shape[0] == num_ddim_timesteps
         | 
| 86 | 
            +
                # add one to get the final alpha values right (the ones from first scale to data during sampling)
         | 
| 87 | 
            +
                steps_out = ddim_timesteps + 1
         | 
| 88 | 
            +
                if verbose:
         | 
| 89 | 
            +
                    print(f"Selected timesteps for ddim sampler: {steps_out}")
         | 
| 90 | 
            +
                return steps_out
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
         | 
| 94 | 
            +
                # select alphas for computing the variance schedule
         | 
| 95 | 
            +
                # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}')
         | 
| 96 | 
            +
                alphas = alphacums[ddim_timesteps]
         | 
| 97 | 
            +
                alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # according the the formula provided in https://arxiv.org/abs/2010.02502
         | 
| 100 | 
            +
                sigmas = eta * np.sqrt(
         | 
| 101 | 
            +
                    (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
         | 
| 102 | 
            +
                )
         | 
| 103 | 
            +
                if verbose:
         | 
| 104 | 
            +
                    print(
         | 
| 105 | 
            +
                        f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
                    print(
         | 
| 108 | 
            +
                        f"For the chosen value of eta, which is {eta}, "
         | 
| 109 | 
            +
                        f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
         | 
| 110 | 
            +
                    )
         | 
| 111 | 
            +
                return sigmas, alphas, alphas_prev
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
            def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
         | 
| 115 | 
            +
                """
         | 
| 116 | 
            +
                Create a beta schedule that discretizes the given alpha_t_bar function,
         | 
| 117 | 
            +
                which defines the cumulative product of (1-beta) over time from t = [0,1].
         | 
| 118 | 
            +
                :param num_diffusion_timesteps: the number of betas to produce.
         | 
| 119 | 
            +
                :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
         | 
| 120 | 
            +
                                  produces the cumulative product of (1-beta) up to that
         | 
| 121 | 
            +
                                  part of the diffusion process.
         | 
| 122 | 
            +
                :param max_beta: the maximum beta to use; use values lower than 1 to
         | 
| 123 | 
            +
                                 prevent singularities.
         | 
| 124 | 
            +
                """
         | 
| 125 | 
            +
                betas = []
         | 
| 126 | 
            +
                for i in range(num_diffusion_timesteps):
         | 
| 127 | 
            +
                    t1 = i / num_diffusion_timesteps
         | 
| 128 | 
            +
                    t2 = (i + 1) / num_diffusion_timesteps
         | 
| 129 | 
            +
                    betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
         | 
| 130 | 
            +
                return np.array(betas)
         | 
    	
        lvdm/modules/__pycache__/attention.cpython-312.pyc
    ADDED
    
    | Binary file (27.6 kB). View file | 
|  | 
    	
        lvdm/modules/attention.py
    ADDED
    
    | @@ -0,0 +1,612 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from functools import partial
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from torch import nn, einsum
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
            from einops import rearrange, repeat
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            try:
         | 
| 8 | 
            +
                import xformers
         | 
| 9 | 
            +
                import xformers.ops
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                XFORMERS_IS_AVAILBLE = True
         | 
| 12 | 
            +
            except:
         | 
| 13 | 
            +
                XFORMERS_IS_AVAILBLE = False
         | 
| 14 | 
            +
            from lvdm.common import (
         | 
| 15 | 
            +
                checkpoint,
         | 
| 16 | 
            +
                exists,
         | 
| 17 | 
            +
                default,
         | 
| 18 | 
            +
            )
         | 
| 19 | 
            +
            from lvdm.basics import (
         | 
| 20 | 
            +
                zero_module,
         | 
| 21 | 
            +
            )
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class RelativePosition(nn.Module):
         | 
| 25 | 
            +
                """https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py"""
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def __init__(self, num_units, max_relative_position):
         | 
| 28 | 
            +
                    super().__init__()
         | 
| 29 | 
            +
                    self.num_units = num_units
         | 
| 30 | 
            +
                    self.max_relative_position = max_relative_position
         | 
| 31 | 
            +
                    self.embeddings_table = nn.Parameter(
         | 
| 32 | 
            +
                        torch.Tensor(max_relative_position * 2 + 1, num_units)
         | 
| 33 | 
            +
                    )
         | 
| 34 | 
            +
                    nn.init.xavier_uniform_(self.embeddings_table)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                def forward(self, length_q, length_k):
         | 
| 37 | 
            +
                    device = self.embeddings_table.device
         | 
| 38 | 
            +
                    range_vec_q = torch.arange(length_q, device=device)
         | 
| 39 | 
            +
                    range_vec_k = torch.arange(length_k, device=device)
         | 
| 40 | 
            +
                    distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
         | 
| 41 | 
            +
                    distance_mat_clipped = torch.clamp(
         | 
| 42 | 
            +
                        distance_mat, -self.max_relative_position, self.max_relative_position
         | 
| 43 | 
            +
                    )
         | 
| 44 | 
            +
                    final_mat = distance_mat_clipped + self.max_relative_position
         | 
| 45 | 
            +
                    final_mat = final_mat.long()
         | 
| 46 | 
            +
                    embeddings = self.embeddings_table[final_mat]
         | 
| 47 | 
            +
                    return embeddings
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            class CrossAttention(nn.Module):
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def __init__(
         | 
| 53 | 
            +
                    self,
         | 
| 54 | 
            +
                    query_dim,
         | 
| 55 | 
            +
                    context_dim=None,
         | 
| 56 | 
            +
                    heads=8,
         | 
| 57 | 
            +
                    dim_head=64,
         | 
| 58 | 
            +
                    dropout=0.0,
         | 
| 59 | 
            +
                    relative_position=False,
         | 
| 60 | 
            +
                    temporal_length=None,
         | 
| 61 | 
            +
                    img_cross_attention=False,
         | 
| 62 | 
            +
                    record_attn_probs=False,
         | 
| 63 | 
            +
                ):
         | 
| 64 | 
            +
                    super().__init__()
         | 
| 65 | 
            +
                    inner_dim = dim_head * heads
         | 
| 66 | 
            +
                    context_dim = default(context_dim, query_dim)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    self.scale = dim_head**-0.5
         | 
| 69 | 
            +
                    self.heads = heads
         | 
| 70 | 
            +
                    self.dim_head = dim_head
         | 
| 71 | 
            +
                    self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         | 
| 72 | 
            +
                    self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         | 
| 73 | 
            +
                    self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
         | 
| 74 | 
            +
                    self.to_out = nn.Sequential(
         | 
| 75 | 
            +
                        nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
         | 
| 76 | 
            +
                    )
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    self.image_cross_attention_scale = 1.0
         | 
| 79 | 
            +
                    self.text_context_len = 200
         | 
| 80 | 
            +
                    self.img_cross_attention = img_cross_attention
         | 
| 81 | 
            +
                    if self.img_cross_attention:
         | 
| 82 | 
            +
                        self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
         | 
| 83 | 
            +
                        self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    self.relative_position = relative_position
         | 
| 86 | 
            +
                    if self.relative_position:
         | 
| 87 | 
            +
                        assert temporal_length is not None
         | 
| 88 | 
            +
                        self.relative_position_k = RelativePosition(
         | 
| 89 | 
            +
                            num_units=dim_head, max_relative_position=temporal_length
         | 
| 90 | 
            +
                        )
         | 
| 91 | 
            +
                        self.relative_position_v = RelativePosition(
         | 
| 92 | 
            +
                            num_units=dim_head, max_relative_position=temporal_length
         | 
| 93 | 
            +
                        )
         | 
| 94 | 
            +
                    else:
         | 
| 95 | 
            +
                        ## only used for spatial attention, while NOT for temporal attention
         | 
| 96 | 
            +
                        if XFORMERS_IS_AVAILBLE and temporal_length is None:
         | 
| 97 | 
            +
                            self.forward = self.efficient_forward
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.record_attn_probs = record_attn_probs
         | 
| 100 | 
            +
                    self.attention_probs = None
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def forward(self, x, context=None, mask=None):
         | 
| 103 | 
            +
                    h = self.heads
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    q = self.to_q(x)
         | 
| 106 | 
            +
                    context = default(context, x)
         | 
| 107 | 
            +
                    ## considering image token additionally
         | 
| 108 | 
            +
                    if context is not None and self.img_cross_attention:
         | 
| 109 | 
            +
                        context, context_img = (
         | 
| 110 | 
            +
                            context[:, : self.text_context_len, :],
         | 
| 111 | 
            +
                            context[:, self.text_context_len :, :],
         | 
| 112 | 
            +
                        )
         | 
| 113 | 
            +
                        k = self.to_k(context)
         | 
| 114 | 
            +
                        v = self.to_v(context)
         | 
| 115 | 
            +
                        k_ip = self.to_k_ip(context_img)
         | 
| 116 | 
            +
                        v_ip = self.to_v_ip(context_img)
         | 
| 117 | 
            +
                    else:
         | 
| 118 | 
            +
                        k = self.to_k(context)
         | 
| 119 | 
            +
                        v = self.to_v(context)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    # Record the attention probs
         | 
| 124 | 
            +
                    if self.record_attn_probs:
         | 
| 125 | 
            +
                        attention_score = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
         | 
| 126 | 
            +
                        self.attention_probs = attention_score.softmax(dim=-1)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
         | 
| 129 | 
            +
                    if self.relative_position:
         | 
| 130 | 
            +
                        len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
         | 
| 131 | 
            +
                        k2 = self.relative_position_k(len_q, len_k)
         | 
| 132 | 
            +
                        sim2 = einsum("b t d, t s d -> b t s", q, k2) * self.scale  # TODO check
         | 
| 133 | 
            +
                        sim += sim2
         | 
| 134 | 
            +
                    del k
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    if exists(mask):
         | 
| 137 | 
            +
                        ## feasible for causal attention mask only
         | 
| 138 | 
            +
                        max_neg_value = -torch.finfo(sim.dtype).max
         | 
| 139 | 
            +
                        mask = repeat(mask, "b i j -> (b h) i j", h=h)
         | 
| 140 | 
            +
                        sim.masked_fill_(~(mask > 0.5), max_neg_value)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    # attention, what we cannot get enough of
         | 
| 143 | 
            +
                    sim = sim.softmax(dim=-1)
         | 
| 144 | 
            +
                    out = torch.einsum("b i j, b j d -> b i d", sim, v)
         | 
| 145 | 
            +
                    if self.relative_position:
         | 
| 146 | 
            +
                        v2 = self.relative_position_v(len_q, len_v)
         | 
| 147 | 
            +
                        out2 = einsum("b t s, t s d -> b t d", sim, v2)  # TODO check
         | 
| 148 | 
            +
                        out += out2
         | 
| 149 | 
            +
                    out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    ## considering image token additionally
         | 
| 152 | 
            +
                    if context is not None and self.img_cross_attention:
         | 
| 153 | 
            +
                        k_ip, v_ip = map(
         | 
| 154 | 
            +
                            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k_ip, v_ip)
         | 
| 155 | 
            +
                        )
         | 
| 156 | 
            +
                        sim_ip = torch.einsum("b i d, b j d -> b i j", q, k_ip) * self.scale
         | 
| 157 | 
            +
                        del k_ip
         | 
| 158 | 
            +
                        sim_ip = sim_ip.softmax(dim=-1)
         | 
| 159 | 
            +
                        out_ip = torch.einsum("b i j, b j d -> b i d", sim_ip, v_ip)
         | 
| 160 | 
            +
                        out_ip = rearrange(out_ip, "(b h) n d -> b n (h d)", h=h)
         | 
| 161 | 
            +
                        out = out + self.image_cross_attention_scale * out_ip
         | 
| 162 | 
            +
                    del q
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    return self.to_out(out)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                def efficient_forward(self, x, context=None, mask=None):
         | 
| 167 | 
            +
                    q = self.to_q(x)
         | 
| 168 | 
            +
                    context = default(context, x)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    ## considering image token additionally
         | 
| 171 | 
            +
                    if context is not None and self.img_cross_attention:
         | 
| 172 | 
            +
                        context, context_img = (
         | 
| 173 | 
            +
                            context[:, : self.text_context_len, :],
         | 
| 174 | 
            +
                            context[:, self.text_context_len :, :],
         | 
| 175 | 
            +
                        )
         | 
| 176 | 
            +
                        k = self.to_k(context)
         | 
| 177 | 
            +
                        v = self.to_v(context)
         | 
| 178 | 
            +
                        k_ip = self.to_k_ip(context_img)
         | 
| 179 | 
            +
                        v_ip = self.to_v_ip(context_img)
         | 
| 180 | 
            +
                    else:
         | 
| 181 | 
            +
                        k = self.to_k(context)
         | 
| 182 | 
            +
                        v = self.to_v(context)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    b, _, _ = q.shape
         | 
| 185 | 
            +
                    # Record the attention probs
         | 
| 186 | 
            +
                    if self.record_attn_probs:
         | 
| 187 | 
            +
                        q, k, v = map(
         | 
| 188 | 
            +
                            lambda t: t.unsqueeze(3)
         | 
| 189 | 
            +
                            .reshape(b, t.shape[1], self.heads, self.dim_head)
         | 
| 190 | 
            +
                            .permute(0, 2, 1, 3)
         | 
| 191 | 
            +
                            .reshape(b * self.heads, t.shape[1], self.dim_head)
         | 
| 192 | 
            +
                            .contiguous(),
         | 
| 193 | 
            +
                            (q, k, v),
         | 
| 194 | 
            +
                        )
         | 
| 195 | 
            +
                        attention_score = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
         | 
| 196 | 
            +
                        self.attention_probs = attention_score.softmax(dim=-1)
         | 
| 197 | 
            +
                    else:
         | 
| 198 | 
            +
                        q, k, v = map(
         | 
| 199 | 
            +
                            lambda t: t.unsqueeze(3)
         | 
| 200 | 
            +
                            .reshape(b, t.shape[1], self.heads, self.dim_head)
         | 
| 201 | 
            +
                            .contiguous(),
         | 
| 202 | 
            +
                            (q, k, v),
         | 
| 203 | 
            +
                        )
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    # actually compute the attention, what we cannot get enough of
         | 
| 206 | 
            +
                    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)
         | 
| 207 | 
            +
                    if not self.record_attn_probs:
         | 
| 208 | 
            +
                        out = out.permute(0, 2, 1, 3).reshape(b * self.heads, out.shape[1], self.dim_head)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    ## considering image token additionally
         | 
| 211 | 
            +
                    if context is not None and self.img_cross_attention:
         | 
| 212 | 
            +
                        k_ip, v_ip = map(
         | 
| 213 | 
            +
                            lambda t: t.unsqueeze(3)
         | 
| 214 | 
            +
                            .reshape(b, t.shape[1], self.heads, self.dim_head)
         | 
| 215 | 
            +
                            .permute(0, 2, 1, 3)
         | 
| 216 | 
            +
                            .reshape(b * self.heads, t.shape[1], self.dim_head)
         | 
| 217 | 
            +
                            .contiguous(),
         | 
| 218 | 
            +
                            (k_ip, v_ip),
         | 
| 219 | 
            +
                        )
         | 
| 220 | 
            +
                        out_ip = xformers.ops.memory_efficient_attention(
         | 
| 221 | 
            +
                            q, k_ip, v_ip, attn_bias=None, op=None
         | 
| 222 | 
            +
                        )
         | 
| 223 | 
            +
                        out_ip = (
         | 
| 224 | 
            +
                            out_ip.unsqueeze(0)
         | 
| 225 | 
            +
                            .reshape(b, self.heads, out.shape[1], self.dim_head)
         | 
| 226 | 
            +
                            .permute(0, 2, 1, 3)
         | 
| 227 | 
            +
                            .reshape(b, out.shape[1], self.heads * self.dim_head)
         | 
| 228 | 
            +
                        )
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    if exists(mask):
         | 
| 231 | 
            +
                        raise NotImplementedError
         | 
| 232 | 
            +
                    out = (
         | 
| 233 | 
            +
                        out.unsqueeze(0)
         | 
| 234 | 
            +
                        .reshape(b, self.heads, out.shape[1], self.dim_head)
         | 
| 235 | 
            +
                        .permute(0, 2, 1, 3)
         | 
| 236 | 
            +
                        .reshape(b, out.shape[1], self.heads * self.dim_head)
         | 
| 237 | 
            +
                    )
         | 
| 238 | 
            +
                    if context is not None and self.img_cross_attention:
         | 
| 239 | 
            +
                        out = out + self.image_cross_attention_scale * out_ip
         | 
| 240 | 
            +
                    return self.to_out(out)
         | 
| 241 | 
            +
             | 
| 242 | 
            +
             | 
| 243 | 
            +
            class BasicTransformerBlock(nn.Module):
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                def __init__(
         | 
| 246 | 
            +
                    self,
         | 
| 247 | 
            +
                    dim,
         | 
| 248 | 
            +
                    n_heads,
         | 
| 249 | 
            +
                    d_head,
         | 
| 250 | 
            +
                    dropout=0.0,
         | 
| 251 | 
            +
                    context_dim=None,
         | 
| 252 | 
            +
                    gated_ff=True,
         | 
| 253 | 
            +
                    checkpoint=True,
         | 
| 254 | 
            +
                    disable_self_attn=False,
         | 
| 255 | 
            +
                    attention_cls=None,
         | 
| 256 | 
            +
                    img_cross_attention=False,
         | 
| 257 | 
            +
                    record_attn_probs=False,
         | 
| 258 | 
            +
                ):
         | 
| 259 | 
            +
                    super().__init__()
         | 
| 260 | 
            +
                    attn_cls = CrossAttention if attention_cls is None else attention_cls
         | 
| 261 | 
            +
                    self.disable_self_attn = disable_self_attn
         | 
| 262 | 
            +
                    self.attn1 = attn_cls(
         | 
| 263 | 
            +
                        query_dim=dim,
         | 
| 264 | 
            +
                        heads=n_heads,
         | 
| 265 | 
            +
                        dim_head=d_head,
         | 
| 266 | 
            +
                        dropout=dropout,
         | 
| 267 | 
            +
                        context_dim=context_dim if self.disable_self_attn else None,
         | 
| 268 | 
            +
                        record_attn_probs=record_attn_probs,
         | 
| 269 | 
            +
                    )
         | 
| 270 | 
            +
                    self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         | 
| 271 | 
            +
                    self.attn2 = attn_cls(
         | 
| 272 | 
            +
                        query_dim=dim,
         | 
| 273 | 
            +
                        context_dim=context_dim,
         | 
| 274 | 
            +
                        heads=n_heads,
         | 
| 275 | 
            +
                        dim_head=d_head,
         | 
| 276 | 
            +
                        dropout=dropout,
         | 
| 277 | 
            +
                        img_cross_attention=img_cross_attention,
         | 
| 278 | 
            +
                    )
         | 
| 279 | 
            +
                    self.norm1 = nn.LayerNorm(dim)
         | 
| 280 | 
            +
                    self.norm2 = nn.LayerNorm(dim)
         | 
| 281 | 
            +
                    self.norm3 = nn.LayerNorm(dim)
         | 
| 282 | 
            +
                    self.checkpoint = checkpoint
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                def forward(self, x, context=None, mask=None):
         | 
| 285 | 
            +
                    ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
         | 
| 286 | 
            +
                    input_tuple = (
         | 
| 287 | 
            +
                        x,
         | 
| 288 | 
            +
                    )  ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
         | 
| 289 | 
            +
                    if context is not None:
         | 
| 290 | 
            +
                        input_tuple = (x, context)
         | 
| 291 | 
            +
                    if mask is not None:
         | 
| 292 | 
            +
                        forward_mask = partial(self._forward, mask=mask)
         | 
| 293 | 
            +
                        return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint)
         | 
| 294 | 
            +
                    if context is not None and mask is not None:
         | 
| 295 | 
            +
                        input_tuple = (x, context, mask)
         | 
| 296 | 
            +
                    return checkpoint(
         | 
| 297 | 
            +
                        self._forward, input_tuple, self.parameters(), self.checkpoint
         | 
| 298 | 
            +
                    )
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                def _forward(self, x, context=None, mask=None):
         | 
| 301 | 
            +
                    x = (
         | 
| 302 | 
            +
                        self.attn1(
         | 
| 303 | 
            +
                            self.norm1(x),
         | 
| 304 | 
            +
                            context=context if self.disable_self_attn else None,
         | 
| 305 | 
            +
                            mask=mask,
         | 
| 306 | 
            +
                        )
         | 
| 307 | 
            +
                        + x
         | 
| 308 | 
            +
                    )
         | 
| 309 | 
            +
                    x = self.attn2(self.norm2(x), context=context, mask=mask) + x
         | 
| 310 | 
            +
                    x = self.ff(self.norm3(x)) + x
         | 
| 311 | 
            +
                    return x
         | 
| 312 | 
            +
             | 
| 313 | 
            +
             | 
| 314 | 
            +
            class SpatialTransformer(nn.Module):
         | 
| 315 | 
            +
                """
         | 
| 316 | 
            +
                Transformer block for image-like data in spatial axis.
         | 
| 317 | 
            +
                First, project the input (aka embedding)
         | 
| 318 | 
            +
                and reshape to b, t, d.
         | 
| 319 | 
            +
                Then apply standard transformer action.
         | 
| 320 | 
            +
                Finally, reshape to image
         | 
| 321 | 
            +
                NEW: use_linear for more efficiency instead of the 1x1 convs
         | 
| 322 | 
            +
                """
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                def __init__(
         | 
| 325 | 
            +
                    self,
         | 
| 326 | 
            +
                    in_channels,
         | 
| 327 | 
            +
                    n_heads,
         | 
| 328 | 
            +
                    d_head,
         | 
| 329 | 
            +
                    depth=1,
         | 
| 330 | 
            +
                    dropout=0.0,
         | 
| 331 | 
            +
                    context_dim=None,
         | 
| 332 | 
            +
                    use_checkpoint=True,
         | 
| 333 | 
            +
                    disable_self_attn=False,
         | 
| 334 | 
            +
                    use_linear=False,
         | 
| 335 | 
            +
                    img_cross_attention=False,
         | 
| 336 | 
            +
                ):
         | 
| 337 | 
            +
                    super().__init__()
         | 
| 338 | 
            +
                    self.in_channels = in_channels
         | 
| 339 | 
            +
                    inner_dim = n_heads * d_head
         | 
| 340 | 
            +
                    self.norm = torch.nn.GroupNorm(
         | 
| 341 | 
            +
                        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
         | 
| 342 | 
            +
                    )
         | 
| 343 | 
            +
                    if not use_linear:
         | 
| 344 | 
            +
                        self.proj_in = nn.Conv2d(
         | 
| 345 | 
            +
                            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
         | 
| 346 | 
            +
                        )
         | 
| 347 | 
            +
                    else:
         | 
| 348 | 
            +
                        self.proj_in = nn.Linear(in_channels, inner_dim)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    self.transformer_blocks = nn.ModuleList(
         | 
| 351 | 
            +
                        [
         | 
| 352 | 
            +
                            BasicTransformerBlock(
         | 
| 353 | 
            +
                                inner_dim,
         | 
| 354 | 
            +
                                n_heads,
         | 
| 355 | 
            +
                                d_head,
         | 
| 356 | 
            +
                                dropout=dropout,
         | 
| 357 | 
            +
                                context_dim=context_dim,
         | 
| 358 | 
            +
                                img_cross_attention=img_cross_attention,
         | 
| 359 | 
            +
                                disable_self_attn=disable_self_attn,
         | 
| 360 | 
            +
                                checkpoint=use_checkpoint,
         | 
| 361 | 
            +
                            )
         | 
| 362 | 
            +
                            for d in range(depth)
         | 
| 363 | 
            +
                        ]
         | 
| 364 | 
            +
                    )
         | 
| 365 | 
            +
                    if not use_linear:
         | 
| 366 | 
            +
                        self.proj_out = zero_module(
         | 
| 367 | 
            +
                            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         | 
| 368 | 
            +
                        )
         | 
| 369 | 
            +
                    else:
         | 
| 370 | 
            +
                        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
         | 
| 371 | 
            +
                    self.use_linear = use_linear
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                def forward(self, x, context=None):
         | 
| 374 | 
            +
                    b, c, h, w = x.shape
         | 
| 375 | 
            +
                    x_in = x
         | 
| 376 | 
            +
                    x = self.norm(x)
         | 
| 377 | 
            +
                    if not self.use_linear:
         | 
| 378 | 
            +
                        x = self.proj_in(x)
         | 
| 379 | 
            +
                    x = rearrange(x, "b c h w -> b (h w) c").contiguous()
         | 
| 380 | 
            +
                    if self.use_linear:
         | 
| 381 | 
            +
                        x = self.proj_in(x)
         | 
| 382 | 
            +
                    for i, block in enumerate(self.transformer_blocks):
         | 
| 383 | 
            +
                        x = block(x, context=context)
         | 
| 384 | 
            +
                    if self.use_linear:
         | 
| 385 | 
            +
                        x = self.proj_out(x)
         | 
| 386 | 
            +
                    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
         | 
| 387 | 
            +
                    if not self.use_linear:
         | 
| 388 | 
            +
                        x = self.proj_out(x)
         | 
| 389 | 
            +
                    return x + x_in
         | 
| 390 | 
            +
             | 
| 391 | 
            +
             | 
| 392 | 
            +
            class TemporalTransformer(nn.Module):
         | 
| 393 | 
            +
                """
         | 
| 394 | 
            +
                Transformer block for image-like data in temporal axis.
         | 
| 395 | 
            +
                First, reshape to b, t, d.
         | 
| 396 | 
            +
                Then apply standard transformer action.
         | 
| 397 | 
            +
                Finally, reshape to image
         | 
| 398 | 
            +
                """
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                def __init__(
         | 
| 401 | 
            +
                    self,
         | 
| 402 | 
            +
                    in_channels,
         | 
| 403 | 
            +
                    n_heads,
         | 
| 404 | 
            +
                    d_head,
         | 
| 405 | 
            +
                    depth=1,
         | 
| 406 | 
            +
                    dropout=0.0,
         | 
| 407 | 
            +
                    context_dim=None,
         | 
| 408 | 
            +
                    use_checkpoint=True,
         | 
| 409 | 
            +
                    use_linear=False,
         | 
| 410 | 
            +
                    only_self_att=True,
         | 
| 411 | 
            +
                    causal_attention=False,
         | 
| 412 | 
            +
                    relative_position=False,
         | 
| 413 | 
            +
                    temporal_length=None,
         | 
| 414 | 
            +
                    record_attn_probs=False,
         | 
| 415 | 
            +
                ):
         | 
| 416 | 
            +
                    super().__init__()
         | 
| 417 | 
            +
                    self.only_self_att = only_self_att
         | 
| 418 | 
            +
                    self.relative_position = relative_position
         | 
| 419 | 
            +
                    self.causal_attention = causal_attention
         | 
| 420 | 
            +
                    self.in_channels = in_channels
         | 
| 421 | 
            +
                    inner_dim = n_heads * d_head
         | 
| 422 | 
            +
                    self.norm = torch.nn.GroupNorm(
         | 
| 423 | 
            +
                        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
         | 
| 424 | 
            +
                    )
         | 
| 425 | 
            +
                    self.proj_in = nn.Conv1d(
         | 
| 426 | 
            +
                        in_channels, inner_dim, kernel_size=1, stride=1, padding=0
         | 
| 427 | 
            +
                    )
         | 
| 428 | 
            +
                    if not use_linear:
         | 
| 429 | 
            +
                        self.proj_in = nn.Conv1d(
         | 
| 430 | 
            +
                            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
         | 
| 431 | 
            +
                        )
         | 
| 432 | 
            +
                    else:
         | 
| 433 | 
            +
                        self.proj_in = nn.Linear(in_channels, inner_dim)
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                    if relative_position:
         | 
| 436 | 
            +
                        assert temporal_length is not None
         | 
| 437 | 
            +
                        attention_cls = partial(
         | 
| 438 | 
            +
                            CrossAttention, relative_position=True, temporal_length=temporal_length
         | 
| 439 | 
            +
                        )
         | 
| 440 | 
            +
                    else:
         | 
| 441 | 
            +
                        attention_cls = None
         | 
| 442 | 
            +
                    if self.causal_attention:
         | 
| 443 | 
            +
                        assert temporal_length is not None
         | 
| 444 | 
            +
                        self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                    if self.only_self_att:
         | 
| 447 | 
            +
                        context_dim = None
         | 
| 448 | 
            +
                    self.transformer_blocks = nn.ModuleList(
         | 
| 449 | 
            +
                        [
         | 
| 450 | 
            +
                            BasicTransformerBlock(
         | 
| 451 | 
            +
                                inner_dim,
         | 
| 452 | 
            +
                                n_heads,
         | 
| 453 | 
            +
                                d_head,
         | 
| 454 | 
            +
                                dropout=dropout,
         | 
| 455 | 
            +
                                context_dim=context_dim,
         | 
| 456 | 
            +
                                attention_cls=attention_cls,
         | 
| 457 | 
            +
                                checkpoint=use_checkpoint,
         | 
| 458 | 
            +
                                record_attn_probs=record_attn_probs,
         | 
| 459 | 
            +
                            )
         | 
| 460 | 
            +
                            for d in range(depth)
         | 
| 461 | 
            +
                        ]
         | 
| 462 | 
            +
                    )
         | 
| 463 | 
            +
                    if not use_linear:
         | 
| 464 | 
            +
                        self.proj_out = zero_module(
         | 
| 465 | 
            +
                            nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         | 
| 466 | 
            +
                        )
         | 
| 467 | 
            +
                    else:
         | 
| 468 | 
            +
                        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
         | 
| 469 | 
            +
                    self.use_linear = use_linear
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                def forward(self, x, context=None):
         | 
| 472 | 
            +
                    b, c, t, h, w = x.shape
         | 
| 473 | 
            +
                    x_in = x
         | 
| 474 | 
            +
                    x = self.norm(x)
         | 
| 475 | 
            +
                    x = rearrange(x, "b c t h w -> (b h w) c t").contiguous()
         | 
| 476 | 
            +
                    if not self.use_linear:
         | 
| 477 | 
            +
                        x = self.proj_in(x)
         | 
| 478 | 
            +
                    x = rearrange(x, "bhw c t -> bhw t c").contiguous()
         | 
| 479 | 
            +
                    if self.use_linear:
         | 
| 480 | 
            +
                        x = self.proj_in(x)
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    if self.causal_attention:
         | 
| 483 | 
            +
                        mask = self.mask.to(x.device)
         | 
| 484 | 
            +
                        mask = repeat(mask, "l i j -> (l bhw) i j", bhw=b * h * w)
         | 
| 485 | 
            +
                    else:
         | 
| 486 | 
            +
                        mask = None
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                    if self.only_self_att:
         | 
| 489 | 
            +
                        ## note: if no context is given, cross-attention defaults to self-attention
         | 
| 490 | 
            +
                        for i, block in enumerate(self.transformer_blocks):
         | 
| 491 | 
            +
                            x = block(x, mask=mask)
         | 
| 492 | 
            +
                        x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
         | 
| 493 | 
            +
                    else:
         | 
| 494 | 
            +
                        x = rearrange(x, "(b hw) t c -> b hw t c", b=b).contiguous()
         | 
| 495 | 
            +
                        context = rearrange(context, "(b t) l con -> b t l con", t=t).contiguous()
         | 
| 496 | 
            +
                        for i, block in enumerate(self.transformer_blocks):
         | 
| 497 | 
            +
                            # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
         | 
| 498 | 
            +
                            for j in range(b):
         | 
| 499 | 
            +
                                context_j = repeat(
         | 
| 500 | 
            +
                                    context[j], "t l con -> (t r) l con", r=(h * w) // t, t=t
         | 
| 501 | 
            +
                                ).contiguous()
         | 
| 502 | 
            +
                                ## note: causal mask will not applied in cross-attention case
         | 
| 503 | 
            +
                                x[j] = block(x[j], context=context_j)
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                    if self.use_linear:
         | 
| 506 | 
            +
                        x = self.proj_out(x)
         | 
| 507 | 
            +
                        x = rearrange(x, "b (h w) t c -> b c t h w", h=h, w=w).contiguous()
         | 
| 508 | 
            +
                    if not self.use_linear:
         | 
| 509 | 
            +
                        x = rearrange(x, "b hw t c -> (b hw) c t").contiguous()
         | 
| 510 | 
            +
                        x = self.proj_out(x)
         | 
| 511 | 
            +
                        x = rearrange(x, "(b h w) c t -> b c t h w", b=b, h=h, w=w).contiguous()
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                    return x + x_in
         | 
| 514 | 
            +
             | 
| 515 | 
            +
             | 
| 516 | 
            +
            class GEGLU(nn.Module):
         | 
| 517 | 
            +
                def __init__(self, dim_in, dim_out):
         | 
| 518 | 
            +
                    super().__init__()
         | 
| 519 | 
            +
                    self.proj = nn.Linear(dim_in, dim_out * 2)
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                def forward(self, x):
         | 
| 522 | 
            +
                    x, gate = self.proj(x).chunk(2, dim=-1)
         | 
| 523 | 
            +
                    return x * F.gelu(gate)
         | 
| 524 | 
            +
             | 
| 525 | 
            +
             | 
| 526 | 
            +
            class FeedForward(nn.Module):
         | 
| 527 | 
            +
                def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         | 
| 528 | 
            +
                    super().__init__()
         | 
| 529 | 
            +
                    inner_dim = int(dim * mult)
         | 
| 530 | 
            +
                    dim_out = default(dim_out, dim)
         | 
| 531 | 
            +
                    project_in = (
         | 
| 532 | 
            +
                        nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
         | 
| 533 | 
            +
                        if not glu
         | 
| 534 | 
            +
                        else GEGLU(dim, inner_dim)
         | 
| 535 | 
            +
                    )
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    self.net = nn.Sequential(
         | 
| 538 | 
            +
                        project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
         | 
| 539 | 
            +
                    )
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                def forward(self, x):
         | 
| 542 | 
            +
                    return self.net(x)
         | 
| 543 | 
            +
             | 
| 544 | 
            +
             | 
| 545 | 
            +
            class LinearAttention(nn.Module):
         | 
| 546 | 
            +
                def __init__(self, dim, heads=4, dim_head=32):
         | 
| 547 | 
            +
                    super().__init__()
         | 
| 548 | 
            +
                    self.heads = heads
         | 
| 549 | 
            +
                    hidden_dim = dim_head * heads
         | 
| 550 | 
            +
                    self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
         | 
| 551 | 
            +
                    self.to_out = nn.Conv2d(hidden_dim, dim, 1)
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                def forward(self, x):
         | 
| 554 | 
            +
                    b, c, h, w = x.shape
         | 
| 555 | 
            +
                    qkv = self.to_qkv(x)
         | 
| 556 | 
            +
                    q, k, v = rearrange(
         | 
| 557 | 
            +
                        qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
         | 
| 558 | 
            +
                    )
         | 
| 559 | 
            +
                    k = k.softmax(dim=-1)
         | 
| 560 | 
            +
                    context = torch.einsum("bhdn,bhen->bhde", k, v)
         | 
| 561 | 
            +
                    out = torch.einsum("bhde,bhdn->bhen", context, q)
         | 
| 562 | 
            +
                    out = rearrange(
         | 
| 563 | 
            +
                        out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
         | 
| 564 | 
            +
                    )
         | 
| 565 | 
            +
                    return self.to_out(out)
         | 
| 566 | 
            +
             | 
| 567 | 
            +
             | 
| 568 | 
            +
            class SpatialSelfAttention(nn.Module):
         | 
| 569 | 
            +
                def __init__(self, in_channels):
         | 
| 570 | 
            +
                    super().__init__()
         | 
| 571 | 
            +
                    self.in_channels = in_channels
         | 
| 572 | 
            +
             | 
| 573 | 
            +
                    self.norm = torch.nn.GroupNorm(
         | 
| 574 | 
            +
                        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
         | 
| 575 | 
            +
                    )
         | 
| 576 | 
            +
                    self.q = torch.nn.Conv2d(
         | 
| 577 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 578 | 
            +
                    )
         | 
| 579 | 
            +
                    self.k = torch.nn.Conv2d(
         | 
| 580 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 581 | 
            +
                    )
         | 
| 582 | 
            +
                    self.v = torch.nn.Conv2d(
         | 
| 583 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 584 | 
            +
                    )
         | 
| 585 | 
            +
                    self.proj_out = torch.nn.Conv2d(
         | 
| 586 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 587 | 
            +
                    )
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                def forward(self, x):
         | 
| 590 | 
            +
                    h_ = x
         | 
| 591 | 
            +
                    h_ = self.norm(h_)
         | 
| 592 | 
            +
                    q = self.q(h_)
         | 
| 593 | 
            +
                    k = self.k(h_)
         | 
| 594 | 
            +
                    v = self.v(h_)
         | 
| 595 | 
            +
             | 
| 596 | 
            +
                    # compute attention
         | 
| 597 | 
            +
                    b, c, h, w = q.shape
         | 
| 598 | 
            +
                    q = rearrange(q, "b c h w -> b (h w) c")
         | 
| 599 | 
            +
                    k = rearrange(k, "b c h w -> b c (h w)")
         | 
| 600 | 
            +
                    w_ = torch.einsum("bij,bjk->bik", q, k)
         | 
| 601 | 
            +
             | 
| 602 | 
            +
                    w_ = w_ * (int(c) ** (-0.5))
         | 
| 603 | 
            +
                    w_ = torch.nn.functional.softmax(w_, dim=2)
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                    # attend to values
         | 
| 606 | 
            +
                    v = rearrange(v, "b c h w -> b c (h w)")
         | 
| 607 | 
            +
                    w_ = rearrange(w_, "b i j -> b j i")
         | 
| 608 | 
            +
                    h_ = torch.einsum("bij,bjk->bik", v, w_)
         | 
| 609 | 
            +
                    h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
         | 
| 610 | 
            +
                    h_ = self.proj_out(h_)
         | 
| 611 | 
            +
             | 
| 612 | 
            +
                    return x + h_
         | 
    	
        lvdm/modules/encoders/__pycache__/condition.cpython-312.pyc
    ADDED
    
    | Binary file (23.3 kB). View file | 
|  | 
    	
        lvdm/modules/encoders/__pycache__/ip_resampler.cpython-312.pyc
    ADDED
    
    | Binary file (7.16 kB). View file | 
|  | 
    	
        lvdm/modules/encoders/condition.py
    ADDED
    
    | @@ -0,0 +1,512 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            from torch.utils.checkpoint import checkpoint
         | 
| 4 | 
            +
            import kornia
         | 
| 5 | 
            +
            import open_clip
         | 
| 6 | 
            +
            from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
         | 
| 7 | 
            +
            from lvdm.common import autocast
         | 
| 8 | 
            +
            from utils.utils import count_params
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class AbstractEncoder(nn.Module):
         | 
| 12 | 
            +
                def __init__(self):
         | 
| 13 | 
            +
                    super().__init__()
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                def encode(self, *args, **kwargs):
         | 
| 16 | 
            +
                    raise NotImplementedError
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class IdentityEncoder(AbstractEncoder):
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def encode(self, x):
         | 
| 22 | 
            +
                    return x
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class ClassEmbedder(nn.Module):
         | 
| 26 | 
            +
                def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
         | 
| 27 | 
            +
                    super().__init__()
         | 
| 28 | 
            +
                    self.key = key
         | 
| 29 | 
            +
                    self.embedding = nn.Embedding(n_classes, embed_dim)
         | 
| 30 | 
            +
                    self.n_classes = n_classes
         | 
| 31 | 
            +
                    self.ucg_rate = ucg_rate
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def forward(self, batch, key=None, disable_dropout=False):
         | 
| 34 | 
            +
                    if key is None:
         | 
| 35 | 
            +
                        key = self.key
         | 
| 36 | 
            +
                    # this is for use in crossattn
         | 
| 37 | 
            +
                    c = batch[key][:, None]
         | 
| 38 | 
            +
                    if self.ucg_rate > 0.0 and not disable_dropout:
         | 
| 39 | 
            +
                        mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
         | 
| 40 | 
            +
                        c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
         | 
| 41 | 
            +
                        c = c.long()
         | 
| 42 | 
            +
                    c = self.embedding(c)
         | 
| 43 | 
            +
                    return c
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def get_unconditional_conditioning(self, bs, device="cuda"):
         | 
| 46 | 
            +
                    uc_class = (
         | 
| 47 | 
            +
                        self.n_classes - 1
         | 
| 48 | 
            +
                    )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
         | 
| 49 | 
            +
                    uc = torch.ones((bs,), device=device) * uc_class
         | 
| 50 | 
            +
                    uc = {self.key: uc}
         | 
| 51 | 
            +
                    return uc
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def disabled_train(self, mode=True):
         | 
| 55 | 
            +
                """Overwrite model.train with this function to make sure train/eval mode
         | 
| 56 | 
            +
                does not change anymore."""
         | 
| 57 | 
            +
                return self
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            class FrozenT5Embedder(AbstractEncoder):
         | 
| 61 | 
            +
                """Uses the T5 transformer encoder for text"""
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def __init__(
         | 
| 64 | 
            +
                    self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
         | 
| 65 | 
            +
                ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         | 
| 66 | 
            +
                    super().__init__()
         | 
| 67 | 
            +
                    self.tokenizer = T5Tokenizer.from_pretrained(version)
         | 
| 68 | 
            +
                    self.transformer = T5EncoderModel.from_pretrained(version)
         | 
| 69 | 
            +
                    self.device = device
         | 
| 70 | 
            +
                    self.max_length = max_length  # TODO: typical value?
         | 
| 71 | 
            +
                    if freeze:
         | 
| 72 | 
            +
                        self.freeze()
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                def freeze(self):
         | 
| 75 | 
            +
                    self.transformer = self.transformer.eval()
         | 
| 76 | 
            +
                    # self.train = disabled_train
         | 
| 77 | 
            +
                    for param in self.parameters():
         | 
| 78 | 
            +
                        param.requires_grad = False
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                def forward(self, text):
         | 
| 81 | 
            +
                    batch_encoding = self.tokenizer(
         | 
| 82 | 
            +
                        text,
         | 
| 83 | 
            +
                        truncation=True,
         | 
| 84 | 
            +
                        max_length=self.max_length,
         | 
| 85 | 
            +
                        return_length=True,
         | 
| 86 | 
            +
                        return_overflowing_tokens=False,
         | 
| 87 | 
            +
                        padding="max_length",
         | 
| 88 | 
            +
                        return_tensors="pt",
         | 
| 89 | 
            +
                    )
         | 
| 90 | 
            +
                    tokens = batch_encoding["input_ids"].to(self.device)
         | 
| 91 | 
            +
                    outputs = self.transformer(input_ids=tokens)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    z = outputs.last_hidden_state
         | 
| 94 | 
            +
                    return z
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def encode(self, text):
         | 
| 97 | 
            +
                    return self(text)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            class FrozenCLIPEmbedder(AbstractEncoder):
         | 
| 101 | 
            +
                """Uses the CLIP transformer encoder for text (from huggingface)"""
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                LAYERS = ["last", "pooled", "hidden"]
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                def __init__(
         | 
| 106 | 
            +
                    self,
         | 
| 107 | 
            +
                    version="openai/clip-vit-large-patch14",
         | 
| 108 | 
            +
                    device="cuda",
         | 
| 109 | 
            +
                    max_length=77,
         | 
| 110 | 
            +
                    freeze=True,
         | 
| 111 | 
            +
                    layer="last",
         | 
| 112 | 
            +
                    layer_idx=None,
         | 
| 113 | 
            +
                ):  # clip-vit-base-patch32
         | 
| 114 | 
            +
                    super().__init__()
         | 
| 115 | 
            +
                    assert layer in self.LAYERS
         | 
| 116 | 
            +
                    self.tokenizer = CLIPTokenizer.from_pretrained(version)
         | 
| 117 | 
            +
                    self.transformer = CLIPTextModel.from_pretrained(version)
         | 
| 118 | 
            +
                    self.device = device
         | 
| 119 | 
            +
                    self.max_length = max_length
         | 
| 120 | 
            +
                    if freeze:
         | 
| 121 | 
            +
                        self.freeze()
         | 
| 122 | 
            +
                    self.layer = layer
         | 
| 123 | 
            +
                    self.layer_idx = layer_idx
         | 
| 124 | 
            +
                    if layer == "hidden":
         | 
| 125 | 
            +
                        assert layer_idx is not None
         | 
| 126 | 
            +
                        assert 0 <= abs(layer_idx) <= 12
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def freeze(self):
         | 
| 129 | 
            +
                    self.transformer = self.transformer.eval()
         | 
| 130 | 
            +
                    # self.train = disabled_train
         | 
| 131 | 
            +
                    for param in self.parameters():
         | 
| 132 | 
            +
                        param.requires_grad = False
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def forward(self, text):
         | 
| 135 | 
            +
                    batch_encoding = self.tokenizer(
         | 
| 136 | 
            +
                        text,
         | 
| 137 | 
            +
                        truncation=True,
         | 
| 138 | 
            +
                        max_length=self.max_length,
         | 
| 139 | 
            +
                        return_length=True,
         | 
| 140 | 
            +
                        return_overflowing_tokens=False,
         | 
| 141 | 
            +
                        padding="max_length",
         | 
| 142 | 
            +
                        return_tensors="pt",
         | 
| 143 | 
            +
                    )
         | 
| 144 | 
            +
                    tokens = batch_encoding["input_ids"].to(self.device)
         | 
| 145 | 
            +
                    outputs = self.transformer(
         | 
| 146 | 
            +
                        input_ids=tokens, output_hidden_states=self.layer == "hidden"
         | 
| 147 | 
            +
                    )
         | 
| 148 | 
            +
                    if self.layer == "last":
         | 
| 149 | 
            +
                        z = outputs.last_hidden_state
         | 
| 150 | 
            +
                    elif self.layer == "pooled":
         | 
| 151 | 
            +
                        z = outputs.pooler_output[:, None, :]
         | 
| 152 | 
            +
                    else:
         | 
| 153 | 
            +
                        z = outputs.hidden_states[self.layer_idx]
         | 
| 154 | 
            +
                    return z
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                def encode(self, text):
         | 
| 157 | 
            +
                    return self(text)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            class ClipImageEmbedder(nn.Module):
         | 
| 161 | 
            +
                def __init__(
         | 
| 162 | 
            +
                    self,
         | 
| 163 | 
            +
                    model,
         | 
| 164 | 
            +
                    jit=False,
         | 
| 165 | 
            +
                    device="cuda" if torch.cuda.is_available() else "cpu",
         | 
| 166 | 
            +
                    antialias=True,
         | 
| 167 | 
            +
                    ucg_rate=0.0,
         | 
| 168 | 
            +
                ):
         | 
| 169 | 
            +
                    super().__init__()
         | 
| 170 | 
            +
                    from clip import load as load_clip
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    self.model, _ = load_clip(name=model, device=device, jit=jit)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    self.antialias = antialias
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    self.register_buffer(
         | 
| 177 | 
            +
                        "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
         | 
| 178 | 
            +
                    )
         | 
| 179 | 
            +
                    self.register_buffer(
         | 
| 180 | 
            +
                        "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
         | 
| 181 | 
            +
                    )
         | 
| 182 | 
            +
                    self.ucg_rate = ucg_rate
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                def preprocess(self, x):
         | 
| 185 | 
            +
                    # normalize to [0,1]
         | 
| 186 | 
            +
                    x = kornia.geometry.resize(
         | 
| 187 | 
            +
                        x,
         | 
| 188 | 
            +
                        (224, 224),
         | 
| 189 | 
            +
                        interpolation="bicubic",
         | 
| 190 | 
            +
                        align_corners=True,
         | 
| 191 | 
            +
                        antialias=self.antialias,
         | 
| 192 | 
            +
                    )
         | 
| 193 | 
            +
                    x = (x + 1.0) / 2.0
         | 
| 194 | 
            +
                    # re-normalize according to clip
         | 
| 195 | 
            +
                    x = kornia.enhance.normalize(x, self.mean, self.std)
         | 
| 196 | 
            +
                    return x
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def forward(self, x, no_dropout=False):
         | 
| 199 | 
            +
                    # x is assumed to be in range [-1,1]
         | 
| 200 | 
            +
                    out = self.model.encode_image(self.preprocess(x))
         | 
| 201 | 
            +
                    out = out.to(x.dtype)
         | 
| 202 | 
            +
                    if self.ucg_rate > 0.0 and not no_dropout:
         | 
| 203 | 
            +
                        out = (
         | 
| 204 | 
            +
                            torch.bernoulli(
         | 
| 205 | 
            +
                                (1.0 - self.ucg_rate) * torch.ones(out.shape[0], device=out.device)
         | 
| 206 | 
            +
                            )[:, None]
         | 
| 207 | 
            +
                            * out
         | 
| 208 | 
            +
                        )
         | 
| 209 | 
            +
                    return out
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            class FrozenOpenCLIPEmbedder(AbstractEncoder):
         | 
| 213 | 
            +
                """
         | 
| 214 | 
            +
                Uses the OpenCLIP transformer encoder for text
         | 
| 215 | 
            +
                """
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                LAYERS = [
         | 
| 218 | 
            +
                    # "pooled",
         | 
| 219 | 
            +
                    "last",
         | 
| 220 | 
            +
                    "penultimate",
         | 
| 221 | 
            +
                ]
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                def __init__(
         | 
| 224 | 
            +
                    self,
         | 
| 225 | 
            +
                    arch="ViT-H-14",
         | 
| 226 | 
            +
                    version="laion2b_s32b_b79k",
         | 
| 227 | 
            +
                    device="cuda",
         | 
| 228 | 
            +
                    max_length=77,
         | 
| 229 | 
            +
                    freeze=True,
         | 
| 230 | 
            +
                    layer="last",
         | 
| 231 | 
            +
                ):
         | 
| 232 | 
            +
                    super().__init__()
         | 
| 233 | 
            +
                    assert layer in self.LAYERS
         | 
| 234 | 
            +
                    model, _, _ = open_clip.create_model_and_transforms(
         | 
| 235 | 
            +
                        arch, device=torch.device("cpu")
         | 
| 236 | 
            +
                    )
         | 
| 237 | 
            +
                    del model.visual
         | 
| 238 | 
            +
                    self.model = model
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    self.device = device
         | 
| 241 | 
            +
                    self.max_length = max_length
         | 
| 242 | 
            +
                    if freeze:
         | 
| 243 | 
            +
                        self.freeze()
         | 
| 244 | 
            +
                    self.layer = layer
         | 
| 245 | 
            +
                    if self.layer == "last":
         | 
| 246 | 
            +
                        self.layer_idx = 0
         | 
| 247 | 
            +
                    elif self.layer == "penultimate":
         | 
| 248 | 
            +
                        self.layer_idx = 1
         | 
| 249 | 
            +
                    else:
         | 
| 250 | 
            +
                        raise NotImplementedError()
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def freeze(self):
         | 
| 253 | 
            +
                    self.model = self.model.eval()
         | 
| 254 | 
            +
                    for param in self.parameters():
         | 
| 255 | 
            +
                        param.requires_grad = False
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                def forward(self, text):
         | 
| 258 | 
            +
                    self.device = self.model.positional_embedding.device
         | 
| 259 | 
            +
                    tokens = open_clip.tokenize(text)
         | 
| 260 | 
            +
                    z = self.encode_with_transformer(tokens.to(self.device))
         | 
| 261 | 
            +
                    return z
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                def encode_with_transformer(self, text):
         | 
| 264 | 
            +
                    x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
         | 
| 265 | 
            +
                    x = x + self.model.positional_embedding
         | 
| 266 | 
            +
                    x = x.permute(1, 0, 2)  # NLD -> LND
         | 
| 267 | 
            +
                    x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
         | 
| 268 | 
            +
                    x = x.permute(1, 0, 2)  # LND -> NLD
         | 
| 269 | 
            +
                    x = self.model.ln_final(x)
         | 
| 270 | 
            +
                    return x
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
         | 
| 273 | 
            +
                    for i, r in enumerate(self.model.transformer.resblocks):
         | 
| 274 | 
            +
                        if i == len(self.model.transformer.resblocks) - self.layer_idx:
         | 
| 275 | 
            +
                            break
         | 
| 276 | 
            +
                        if (
         | 
| 277 | 
            +
                            self.model.transformer.grad_checkpointing
         | 
| 278 | 
            +
                            and not torch.jit.is_scripting()
         | 
| 279 | 
            +
                        ):
         | 
| 280 | 
            +
                            x = checkpoint(r, x, attn_mask)
         | 
| 281 | 
            +
                        else:
         | 
| 282 | 
            +
                            x = r(x, attn_mask=attn_mask)
         | 
| 283 | 
            +
                    return x
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                def encode(self, text):
         | 
| 286 | 
            +
                    return self(text)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
             | 
| 289 | 
            +
            class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
         | 
| 290 | 
            +
                """
         | 
| 291 | 
            +
                Uses the OpenCLIP vision transformer encoder for images
         | 
| 292 | 
            +
                """
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                def __init__(
         | 
| 295 | 
            +
                    self,
         | 
| 296 | 
            +
                    arch="ViT-H-14",
         | 
| 297 | 
            +
                    version="laion2b_s32b_b79k",
         | 
| 298 | 
            +
                    device="cuda",
         | 
| 299 | 
            +
                    max_length=77,
         | 
| 300 | 
            +
                    freeze=True,
         | 
| 301 | 
            +
                    layer="pooled",
         | 
| 302 | 
            +
                    antialias=True,
         | 
| 303 | 
            +
                    ucg_rate=0.0,
         | 
| 304 | 
            +
                ):
         | 
| 305 | 
            +
                    super().__init__()
         | 
| 306 | 
            +
                    model, _, _ = open_clip.create_model_and_transforms(
         | 
| 307 | 
            +
                        arch,
         | 
| 308 | 
            +
                        device=torch.device("cpu"),
         | 
| 309 | 
            +
                        pretrained=version,
         | 
| 310 | 
            +
                    )
         | 
| 311 | 
            +
                    del model.transformer
         | 
| 312 | 
            +
                    self.model = model
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    self.device = device
         | 
| 315 | 
            +
                    self.max_length = max_length
         | 
| 316 | 
            +
                    if freeze:
         | 
| 317 | 
            +
                        self.freeze()
         | 
| 318 | 
            +
                    self.layer = layer
         | 
| 319 | 
            +
                    if self.layer == "penultimate":
         | 
| 320 | 
            +
                        raise NotImplementedError()
         | 
| 321 | 
            +
                        self.layer_idx = 1
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    self.antialias = antialias
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    self.register_buffer(
         | 
| 326 | 
            +
                        "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
         | 
| 327 | 
            +
                    )
         | 
| 328 | 
            +
                    self.register_buffer(
         | 
| 329 | 
            +
                        "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
         | 
| 330 | 
            +
                    )
         | 
| 331 | 
            +
                    self.ucg_rate = ucg_rate
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                def preprocess(self, x):
         | 
| 334 | 
            +
                    # normalize to [0,1]
         | 
| 335 | 
            +
                    x = kornia.geometry.resize(
         | 
| 336 | 
            +
                        x,
         | 
| 337 | 
            +
                        (224, 224),
         | 
| 338 | 
            +
                        interpolation="bicubic",
         | 
| 339 | 
            +
                        align_corners=True,
         | 
| 340 | 
            +
                        antialias=self.antialias,
         | 
| 341 | 
            +
                    )
         | 
| 342 | 
            +
                    x = (x + 1.0) / 2.0
         | 
| 343 | 
            +
                    # renormalize according to clip
         | 
| 344 | 
            +
                    x = kornia.enhance.normalize(x, self.mean, self.std)
         | 
| 345 | 
            +
                    return x
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                def freeze(self):
         | 
| 348 | 
            +
                    self.model = self.model.eval()
         | 
| 349 | 
            +
                    for param in self.parameters():
         | 
| 350 | 
            +
                        param.requires_grad = False
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                @autocast
         | 
| 353 | 
            +
                def forward(self, image, no_dropout=False):
         | 
| 354 | 
            +
                    z = self.encode_with_vision_transformer(image)
         | 
| 355 | 
            +
                    if self.ucg_rate > 0.0 and not no_dropout:
         | 
| 356 | 
            +
                        z = (
         | 
| 357 | 
            +
                            torch.bernoulli(
         | 
| 358 | 
            +
                                (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
         | 
| 359 | 
            +
                            )[:, None]
         | 
| 360 | 
            +
                            * z
         | 
| 361 | 
            +
                        )
         | 
| 362 | 
            +
                    return z
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                def encode_with_vision_transformer(self, img):
         | 
| 365 | 
            +
                    img = self.preprocess(img)
         | 
| 366 | 
            +
                    x = self.model.visual(img)
         | 
| 367 | 
            +
                    return x
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                def encode(self, text):
         | 
| 370 | 
            +
                    return self(text)
         | 
| 371 | 
            +
             | 
| 372 | 
            +
             | 
| 373 | 
            +
            class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
         | 
| 374 | 
            +
                """
         | 
| 375 | 
            +
                Uses the OpenCLIP vision transformer encoder for images
         | 
| 376 | 
            +
                """
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                def __init__(
         | 
| 379 | 
            +
                    self,
         | 
| 380 | 
            +
                    arch="ViT-H-14",
         | 
| 381 | 
            +
                    version="laion2b_s32b_b79k",
         | 
| 382 | 
            +
                    device="cuda",
         | 
| 383 | 
            +
                    freeze=True,
         | 
| 384 | 
            +
                    layer="pooled",
         | 
| 385 | 
            +
                    antialias=True,
         | 
| 386 | 
            +
                ):
         | 
| 387 | 
            +
                    super().__init__()
         | 
| 388 | 
            +
                    model, _, _ = open_clip.create_model_and_transforms(
         | 
| 389 | 
            +
                        arch,
         | 
| 390 | 
            +
                        device=torch.device("cpu"),
         | 
| 391 | 
            +
                        pretrained=version,
         | 
| 392 | 
            +
                    )
         | 
| 393 | 
            +
                    del model.transformer
         | 
| 394 | 
            +
                    self.model = model
         | 
| 395 | 
            +
                    self.device = device
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    if freeze:
         | 
| 398 | 
            +
                        self.freeze()
         | 
| 399 | 
            +
                    self.layer = layer
         | 
| 400 | 
            +
                    if self.layer == "penultimate":
         | 
| 401 | 
            +
                        raise NotImplementedError()
         | 
| 402 | 
            +
                        self.layer_idx = 1
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                    self.antialias = antialias
         | 
| 405 | 
            +
                    self.register_buffer(
         | 
| 406 | 
            +
                        "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
         | 
| 407 | 
            +
                    )
         | 
| 408 | 
            +
                    self.register_buffer(
         | 
| 409 | 
            +
                        "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
         | 
| 410 | 
            +
                    )
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                def preprocess(self, x):
         | 
| 413 | 
            +
                    # normalize to [0,1]
         | 
| 414 | 
            +
                    x = kornia.geometry.resize(
         | 
| 415 | 
            +
                        x,
         | 
| 416 | 
            +
                        (224, 224),
         | 
| 417 | 
            +
                        interpolation="bicubic",
         | 
| 418 | 
            +
                        align_corners=True,
         | 
| 419 | 
            +
                        antialias=self.antialias,
         | 
| 420 | 
            +
                    )
         | 
| 421 | 
            +
                    x = (x + 1.0) / 2.0
         | 
| 422 | 
            +
                    # renormalize according to clip
         | 
| 423 | 
            +
                    x = kornia.enhance.normalize(x, self.mean, self.std)
         | 
| 424 | 
            +
                    return x
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                def freeze(self):
         | 
| 427 | 
            +
                    self.model = self.model.eval()
         | 
| 428 | 
            +
                    for param in self.model.parameters():
         | 
| 429 | 
            +
                        param.requires_grad = False
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                def forward(self, image, no_dropout=False):
         | 
| 432 | 
            +
                    ## image: b c h w
         | 
| 433 | 
            +
                    z = self.encode_with_vision_transformer(image)
         | 
| 434 | 
            +
                    return z
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                def encode_with_vision_transformer(self, x):
         | 
| 437 | 
            +
                    x = self.preprocess(x)
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                    # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
         | 
| 440 | 
            +
                    if self.model.visual.input_patchnorm:
         | 
| 441 | 
            +
                        # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
         | 
| 442 | 
            +
                        x = x.reshape(
         | 
| 443 | 
            +
                            x.shape[0],
         | 
| 444 | 
            +
                            x.shape[1],
         | 
| 445 | 
            +
                            self.model.visual.grid_size[0],
         | 
| 446 | 
            +
                            self.model.visual.patch_size[0],
         | 
| 447 | 
            +
                            self.model.visual.grid_size[1],
         | 
| 448 | 
            +
                            self.model.visual.patch_size[1],
         | 
| 449 | 
            +
                        )
         | 
| 450 | 
            +
                        x = x.permute(0, 2, 4, 1, 3, 5)
         | 
| 451 | 
            +
                        x = x.reshape(
         | 
| 452 | 
            +
                            x.shape[0],
         | 
| 453 | 
            +
                            self.model.visual.grid_size[0] * self.model.visual.grid_size[1],
         | 
| 454 | 
            +
                            -1,
         | 
| 455 | 
            +
                        )
         | 
| 456 | 
            +
                        x = self.model.visual.patchnorm_pre_ln(x)
         | 
| 457 | 
            +
                        x = self.model.visual.conv1(x)
         | 
| 458 | 
            +
                    else:
         | 
| 459 | 
            +
                        x = self.model.visual.conv1(x)  # shape = [*, width, grid, grid]
         | 
| 460 | 
            +
                        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
         | 
| 461 | 
            +
                        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                    # class embeddings and positional embeddings
         | 
| 464 | 
            +
                    x = torch.cat(
         | 
| 465 | 
            +
                        [
         | 
| 466 | 
            +
                            self.model.visual.class_embedding.to(x.dtype)
         | 
| 467 | 
            +
                            + torch.zeros(
         | 
| 468 | 
            +
                                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
         | 
| 469 | 
            +
                            ),
         | 
| 470 | 
            +
                            x,
         | 
| 471 | 
            +
                        ],
         | 
| 472 | 
            +
                        dim=1,
         | 
| 473 | 
            +
                    )  # shape = [*, grid ** 2 + 1, width]
         | 
| 474 | 
            +
                    x = x + self.model.visual.positional_embedding.to(x.dtype)
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
         | 
| 477 | 
            +
                    x = self.model.visual.patch_dropout(x)
         | 
| 478 | 
            +
                    x = self.model.visual.ln_pre(x)
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    x = x.permute(1, 0, 2)  # NLD -> LND
         | 
| 481 | 
            +
                    x = self.model.visual.transformer(x)
         | 
| 482 | 
            +
                    x = x.permute(1, 0, 2)  # LND -> NLD
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                    return x
         | 
| 485 | 
            +
             | 
| 486 | 
            +
             | 
| 487 | 
            +
            class FrozenCLIPT5Encoder(AbstractEncoder):
         | 
| 488 | 
            +
                def __init__(
         | 
| 489 | 
            +
                    self,
         | 
| 490 | 
            +
                    clip_version="openai/clip-vit-large-patch14",
         | 
| 491 | 
            +
                    t5_version="google/t5-v1_1-xl",
         | 
| 492 | 
            +
                    device="cuda",
         | 
| 493 | 
            +
                    clip_max_length=77,
         | 
| 494 | 
            +
                    t5_max_length=77,
         | 
| 495 | 
            +
                ):
         | 
| 496 | 
            +
                    super().__init__()
         | 
| 497 | 
            +
                    self.clip_encoder = FrozenCLIPEmbedder(
         | 
| 498 | 
            +
                        clip_version, device, max_length=clip_max_length
         | 
| 499 | 
            +
                    )
         | 
| 500 | 
            +
                    self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
         | 
| 501 | 
            +
                    print(
         | 
| 502 | 
            +
                        f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
         | 
| 503 | 
            +
                        f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
         | 
| 504 | 
            +
                    )
         | 
| 505 | 
            +
             | 
| 506 | 
            +
                def encode(self, text):
         | 
| 507 | 
            +
                    return self(text)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                def forward(self, text):
         | 
| 510 | 
            +
                    clip_z = self.clip_encoder.encode(text)
         | 
| 511 | 
            +
                    t5_z = self.t5_encoder.encode(text)
         | 
| 512 | 
            +
                    return [clip_z, t5_z]
         | 
    	
        lvdm/modules/encoders/ip_resampler.py
    ADDED
    
    | @@ -0,0 +1,148 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class ImageProjModel(nn.Module):
         | 
| 8 | 
            +
                """Projection Model"""
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                def __init__(
         | 
| 11 | 
            +
                    self,
         | 
| 12 | 
            +
                    cross_attention_dim=1024,
         | 
| 13 | 
            +
                    clip_embeddings_dim=1024,
         | 
| 14 | 
            +
                    clip_extra_context_tokens=4,
         | 
| 15 | 
            +
                ):
         | 
| 16 | 
            +
                    super().__init__()
         | 
| 17 | 
            +
                    self.cross_attention_dim = cross_attention_dim
         | 
| 18 | 
            +
                    self.clip_extra_context_tokens = clip_extra_context_tokens
         | 
| 19 | 
            +
                    self.proj = nn.Linear(
         | 
| 20 | 
            +
                        clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
         | 
| 21 | 
            +
                    )
         | 
| 22 | 
            +
                    self.norm = nn.LayerNorm(cross_attention_dim)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def forward(self, image_embeds):
         | 
| 25 | 
            +
                    # embeds = image_embeds
         | 
| 26 | 
            +
                    embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
         | 
| 27 | 
            +
                    clip_extra_context_tokens = self.proj(embeds).reshape(
         | 
| 28 | 
            +
                        -1, self.clip_extra_context_tokens, self.cross_attention_dim
         | 
| 29 | 
            +
                    )
         | 
| 30 | 
            +
                    clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
         | 
| 31 | 
            +
                    return clip_extra_context_tokens
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            # FFN
         | 
| 35 | 
            +
            def FeedForward(dim, mult=4):
         | 
| 36 | 
            +
                inner_dim = int(dim * mult)
         | 
| 37 | 
            +
                return nn.Sequential(
         | 
| 38 | 
            +
                    nn.LayerNorm(dim),
         | 
| 39 | 
            +
                    nn.Linear(dim, inner_dim, bias=False),
         | 
| 40 | 
            +
                    nn.GELU(),
         | 
| 41 | 
            +
                    nn.Linear(inner_dim, dim, bias=False),
         | 
| 42 | 
            +
                )
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def reshape_tensor(x, heads):
         | 
| 46 | 
            +
                bs, length, width = x.shape
         | 
| 47 | 
            +
                # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
         | 
| 48 | 
            +
                x = x.view(bs, length, heads, -1)
         | 
| 49 | 
            +
                # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
         | 
| 50 | 
            +
                x = x.transpose(1, 2)
         | 
| 51 | 
            +
                # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
         | 
| 52 | 
            +
                x = x.reshape(bs, heads, length, -1)
         | 
| 53 | 
            +
                return x
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            class PerceiverAttention(nn.Module):
         | 
| 57 | 
            +
                def __init__(self, *, dim, dim_head=64, heads=8):
         | 
| 58 | 
            +
                    super().__init__()
         | 
| 59 | 
            +
                    self.scale = dim_head**-0.5
         | 
| 60 | 
            +
                    self.dim_head = dim_head
         | 
| 61 | 
            +
                    self.heads = heads
         | 
| 62 | 
            +
                    inner_dim = dim_head * heads
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    self.norm1 = nn.LayerNorm(dim)
         | 
| 65 | 
            +
                    self.norm2 = nn.LayerNorm(dim)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    self.to_q = nn.Linear(dim, inner_dim, bias=False)
         | 
| 68 | 
            +
                    self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         | 
| 69 | 
            +
                    self.to_out = nn.Linear(inner_dim, dim, bias=False)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                def forward(self, x, latents):
         | 
| 72 | 
            +
                    """
         | 
| 73 | 
            +
                    Args:
         | 
| 74 | 
            +
                        x (torch.Tensor): image features
         | 
| 75 | 
            +
                            shape (b, n1, D)
         | 
| 76 | 
            +
                        latent (torch.Tensor): latent features
         | 
| 77 | 
            +
                            shape (b, n2, D)
         | 
| 78 | 
            +
                    """
         | 
| 79 | 
            +
                    x = self.norm1(x)
         | 
| 80 | 
            +
                    latents = self.norm2(latents)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    b, l, _ = latents.shape
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    q = self.to_q(latents)
         | 
| 85 | 
            +
                    kv_input = torch.cat((x, latents), dim=-2)
         | 
| 86 | 
            +
                    k, v = self.to_kv(kv_input).chunk(2, dim=-1)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    q = reshape_tensor(q, self.heads)
         | 
| 89 | 
            +
                    k = reshape_tensor(k, self.heads)
         | 
| 90 | 
            +
                    v = reshape_tensor(v, self.heads)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    # attention
         | 
| 93 | 
            +
                    scale = 1 / math.sqrt(math.sqrt(self.dim_head))
         | 
| 94 | 
            +
                    weight = (q * scale) @ (k * scale).transpose(
         | 
| 95 | 
            +
                        -2, -1
         | 
| 96 | 
            +
                    )  # More stable with f16 than dividing afterwards
         | 
| 97 | 
            +
                    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         | 
| 98 | 
            +
                    out = weight @ v
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    return self.to_out(out)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            class Resampler(nn.Module):
         | 
| 106 | 
            +
                def __init__(
         | 
| 107 | 
            +
                    self,
         | 
| 108 | 
            +
                    dim=1024,
         | 
| 109 | 
            +
                    depth=8,
         | 
| 110 | 
            +
                    dim_head=64,
         | 
| 111 | 
            +
                    heads=16,
         | 
| 112 | 
            +
                    num_queries=8,
         | 
| 113 | 
            +
                    embedding_dim=768,
         | 
| 114 | 
            +
                    output_dim=1024,
         | 
| 115 | 
            +
                    ff_mult=4,
         | 
| 116 | 
            +
                ):
         | 
| 117 | 
            +
                    super().__init__()
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    self.proj_in = nn.Linear(embedding_dim, dim)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    self.proj_out = nn.Linear(dim, output_dim)
         | 
| 124 | 
            +
                    self.norm_out = nn.LayerNorm(output_dim)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    self.layers = nn.ModuleList([])
         | 
| 127 | 
            +
                    for _ in range(depth):
         | 
| 128 | 
            +
                        self.layers.append(
         | 
| 129 | 
            +
                            nn.ModuleList(
         | 
| 130 | 
            +
                                [
         | 
| 131 | 
            +
                                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
         | 
| 132 | 
            +
                                    FeedForward(dim=dim, mult=ff_mult),
         | 
| 133 | 
            +
                                ]
         | 
| 134 | 
            +
                            )
         | 
| 135 | 
            +
                        )
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                def forward(self, x):
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    latents = self.latents.repeat(x.size(0), 1, 1)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    x = self.proj_in(x)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    for attn, ff in self.layers:
         | 
| 144 | 
            +
                        latents = attn(x, latents) + latents
         | 
| 145 | 
            +
                        latents = ff(latents) + latents
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    latents = self.proj_out(latents)
         | 
| 148 | 
            +
                    return self.norm_out(latents)
         | 
    	
        lvdm/modules/networks/__pycache__/ae_modules.cpython-312.pyc
    ADDED
    
    | Binary file (39.4 kB). View file | 
|  | 
    	
        lvdm/modules/networks/__pycache__/openaimodel3d.cpython-312.pyc
    ADDED
    
    | Binary file (24.9 kB). View file | 
|  | 
    	
        lvdm/modules/networks/ae_modules.py
    ADDED
    
    | @@ -0,0 +1,1025 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # pytorch_diffusion + derived encoder decoder
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
            from einops import rearrange
         | 
| 7 | 
            +
            from utils.utils import instantiate_from_config
         | 
| 8 | 
            +
            from lvdm.modules.attention import LinearAttention
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def nonlinearity(x):
         | 
| 12 | 
            +
                # swish
         | 
| 13 | 
            +
                return x * torch.sigmoid(x)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def Normalize(in_channels, num_groups=32):
         | 
| 17 | 
            +
                return torch.nn.GroupNorm(
         | 
| 18 | 
            +
                    num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
         | 
| 19 | 
            +
                )
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class LinAttnBlock(LinearAttention):
         | 
| 23 | 
            +
                """to match AttnBlock usage"""
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def __init__(self, in_channels):
         | 
| 26 | 
            +
                    super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            class AttnBlock(nn.Module):
         | 
| 30 | 
            +
                def __init__(self, in_channels):
         | 
| 31 | 
            +
                    super().__init__()
         | 
| 32 | 
            +
                    self.in_channels = in_channels
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    self.norm = Normalize(in_channels)
         | 
| 35 | 
            +
                    self.q = torch.nn.Conv2d(
         | 
| 36 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 37 | 
            +
                    )
         | 
| 38 | 
            +
                    self.k = torch.nn.Conv2d(
         | 
| 39 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 40 | 
            +
                    )
         | 
| 41 | 
            +
                    self.v = torch.nn.Conv2d(
         | 
| 42 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 43 | 
            +
                    )
         | 
| 44 | 
            +
                    self.proj_out = torch.nn.Conv2d(
         | 
| 45 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 46 | 
            +
                    )
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                def forward(self, x):
         | 
| 49 | 
            +
                    h_ = x
         | 
| 50 | 
            +
                    h_ = self.norm(h_)
         | 
| 51 | 
            +
                    q = self.q(h_)
         | 
| 52 | 
            +
                    k = self.k(h_)
         | 
| 53 | 
            +
                    v = self.v(h_)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    # compute attention
         | 
| 56 | 
            +
                    b, c, h, w = q.shape
         | 
| 57 | 
            +
                    q = q.reshape(b, c, h * w)  # bcl
         | 
| 58 | 
            +
                    q = q.permute(0, 2, 1)  # bcl -> blc l=hw
         | 
| 59 | 
            +
                    k = k.reshape(b, c, h * w)  # bcl
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
         | 
| 62 | 
            +
                    w_ = w_ * (int(c) ** (-0.5))
         | 
| 63 | 
            +
                    w_ = torch.nn.functional.softmax(w_, dim=2)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    # attend to values
         | 
| 66 | 
            +
                    v = v.reshape(b, c, h * w)
         | 
| 67 | 
            +
                    w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
         | 
| 68 | 
            +
                    h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
         | 
| 69 | 
            +
                    h_ = h_.reshape(b, c, h, w)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    h_ = self.proj_out(h_)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    return x + h_
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def make_attn(in_channels, attn_type="vanilla"):
         | 
| 77 | 
            +
                assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
         | 
| 78 | 
            +
                # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
         | 
| 79 | 
            +
                if attn_type == "vanilla":
         | 
| 80 | 
            +
                    return AttnBlock(in_channels)
         | 
| 81 | 
            +
                elif attn_type == "none":
         | 
| 82 | 
            +
                    return nn.Identity(in_channels)
         | 
| 83 | 
            +
                else:
         | 
| 84 | 
            +
                    return LinAttnBlock(in_channels)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            class Downsample(nn.Module):
         | 
| 88 | 
            +
                def __init__(self, in_channels, with_conv):
         | 
| 89 | 
            +
                    super().__init__()
         | 
| 90 | 
            +
                    self.with_conv = with_conv
         | 
| 91 | 
            +
                    self.in_channels = in_channels
         | 
| 92 | 
            +
                    if self.with_conv:
         | 
| 93 | 
            +
                        # no asymmetric padding in torch conv, must do it ourselves
         | 
| 94 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 95 | 
            +
                            in_channels, in_channels, kernel_size=3, stride=2, padding=0
         | 
| 96 | 
            +
                        )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def forward(self, x):
         | 
| 99 | 
            +
                    if self.with_conv:
         | 
| 100 | 
            +
                        pad = (0, 1, 0, 1)
         | 
| 101 | 
            +
                        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
         | 
| 102 | 
            +
                        x = self.conv(x)
         | 
| 103 | 
            +
                    else:
         | 
| 104 | 
            +
                        x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
         | 
| 105 | 
            +
                    return x
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            class Upsample(nn.Module):
         | 
| 109 | 
            +
                def __init__(self, in_channels, with_conv):
         | 
| 110 | 
            +
                    super().__init__()
         | 
| 111 | 
            +
                    self.with_conv = with_conv
         | 
| 112 | 
            +
                    self.in_channels = in_channels
         | 
| 113 | 
            +
                    if self.with_conv:
         | 
| 114 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 115 | 
            +
                            in_channels, in_channels, kernel_size=3, stride=1, padding=1
         | 
| 116 | 
            +
                        )
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def forward(self, x):
         | 
| 119 | 
            +
                    x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
         | 
| 120 | 
            +
                    if self.with_conv:
         | 
| 121 | 
            +
                        x = self.conv(x)
         | 
| 122 | 
            +
                    return x
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
            def get_timestep_embedding(timesteps, embedding_dim):
         | 
| 126 | 
            +
                """
         | 
| 127 | 
            +
                This matches the implementation in Denoising Diffusion Probabilistic Models:
         | 
| 128 | 
            +
                From Fairseq.
         | 
| 129 | 
            +
                Build sinusoidal embeddings.
         | 
| 130 | 
            +
                This matches the implementation in tensor2tensor, but differs slightly
         | 
| 131 | 
            +
                from the description in Section 3.5 of "Attention Is All You Need".
         | 
| 132 | 
            +
                """
         | 
| 133 | 
            +
                assert len(timesteps.shape) == 1
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                half_dim = embedding_dim // 2
         | 
| 136 | 
            +
                emb = math.log(10000) / (half_dim - 1)
         | 
| 137 | 
            +
                emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
         | 
| 138 | 
            +
                emb = emb.to(device=timesteps.device)
         | 
| 139 | 
            +
                emb = timesteps.float()[:, None] * emb[None, :]
         | 
| 140 | 
            +
                emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
         | 
| 141 | 
            +
                if embedding_dim % 2 == 1:  # zero pad
         | 
| 142 | 
            +
                    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
         | 
| 143 | 
            +
                return emb
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            class ResnetBlock(nn.Module):
         | 
| 147 | 
            +
                def __init__(
         | 
| 148 | 
            +
                    self,
         | 
| 149 | 
            +
                    *,
         | 
| 150 | 
            +
                    in_channels,
         | 
| 151 | 
            +
                    out_channels=None,
         | 
| 152 | 
            +
                    conv_shortcut=False,
         | 
| 153 | 
            +
                    dropout,
         | 
| 154 | 
            +
                    temb_channels=512,
         | 
| 155 | 
            +
                ):
         | 
| 156 | 
            +
                    super().__init__()
         | 
| 157 | 
            +
                    self.in_channels = in_channels
         | 
| 158 | 
            +
                    out_channels = in_channels if out_channels is None else out_channels
         | 
| 159 | 
            +
                    self.out_channels = out_channels
         | 
| 160 | 
            +
                    self.use_conv_shortcut = conv_shortcut
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    self.norm1 = Normalize(in_channels)
         | 
| 163 | 
            +
                    self.conv1 = torch.nn.Conv2d(
         | 
| 164 | 
            +
                        in_channels, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 165 | 
            +
                    )
         | 
| 166 | 
            +
                    if temb_channels > 0:
         | 
| 167 | 
            +
                        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
         | 
| 168 | 
            +
                    self.norm2 = Normalize(out_channels)
         | 
| 169 | 
            +
                    self.dropout = torch.nn.Dropout(dropout)
         | 
| 170 | 
            +
                    self.conv2 = torch.nn.Conv2d(
         | 
| 171 | 
            +
                        out_channels, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 172 | 
            +
                    )
         | 
| 173 | 
            +
                    if self.in_channels != self.out_channels:
         | 
| 174 | 
            +
                        if self.use_conv_shortcut:
         | 
| 175 | 
            +
                            self.conv_shortcut = torch.nn.Conv2d(
         | 
| 176 | 
            +
                                in_channels, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 177 | 
            +
                            )
         | 
| 178 | 
            +
                        else:
         | 
| 179 | 
            +
                            self.nin_shortcut = torch.nn.Conv2d(
         | 
| 180 | 
            +
                                in_channels, out_channels, kernel_size=1, stride=1, padding=0
         | 
| 181 | 
            +
                            )
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                def forward(self, x, temb):
         | 
| 184 | 
            +
                    h = x
         | 
| 185 | 
            +
                    h = self.norm1(h)
         | 
| 186 | 
            +
                    h = nonlinearity(h)
         | 
| 187 | 
            +
                    h = self.conv1(h)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    if temb is not None:
         | 
| 190 | 
            +
                        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    h = self.norm2(h)
         | 
| 193 | 
            +
                    h = nonlinearity(h)
         | 
| 194 | 
            +
                    h = self.dropout(h)
         | 
| 195 | 
            +
                    h = self.conv2(h)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    if self.in_channels != self.out_channels:
         | 
| 198 | 
            +
                        if self.use_conv_shortcut:
         | 
| 199 | 
            +
                            x = self.conv_shortcut(x)
         | 
| 200 | 
            +
                        else:
         | 
| 201 | 
            +
                            x = self.nin_shortcut(x)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    return x + h
         | 
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            class Model(nn.Module):
         | 
| 207 | 
            +
                def __init__(
         | 
| 208 | 
            +
                    self,
         | 
| 209 | 
            +
                    *,
         | 
| 210 | 
            +
                    ch,
         | 
| 211 | 
            +
                    out_ch,
         | 
| 212 | 
            +
                    ch_mult=(1, 2, 4, 8),
         | 
| 213 | 
            +
                    num_res_blocks,
         | 
| 214 | 
            +
                    attn_resolutions,
         | 
| 215 | 
            +
                    dropout=0.0,
         | 
| 216 | 
            +
                    resamp_with_conv=True,
         | 
| 217 | 
            +
                    in_channels,
         | 
| 218 | 
            +
                    resolution,
         | 
| 219 | 
            +
                    use_timestep=True,
         | 
| 220 | 
            +
                    use_linear_attn=False,
         | 
| 221 | 
            +
                    attn_type="vanilla",
         | 
| 222 | 
            +
                ):
         | 
| 223 | 
            +
                    super().__init__()
         | 
| 224 | 
            +
                    if use_linear_attn:
         | 
| 225 | 
            +
                        attn_type = "linear"
         | 
| 226 | 
            +
                    self.ch = ch
         | 
| 227 | 
            +
                    self.temb_ch = self.ch * 4
         | 
| 228 | 
            +
                    self.num_resolutions = len(ch_mult)
         | 
| 229 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 230 | 
            +
                    self.resolution = resolution
         | 
| 231 | 
            +
                    self.in_channels = in_channels
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    self.use_timestep = use_timestep
         | 
| 234 | 
            +
                    if self.use_timestep:
         | 
| 235 | 
            +
                        # timestep embedding
         | 
| 236 | 
            +
                        self.temb = nn.Module()
         | 
| 237 | 
            +
                        self.temb.dense = nn.ModuleList(
         | 
| 238 | 
            +
                            [
         | 
| 239 | 
            +
                                torch.nn.Linear(self.ch, self.temb_ch),
         | 
| 240 | 
            +
                                torch.nn.Linear(self.temb_ch, self.temb_ch),
         | 
| 241 | 
            +
                            ]
         | 
| 242 | 
            +
                        )
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    # downsampling
         | 
| 245 | 
            +
                    self.conv_in = torch.nn.Conv2d(
         | 
| 246 | 
            +
                        in_channels, self.ch, kernel_size=3, stride=1, padding=1
         | 
| 247 | 
            +
                    )
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    curr_res = resolution
         | 
| 250 | 
            +
                    in_ch_mult = (1,) + tuple(ch_mult)
         | 
| 251 | 
            +
                    self.down = nn.ModuleList()
         | 
| 252 | 
            +
                    for i_level in range(self.num_resolutions):
         | 
| 253 | 
            +
                        block = nn.ModuleList()
         | 
| 254 | 
            +
                        attn = nn.ModuleList()
         | 
| 255 | 
            +
                        block_in = ch * in_ch_mult[i_level]
         | 
| 256 | 
            +
                        block_out = ch * ch_mult[i_level]
         | 
| 257 | 
            +
                        for i_block in range(self.num_res_blocks):
         | 
| 258 | 
            +
                            block.append(
         | 
| 259 | 
            +
                                ResnetBlock(
         | 
| 260 | 
            +
                                    in_channels=block_in,
         | 
| 261 | 
            +
                                    out_channels=block_out,
         | 
| 262 | 
            +
                                    temb_channels=self.temb_ch,
         | 
| 263 | 
            +
                                    dropout=dropout,
         | 
| 264 | 
            +
                                )
         | 
| 265 | 
            +
                            )
         | 
| 266 | 
            +
                            block_in = block_out
         | 
| 267 | 
            +
                            if curr_res in attn_resolutions:
         | 
| 268 | 
            +
                                attn.append(make_attn(block_in, attn_type=attn_type))
         | 
| 269 | 
            +
                        down = nn.Module()
         | 
| 270 | 
            +
                        down.block = block
         | 
| 271 | 
            +
                        down.attn = attn
         | 
| 272 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 273 | 
            +
                            down.downsample = Downsample(block_in, resamp_with_conv)
         | 
| 274 | 
            +
                            curr_res = curr_res // 2
         | 
| 275 | 
            +
                        self.down.append(down)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    # middle
         | 
| 278 | 
            +
                    self.mid = nn.Module()
         | 
| 279 | 
            +
                    self.mid.block_1 = ResnetBlock(
         | 
| 280 | 
            +
                        in_channels=block_in,
         | 
| 281 | 
            +
                        out_channels=block_in,
         | 
| 282 | 
            +
                        temb_channels=self.temb_ch,
         | 
| 283 | 
            +
                        dropout=dropout,
         | 
| 284 | 
            +
                    )
         | 
| 285 | 
            +
                    self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
         | 
| 286 | 
            +
                    self.mid.block_2 = ResnetBlock(
         | 
| 287 | 
            +
                        in_channels=block_in,
         | 
| 288 | 
            +
                        out_channels=block_in,
         | 
| 289 | 
            +
                        temb_channels=self.temb_ch,
         | 
| 290 | 
            +
                        dropout=dropout,
         | 
| 291 | 
            +
                    )
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    # upsampling
         | 
| 294 | 
            +
                    self.up = nn.ModuleList()
         | 
| 295 | 
            +
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 296 | 
            +
                        block = nn.ModuleList()
         | 
| 297 | 
            +
                        attn = nn.ModuleList()
         | 
| 298 | 
            +
                        block_out = ch * ch_mult[i_level]
         | 
| 299 | 
            +
                        skip_in = ch * ch_mult[i_level]
         | 
| 300 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 301 | 
            +
                            if i_block == self.num_res_blocks:
         | 
| 302 | 
            +
                                skip_in = ch * in_ch_mult[i_level]
         | 
| 303 | 
            +
                            block.append(
         | 
| 304 | 
            +
                                ResnetBlock(
         | 
| 305 | 
            +
                                    in_channels=block_in + skip_in,
         | 
| 306 | 
            +
                                    out_channels=block_out,
         | 
| 307 | 
            +
                                    temb_channels=self.temb_ch,
         | 
| 308 | 
            +
                                    dropout=dropout,
         | 
| 309 | 
            +
                                )
         | 
| 310 | 
            +
                            )
         | 
| 311 | 
            +
                            block_in = block_out
         | 
| 312 | 
            +
                            if curr_res in attn_resolutions:
         | 
| 313 | 
            +
                                attn.append(make_attn(block_in, attn_type=attn_type))
         | 
| 314 | 
            +
                        up = nn.Module()
         | 
| 315 | 
            +
                        up.block = block
         | 
| 316 | 
            +
                        up.attn = attn
         | 
| 317 | 
            +
                        if i_level != 0:
         | 
| 318 | 
            +
                            up.upsample = Upsample(block_in, resamp_with_conv)
         | 
| 319 | 
            +
                            curr_res = curr_res * 2
         | 
| 320 | 
            +
                        self.up.insert(0, up)  # prepend to get consistent order
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    # end
         | 
| 323 | 
            +
                    self.norm_out = Normalize(block_in)
         | 
| 324 | 
            +
                    self.conv_out = torch.nn.Conv2d(
         | 
| 325 | 
            +
                        block_in, out_ch, kernel_size=3, stride=1, padding=1
         | 
| 326 | 
            +
                    )
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                def forward(self, x, t=None, context=None):
         | 
| 329 | 
            +
                    # assert x.shape[2] == x.shape[3] == self.resolution
         | 
| 330 | 
            +
                    if context is not None:
         | 
| 331 | 
            +
                        # assume aligned context, cat along channel axis
         | 
| 332 | 
            +
                        x = torch.cat((x, context), dim=1)
         | 
| 333 | 
            +
                    if self.use_timestep:
         | 
| 334 | 
            +
                        # timestep embedding
         | 
| 335 | 
            +
                        assert t is not None
         | 
| 336 | 
            +
                        temb = get_timestep_embedding(t, self.ch)
         | 
| 337 | 
            +
                        temb = self.temb.dense[0](temb)
         | 
| 338 | 
            +
                        temb = nonlinearity(temb)
         | 
| 339 | 
            +
                        temb = self.temb.dense[1](temb)
         | 
| 340 | 
            +
                    else:
         | 
| 341 | 
            +
                        temb = None
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    # downsampling
         | 
| 344 | 
            +
                    hs = [self.conv_in(x)]
         | 
| 345 | 
            +
                    for i_level in range(self.num_resolutions):
         | 
| 346 | 
            +
                        for i_block in range(self.num_res_blocks):
         | 
| 347 | 
            +
                            h = self.down[i_level].block[i_block](hs[-1], temb)
         | 
| 348 | 
            +
                            if len(self.down[i_level].attn) > 0:
         | 
| 349 | 
            +
                                h = self.down[i_level].attn[i_block](h)
         | 
| 350 | 
            +
                            hs.append(h)
         | 
| 351 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 352 | 
            +
                            hs.append(self.down[i_level].downsample(hs[-1]))
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    # middle
         | 
| 355 | 
            +
                    h = hs[-1]
         | 
| 356 | 
            +
                    h = self.mid.block_1(h, temb)
         | 
| 357 | 
            +
                    h = self.mid.attn_1(h)
         | 
| 358 | 
            +
                    h = self.mid.block_2(h, temb)
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    # upsampling
         | 
| 361 | 
            +
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 362 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 363 | 
            +
                            h = self.up[i_level].block[i_block](
         | 
| 364 | 
            +
                                torch.cat([h, hs.pop()], dim=1), temb
         | 
| 365 | 
            +
                            )
         | 
| 366 | 
            +
                            if len(self.up[i_level].attn) > 0:
         | 
| 367 | 
            +
                                h = self.up[i_level].attn[i_block](h)
         | 
| 368 | 
            +
                        if i_level != 0:
         | 
| 369 | 
            +
                            h = self.up[i_level].upsample(h)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    # end
         | 
| 372 | 
            +
                    h = self.norm_out(h)
         | 
| 373 | 
            +
                    h = nonlinearity(h)
         | 
| 374 | 
            +
                    h = self.conv_out(h)
         | 
| 375 | 
            +
                    return h
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                def get_last_layer(self):
         | 
| 378 | 
            +
                    return self.conv_out.weight
         | 
| 379 | 
            +
             | 
| 380 | 
            +
             | 
| 381 | 
            +
            class Encoder(nn.Module):
         | 
| 382 | 
            +
                def __init__(
         | 
| 383 | 
            +
                    self,
         | 
| 384 | 
            +
                    *,
         | 
| 385 | 
            +
                    ch,
         | 
| 386 | 
            +
                    out_ch,
         | 
| 387 | 
            +
                    ch_mult=(1, 2, 4, 8),
         | 
| 388 | 
            +
                    num_res_blocks,
         | 
| 389 | 
            +
                    attn_resolutions,
         | 
| 390 | 
            +
                    dropout=0.0,
         | 
| 391 | 
            +
                    resamp_with_conv=True,
         | 
| 392 | 
            +
                    in_channels,
         | 
| 393 | 
            +
                    resolution,
         | 
| 394 | 
            +
                    z_channels,
         | 
| 395 | 
            +
                    double_z=True,
         | 
| 396 | 
            +
                    use_linear_attn=False,
         | 
| 397 | 
            +
                    attn_type="vanilla",
         | 
| 398 | 
            +
                    **ignore_kwargs,
         | 
| 399 | 
            +
                ):
         | 
| 400 | 
            +
                    super().__init__()
         | 
| 401 | 
            +
                    if use_linear_attn:
         | 
| 402 | 
            +
                        attn_type = "linear"
         | 
| 403 | 
            +
                    self.ch = ch
         | 
| 404 | 
            +
                    self.temb_ch = 0
         | 
| 405 | 
            +
                    self.num_resolutions = len(ch_mult)
         | 
| 406 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 407 | 
            +
                    self.resolution = resolution
         | 
| 408 | 
            +
                    self.in_channels = in_channels
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                    # downsampling
         | 
| 411 | 
            +
                    self.conv_in = torch.nn.Conv2d(
         | 
| 412 | 
            +
                        in_channels, self.ch, kernel_size=3, stride=1, padding=1
         | 
| 413 | 
            +
                    )
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    curr_res = resolution
         | 
| 416 | 
            +
                    in_ch_mult = (1,) + tuple(ch_mult)
         | 
| 417 | 
            +
                    self.in_ch_mult = in_ch_mult
         | 
| 418 | 
            +
                    self.down = nn.ModuleList()
         | 
| 419 | 
            +
                    for i_level in range(self.num_resolutions):
         | 
| 420 | 
            +
                        block = nn.ModuleList()
         | 
| 421 | 
            +
                        attn = nn.ModuleList()
         | 
| 422 | 
            +
                        block_in = ch * in_ch_mult[i_level]
         | 
| 423 | 
            +
                        block_out = ch * ch_mult[i_level]
         | 
| 424 | 
            +
                        for i_block in range(self.num_res_blocks):
         | 
| 425 | 
            +
                            block.append(
         | 
| 426 | 
            +
                                ResnetBlock(
         | 
| 427 | 
            +
                                    in_channels=block_in,
         | 
| 428 | 
            +
                                    out_channels=block_out,
         | 
| 429 | 
            +
                                    temb_channels=self.temb_ch,
         | 
| 430 | 
            +
                                    dropout=dropout,
         | 
| 431 | 
            +
                                )
         | 
| 432 | 
            +
                            )
         | 
| 433 | 
            +
                            block_in = block_out
         | 
| 434 | 
            +
                            if curr_res in attn_resolutions:
         | 
| 435 | 
            +
                                attn.append(make_attn(block_in, attn_type=attn_type))
         | 
| 436 | 
            +
                        down = nn.Module()
         | 
| 437 | 
            +
                        down.block = block
         | 
| 438 | 
            +
                        down.attn = attn
         | 
| 439 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 440 | 
            +
                            down.downsample = Downsample(block_in, resamp_with_conv)
         | 
| 441 | 
            +
                            curr_res = curr_res // 2
         | 
| 442 | 
            +
                        self.down.append(down)
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    # middle
         | 
| 445 | 
            +
                    self.mid = nn.Module()
         | 
| 446 | 
            +
                    self.mid.block_1 = ResnetBlock(
         | 
| 447 | 
            +
                        in_channels=block_in,
         | 
| 448 | 
            +
                        out_channels=block_in,
         | 
| 449 | 
            +
                        temb_channels=self.temb_ch,
         | 
| 450 | 
            +
                        dropout=dropout,
         | 
| 451 | 
            +
                    )
         | 
| 452 | 
            +
                    self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
         | 
| 453 | 
            +
                    self.mid.block_2 = ResnetBlock(
         | 
| 454 | 
            +
                        in_channels=block_in,
         | 
| 455 | 
            +
                        out_channels=block_in,
         | 
| 456 | 
            +
                        temb_channels=self.temb_ch,
         | 
| 457 | 
            +
                        dropout=dropout,
         | 
| 458 | 
            +
                    )
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                    # end
         | 
| 461 | 
            +
                    self.norm_out = Normalize(block_in)
         | 
| 462 | 
            +
                    self.conv_out = torch.nn.Conv2d(
         | 
| 463 | 
            +
                        block_in,
         | 
| 464 | 
            +
                        2 * z_channels if double_z else z_channels,
         | 
| 465 | 
            +
                        kernel_size=3,
         | 
| 466 | 
            +
                        stride=1,
         | 
| 467 | 
            +
                        padding=1,
         | 
| 468 | 
            +
                    )
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                def forward(self, x):
         | 
| 471 | 
            +
                    # timestep embedding
         | 
| 472 | 
            +
                    temb = None
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    # print(f'encoder-input={x.shape}')
         | 
| 475 | 
            +
                    # downsampling
         | 
| 476 | 
            +
                    hs = [self.conv_in(x)]
         | 
| 477 | 
            +
                    # print(f'encoder-conv in feat={hs[0].shape}')
         | 
| 478 | 
            +
                    for i_level in range(self.num_resolutions):
         | 
| 479 | 
            +
                        for i_block in range(self.num_res_blocks):
         | 
| 480 | 
            +
                            h = self.down[i_level].block[i_block](hs[-1], temb)
         | 
| 481 | 
            +
                            # print(f'encoder-down feat={h.shape}')
         | 
| 482 | 
            +
                            if len(self.down[i_level].attn) > 0:
         | 
| 483 | 
            +
                                h = self.down[i_level].attn[i_block](h)
         | 
| 484 | 
            +
                            hs.append(h)
         | 
| 485 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 486 | 
            +
                            # print(f'encoder-downsample (input)={hs[-1].shape}')
         | 
| 487 | 
            +
                            hs.append(self.down[i_level].downsample(hs[-1]))
         | 
| 488 | 
            +
                            # print(f'encoder-downsample (output)={hs[-1].shape}')
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    # middle
         | 
| 491 | 
            +
                    h = hs[-1]
         | 
| 492 | 
            +
                    h = self.mid.block_1(h, temb)
         | 
| 493 | 
            +
                    # print(f'encoder-mid1 feat={h.shape}')
         | 
| 494 | 
            +
                    h = self.mid.attn_1(h)
         | 
| 495 | 
            +
                    h = self.mid.block_2(h, temb)
         | 
| 496 | 
            +
                    # print(f'encoder-mid2 feat={h.shape}')
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                    # end
         | 
| 499 | 
            +
                    h = self.norm_out(h)
         | 
| 500 | 
            +
                    h = nonlinearity(h)
         | 
| 501 | 
            +
                    h = self.conv_out(h)
         | 
| 502 | 
            +
                    # print(f'end feat={h.shape}')
         | 
| 503 | 
            +
                    return h
         | 
| 504 | 
            +
             | 
| 505 | 
            +
             | 
| 506 | 
            +
            class Decoder(nn.Module):
         | 
| 507 | 
            +
                def __init__(
         | 
| 508 | 
            +
                    self,
         | 
| 509 | 
            +
                    *,
         | 
| 510 | 
            +
                    ch,
         | 
| 511 | 
            +
                    out_ch,
         | 
| 512 | 
            +
                    ch_mult=(1, 2, 4, 8),
         | 
| 513 | 
            +
                    num_res_blocks,
         | 
| 514 | 
            +
                    attn_resolutions,
         | 
| 515 | 
            +
                    dropout=0.0,
         | 
| 516 | 
            +
                    resamp_with_conv=True,
         | 
| 517 | 
            +
                    in_channels,
         | 
| 518 | 
            +
                    resolution,
         | 
| 519 | 
            +
                    z_channels,
         | 
| 520 | 
            +
                    give_pre_end=False,
         | 
| 521 | 
            +
                    tanh_out=False,
         | 
| 522 | 
            +
                    use_linear_attn=False,
         | 
| 523 | 
            +
                    attn_type="vanilla",
         | 
| 524 | 
            +
                    **ignorekwargs,
         | 
| 525 | 
            +
                ):
         | 
| 526 | 
            +
                    super().__init__()
         | 
| 527 | 
            +
                    if use_linear_attn:
         | 
| 528 | 
            +
                        attn_type = "linear"
         | 
| 529 | 
            +
                    self.ch = ch
         | 
| 530 | 
            +
                    self.temb_ch = 0
         | 
| 531 | 
            +
                    self.num_resolutions = len(ch_mult)
         | 
| 532 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 533 | 
            +
                    self.resolution = resolution
         | 
| 534 | 
            +
                    self.in_channels = in_channels
         | 
| 535 | 
            +
                    self.give_pre_end = give_pre_end
         | 
| 536 | 
            +
                    self.tanh_out = tanh_out
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                    # compute in_ch_mult, block_in and curr_res at lowest res
         | 
| 539 | 
            +
                    in_ch_mult = (1,) + tuple(ch_mult)
         | 
| 540 | 
            +
                    block_in = ch * ch_mult[self.num_resolutions - 1]
         | 
| 541 | 
            +
                    curr_res = resolution // 2 ** (self.num_resolutions - 1)
         | 
| 542 | 
            +
                    self.z_shape = (1, z_channels, curr_res, curr_res)
         | 
| 543 | 
            +
                    print(
         | 
| 544 | 
            +
                        "AE working on z of shape {} = {} dimensions.".format(
         | 
| 545 | 
            +
                            self.z_shape, np.prod(self.z_shape)
         | 
| 546 | 
            +
                        )
         | 
| 547 | 
            +
                    )
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    # z to block_in
         | 
| 550 | 
            +
                    self.conv_in = torch.nn.Conv2d(
         | 
| 551 | 
            +
                        z_channels, block_in, kernel_size=3, stride=1, padding=1
         | 
| 552 | 
            +
                    )
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                    # middle
         | 
| 555 | 
            +
                    self.mid = nn.Module()
         | 
| 556 | 
            +
                    self.mid.block_1 = ResnetBlock(
         | 
| 557 | 
            +
                        in_channels=block_in,
         | 
| 558 | 
            +
                        out_channels=block_in,
         | 
| 559 | 
            +
                        temb_channels=self.temb_ch,
         | 
| 560 | 
            +
                        dropout=dropout,
         | 
| 561 | 
            +
                    )
         | 
| 562 | 
            +
                    self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
         | 
| 563 | 
            +
                    self.mid.block_2 = ResnetBlock(
         | 
| 564 | 
            +
                        in_channels=block_in,
         | 
| 565 | 
            +
                        out_channels=block_in,
         | 
| 566 | 
            +
                        temb_channels=self.temb_ch,
         | 
| 567 | 
            +
                        dropout=dropout,
         | 
| 568 | 
            +
                    )
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                    # upsampling
         | 
| 571 | 
            +
                    self.up = nn.ModuleList()
         | 
| 572 | 
            +
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 573 | 
            +
                        block = nn.ModuleList()
         | 
| 574 | 
            +
                        attn = nn.ModuleList()
         | 
| 575 | 
            +
                        block_out = ch * ch_mult[i_level]
         | 
| 576 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 577 | 
            +
                            block.append(
         | 
| 578 | 
            +
                                ResnetBlock(
         | 
| 579 | 
            +
                                    in_channels=block_in,
         | 
| 580 | 
            +
                                    out_channels=block_out,
         | 
| 581 | 
            +
                                    temb_channels=self.temb_ch,
         | 
| 582 | 
            +
                                    dropout=dropout,
         | 
| 583 | 
            +
                                )
         | 
| 584 | 
            +
                            )
         | 
| 585 | 
            +
                            block_in = block_out
         | 
| 586 | 
            +
                            if curr_res in attn_resolutions:
         | 
| 587 | 
            +
                                attn.append(make_attn(block_in, attn_type=attn_type))
         | 
| 588 | 
            +
                        up = nn.Module()
         | 
| 589 | 
            +
                        up.block = block
         | 
| 590 | 
            +
                        up.attn = attn
         | 
| 591 | 
            +
                        if i_level != 0:
         | 
| 592 | 
            +
                            up.upsample = Upsample(block_in, resamp_with_conv)
         | 
| 593 | 
            +
                            curr_res = curr_res * 2
         | 
| 594 | 
            +
                        self.up.insert(0, up)  # prepend to get consistent order
         | 
| 595 | 
            +
             | 
| 596 | 
            +
                    # end
         | 
| 597 | 
            +
                    self.norm_out = Normalize(block_in)
         | 
| 598 | 
            +
                    self.conv_out = torch.nn.Conv2d(
         | 
| 599 | 
            +
                        block_in, out_ch, kernel_size=3, stride=1, padding=1
         | 
| 600 | 
            +
                    )
         | 
| 601 | 
            +
             | 
| 602 | 
            +
                def forward(self, z):
         | 
| 603 | 
            +
                    # assert z.shape[1:] == self.z_shape[1:]
         | 
| 604 | 
            +
                    self.last_z_shape = z.shape
         | 
| 605 | 
            +
             | 
| 606 | 
            +
                    # print(f'decoder-input={z.shape}')
         | 
| 607 | 
            +
                    # timestep embedding
         | 
| 608 | 
            +
                    temb = None
         | 
| 609 | 
            +
             | 
| 610 | 
            +
                    # z to block_in
         | 
| 611 | 
            +
                    h = self.conv_in(z)
         | 
| 612 | 
            +
                    # print(f'decoder-conv in feat={h.shape}')
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                    # middle
         | 
| 615 | 
            +
                    h = self.mid.block_1(h, temb)
         | 
| 616 | 
            +
                    h = self.mid.attn_1(h)
         | 
| 617 | 
            +
                    h = self.mid.block_2(h, temb)
         | 
| 618 | 
            +
                    # print(f'decoder-mid feat={h.shape}')
         | 
| 619 | 
            +
             | 
| 620 | 
            +
                    # upsampling
         | 
| 621 | 
            +
                    for i_level in reversed(range(self.num_resolutions)):
         | 
| 622 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 623 | 
            +
                            h = self.up[i_level].block[i_block](h, temb)
         | 
| 624 | 
            +
                            if len(self.up[i_level].attn) > 0:
         | 
| 625 | 
            +
                                h = self.up[i_level].attn[i_block](h)
         | 
| 626 | 
            +
                            # print(f'decoder-up feat={h.shape}')
         | 
| 627 | 
            +
                        if i_level != 0:
         | 
| 628 | 
            +
                            h = self.up[i_level].upsample(h)
         | 
| 629 | 
            +
                            # print(f'decoder-upsample feat={h.shape}')
         | 
| 630 | 
            +
             | 
| 631 | 
            +
                    # end
         | 
| 632 | 
            +
                    if self.give_pre_end:
         | 
| 633 | 
            +
                        return h
         | 
| 634 | 
            +
             | 
| 635 | 
            +
                    h = self.norm_out(h)
         | 
| 636 | 
            +
                    h = nonlinearity(h)
         | 
| 637 | 
            +
                    h = self.conv_out(h)
         | 
| 638 | 
            +
                    # print(f'decoder-conv_out feat={h.shape}')
         | 
| 639 | 
            +
                    if self.tanh_out:
         | 
| 640 | 
            +
                        h = torch.tanh(h)
         | 
| 641 | 
            +
                    return h
         | 
| 642 | 
            +
             | 
| 643 | 
            +
             | 
| 644 | 
            +
            class SimpleDecoder(nn.Module):
         | 
| 645 | 
            +
                def __init__(self, in_channels, out_channels, *args, **kwargs):
         | 
| 646 | 
            +
                    super().__init__()
         | 
| 647 | 
            +
                    self.model = nn.ModuleList(
         | 
| 648 | 
            +
                        [
         | 
| 649 | 
            +
                            nn.Conv2d(in_channels, in_channels, 1),
         | 
| 650 | 
            +
                            ResnetBlock(
         | 
| 651 | 
            +
                                in_channels=in_channels,
         | 
| 652 | 
            +
                                out_channels=2 * in_channels,
         | 
| 653 | 
            +
                                temb_channels=0,
         | 
| 654 | 
            +
                                dropout=0.0,
         | 
| 655 | 
            +
                            ),
         | 
| 656 | 
            +
                            ResnetBlock(
         | 
| 657 | 
            +
                                in_channels=2 * in_channels,
         | 
| 658 | 
            +
                                out_channels=4 * in_channels,
         | 
| 659 | 
            +
                                temb_channels=0,
         | 
| 660 | 
            +
                                dropout=0.0,
         | 
| 661 | 
            +
                            ),
         | 
| 662 | 
            +
                            ResnetBlock(
         | 
| 663 | 
            +
                                in_channels=4 * in_channels,
         | 
| 664 | 
            +
                                out_channels=2 * in_channels,
         | 
| 665 | 
            +
                                temb_channels=0,
         | 
| 666 | 
            +
                                dropout=0.0,
         | 
| 667 | 
            +
                            ),
         | 
| 668 | 
            +
                            nn.Conv2d(2 * in_channels, in_channels, 1),
         | 
| 669 | 
            +
                            Upsample(in_channels, with_conv=True),
         | 
| 670 | 
            +
                        ]
         | 
| 671 | 
            +
                    )
         | 
| 672 | 
            +
                    # end
         | 
| 673 | 
            +
                    self.norm_out = Normalize(in_channels)
         | 
| 674 | 
            +
                    self.conv_out = torch.nn.Conv2d(
         | 
| 675 | 
            +
                        in_channels, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 676 | 
            +
                    )
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                def forward(self, x):
         | 
| 679 | 
            +
                    for i, layer in enumerate(self.model):
         | 
| 680 | 
            +
                        if i in [1, 2, 3]:
         | 
| 681 | 
            +
                            x = layer(x, None)
         | 
| 682 | 
            +
                        else:
         | 
| 683 | 
            +
                            x = layer(x)
         | 
| 684 | 
            +
             | 
| 685 | 
            +
                    h = self.norm_out(x)
         | 
| 686 | 
            +
                    h = nonlinearity(h)
         | 
| 687 | 
            +
                    x = self.conv_out(h)
         | 
| 688 | 
            +
                    return x
         | 
| 689 | 
            +
             | 
| 690 | 
            +
             | 
| 691 | 
            +
            class UpsampleDecoder(nn.Module):
         | 
| 692 | 
            +
                def __init__(
         | 
| 693 | 
            +
                    self,
         | 
| 694 | 
            +
                    in_channels,
         | 
| 695 | 
            +
                    out_channels,
         | 
| 696 | 
            +
                    ch,
         | 
| 697 | 
            +
                    num_res_blocks,
         | 
| 698 | 
            +
                    resolution,
         | 
| 699 | 
            +
                    ch_mult=(2, 2),
         | 
| 700 | 
            +
                    dropout=0.0,
         | 
| 701 | 
            +
                ):
         | 
| 702 | 
            +
                    super().__init__()
         | 
| 703 | 
            +
                    # upsampling
         | 
| 704 | 
            +
                    self.temb_ch = 0
         | 
| 705 | 
            +
                    self.num_resolutions = len(ch_mult)
         | 
| 706 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 707 | 
            +
                    block_in = in_channels
         | 
| 708 | 
            +
                    curr_res = resolution // 2 ** (self.num_resolutions - 1)
         | 
| 709 | 
            +
                    self.res_blocks = nn.ModuleList()
         | 
| 710 | 
            +
                    self.upsample_blocks = nn.ModuleList()
         | 
| 711 | 
            +
                    for i_level in range(self.num_resolutions):
         | 
| 712 | 
            +
                        res_block = []
         | 
| 713 | 
            +
                        block_out = ch * ch_mult[i_level]
         | 
| 714 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 715 | 
            +
                            res_block.append(
         | 
| 716 | 
            +
                                ResnetBlock(
         | 
| 717 | 
            +
                                    in_channels=block_in,
         | 
| 718 | 
            +
                                    out_channels=block_out,
         | 
| 719 | 
            +
                                    temb_channels=self.temb_ch,
         | 
| 720 | 
            +
                                    dropout=dropout,
         | 
| 721 | 
            +
                                )
         | 
| 722 | 
            +
                            )
         | 
| 723 | 
            +
                            block_in = block_out
         | 
| 724 | 
            +
                        self.res_blocks.append(nn.ModuleList(res_block))
         | 
| 725 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 726 | 
            +
                            self.upsample_blocks.append(Upsample(block_in, True))
         | 
| 727 | 
            +
                            curr_res = curr_res * 2
         | 
| 728 | 
            +
             | 
| 729 | 
            +
                    # end
         | 
| 730 | 
            +
                    self.norm_out = Normalize(block_in)
         | 
| 731 | 
            +
                    self.conv_out = torch.nn.Conv2d(
         | 
| 732 | 
            +
                        block_in, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 733 | 
            +
                    )
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                def forward(self, x):
         | 
| 736 | 
            +
                    # upsampling
         | 
| 737 | 
            +
                    h = x
         | 
| 738 | 
            +
                    for k, i_level in enumerate(range(self.num_resolutions)):
         | 
| 739 | 
            +
                        for i_block in range(self.num_res_blocks + 1):
         | 
| 740 | 
            +
                            h = self.res_blocks[i_level][i_block](h, None)
         | 
| 741 | 
            +
                        if i_level != self.num_resolutions - 1:
         | 
| 742 | 
            +
                            h = self.upsample_blocks[k](h)
         | 
| 743 | 
            +
                    h = self.norm_out(h)
         | 
| 744 | 
            +
                    h = nonlinearity(h)
         | 
| 745 | 
            +
                    h = self.conv_out(h)
         | 
| 746 | 
            +
                    return h
         | 
| 747 | 
            +
             | 
| 748 | 
            +
             | 
| 749 | 
            +
            class LatentRescaler(nn.Module):
         | 
| 750 | 
            +
                def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
         | 
| 751 | 
            +
                    super().__init__()
         | 
| 752 | 
            +
                    # residual block, interpolate, residual block
         | 
| 753 | 
            +
                    self.factor = factor
         | 
| 754 | 
            +
                    self.conv_in = nn.Conv2d(
         | 
| 755 | 
            +
                        in_channels, mid_channels, kernel_size=3, stride=1, padding=1
         | 
| 756 | 
            +
                    )
         | 
| 757 | 
            +
                    self.res_block1 = nn.ModuleList(
         | 
| 758 | 
            +
                        [
         | 
| 759 | 
            +
                            ResnetBlock(
         | 
| 760 | 
            +
                                in_channels=mid_channels,
         | 
| 761 | 
            +
                                out_channels=mid_channels,
         | 
| 762 | 
            +
                                temb_channels=0,
         | 
| 763 | 
            +
                                dropout=0.0,
         | 
| 764 | 
            +
                            )
         | 
| 765 | 
            +
                            for _ in range(depth)
         | 
| 766 | 
            +
                        ]
         | 
| 767 | 
            +
                    )
         | 
| 768 | 
            +
                    self.attn = AttnBlock(mid_channels)
         | 
| 769 | 
            +
                    self.res_block2 = nn.ModuleList(
         | 
| 770 | 
            +
                        [
         | 
| 771 | 
            +
                            ResnetBlock(
         | 
| 772 | 
            +
                                in_channels=mid_channels,
         | 
| 773 | 
            +
                                out_channels=mid_channels,
         | 
| 774 | 
            +
                                temb_channels=0,
         | 
| 775 | 
            +
                                dropout=0.0,
         | 
| 776 | 
            +
                            )
         | 
| 777 | 
            +
                            for _ in range(depth)
         | 
| 778 | 
            +
                        ]
         | 
| 779 | 
            +
                    )
         | 
| 780 | 
            +
             | 
| 781 | 
            +
                    self.conv_out = nn.Conv2d(
         | 
| 782 | 
            +
                        mid_channels,
         | 
| 783 | 
            +
                        out_channels,
         | 
| 784 | 
            +
                        kernel_size=1,
         | 
| 785 | 
            +
                    )
         | 
| 786 | 
            +
             | 
| 787 | 
            +
                def forward(self, x):
         | 
| 788 | 
            +
                    x = self.conv_in(x)
         | 
| 789 | 
            +
                    for block in self.res_block1:
         | 
| 790 | 
            +
                        x = block(x, None)
         | 
| 791 | 
            +
                    x = torch.nn.functional.interpolate(
         | 
| 792 | 
            +
                        x,
         | 
| 793 | 
            +
                        size=(
         | 
| 794 | 
            +
                            int(round(x.shape[2] * self.factor)),
         | 
| 795 | 
            +
                            int(round(x.shape[3] * self.factor)),
         | 
| 796 | 
            +
                        ),
         | 
| 797 | 
            +
                    )
         | 
| 798 | 
            +
                    x = self.attn(x)
         | 
| 799 | 
            +
                    for block in self.res_block2:
         | 
| 800 | 
            +
                        x = block(x, None)
         | 
| 801 | 
            +
                    x = self.conv_out(x)
         | 
| 802 | 
            +
                    return x
         | 
| 803 | 
            +
             | 
| 804 | 
            +
             | 
| 805 | 
            +
            class MergedRescaleEncoder(nn.Module):
         | 
| 806 | 
            +
                def __init__(
         | 
| 807 | 
            +
                    self,
         | 
| 808 | 
            +
                    in_channels,
         | 
| 809 | 
            +
                    ch,
         | 
| 810 | 
            +
                    resolution,
         | 
| 811 | 
            +
                    out_ch,
         | 
| 812 | 
            +
                    num_res_blocks,
         | 
| 813 | 
            +
                    attn_resolutions,
         | 
| 814 | 
            +
                    dropout=0.0,
         | 
| 815 | 
            +
                    resamp_with_conv=True,
         | 
| 816 | 
            +
                    ch_mult=(1, 2, 4, 8),
         | 
| 817 | 
            +
                    rescale_factor=1.0,
         | 
| 818 | 
            +
                    rescale_module_depth=1,
         | 
| 819 | 
            +
                ):
         | 
| 820 | 
            +
                    super().__init__()
         | 
| 821 | 
            +
                    intermediate_chn = ch * ch_mult[-1]
         | 
| 822 | 
            +
                    self.encoder = Encoder(
         | 
| 823 | 
            +
                        in_channels=in_channels,
         | 
| 824 | 
            +
                        num_res_blocks=num_res_blocks,
         | 
| 825 | 
            +
                        ch=ch,
         | 
| 826 | 
            +
                        ch_mult=ch_mult,
         | 
| 827 | 
            +
                        z_channels=intermediate_chn,
         | 
| 828 | 
            +
                        double_z=False,
         | 
| 829 | 
            +
                        resolution=resolution,
         | 
| 830 | 
            +
                        attn_resolutions=attn_resolutions,
         | 
| 831 | 
            +
                        dropout=dropout,
         | 
| 832 | 
            +
                        resamp_with_conv=resamp_with_conv,
         | 
| 833 | 
            +
                        out_ch=None,
         | 
| 834 | 
            +
                    )
         | 
| 835 | 
            +
                    self.rescaler = LatentRescaler(
         | 
| 836 | 
            +
                        factor=rescale_factor,
         | 
| 837 | 
            +
                        in_channels=intermediate_chn,
         | 
| 838 | 
            +
                        mid_channels=intermediate_chn,
         | 
| 839 | 
            +
                        out_channels=out_ch,
         | 
| 840 | 
            +
                        depth=rescale_module_depth,
         | 
| 841 | 
            +
                    )
         | 
| 842 | 
            +
             | 
| 843 | 
            +
                def forward(self, x):
         | 
| 844 | 
            +
                    x = self.encoder(x)
         | 
| 845 | 
            +
                    x = self.rescaler(x)
         | 
| 846 | 
            +
                    return x
         | 
| 847 | 
            +
             | 
| 848 | 
            +
             | 
| 849 | 
            +
            class MergedRescaleDecoder(nn.Module):
         | 
| 850 | 
            +
                def __init__(
         | 
| 851 | 
            +
                    self,
         | 
| 852 | 
            +
                    z_channels,
         | 
| 853 | 
            +
                    out_ch,
         | 
| 854 | 
            +
                    resolution,
         | 
| 855 | 
            +
                    num_res_blocks,
         | 
| 856 | 
            +
                    attn_resolutions,
         | 
| 857 | 
            +
                    ch,
         | 
| 858 | 
            +
                    ch_mult=(1, 2, 4, 8),
         | 
| 859 | 
            +
                    dropout=0.0,
         | 
| 860 | 
            +
                    resamp_with_conv=True,
         | 
| 861 | 
            +
                    rescale_factor=1.0,
         | 
| 862 | 
            +
                    rescale_module_depth=1,
         | 
| 863 | 
            +
                ):
         | 
| 864 | 
            +
                    super().__init__()
         | 
| 865 | 
            +
                    tmp_chn = z_channels * ch_mult[-1]
         | 
| 866 | 
            +
                    self.decoder = Decoder(
         | 
| 867 | 
            +
                        out_ch=out_ch,
         | 
| 868 | 
            +
                        z_channels=tmp_chn,
         | 
| 869 | 
            +
                        attn_resolutions=attn_resolutions,
         | 
| 870 | 
            +
                        dropout=dropout,
         | 
| 871 | 
            +
                        resamp_with_conv=resamp_with_conv,
         | 
| 872 | 
            +
                        in_channels=None,
         | 
| 873 | 
            +
                        num_res_blocks=num_res_blocks,
         | 
| 874 | 
            +
                        ch_mult=ch_mult,
         | 
| 875 | 
            +
                        resolution=resolution,
         | 
| 876 | 
            +
                        ch=ch,
         | 
| 877 | 
            +
                    )
         | 
| 878 | 
            +
                    self.rescaler = LatentRescaler(
         | 
| 879 | 
            +
                        factor=rescale_factor,
         | 
| 880 | 
            +
                        in_channels=z_channels,
         | 
| 881 | 
            +
                        mid_channels=tmp_chn,
         | 
| 882 | 
            +
                        out_channels=tmp_chn,
         | 
| 883 | 
            +
                        depth=rescale_module_depth,
         | 
| 884 | 
            +
                    )
         | 
| 885 | 
            +
             | 
| 886 | 
            +
                def forward(self, x):
         | 
| 887 | 
            +
                    x = self.rescaler(x)
         | 
| 888 | 
            +
                    x = self.decoder(x)
         | 
| 889 | 
            +
                    return x
         | 
| 890 | 
            +
             | 
| 891 | 
            +
             | 
| 892 | 
            +
            class Upsampler(nn.Module):
         | 
| 893 | 
            +
                def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
         | 
| 894 | 
            +
                    super().__init__()
         | 
| 895 | 
            +
                    assert out_size >= in_size
         | 
| 896 | 
            +
                    num_blocks = int(np.log2(out_size // in_size)) + 1
         | 
| 897 | 
            +
                    factor_up = 1.0 + (out_size % in_size)
         | 
| 898 | 
            +
                    print(
         | 
| 899 | 
            +
                        f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
         | 
| 900 | 
            +
                    )
         | 
| 901 | 
            +
                    self.rescaler = LatentRescaler(
         | 
| 902 | 
            +
                        factor=factor_up,
         | 
| 903 | 
            +
                        in_channels=in_channels,
         | 
| 904 | 
            +
                        mid_channels=2 * in_channels,
         | 
| 905 | 
            +
                        out_channels=in_channels,
         | 
| 906 | 
            +
                    )
         | 
| 907 | 
            +
                    self.decoder = Decoder(
         | 
| 908 | 
            +
                        out_ch=out_channels,
         | 
| 909 | 
            +
                        resolution=out_size,
         | 
| 910 | 
            +
                        z_channels=in_channels,
         | 
| 911 | 
            +
                        num_res_blocks=2,
         | 
| 912 | 
            +
                        attn_resolutions=[],
         | 
| 913 | 
            +
                        in_channels=None,
         | 
| 914 | 
            +
                        ch=in_channels,
         | 
| 915 | 
            +
                        ch_mult=[ch_mult for _ in range(num_blocks)],
         | 
| 916 | 
            +
                    )
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                def forward(self, x):
         | 
| 919 | 
            +
                    x = self.rescaler(x)
         | 
| 920 | 
            +
                    x = self.decoder(x)
         | 
| 921 | 
            +
                    return x
         | 
| 922 | 
            +
             | 
| 923 | 
            +
             | 
| 924 | 
            +
            class Resize(nn.Module):
         | 
| 925 | 
            +
                def __init__(self, in_channels=None, learned=False, mode="bilinear"):
         | 
| 926 | 
            +
                    super().__init__()
         | 
| 927 | 
            +
                    self.with_conv = learned
         | 
| 928 | 
            +
                    self.mode = mode
         | 
| 929 | 
            +
                    if self.with_conv:
         | 
| 930 | 
            +
                        print(
         | 
| 931 | 
            +
                            f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
         | 
| 932 | 
            +
                        )
         | 
| 933 | 
            +
                        raise NotImplementedError()
         | 
| 934 | 
            +
                        assert in_channels is not None
         | 
| 935 | 
            +
                        # no asymmetric padding in torch conv, must do it ourselves
         | 
| 936 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 937 | 
            +
                            in_channels, in_channels, kernel_size=4, stride=2, padding=1
         | 
| 938 | 
            +
                        )
         | 
| 939 | 
            +
             | 
| 940 | 
            +
                def forward(self, x, scale_factor=1.0):
         | 
| 941 | 
            +
                    if scale_factor == 1.0:
         | 
| 942 | 
            +
                        return x
         | 
| 943 | 
            +
                    else:
         | 
| 944 | 
            +
                        x = torch.nn.functional.interpolate(
         | 
| 945 | 
            +
                            x, mode=self.mode, align_corners=False, scale_factor=scale_factor
         | 
| 946 | 
            +
                        )
         | 
| 947 | 
            +
                    return x
         | 
| 948 | 
            +
             | 
| 949 | 
            +
             | 
| 950 | 
            +
            class FirstStagePostProcessor(nn.Module):
         | 
| 951 | 
            +
             | 
| 952 | 
            +
                def __init__(
         | 
| 953 | 
            +
                    self,
         | 
| 954 | 
            +
                    ch_mult: list,
         | 
| 955 | 
            +
                    in_channels,
         | 
| 956 | 
            +
                    pretrained_model: nn.Module = None,
         | 
| 957 | 
            +
                    reshape=False,
         | 
| 958 | 
            +
                    n_channels=None,
         | 
| 959 | 
            +
                    dropout=0.0,
         | 
| 960 | 
            +
                    pretrained_config=None,
         | 
| 961 | 
            +
                ):
         | 
| 962 | 
            +
                    super().__init__()
         | 
| 963 | 
            +
                    if pretrained_config is None:
         | 
| 964 | 
            +
                        assert (
         | 
| 965 | 
            +
                            pretrained_model is not None
         | 
| 966 | 
            +
                        ), 'Either "pretrained_model" or "pretrained_config" must not be None'
         | 
| 967 | 
            +
                        self.pretrained_model = pretrained_model
         | 
| 968 | 
            +
                    else:
         | 
| 969 | 
            +
                        assert (
         | 
| 970 | 
            +
                            pretrained_config is not None
         | 
| 971 | 
            +
                        ), 'Either "pretrained_model" or "pretrained_config" must not be None'
         | 
| 972 | 
            +
                        self.instantiate_pretrained(pretrained_config)
         | 
| 973 | 
            +
             | 
| 974 | 
            +
                    self.do_reshape = reshape
         | 
| 975 | 
            +
             | 
| 976 | 
            +
                    if n_channels is None:
         | 
| 977 | 
            +
                        n_channels = self.pretrained_model.encoder.ch
         | 
| 978 | 
            +
             | 
| 979 | 
            +
                    self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
         | 
| 980 | 
            +
                    self.proj = nn.Conv2d(
         | 
| 981 | 
            +
                        in_channels, n_channels, kernel_size=3, stride=1, padding=1
         | 
| 982 | 
            +
                    )
         | 
| 983 | 
            +
             | 
| 984 | 
            +
                    blocks = []
         | 
| 985 | 
            +
                    downs = []
         | 
| 986 | 
            +
                    ch_in = n_channels
         | 
| 987 | 
            +
                    for m in ch_mult:
         | 
| 988 | 
            +
                        blocks.append(
         | 
| 989 | 
            +
                            ResnetBlock(
         | 
| 990 | 
            +
                                in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
         | 
| 991 | 
            +
                            )
         | 
| 992 | 
            +
                        )
         | 
| 993 | 
            +
                        ch_in = m * n_channels
         | 
| 994 | 
            +
                        downs.append(Downsample(ch_in, with_conv=False))
         | 
| 995 | 
            +
             | 
| 996 | 
            +
                    self.model = nn.ModuleList(blocks)
         | 
| 997 | 
            +
                    self.downsampler = nn.ModuleList(downs)
         | 
| 998 | 
            +
             | 
| 999 | 
            +
                def instantiate_pretrained(self, config):
         | 
| 1000 | 
            +
                    model = instantiate_from_config(config)
         | 
| 1001 | 
            +
                    self.pretrained_model = model.eval()
         | 
| 1002 | 
            +
                    # self.pretrained_model.train = False
         | 
| 1003 | 
            +
                    for param in self.pretrained_model.parameters():
         | 
| 1004 | 
            +
                        param.requires_grad = False
         | 
| 1005 | 
            +
             | 
| 1006 | 
            +
                @torch.no_grad()
         | 
| 1007 | 
            +
                def encode_with_pretrained(self, x):
         | 
| 1008 | 
            +
                    c = self.pretrained_model.encode(x)
         | 
| 1009 | 
            +
                    if isinstance(c, DiagonalGaussianDistribution):
         | 
| 1010 | 
            +
                        c = c.mode()
         | 
| 1011 | 
            +
                    return c
         | 
| 1012 | 
            +
             | 
| 1013 | 
            +
                def forward(self, x):
         | 
| 1014 | 
            +
                    z_fs = self.encode_with_pretrained(x)
         | 
| 1015 | 
            +
                    z = self.proj_norm(z_fs)
         | 
| 1016 | 
            +
                    z = self.proj(z)
         | 
| 1017 | 
            +
                    z = nonlinearity(z)
         | 
| 1018 | 
            +
             | 
| 1019 | 
            +
                    for submodel, downmodel in zip(self.model, self.downsampler):
         | 
| 1020 | 
            +
                        z = submodel(z, temb=None)
         | 
| 1021 | 
            +
                        z = downmodel(z)
         | 
| 1022 | 
            +
             | 
| 1023 | 
            +
                    if self.do_reshape:
         | 
| 1024 | 
            +
                        z = rearrange(z, "b c h w -> b (h w) c")
         | 
| 1025 | 
            +
                    return z
         | 
    	
        lvdm/modules/networks/openaimodel3d.py
    ADDED
    
    | @@ -0,0 +1,740 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from functools import partial
         | 
| 2 | 
            +
            from abc import abstractmethod
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            from einops import rearrange
         | 
| 6 | 
            +
            import torch.nn.functional as F
         | 
| 7 | 
            +
            from lvdm.models.utils_diffusion import timestep_embedding
         | 
| 8 | 
            +
            from lvdm.common import checkpoint
         | 
| 9 | 
            +
            from lvdm.basics import zero_module, conv_nd, linear, avg_pool_nd, normalization
         | 
| 10 | 
            +
            from lvdm.modules.attention import SpatialTransformer, TemporalTransformer
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class TimestepBlock(nn.Module):
         | 
| 14 | 
            +
                """
         | 
| 15 | 
            +
                Any module where forward() takes timestep embeddings as a second argument.
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                @abstractmethod
         | 
| 19 | 
            +
                def forward(self, x, emb):
         | 
| 20 | 
            +
                    """
         | 
| 21 | 
            +
                    Apply the module to `x` given `emb` timestep embeddings.
         | 
| 22 | 
            +
                    """
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                A sequential module that passes timestep embeddings to the children that
         | 
| 28 | 
            +
                support it as an extra input.
         | 
| 29 | 
            +
                """
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                def forward(self, x, emb, context=None, batch_size=None):
         | 
| 32 | 
            +
                    for layer in self:
         | 
| 33 | 
            +
                        if isinstance(layer, TimestepBlock):
         | 
| 34 | 
            +
                            x = layer(x, emb, batch_size)
         | 
| 35 | 
            +
                        elif isinstance(layer, SpatialTransformer):
         | 
| 36 | 
            +
                            x = layer(x, context)
         | 
| 37 | 
            +
                        elif isinstance(layer, TemporalTransformer):
         | 
| 38 | 
            +
                            x = rearrange(x, "(b f) c h w -> b c f h w", b=batch_size)
         | 
| 39 | 
            +
                            x = layer(x, context)
         | 
| 40 | 
            +
                            x = rearrange(x, "b c f h w -> (b f) c h w")
         | 
| 41 | 
            +
                        else:
         | 
| 42 | 
            +
                            x = layer(
         | 
| 43 | 
            +
                                x,
         | 
| 44 | 
            +
                            )
         | 
| 45 | 
            +
                    return x
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            class Downsample(nn.Module):
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                A downsampling layer with an optional convolution.
         | 
| 51 | 
            +
                :param channels: channels in the inputs and outputs.
         | 
| 52 | 
            +
                :param use_conv: a bool determining if a convolution is applied.
         | 
| 53 | 
            +
                :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
         | 
| 54 | 
            +
                             downsampling occurs in the inner-two dimensions.
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
         | 
| 58 | 
            +
                    super().__init__()
         | 
| 59 | 
            +
                    self.channels = channels
         | 
| 60 | 
            +
                    self.out_channels = out_channels or channels
         | 
| 61 | 
            +
                    self.use_conv = use_conv
         | 
| 62 | 
            +
                    self.dims = dims
         | 
| 63 | 
            +
                    stride = 2 if dims != 3 else (1, 2, 2)
         | 
| 64 | 
            +
                    if use_conv:
         | 
| 65 | 
            +
                        self.op = conv_nd(
         | 
| 66 | 
            +
                            dims,
         | 
| 67 | 
            +
                            self.channels,
         | 
| 68 | 
            +
                            self.out_channels,
         | 
| 69 | 
            +
                            3,
         | 
| 70 | 
            +
                            stride=stride,
         | 
| 71 | 
            +
                            padding=padding,
         | 
| 72 | 
            +
                        )
         | 
| 73 | 
            +
                    else:
         | 
| 74 | 
            +
                        assert self.channels == self.out_channels
         | 
| 75 | 
            +
                        self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def forward(self, x):
         | 
| 78 | 
            +
                    assert x.shape[1] == self.channels
         | 
| 79 | 
            +
                    return self.op(x)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            class Upsample(nn.Module):
         | 
| 83 | 
            +
                """
         | 
| 84 | 
            +
                An upsampling layer with an optional convolution.
         | 
| 85 | 
            +
                :param channels: channels in the inputs and outputs.
         | 
| 86 | 
            +
                :param use_conv: a bool determining if a convolution is applied.
         | 
| 87 | 
            +
                :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
         | 
| 88 | 
            +
                             upsampling occurs in the inner-two dimensions.
         | 
| 89 | 
            +
                """
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
         | 
| 92 | 
            +
                    super().__init__()
         | 
| 93 | 
            +
                    self.channels = channels
         | 
| 94 | 
            +
                    self.out_channels = out_channels or channels
         | 
| 95 | 
            +
                    self.use_conv = use_conv
         | 
| 96 | 
            +
                    self.dims = dims
         | 
| 97 | 
            +
                    if use_conv:
         | 
| 98 | 
            +
                        self.conv = conv_nd(
         | 
| 99 | 
            +
                            dims, self.channels, self.out_channels, 3, padding=padding
         | 
| 100 | 
            +
                        )
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def forward(self, x):
         | 
| 103 | 
            +
                    assert x.shape[1] == self.channels
         | 
| 104 | 
            +
                    if self.dims == 3:
         | 
| 105 | 
            +
                        x = F.interpolate(
         | 
| 106 | 
            +
                            x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
         | 
| 107 | 
            +
                        )
         | 
| 108 | 
            +
                    else:
         | 
| 109 | 
            +
                        x = F.interpolate(x, scale_factor=2, mode="nearest")
         | 
| 110 | 
            +
                    if self.use_conv:
         | 
| 111 | 
            +
                        x = self.conv(x)
         | 
| 112 | 
            +
                    return x
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            class ResBlock(TimestepBlock):
         | 
| 116 | 
            +
                """
         | 
| 117 | 
            +
                A residual block that can optionally change the number of channels.
         | 
| 118 | 
            +
                :param channels: the number of input channels.
         | 
| 119 | 
            +
                :param emb_channels: the number of timestep embedding channels.
         | 
| 120 | 
            +
                :param dropout: the rate of dropout.
         | 
| 121 | 
            +
                :param out_channels: if specified, the number of out channels.
         | 
| 122 | 
            +
                :param use_conv: if True and out_channels is specified, use a spatial
         | 
| 123 | 
            +
                    convolution instead of a smaller 1x1 convolution to change the
         | 
| 124 | 
            +
                    channels in the skip connection.
         | 
| 125 | 
            +
                :param dims: determines if the signal is 1D, 2D, or 3D.
         | 
| 126 | 
            +
                :param up: if True, use this block for upsampling.
         | 
| 127 | 
            +
                :param down: if True, use this block for downsampling.
         | 
| 128 | 
            +
                """
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                def __init__(
         | 
| 131 | 
            +
                    self,
         | 
| 132 | 
            +
                    channels,
         | 
| 133 | 
            +
                    emb_channels,
         | 
| 134 | 
            +
                    dropout,
         | 
| 135 | 
            +
                    out_channels=None,
         | 
| 136 | 
            +
                    use_scale_shift_norm=False,
         | 
| 137 | 
            +
                    dims=2,
         | 
| 138 | 
            +
                    use_checkpoint=False,
         | 
| 139 | 
            +
                    use_conv=False,
         | 
| 140 | 
            +
                    up=False,
         | 
| 141 | 
            +
                    down=False,
         | 
| 142 | 
            +
                    use_temporal_conv=False,
         | 
| 143 | 
            +
                    tempspatial_aware=False,
         | 
| 144 | 
            +
                ):
         | 
| 145 | 
            +
                    super().__init__()
         | 
| 146 | 
            +
                    self.channels = channels
         | 
| 147 | 
            +
                    self.emb_channels = emb_channels
         | 
| 148 | 
            +
                    self.dropout = dropout
         | 
| 149 | 
            +
                    self.out_channels = out_channels or channels
         | 
| 150 | 
            +
                    self.use_conv = use_conv
         | 
| 151 | 
            +
                    self.use_checkpoint = use_checkpoint
         | 
| 152 | 
            +
                    self.use_scale_shift_norm = use_scale_shift_norm
         | 
| 153 | 
            +
                    self.use_temporal_conv = use_temporal_conv
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    self.in_layers = nn.Sequential(
         | 
| 156 | 
            +
                        normalization(channels),
         | 
| 157 | 
            +
                        nn.SiLU(),
         | 
| 158 | 
            +
                        conv_nd(dims, channels, self.out_channels, 3, padding=1),
         | 
| 159 | 
            +
                    )
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    self.updown = up or down
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    if up:
         | 
| 164 | 
            +
                        self.h_upd = Upsample(channels, False, dims)
         | 
| 165 | 
            +
                        self.x_upd = Upsample(channels, False, dims)
         | 
| 166 | 
            +
                    elif down:
         | 
| 167 | 
            +
                        self.h_upd = Downsample(channels, False, dims)
         | 
| 168 | 
            +
                        self.x_upd = Downsample(channels, False, dims)
         | 
| 169 | 
            +
                    else:
         | 
| 170 | 
            +
                        self.h_upd = self.x_upd = nn.Identity()
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    self.emb_layers = nn.Sequential(
         | 
| 173 | 
            +
                        nn.SiLU(),
         | 
| 174 | 
            +
                        nn.Linear(
         | 
| 175 | 
            +
                            emb_channels,
         | 
| 176 | 
            +
                            2 * self.out_channels if use_scale_shift_norm else self.out_channels,
         | 
| 177 | 
            +
                        ),
         | 
| 178 | 
            +
                    )
         | 
| 179 | 
            +
                    self.out_layers = nn.Sequential(
         | 
| 180 | 
            +
                        normalization(self.out_channels),
         | 
| 181 | 
            +
                        nn.SiLU(),
         | 
| 182 | 
            +
                        nn.Dropout(p=dropout),
         | 
| 183 | 
            +
                        zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
         | 
| 184 | 
            +
                    )
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    if self.out_channels == channels:
         | 
| 187 | 
            +
                        self.skip_connection = nn.Identity()
         | 
| 188 | 
            +
                    elif use_conv:
         | 
| 189 | 
            +
                        self.skip_connection = conv_nd(
         | 
| 190 | 
            +
                            dims, channels, self.out_channels, 3, padding=1
         | 
| 191 | 
            +
                        )
         | 
| 192 | 
            +
                    else:
         | 
| 193 | 
            +
                        self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    if self.use_temporal_conv:
         | 
| 196 | 
            +
                        self.temopral_conv = TemporalConvBlock(
         | 
| 197 | 
            +
                            self.out_channels,
         | 
| 198 | 
            +
                            self.out_channels,
         | 
| 199 | 
            +
                            dropout=0.1,
         | 
| 200 | 
            +
                            spatial_aware=tempspatial_aware,
         | 
| 201 | 
            +
                        )
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                def forward(self, x, emb, batch_size=None):
         | 
| 204 | 
            +
                    """
         | 
| 205 | 
            +
                    Apply the block to a Tensor, conditioned on a timestep embedding.
         | 
| 206 | 
            +
                    :param x: an [N x C x ...] Tensor of features.
         | 
| 207 | 
            +
                    :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         | 
| 208 | 
            +
                    :return: an [N x C x ...] Tensor of outputs.
         | 
| 209 | 
            +
                    """
         | 
| 210 | 
            +
                    input_tuple = (
         | 
| 211 | 
            +
                        x,
         | 
| 212 | 
            +
                        emb,
         | 
| 213 | 
            +
                    )
         | 
| 214 | 
            +
                    if batch_size:
         | 
| 215 | 
            +
                        forward_batchsize = partial(self._forward, batch_size=batch_size)
         | 
| 216 | 
            +
                        return checkpoint(
         | 
| 217 | 
            +
                            forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint
         | 
| 218 | 
            +
                        )
         | 
| 219 | 
            +
                    return checkpoint(
         | 
| 220 | 
            +
                        self._forward, input_tuple, self.parameters(), self.use_checkpoint
         | 
| 221 | 
            +
                    )
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                def _forward(
         | 
| 224 | 
            +
                    self,
         | 
| 225 | 
            +
                    x,
         | 
| 226 | 
            +
                    emb,
         | 
| 227 | 
            +
                    batch_size=None,
         | 
| 228 | 
            +
                ):
         | 
| 229 | 
            +
                    if self.updown:
         | 
| 230 | 
            +
                        in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
         | 
| 231 | 
            +
                        h = in_rest(x)
         | 
| 232 | 
            +
                        h = self.h_upd(h)
         | 
| 233 | 
            +
                        x = self.x_upd(x)
         | 
| 234 | 
            +
                        h = in_conv(h)
         | 
| 235 | 
            +
                    else:
         | 
| 236 | 
            +
                        h = self.in_layers(x)
         | 
| 237 | 
            +
                    emb_out = self.emb_layers(emb).type(h.dtype)
         | 
| 238 | 
            +
                    while len(emb_out.shape) < len(h.shape):
         | 
| 239 | 
            +
                        emb_out = emb_out[..., None]
         | 
| 240 | 
            +
                    if self.use_scale_shift_norm:
         | 
| 241 | 
            +
                        out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
         | 
| 242 | 
            +
                        scale, shift = torch.chunk(emb_out, 2, dim=1)
         | 
| 243 | 
            +
                        h = out_norm(h) * (1 + scale) + shift
         | 
| 244 | 
            +
                        h = out_rest(h)
         | 
| 245 | 
            +
                    else:
         | 
| 246 | 
            +
                        h = h + emb_out
         | 
| 247 | 
            +
                        h = self.out_layers(h)
         | 
| 248 | 
            +
                    h = self.skip_connection(x) + h
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                    if self.use_temporal_conv and batch_size:
         | 
| 251 | 
            +
                        h = rearrange(h, "(b t) c h w -> b c t h w", b=batch_size)
         | 
| 252 | 
            +
                        h = self.temopral_conv(h)
         | 
| 253 | 
            +
                        h = rearrange(h, "b c t h w -> (b t) c h w")
         | 
| 254 | 
            +
                    return h
         | 
| 255 | 
            +
             | 
| 256 | 
            +
             | 
| 257 | 
            +
            class TemporalConvBlock(nn.Module):
         | 
| 258 | 
            +
                """
         | 
| 259 | 
            +
                Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
         | 
| 260 | 
            +
                """
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                def __init__(
         | 
| 263 | 
            +
                    self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False
         | 
| 264 | 
            +
                ):
         | 
| 265 | 
            +
                    super(TemporalConvBlock, self).__init__()
         | 
| 266 | 
            +
                    if out_channels is None:
         | 
| 267 | 
            +
                        out_channels = in_channels
         | 
| 268 | 
            +
                    self.in_channels = in_channels
         | 
| 269 | 
            +
                    self.out_channels = out_channels
         | 
| 270 | 
            +
                    kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 3)
         | 
| 271 | 
            +
                    padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 1)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    # conv layers
         | 
| 274 | 
            +
                    self.conv1 = nn.Sequential(
         | 
| 275 | 
            +
                        nn.GroupNorm(32, in_channels),
         | 
| 276 | 
            +
                        nn.SiLU(),
         | 
| 277 | 
            +
                        nn.Conv3d(in_channels, out_channels, kernel_shape, padding=padding_shape),
         | 
| 278 | 
            +
                    )
         | 
| 279 | 
            +
                    self.conv2 = nn.Sequential(
         | 
| 280 | 
            +
                        nn.GroupNorm(32, out_channels),
         | 
| 281 | 
            +
                        nn.SiLU(),
         | 
| 282 | 
            +
                        nn.Dropout(dropout),
         | 
| 283 | 
            +
                        nn.Conv3d(out_channels, in_channels, kernel_shape, padding=padding_shape),
         | 
| 284 | 
            +
                    )
         | 
| 285 | 
            +
                    self.conv3 = nn.Sequential(
         | 
| 286 | 
            +
                        nn.GroupNorm(32, out_channels),
         | 
| 287 | 
            +
                        nn.SiLU(),
         | 
| 288 | 
            +
                        nn.Dropout(dropout),
         | 
| 289 | 
            +
                        nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0)),
         | 
| 290 | 
            +
                    )
         | 
| 291 | 
            +
                    self.conv4 = nn.Sequential(
         | 
| 292 | 
            +
                        nn.GroupNorm(32, out_channels),
         | 
| 293 | 
            +
                        nn.SiLU(),
         | 
| 294 | 
            +
                        nn.Dropout(dropout),
         | 
| 295 | 
            +
                        nn.Conv3d(out_channels, in_channels, (3, 1, 1), padding=(1, 0, 0)),
         | 
| 296 | 
            +
                    )
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    # zero out the last layer params,so the conv block is identity
         | 
| 299 | 
            +
                    nn.init.zeros_(self.conv4[-1].weight)
         | 
| 300 | 
            +
                    nn.init.zeros_(self.conv4[-1].bias)
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                def forward(self, x):
         | 
| 303 | 
            +
                    identity = x
         | 
| 304 | 
            +
                    x = self.conv1(x)
         | 
| 305 | 
            +
                    x = self.conv2(x)
         | 
| 306 | 
            +
                    x = self.conv3(x)
         | 
| 307 | 
            +
                    x = self.conv4(x)
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                    return x + identity
         | 
| 310 | 
            +
             | 
| 311 | 
            +
             | 
| 312 | 
            +
            class UNetModel(nn.Module):
         | 
| 313 | 
            +
                """
         | 
| 314 | 
            +
                The full UNet model with attention and timestep embedding.
         | 
| 315 | 
            +
                :param in_channels: in_channels in the input Tensor.
         | 
| 316 | 
            +
                :param model_channels: base channel count for the model.
         | 
| 317 | 
            +
                :param out_channels: channels in the output Tensor.
         | 
| 318 | 
            +
                :param num_res_blocks: number of residual blocks per downsample.
         | 
| 319 | 
            +
                :param attention_resolutions: a collection of downsample rates at which
         | 
| 320 | 
            +
                    attention will take place. May be a set, list, or tuple.
         | 
| 321 | 
            +
                    For example, if this contains 4, then at 4x downsampling, attention
         | 
| 322 | 
            +
                    will be used.
         | 
| 323 | 
            +
                :param dropout: the dropout probability.
         | 
| 324 | 
            +
                :param channel_mult: channel multiplier for each level of the UNet.
         | 
| 325 | 
            +
                :param conv_resample: if True, use learned convolutions for upsampling and
         | 
| 326 | 
            +
                    downsampling.
         | 
| 327 | 
            +
                :param dims: determines if the signal is 1D, 2D, or 3D.
         | 
| 328 | 
            +
                :param num_classes: if specified (as an int), then this model will be
         | 
| 329 | 
            +
                    class-conditional with `num_classes` classes.
         | 
| 330 | 
            +
                :param use_checkpoint: use gradient checkpointing to reduce memory usage.
         | 
| 331 | 
            +
                :param num_heads: the number of attention heads in each attention layer.
         | 
| 332 | 
            +
                :param num_heads_channels: if specified, ignore num_heads and instead use
         | 
| 333 | 
            +
                                           a fixed channel width per attention head.
         | 
| 334 | 
            +
                :param num_heads_upsample: works with num_heads to set a different number
         | 
| 335 | 
            +
                                           of heads for upsampling. Deprecated.
         | 
| 336 | 
            +
                :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
         | 
| 337 | 
            +
                :param resblock_updown: use residual blocks for up/downsampling.
         | 
| 338 | 
            +
                """
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                def __init__(
         | 
| 341 | 
            +
                    self,
         | 
| 342 | 
            +
                    in_channels,
         | 
| 343 | 
            +
                    model_channels,
         | 
| 344 | 
            +
                    out_channels,
         | 
| 345 | 
            +
                    num_res_blocks,
         | 
| 346 | 
            +
                    attention_resolutions,
         | 
| 347 | 
            +
                    dropout=0.0,
         | 
| 348 | 
            +
                    channel_mult=(1, 2, 4, 8),
         | 
| 349 | 
            +
                    conv_resample=True,
         | 
| 350 | 
            +
                    dims=2,
         | 
| 351 | 
            +
                    context_dim=None,
         | 
| 352 | 
            +
                    use_scale_shift_norm=False,
         | 
| 353 | 
            +
                    resblock_updown=False,
         | 
| 354 | 
            +
                    num_heads=-1,
         | 
| 355 | 
            +
                    num_head_channels=-1,
         | 
| 356 | 
            +
                    transformer_depth=1,
         | 
| 357 | 
            +
                    use_linear=False,
         | 
| 358 | 
            +
                    use_checkpoint=False,
         | 
| 359 | 
            +
                    temporal_conv=False,
         | 
| 360 | 
            +
                    tempspatial_aware=False,
         | 
| 361 | 
            +
                    temporal_attention=True,
         | 
| 362 | 
            +
                    temporal_selfatt_only=True,
         | 
| 363 | 
            +
                    use_relative_position=True,
         | 
| 364 | 
            +
                    use_causal_attention=False,
         | 
| 365 | 
            +
                    temporal_length=None,
         | 
| 366 | 
            +
                    use_fp16=False,
         | 
| 367 | 
            +
                    addition_attention=False,
         | 
| 368 | 
            +
                    use_image_attention=False,
         | 
| 369 | 
            +
                    temporal_transformer_depth=1,
         | 
| 370 | 
            +
                    fps_cond=False,
         | 
| 371 | 
            +
                    time_cond_proj_dim=None,
         | 
| 372 | 
            +
                    motion_cond_proj_dim=None,
         | 
| 373 | 
            +
                    record_attn_probs=False,
         | 
| 374 | 
            +
                ):
         | 
| 375 | 
            +
                    super(UNetModel, self).__init__()
         | 
| 376 | 
            +
                    if num_heads == -1:
         | 
| 377 | 
            +
                        assert (
         | 
| 378 | 
            +
                            num_head_channels != -1
         | 
| 379 | 
            +
                        ), "Either num_heads or num_head_channels has to be set"
         | 
| 380 | 
            +
                    if num_head_channels == -1:
         | 
| 381 | 
            +
                        assert (
         | 
| 382 | 
            +
                            num_heads != -1
         | 
| 383 | 
            +
                        ), "Either num_heads or num_head_channels has to be set"
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    self.in_channels = in_channels
         | 
| 386 | 
            +
                    self.model_channels = model_channels
         | 
| 387 | 
            +
                    self.out_channels = out_channels
         | 
| 388 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 389 | 
            +
                    self.attention_resolutions = attention_resolutions
         | 
| 390 | 
            +
                    self.dropout = dropout
         | 
| 391 | 
            +
                    self.channel_mult = channel_mult
         | 
| 392 | 
            +
                    self.conv_resample = conv_resample
         | 
| 393 | 
            +
                    self.temporal_attention = temporal_attention
         | 
| 394 | 
            +
                    time_embed_dim = model_channels * 4
         | 
| 395 | 
            +
                    self.use_checkpoint = use_checkpoint
         | 
| 396 | 
            +
                    self.dtype = torch.float16 if use_fp16 else torch.float32
         | 
| 397 | 
            +
                    self.addition_attention = addition_attention
         | 
| 398 | 
            +
                    self.use_image_attention = use_image_attention
         | 
| 399 | 
            +
                    self.fps_cond = fps_cond
         | 
| 400 | 
            +
                    self.time_cond_proj_dim = time_cond_proj_dim
         | 
| 401 | 
            +
                    self.motion_cond_proj_dim = motion_cond_proj_dim
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                    self.time_embed = nn.Sequential(
         | 
| 404 | 
            +
                        linear(model_channels, time_embed_dim),
         | 
| 405 | 
            +
                        nn.SiLU(),
         | 
| 406 | 
            +
                        linear(time_embed_dim, time_embed_dim),
         | 
| 407 | 
            +
                    )
         | 
| 408 | 
            +
                    if self.fps_cond:
         | 
| 409 | 
            +
                        self.fps_embedding = nn.Sequential(
         | 
| 410 | 
            +
                            linear(model_channels, time_embed_dim),
         | 
| 411 | 
            +
                            nn.SiLU(),
         | 
| 412 | 
            +
                            linear(time_embed_dim, time_embed_dim),
         | 
| 413 | 
            +
                        )
         | 
| 414 | 
            +
                    if time_cond_proj_dim is not None:
         | 
| 415 | 
            +
                        self.time_cond_proj = nn.Linear(
         | 
| 416 | 
            +
                            time_cond_proj_dim, model_channels, bias=False
         | 
| 417 | 
            +
                        )
         | 
| 418 | 
            +
                    else:
         | 
| 419 | 
            +
                        self.time_cond_proj = None
         | 
| 420 | 
            +
                    
         | 
| 421 | 
            +
                    if motion_cond_proj_dim is not None:
         | 
| 422 | 
            +
                        self.motion_cond_proj = nn.Linear(
         | 
| 423 | 
            +
                            motion_cond_proj_dim, model_channels, bias=False
         | 
| 424 | 
            +
                        )
         | 
| 425 | 
            +
                        self.combine_proj = nn.Linear(
         | 
| 426 | 
            +
                            model_channels * 2, model_channels, bias=False
         | 
| 427 | 
            +
                        )
         | 
| 428 | 
            +
                    else:
         | 
| 429 | 
            +
                        self.motion_cond_proj = None
         | 
| 430 | 
            +
                        self.combine_proj = None
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                    self.input_blocks = nn.ModuleList(
         | 
| 433 | 
            +
                        [
         | 
| 434 | 
            +
                            TimestepEmbedSequential(
         | 
| 435 | 
            +
                                conv_nd(dims, in_channels, model_channels, 3, padding=1)
         | 
| 436 | 
            +
                            )
         | 
| 437 | 
            +
                        ]
         | 
| 438 | 
            +
                    )
         | 
| 439 | 
            +
                    if self.addition_attention:
         | 
| 440 | 
            +
                        self.init_attn = TimestepEmbedSequential(
         | 
| 441 | 
            +
                            TemporalTransformer(
         | 
| 442 | 
            +
                                model_channels,
         | 
| 443 | 
            +
                                n_heads=8,
         | 
| 444 | 
            +
                                d_head=num_head_channels,
         | 
| 445 | 
            +
                                depth=transformer_depth,
         | 
| 446 | 
            +
                                context_dim=context_dim,
         | 
| 447 | 
            +
                                use_checkpoint=use_checkpoint,
         | 
| 448 | 
            +
                                only_self_att=temporal_selfatt_only,
         | 
| 449 | 
            +
                                causal_attention=use_causal_attention,
         | 
| 450 | 
            +
                                relative_position=use_relative_position,
         | 
| 451 | 
            +
                                temporal_length=temporal_length,
         | 
| 452 | 
            +
                            )
         | 
| 453 | 
            +
                        )
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                    input_block_chans = [model_channels]
         | 
| 456 | 
            +
                    ch = model_channels
         | 
| 457 | 
            +
                    ds = 1
         | 
| 458 | 
            +
                    for level, mult in enumerate(channel_mult):
         | 
| 459 | 
            +
                        for _ in range(num_res_blocks):
         | 
| 460 | 
            +
                            layers = [
         | 
| 461 | 
            +
                                ResBlock(
         | 
| 462 | 
            +
                                    ch,
         | 
| 463 | 
            +
                                    time_embed_dim,
         | 
| 464 | 
            +
                                    dropout,
         | 
| 465 | 
            +
                                    out_channels=mult * model_channels,
         | 
| 466 | 
            +
                                    dims=dims,
         | 
| 467 | 
            +
                                    use_checkpoint=use_checkpoint,
         | 
| 468 | 
            +
                                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 469 | 
            +
                                    tempspatial_aware=tempspatial_aware,
         | 
| 470 | 
            +
                                    use_temporal_conv=temporal_conv,
         | 
| 471 | 
            +
                                )
         | 
| 472 | 
            +
                            ]
         | 
| 473 | 
            +
                            ch = mult * model_channels
         | 
| 474 | 
            +
                            if ds in attention_resolutions:
         | 
| 475 | 
            +
                                if num_head_channels == -1:
         | 
| 476 | 
            +
                                    dim_head = ch // num_heads
         | 
| 477 | 
            +
                                else:
         | 
| 478 | 
            +
                                    num_heads = ch // num_head_channels
         | 
| 479 | 
            +
                                    dim_head = num_head_channels
         | 
| 480 | 
            +
                                layers.append(
         | 
| 481 | 
            +
                                    SpatialTransformer(
         | 
| 482 | 
            +
                                        ch,
         | 
| 483 | 
            +
                                        num_heads,
         | 
| 484 | 
            +
                                        dim_head,
         | 
| 485 | 
            +
                                        depth=transformer_depth,
         | 
| 486 | 
            +
                                        context_dim=context_dim,
         | 
| 487 | 
            +
                                        use_linear=use_linear,
         | 
| 488 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 489 | 
            +
                                        disable_self_attn=False,
         | 
| 490 | 
            +
                                        img_cross_attention=self.use_image_attention,
         | 
| 491 | 
            +
                                    )
         | 
| 492 | 
            +
                                )
         | 
| 493 | 
            +
                                if self.temporal_attention:
         | 
| 494 | 
            +
                                    layers.append(
         | 
| 495 | 
            +
                                        TemporalTransformer(
         | 
| 496 | 
            +
                                            ch,
         | 
| 497 | 
            +
                                            num_heads,
         | 
| 498 | 
            +
                                            dim_head,
         | 
| 499 | 
            +
                                            depth=temporal_transformer_depth,
         | 
| 500 | 
            +
                                            context_dim=context_dim,
         | 
| 501 | 
            +
                                            use_linear=use_linear,
         | 
| 502 | 
            +
                                            use_checkpoint=use_checkpoint,
         | 
| 503 | 
            +
                                            only_self_att=temporal_selfatt_only,
         | 
| 504 | 
            +
                                            causal_attention=use_causal_attention,
         | 
| 505 | 
            +
                                            relative_position=use_relative_position,
         | 
| 506 | 
            +
                                            temporal_length=temporal_length,
         | 
| 507 | 
            +
                                        )
         | 
| 508 | 
            +
                                    )
         | 
| 509 | 
            +
                            self.input_blocks.append(TimestepEmbedSequential(*layers))
         | 
| 510 | 
            +
                            input_block_chans.append(ch)
         | 
| 511 | 
            +
                        if level != len(channel_mult) - 1:
         | 
| 512 | 
            +
                            out_ch = ch
         | 
| 513 | 
            +
                            self.input_blocks.append(
         | 
| 514 | 
            +
                                TimestepEmbedSequential(
         | 
| 515 | 
            +
                                    ResBlock(
         | 
| 516 | 
            +
                                        ch,
         | 
| 517 | 
            +
                                        time_embed_dim,
         | 
| 518 | 
            +
                                        dropout,
         | 
| 519 | 
            +
                                        out_channels=out_ch,
         | 
| 520 | 
            +
                                        dims=dims,
         | 
| 521 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 522 | 
            +
                                        use_scale_shift_norm=use_scale_shift_norm,
         | 
| 523 | 
            +
                                        down=True,
         | 
| 524 | 
            +
                                    )
         | 
| 525 | 
            +
                                    if resblock_updown
         | 
| 526 | 
            +
                                    else Downsample(
         | 
| 527 | 
            +
                                        ch, conv_resample, dims=dims, out_channels=out_ch
         | 
| 528 | 
            +
                                    )
         | 
| 529 | 
            +
                                )
         | 
| 530 | 
            +
                            )
         | 
| 531 | 
            +
                            ch = out_ch
         | 
| 532 | 
            +
                            input_block_chans.append(ch)
         | 
| 533 | 
            +
                            ds *= 2
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                    if num_head_channels == -1:
         | 
| 536 | 
            +
                        dim_head = ch // num_heads
         | 
| 537 | 
            +
                    else:
         | 
| 538 | 
            +
                        num_heads = ch // num_head_channels
         | 
| 539 | 
            +
                        dim_head = num_head_channels
         | 
| 540 | 
            +
                    layers = [
         | 
| 541 | 
            +
                        ResBlock(
         | 
| 542 | 
            +
                            ch,
         | 
| 543 | 
            +
                            time_embed_dim,
         | 
| 544 | 
            +
                            dropout,
         | 
| 545 | 
            +
                            dims=dims,
         | 
| 546 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 547 | 
            +
                            use_scale_shift_norm=use_scale_shift_norm,
         | 
| 548 | 
            +
                            tempspatial_aware=tempspatial_aware,
         | 
| 549 | 
            +
                            use_temporal_conv=temporal_conv,
         | 
| 550 | 
            +
                        ),
         | 
| 551 | 
            +
                        SpatialTransformer(
         | 
| 552 | 
            +
                            ch,
         | 
| 553 | 
            +
                            num_heads,
         | 
| 554 | 
            +
                            dim_head,
         | 
| 555 | 
            +
                            depth=transformer_depth,
         | 
| 556 | 
            +
                            context_dim=context_dim,
         | 
| 557 | 
            +
                            use_linear=use_linear,
         | 
| 558 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 559 | 
            +
                            disable_self_attn=False,
         | 
| 560 | 
            +
                            img_cross_attention=self.use_image_attention,
         | 
| 561 | 
            +
                        ),
         | 
| 562 | 
            +
                    ]
         | 
| 563 | 
            +
                    if self.temporal_attention:
         | 
| 564 | 
            +
                        layers.append(
         | 
| 565 | 
            +
                            TemporalTransformer(
         | 
| 566 | 
            +
                                ch,
         | 
| 567 | 
            +
                                num_heads,
         | 
| 568 | 
            +
                                dim_head,
         | 
| 569 | 
            +
                                depth=temporal_transformer_depth,
         | 
| 570 | 
            +
                                context_dim=context_dim,
         | 
| 571 | 
            +
                                use_linear=use_linear,
         | 
| 572 | 
            +
                                use_checkpoint=use_checkpoint,
         | 
| 573 | 
            +
                                only_self_att=temporal_selfatt_only,
         | 
| 574 | 
            +
                                causal_attention=use_causal_attention,
         | 
| 575 | 
            +
                                relative_position=use_relative_position,
         | 
| 576 | 
            +
                                temporal_length=temporal_length,
         | 
| 577 | 
            +
                            )
         | 
| 578 | 
            +
                        )
         | 
| 579 | 
            +
                    layers.append(
         | 
| 580 | 
            +
                        ResBlock(
         | 
| 581 | 
            +
                            ch,
         | 
| 582 | 
            +
                            time_embed_dim,
         | 
| 583 | 
            +
                            dropout,
         | 
| 584 | 
            +
                            dims=dims,
         | 
| 585 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 586 | 
            +
                            use_scale_shift_norm=use_scale_shift_norm,
         | 
| 587 | 
            +
                            tempspatial_aware=tempspatial_aware,
         | 
| 588 | 
            +
                            use_temporal_conv=temporal_conv,
         | 
| 589 | 
            +
                        )
         | 
| 590 | 
            +
                    )
         | 
| 591 | 
            +
                    self.middle_block = TimestepEmbedSequential(*layers)
         | 
| 592 | 
            +
             | 
| 593 | 
            +
                    self.output_blocks = nn.ModuleList([])
         | 
| 594 | 
            +
                    for level, mult in list(enumerate(channel_mult))[::-1]:
         | 
| 595 | 
            +
                        for i in range(num_res_blocks + 1):
         | 
| 596 | 
            +
                            ich = input_block_chans.pop()
         | 
| 597 | 
            +
                            layers = [
         | 
| 598 | 
            +
                                ResBlock(
         | 
| 599 | 
            +
                                    ch + ich,
         | 
| 600 | 
            +
                                    time_embed_dim,
         | 
| 601 | 
            +
                                    dropout,
         | 
| 602 | 
            +
                                    out_channels=mult * model_channels,
         | 
| 603 | 
            +
                                    dims=dims,
         | 
| 604 | 
            +
                                    use_checkpoint=use_checkpoint,
         | 
| 605 | 
            +
                                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 606 | 
            +
                                    tempspatial_aware=tempspatial_aware,
         | 
| 607 | 
            +
                                    use_temporal_conv=temporal_conv,
         | 
| 608 | 
            +
                                )
         | 
| 609 | 
            +
                            ]
         | 
| 610 | 
            +
                            ch = model_channels * mult
         | 
| 611 | 
            +
                            if ds in attention_resolutions:
         | 
| 612 | 
            +
                                if num_head_channels == -1:
         | 
| 613 | 
            +
                                    dim_head = ch // num_heads
         | 
| 614 | 
            +
                                else:
         | 
| 615 | 
            +
                                    num_heads = ch // num_head_channels
         | 
| 616 | 
            +
                                    dim_head = num_head_channels
         | 
| 617 | 
            +
                                layers.append(
         | 
| 618 | 
            +
                                    SpatialTransformer(
         | 
| 619 | 
            +
                                        ch,
         | 
| 620 | 
            +
                                        num_heads,
         | 
| 621 | 
            +
                                        dim_head,
         | 
| 622 | 
            +
                                        depth=transformer_depth,
         | 
| 623 | 
            +
                                        context_dim=context_dim,
         | 
| 624 | 
            +
                                        use_linear=use_linear,
         | 
| 625 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 626 | 
            +
                                        disable_self_attn=False,
         | 
| 627 | 
            +
                                        img_cross_attention=self.use_image_attention,
         | 
| 628 | 
            +
                                    )
         | 
| 629 | 
            +
                                )
         | 
| 630 | 
            +
                                if self.temporal_attention:
         | 
| 631 | 
            +
                                    layers.append(
         | 
| 632 | 
            +
                                        TemporalTransformer(
         | 
| 633 | 
            +
                                            ch,
         | 
| 634 | 
            +
                                            num_heads,
         | 
| 635 | 
            +
                                            dim_head,
         | 
| 636 | 
            +
                                            depth=temporal_transformer_depth,
         | 
| 637 | 
            +
                                            context_dim=context_dim,
         | 
| 638 | 
            +
                                            use_linear=use_linear,
         | 
| 639 | 
            +
                                            use_checkpoint=use_checkpoint,
         | 
| 640 | 
            +
                                            only_self_att=temporal_selfatt_only,
         | 
| 641 | 
            +
                                            causal_attention=use_causal_attention,
         | 
| 642 | 
            +
                                            relative_position=use_relative_position,
         | 
| 643 | 
            +
                                            temporal_length=temporal_length,
         | 
| 644 | 
            +
                                            record_attn_probs=record_attn_probs,
         | 
| 645 | 
            +
                                        )
         | 
| 646 | 
            +
                                    )
         | 
| 647 | 
            +
                            if level and i == num_res_blocks:
         | 
| 648 | 
            +
                                out_ch = ch
         | 
| 649 | 
            +
                                layers.append(
         | 
| 650 | 
            +
                                    ResBlock(
         | 
| 651 | 
            +
                                        ch,
         | 
| 652 | 
            +
                                        time_embed_dim,
         | 
| 653 | 
            +
                                        dropout,
         | 
| 654 | 
            +
                                        out_channels=out_ch,
         | 
| 655 | 
            +
                                        dims=dims,
         | 
| 656 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 657 | 
            +
                                        use_scale_shift_norm=use_scale_shift_norm,
         | 
| 658 | 
            +
                                        up=True,
         | 
| 659 | 
            +
                                    )
         | 
| 660 | 
            +
                                    if resblock_updown
         | 
| 661 | 
            +
                                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
         | 
| 662 | 
            +
                                )
         | 
| 663 | 
            +
                                ds //= 2
         | 
| 664 | 
            +
                            self.output_blocks.append(TimestepEmbedSequential(*layers))
         | 
| 665 | 
            +
             | 
| 666 | 
            +
                    self.out = nn.Sequential(
         | 
| 667 | 
            +
                        normalization(ch),
         | 
| 668 | 
            +
                        nn.SiLU(),
         | 
| 669 | 
            +
                        zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         | 
| 670 | 
            +
                    )
         | 
| 671 | 
            +
             | 
| 672 | 
            +
                def forward(
         | 
| 673 | 
            +
                    self,
         | 
| 674 | 
            +
                    x,
         | 
| 675 | 
            +
                    timesteps,
         | 
| 676 | 
            +
                    context=None,
         | 
| 677 | 
            +
                    features_adapter=None,
         | 
| 678 | 
            +
                    fps=16,
         | 
| 679 | 
            +
                    timestep_cond=None,
         | 
| 680 | 
            +
                    motion_cond=None,
         | 
| 681 | 
            +
                    **kwargs
         | 
| 682 | 
            +
                ):
         | 
| 683 | 
            +
                    t_emb = timestep_embedding(
         | 
| 684 | 
            +
                        timesteps, self.model_channels, repeat_only=False
         | 
| 685 | 
            +
                    ).to(self.dtype)
         | 
| 686 | 
            +
                    if timestep_cond is not None:
         | 
| 687 | 
            +
                        timestep_cond_embed = self.time_cond_proj(timestep_cond)
         | 
| 688 | 
            +
                    else:
         | 
| 689 | 
            +
                        timestep_cond_embed = 0.
         | 
| 690 | 
            +
                    if motion_cond is not None:
         | 
| 691 | 
            +
                        assert timestep_cond is not None
         | 
| 692 | 
            +
                        motion_cond_emb = self.motion_cond_proj(motion_cond)
         | 
| 693 | 
            +
                        combined_cond_emb = self.combine_proj(
         | 
| 694 | 
            +
                            torch.cat([timestep_cond_embed, motion_cond_emb], dim=1)
         | 
| 695 | 
            +
                        )
         | 
| 696 | 
            +
                    else:
         | 
| 697 | 
            +
                        combined_cond_emb = timestep_cond_embed
         | 
| 698 | 
            +
                    emb = self.time_embed(t_emb + combined_cond_emb)
         | 
| 699 | 
            +
             | 
| 700 | 
            +
                    if self.fps_cond:
         | 
| 701 | 
            +
                        if type(fps) == int:
         | 
| 702 | 
            +
                            fps = torch.full_like(timesteps, fps)
         | 
| 703 | 
            +
                        fps_emb = timestep_embedding(
         | 
| 704 | 
            +
                            fps, self.model_channels, repeat_only=False
         | 
| 705 | 
            +
                        ).to(self.dtype)
         | 
| 706 | 
            +
                        emb += self.fps_embedding(fps_emb)
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                    b, _, t, _, _ = x.shape
         | 
| 709 | 
            +
                    ## repeat t times for context [(b t) 77 768] & time embedding
         | 
| 710 | 
            +
                    context = context.repeat_interleave(repeats=t, dim=0)
         | 
| 711 | 
            +
                    emb = emb.repeat_interleave(repeats=t, dim=0)
         | 
| 712 | 
            +
             | 
| 713 | 
            +
                    ## always in shape (b t) c h w, except for temporal layer
         | 
| 714 | 
            +
                    x = rearrange(x, "b c t h w -> (b t) c h w")
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                    h = x.type(self.dtype)
         | 
| 717 | 
            +
                    adapter_idx = 0
         | 
| 718 | 
            +
                    hs = []
         | 
| 719 | 
            +
                    for id, module in enumerate(self.input_blocks):
         | 
| 720 | 
            +
                        h = module(h, emb, context=context, batch_size=b)
         | 
| 721 | 
            +
                        if id == 0 and self.addition_attention:
         | 
| 722 | 
            +
                            h = self.init_attn(h, emb, context=context, batch_size=b)
         | 
| 723 | 
            +
                        ## plug-in adapter features
         | 
| 724 | 
            +
                        if ((id + 1) % 3 == 0) and features_adapter is not None:
         | 
| 725 | 
            +
                            h = h + features_adapter[adapter_idx]
         | 
| 726 | 
            +
                            adapter_idx += 1
         | 
| 727 | 
            +
                        hs.append(h)
         | 
| 728 | 
            +
                    if features_adapter is not None:
         | 
| 729 | 
            +
                        assert len(features_adapter) == adapter_idx, "Wrong features_adapter"
         | 
| 730 | 
            +
             | 
| 731 | 
            +
                    h = self.middle_block(h, emb, context=context, batch_size=b)
         | 
| 732 | 
            +
                    for module in self.output_blocks:
         | 
| 733 | 
            +
                        h = torch.cat([h, hs.pop()], dim=1)
         | 
| 734 | 
            +
                        h = module(h, emb, context=context, batch_size=b)
         | 
| 735 | 
            +
                    h = h.type(x.dtype)
         | 
| 736 | 
            +
                    y = self.out(h)
         | 
| 737 | 
            +
             | 
| 738 | 
            +
                    # reshape back to (b c t h w)
         | 
| 739 | 
            +
                    y = rearrange(y, "(b t) c h w -> b c t h w", b=b)
         | 
| 740 | 
            +
                    return y
         | 
    	
        lvdm/modules/x_transformer.py
    ADDED
    
    | @@ -0,0 +1,704 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from functools import partial
         | 
| 4 | 
            +
            from inspect import isfunction
         | 
| 5 | 
            +
            from collections import namedtuple
         | 
| 6 | 
            +
            from einops import rearrange, repeat
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from torch import nn, einsum
         | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # constants
         | 
| 12 | 
            +
            DEFAULT_DIM_HEAD = 64
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"])
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class AbsolutePositionalEmbedding(nn.Module):
         | 
| 20 | 
            +
                def __init__(self, dim, max_seq_len):
         | 
| 21 | 
            +
                    super().__init__()
         | 
| 22 | 
            +
                    self.emb = nn.Embedding(max_seq_len, dim)
         | 
| 23 | 
            +
                    self.init_()
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def init_(self):
         | 
| 26 | 
            +
                    nn.init.normal_(self.emb.weight, std=0.02)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                def forward(self, x):
         | 
| 29 | 
            +
                    n = torch.arange(x.shape[1], device=x.device)
         | 
| 30 | 
            +
                    return self.emb(n)[None, :, :]
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class FixedPositionalEmbedding(nn.Module):
         | 
| 34 | 
            +
                def __init__(self, dim):
         | 
| 35 | 
            +
                    super().__init__()
         | 
| 36 | 
            +
                    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
         | 
| 37 | 
            +
                    self.register_buffer("inv_freq", inv_freq)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                def forward(self, x, seq_dim=1, offset=0):
         | 
| 40 | 
            +
                    t = (
         | 
| 41 | 
            +
                        torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
         | 
| 42 | 
            +
                        + offset
         | 
| 43 | 
            +
                    )
         | 
| 44 | 
            +
                    sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
         | 
| 45 | 
            +
                    emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
         | 
| 46 | 
            +
                    return emb[None, :, :]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            # helpers
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            def exists(val):
         | 
| 53 | 
            +
                return val is not None
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            def default(val, d):
         | 
| 57 | 
            +
                if exists(val):
         | 
| 58 | 
            +
                    return val
         | 
| 59 | 
            +
                return d() if isfunction(d) else d
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def always(val):
         | 
| 63 | 
            +
                def inner(*args, **kwargs):
         | 
| 64 | 
            +
                    return val
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                return inner
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def not_equals(val):
         | 
| 70 | 
            +
                def inner(x):
         | 
| 71 | 
            +
                    return x != val
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                return inner
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def equals(val):
         | 
| 77 | 
            +
                def inner(x):
         | 
| 78 | 
            +
                    return x == val
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                return inner
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            def max_neg_value(tensor):
         | 
| 84 | 
            +
                return -torch.finfo(tensor.dtype).max
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            # keyword argument helpers
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def pick_and_pop(keys, d):
         | 
| 91 | 
            +
                values = list(map(lambda key: d.pop(key), keys))
         | 
| 92 | 
            +
                return dict(zip(keys, values))
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            def group_dict_by_key(cond, d):
         | 
| 96 | 
            +
                return_val = [dict(), dict()]
         | 
| 97 | 
            +
                for key in d.keys():
         | 
| 98 | 
            +
                    match = bool(cond(key))
         | 
| 99 | 
            +
                    ind = int(not match)
         | 
| 100 | 
            +
                    return_val[ind][key] = d[key]
         | 
| 101 | 
            +
                return (*return_val,)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            def string_begins_with(prefix, str):
         | 
| 105 | 
            +
                return str.startswith(prefix)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            def group_by_key_prefix(prefix, d):
         | 
| 109 | 
            +
                return group_dict_by_key(partial(string_begins_with, prefix), d)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
            def groupby_prefix_and_trim(prefix, d):
         | 
| 113 | 
            +
                kwargs_with_prefix, kwargs = group_dict_by_key(
         | 
| 114 | 
            +
                    partial(string_begins_with, prefix), d
         | 
| 115 | 
            +
                )
         | 
| 116 | 
            +
                kwargs_without_prefix = dict(
         | 
| 117 | 
            +
                    map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
                return kwargs_without_prefix, kwargs
         | 
| 120 | 
            +
             | 
| 121 | 
            +
             | 
| 122 | 
            +
            # classes
         | 
| 123 | 
            +
            class Scale(nn.Module):
         | 
| 124 | 
            +
                def __init__(self, value, fn):
         | 
| 125 | 
            +
                    super().__init__()
         | 
| 126 | 
            +
                    self.value = value
         | 
| 127 | 
            +
                    self.fn = fn
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                def forward(self, x, **kwargs):
         | 
| 130 | 
            +
                    x, *rest = self.fn(x, **kwargs)
         | 
| 131 | 
            +
                    return (x * self.value, *rest)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            class Rezero(nn.Module):
         | 
| 135 | 
            +
                def __init__(self, fn):
         | 
| 136 | 
            +
                    super().__init__()
         | 
| 137 | 
            +
                    self.fn = fn
         | 
| 138 | 
            +
                    self.g = nn.Parameter(torch.zeros(1))
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                def forward(self, x, **kwargs):
         | 
| 141 | 
            +
                    x, *rest = self.fn(x, **kwargs)
         | 
| 142 | 
            +
                    return (x * self.g, *rest)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            class ScaleNorm(nn.Module):
         | 
| 146 | 
            +
                def __init__(self, dim, eps=1e-5):
         | 
| 147 | 
            +
                    super().__init__()
         | 
| 148 | 
            +
                    self.scale = dim**-0.5
         | 
| 149 | 
            +
                    self.eps = eps
         | 
| 150 | 
            +
                    self.g = nn.Parameter(torch.ones(1))
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                def forward(self, x):
         | 
| 153 | 
            +
                    norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
         | 
| 154 | 
            +
                    return x / norm.clamp(min=self.eps) * self.g
         | 
| 155 | 
            +
             | 
| 156 | 
            +
             | 
| 157 | 
            +
            class RMSNorm(nn.Module):
         | 
| 158 | 
            +
                def __init__(self, dim, eps=1e-8):
         | 
| 159 | 
            +
                    super().__init__()
         | 
| 160 | 
            +
                    self.scale = dim**-0.5
         | 
| 161 | 
            +
                    self.eps = eps
         | 
| 162 | 
            +
                    self.g = nn.Parameter(torch.ones(dim))
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                def forward(self, x):
         | 
| 165 | 
            +
                    norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
         | 
| 166 | 
            +
                    return x / norm.clamp(min=self.eps) * self.g
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            class Residual(nn.Module):
         | 
| 170 | 
            +
                def forward(self, x, residual):
         | 
| 171 | 
            +
                    return x + residual
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
            class GRUGating(nn.Module):
         | 
| 175 | 
            +
                def __init__(self, dim):
         | 
| 176 | 
            +
                    super().__init__()
         | 
| 177 | 
            +
                    self.gru = nn.GRUCell(dim, dim)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                def forward(self, x, residual):
         | 
| 180 | 
            +
                    gated_output = self.gru(
         | 
| 181 | 
            +
                        rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")
         | 
| 182 | 
            +
                    )
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    return gated_output.reshape_as(x)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
| 187 | 
            +
            # feedforward
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
            class GEGLU(nn.Module):
         | 
| 191 | 
            +
                def __init__(self, dim_in, dim_out):
         | 
| 192 | 
            +
                    super().__init__()
         | 
| 193 | 
            +
                    self.proj = nn.Linear(dim_in, dim_out * 2)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                def forward(self, x):
         | 
| 196 | 
            +
                    x, gate = self.proj(x).chunk(2, dim=-1)
         | 
| 197 | 
            +
                    return x * F.gelu(gate)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
             | 
| 200 | 
            +
            class FeedForward(nn.Module):
         | 
| 201 | 
            +
                def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         | 
| 202 | 
            +
                    super().__init__()
         | 
| 203 | 
            +
                    inner_dim = int(dim * mult)
         | 
| 204 | 
            +
                    dim_out = default(dim_out, dim)
         | 
| 205 | 
            +
                    project_in = (
         | 
| 206 | 
            +
                        nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
         | 
| 207 | 
            +
                        if not glu
         | 
| 208 | 
            +
                        else GEGLU(dim, inner_dim)
         | 
| 209 | 
            +
                    )
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    self.net = nn.Sequential(
         | 
| 212 | 
            +
                        project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
         | 
| 213 | 
            +
                    )
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                def forward(self, x):
         | 
| 216 | 
            +
                    return self.net(x)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
             | 
| 219 | 
            +
            # attention.
         | 
| 220 | 
            +
            class Attention(nn.Module):
         | 
| 221 | 
            +
                def __init__(
         | 
| 222 | 
            +
                    self,
         | 
| 223 | 
            +
                    dim,
         | 
| 224 | 
            +
                    dim_head=DEFAULT_DIM_HEAD,
         | 
| 225 | 
            +
                    heads=8,
         | 
| 226 | 
            +
                    causal=False,
         | 
| 227 | 
            +
                    mask=None,
         | 
| 228 | 
            +
                    talking_heads=False,
         | 
| 229 | 
            +
                    sparse_topk=None,
         | 
| 230 | 
            +
                    use_entmax15=False,
         | 
| 231 | 
            +
                    num_mem_kv=0,
         | 
| 232 | 
            +
                    dropout=0.0,
         | 
| 233 | 
            +
                    on_attn=False,
         | 
| 234 | 
            +
                ):
         | 
| 235 | 
            +
                    super().__init__()
         | 
| 236 | 
            +
                    if use_entmax15:
         | 
| 237 | 
            +
                        raise NotImplementedError(
         | 
| 238 | 
            +
                            "Check out entmax activation instead of softmax activation!"
         | 
| 239 | 
            +
                        )
         | 
| 240 | 
            +
                    self.scale = dim_head**-0.5
         | 
| 241 | 
            +
                    self.heads = heads
         | 
| 242 | 
            +
                    self.causal = causal
         | 
| 243 | 
            +
                    self.mask = mask
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                    inner_dim = dim_head * heads
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    self.to_q = nn.Linear(dim, inner_dim, bias=False)
         | 
| 248 | 
            +
                    self.to_k = nn.Linear(dim, inner_dim, bias=False)
         | 
| 249 | 
            +
                    self.to_v = nn.Linear(dim, inner_dim, bias=False)
         | 
| 250 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    # talking heads
         | 
| 253 | 
            +
                    self.talking_heads = talking_heads
         | 
| 254 | 
            +
                    if talking_heads:
         | 
| 255 | 
            +
                        self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
         | 
| 256 | 
            +
                        self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    # explicit topk sparse attention
         | 
| 259 | 
            +
                    self.sparse_topk = sparse_topk
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    # entmax
         | 
| 262 | 
            +
                    # self.attn_fn = entmax15 if use_entmax15 else F.softmax
         | 
| 263 | 
            +
                    self.attn_fn = F.softmax
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    # add memory key / values
         | 
| 266 | 
            +
                    self.num_mem_kv = num_mem_kv
         | 
| 267 | 
            +
                    if num_mem_kv > 0:
         | 
| 268 | 
            +
                        self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
         | 
| 269 | 
            +
                        self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    # attention on attention
         | 
| 272 | 
            +
                    self.attn_on_attn = on_attn
         | 
| 273 | 
            +
                    self.to_out = (
         | 
| 274 | 
            +
                        nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
         | 
| 275 | 
            +
                        if on_attn
         | 
| 276 | 
            +
                        else nn.Linear(inner_dim, dim)
         | 
| 277 | 
            +
                    )
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                def forward(
         | 
| 280 | 
            +
                    self,
         | 
| 281 | 
            +
                    x,
         | 
| 282 | 
            +
                    context=None,
         | 
| 283 | 
            +
                    mask=None,
         | 
| 284 | 
            +
                    context_mask=None,
         | 
| 285 | 
            +
                    rel_pos=None,
         | 
| 286 | 
            +
                    sinusoidal_emb=None,
         | 
| 287 | 
            +
                    prev_attn=None,
         | 
| 288 | 
            +
                    mem=None,
         | 
| 289 | 
            +
                ):
         | 
| 290 | 
            +
                    b, n, _, h, talking_heads, device = (
         | 
| 291 | 
            +
                        *x.shape,
         | 
| 292 | 
            +
                        self.heads,
         | 
| 293 | 
            +
                        self.talking_heads,
         | 
| 294 | 
            +
                        x.device,
         | 
| 295 | 
            +
                    )
         | 
| 296 | 
            +
                    kv_input = default(context, x)
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    q_input = x
         | 
| 299 | 
            +
                    k_input = kv_input
         | 
| 300 | 
            +
                    v_input = kv_input
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    if exists(mem):
         | 
| 303 | 
            +
                        k_input = torch.cat((mem, k_input), dim=-2)
         | 
| 304 | 
            +
                        v_input = torch.cat((mem, v_input), dim=-2)
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    if exists(sinusoidal_emb):
         | 
| 307 | 
            +
                        # in shortformer, the query would start at a position offset depending on the past cached memory
         | 
| 308 | 
            +
                        offset = k_input.shape[-2] - q_input.shape[-2]
         | 
| 309 | 
            +
                        q_input = q_input + sinusoidal_emb(q_input, offset=offset)
         | 
| 310 | 
            +
                        k_input = k_input + sinusoidal_emb(k_input)
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                    q = self.to_q(q_input)
         | 
| 313 | 
            +
                    k = self.to_k(k_input)
         | 
| 314 | 
            +
                    v = self.to_v(v_input)
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                    input_mask = None
         | 
| 319 | 
            +
                    if any(map(exists, (mask, context_mask))):
         | 
| 320 | 
            +
                        q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
         | 
| 321 | 
            +
                        k_mask = q_mask if not exists(context) else context_mask
         | 
| 322 | 
            +
                        k_mask = default(
         | 
| 323 | 
            +
                            k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()
         | 
| 324 | 
            +
                        )
         | 
| 325 | 
            +
                        q_mask = rearrange(q_mask, "b i -> b () i ()")
         | 
| 326 | 
            +
                        k_mask = rearrange(k_mask, "b j -> b () () j")
         | 
| 327 | 
            +
                        input_mask = q_mask * k_mask
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    if self.num_mem_kv > 0:
         | 
| 330 | 
            +
                        mem_k, mem_v = map(
         | 
| 331 | 
            +
                            lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)
         | 
| 332 | 
            +
                        )
         | 
| 333 | 
            +
                        k = torch.cat((mem_k, k), dim=-2)
         | 
| 334 | 
            +
                        v = torch.cat((mem_v, v), dim=-2)
         | 
| 335 | 
            +
                        if exists(input_mask):
         | 
| 336 | 
            +
                            input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
         | 
| 339 | 
            +
                    mask_value = max_neg_value(dots)
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    if exists(prev_attn):
         | 
| 342 | 
            +
                        dots = dots + prev_attn
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                    pre_softmax_attn = dots
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    if talking_heads:
         | 
| 347 | 
            +
                        dots = einsum(
         | 
| 348 | 
            +
                            "b h i j, h k -> b k i j", dots, self.pre_softmax_proj
         | 
| 349 | 
            +
                        ).contiguous()
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                    if exists(rel_pos):
         | 
| 352 | 
            +
                        dots = rel_pos(dots)
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    if exists(input_mask):
         | 
| 355 | 
            +
                        dots.masked_fill_(~input_mask, mask_value)
         | 
| 356 | 
            +
                        del input_mask
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    if self.causal:
         | 
| 359 | 
            +
                        i, j = dots.shape[-2:]
         | 
| 360 | 
            +
                        r = torch.arange(i, device=device)
         | 
| 361 | 
            +
                        mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
         | 
| 362 | 
            +
                        mask = F.pad(mask, (j - i, 0), value=False)
         | 
| 363 | 
            +
                        dots.masked_fill_(mask, mask_value)
         | 
| 364 | 
            +
                        del mask
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
         | 
| 367 | 
            +
                        top, _ = dots.topk(self.sparse_topk, dim=-1)
         | 
| 368 | 
            +
                        vk = top[..., -1].unsqueeze(-1).expand_as(dots)
         | 
| 369 | 
            +
                        mask = dots < vk
         | 
| 370 | 
            +
                        dots.masked_fill_(mask, mask_value)
         | 
| 371 | 
            +
                        del mask
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    attn = self.attn_fn(dots, dim=-1)
         | 
| 374 | 
            +
                    post_softmax_attn = attn
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    attn = self.dropout(attn)
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    if talking_heads:
         | 
| 379 | 
            +
                        attn = einsum(
         | 
| 380 | 
            +
                            "b h i j, h k -> b k i j", attn, self.post_softmax_proj
         | 
| 381 | 
            +
                        ).contiguous()
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    out = einsum("b h i j, b h j d -> b h i d", attn, v)
         | 
| 384 | 
            +
                    out = rearrange(out, "b h n d -> b n (h d)")
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                    intermediates = Intermediates(
         | 
| 387 | 
            +
                        pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
         | 
| 388 | 
            +
                    )
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    return self.to_out(out), intermediates
         | 
| 391 | 
            +
             | 
| 392 | 
            +
             | 
| 393 | 
            +
            class AttentionLayers(nn.Module):
         | 
| 394 | 
            +
                def __init__(
         | 
| 395 | 
            +
                    self,
         | 
| 396 | 
            +
                    dim,
         | 
| 397 | 
            +
                    depth,
         | 
| 398 | 
            +
                    heads=8,
         | 
| 399 | 
            +
                    causal=False,
         | 
| 400 | 
            +
                    cross_attend=False,
         | 
| 401 | 
            +
                    only_cross=False,
         | 
| 402 | 
            +
                    use_scalenorm=False,
         | 
| 403 | 
            +
                    use_rmsnorm=False,
         | 
| 404 | 
            +
                    use_rezero=False,
         | 
| 405 | 
            +
                    rel_pos_num_buckets=32,
         | 
| 406 | 
            +
                    rel_pos_max_distance=128,
         | 
| 407 | 
            +
                    position_infused_attn=False,
         | 
| 408 | 
            +
                    custom_layers=None,
         | 
| 409 | 
            +
                    sandwich_coef=None,
         | 
| 410 | 
            +
                    par_ratio=None,
         | 
| 411 | 
            +
                    residual_attn=False,
         | 
| 412 | 
            +
                    cross_residual_attn=False,
         | 
| 413 | 
            +
                    macaron=False,
         | 
| 414 | 
            +
                    pre_norm=True,
         | 
| 415 | 
            +
                    gate_residual=False,
         | 
| 416 | 
            +
                    **kwargs,
         | 
| 417 | 
            +
                ):
         | 
| 418 | 
            +
                    super().__init__()
         | 
| 419 | 
            +
                    ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
         | 
| 420 | 
            +
                    attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                    dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    self.dim = dim
         | 
| 425 | 
            +
                    self.depth = depth
         | 
| 426 | 
            +
                    self.layers = nn.ModuleList([])
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    self.has_pos_emb = position_infused_attn
         | 
| 429 | 
            +
                    self.pia_pos_emb = (
         | 
| 430 | 
            +
                        FixedPositionalEmbedding(dim) if position_infused_attn else None
         | 
| 431 | 
            +
                    )
         | 
| 432 | 
            +
                    self.rotary_pos_emb = always(None)
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                    assert (
         | 
| 435 | 
            +
                        rel_pos_num_buckets <= rel_pos_max_distance
         | 
| 436 | 
            +
                    ), "number of relative position buckets must be less than the relative position max distance"
         | 
| 437 | 
            +
                    self.rel_pos = None
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                    self.pre_norm = pre_norm
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    self.residual_attn = residual_attn
         | 
| 442 | 
            +
                    self.cross_residual_attn = cross_residual_attn
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
         | 
| 445 | 
            +
                    norm_class = RMSNorm if use_rmsnorm else norm_class
         | 
| 446 | 
            +
                    norm_fn = partial(norm_class, dim)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    norm_fn = nn.Identity if use_rezero else norm_fn
         | 
| 449 | 
            +
                    branch_fn = Rezero if use_rezero else None
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    if cross_attend and not only_cross:
         | 
| 452 | 
            +
                        default_block = ("a", "c", "f")
         | 
| 453 | 
            +
                    elif cross_attend and only_cross:
         | 
| 454 | 
            +
                        default_block = ("c", "f")
         | 
| 455 | 
            +
                    else:
         | 
| 456 | 
            +
                        default_block = ("a", "f")
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                    if macaron:
         | 
| 459 | 
            +
                        default_block = ("f",) + default_block
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                    if exists(custom_layers):
         | 
| 462 | 
            +
                        layer_types = custom_layers
         | 
| 463 | 
            +
                    elif exists(par_ratio):
         | 
| 464 | 
            +
                        par_depth = depth * len(default_block)
         | 
| 465 | 
            +
                        assert 1 < par_ratio <= par_depth, "par ratio out of range"
         | 
| 466 | 
            +
                        default_block = tuple(filter(not_equals("f"), default_block))
         | 
| 467 | 
            +
                        par_attn = par_depth // par_ratio
         | 
| 468 | 
            +
                        depth_cut = (
         | 
| 469 | 
            +
                            par_depth * 2 // 3
         | 
| 470 | 
            +
                        )  # 2 / 3 attention layer cutoff suggested by PAR paper
         | 
| 471 | 
            +
                        par_width = (depth_cut + depth_cut // par_attn) // par_attn
         | 
| 472 | 
            +
                        assert (
         | 
| 473 | 
            +
                            len(default_block) <= par_width
         | 
| 474 | 
            +
                        ), "default block is too large for par_ratio"
         | 
| 475 | 
            +
                        par_block = default_block + ("f",) * (par_width - len(default_block))
         | 
| 476 | 
            +
                        par_head = par_block * par_attn
         | 
| 477 | 
            +
                        layer_types = par_head + ("f",) * (par_depth - len(par_head))
         | 
| 478 | 
            +
                    elif exists(sandwich_coef):
         | 
| 479 | 
            +
                        assert (
         | 
| 480 | 
            +
                            sandwich_coef > 0 and sandwich_coef <= depth
         | 
| 481 | 
            +
                        ), "sandwich coefficient should be less than the depth"
         | 
| 482 | 
            +
                        layer_types = (
         | 
| 483 | 
            +
                            ("a",) * sandwich_coef
         | 
| 484 | 
            +
                            + default_block * (depth - sandwich_coef)
         | 
| 485 | 
            +
                            + ("f",) * sandwich_coef
         | 
| 486 | 
            +
                        )
         | 
| 487 | 
            +
                    else:
         | 
| 488 | 
            +
                        layer_types = default_block * depth
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    self.layer_types = layer_types
         | 
| 491 | 
            +
                    self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                    for layer_type in self.layer_types:
         | 
| 494 | 
            +
                        if layer_type == "a":
         | 
| 495 | 
            +
                            layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
         | 
| 496 | 
            +
                        elif layer_type == "c":
         | 
| 497 | 
            +
                            layer = Attention(dim, heads=heads, **attn_kwargs)
         | 
| 498 | 
            +
                        elif layer_type == "f":
         | 
| 499 | 
            +
                            layer = FeedForward(dim, **ff_kwargs)
         | 
| 500 | 
            +
                            layer = layer if not macaron else Scale(0.5, layer)
         | 
| 501 | 
            +
                        else:
         | 
| 502 | 
            +
                            raise Exception(f"invalid layer type {layer_type}")
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                        if isinstance(layer, Attention) and exists(branch_fn):
         | 
| 505 | 
            +
                            layer = branch_fn(layer)
         | 
| 506 | 
            +
             | 
| 507 | 
            +
                        if gate_residual:
         | 
| 508 | 
            +
                            residual_fn = GRUGating(dim)
         | 
| 509 | 
            +
                        else:
         | 
| 510 | 
            +
                            residual_fn = Residual()
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                        self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                def forward(
         | 
| 515 | 
            +
                    self,
         | 
| 516 | 
            +
                    x,
         | 
| 517 | 
            +
                    context=None,
         | 
| 518 | 
            +
                    mask=None,
         | 
| 519 | 
            +
                    context_mask=None,
         | 
| 520 | 
            +
                    mems=None,
         | 
| 521 | 
            +
                    return_hiddens=False,
         | 
| 522 | 
            +
                ):
         | 
| 523 | 
            +
                    hiddens = []
         | 
| 524 | 
            +
                    intermediates = []
         | 
| 525 | 
            +
                    prev_attn = None
         | 
| 526 | 
            +
                    prev_cross_attn = None
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                    mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
         | 
| 531 | 
            +
                        zip(self.layer_types, self.layers)
         | 
| 532 | 
            +
                    ):
         | 
| 533 | 
            +
                        is_last = ind == (len(self.layers) - 1)
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                        if layer_type == "a":
         | 
| 536 | 
            +
                            hiddens.append(x)
         | 
| 537 | 
            +
                            layer_mem = mems.pop(0)
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                        residual = x
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                        if self.pre_norm:
         | 
| 542 | 
            +
                            x = norm(x)
         | 
| 543 | 
            +
             | 
| 544 | 
            +
                        if layer_type == "a":
         | 
| 545 | 
            +
                            out, inter = block(
         | 
| 546 | 
            +
                                x,
         | 
| 547 | 
            +
                                mask=mask,
         | 
| 548 | 
            +
                                sinusoidal_emb=self.pia_pos_emb,
         | 
| 549 | 
            +
                                rel_pos=self.rel_pos,
         | 
| 550 | 
            +
                                prev_attn=prev_attn,
         | 
| 551 | 
            +
                                mem=layer_mem,
         | 
| 552 | 
            +
                            )
         | 
| 553 | 
            +
                        elif layer_type == "c":
         | 
| 554 | 
            +
                            out, inter = block(
         | 
| 555 | 
            +
                                x,
         | 
| 556 | 
            +
                                context=context,
         | 
| 557 | 
            +
                                mask=mask,
         | 
| 558 | 
            +
                                context_mask=context_mask,
         | 
| 559 | 
            +
                                prev_attn=prev_cross_attn,
         | 
| 560 | 
            +
                            )
         | 
| 561 | 
            +
                        elif layer_type == "f":
         | 
| 562 | 
            +
                            out = block(x)
         | 
| 563 | 
            +
             | 
| 564 | 
            +
                        x = residual_fn(out, residual)
         | 
| 565 | 
            +
             | 
| 566 | 
            +
                        if layer_type in ("a", "c"):
         | 
| 567 | 
            +
                            intermediates.append(inter)
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                        if layer_type == "a" and self.residual_attn:
         | 
| 570 | 
            +
                            prev_attn = inter.pre_softmax_attn
         | 
| 571 | 
            +
                        elif layer_type == "c" and self.cross_residual_attn:
         | 
| 572 | 
            +
                            prev_cross_attn = inter.pre_softmax_attn
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                        if not self.pre_norm and not is_last:
         | 
| 575 | 
            +
                            x = norm(x)
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                    if return_hiddens:
         | 
| 578 | 
            +
                        intermediates = LayerIntermediates(
         | 
| 579 | 
            +
                            hiddens=hiddens, attn_intermediates=intermediates
         | 
| 580 | 
            +
                        )
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                        return x, intermediates
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                    return x
         | 
| 585 | 
            +
             | 
| 586 | 
            +
             | 
| 587 | 
            +
            class Encoder(AttentionLayers):
         | 
| 588 | 
            +
                def __init__(self, **kwargs):
         | 
| 589 | 
            +
                    assert "causal" not in kwargs, "cannot set causality on encoder"
         | 
| 590 | 
            +
                    super().__init__(causal=False, **kwargs)
         | 
| 591 | 
            +
             | 
| 592 | 
            +
             | 
| 593 | 
            +
            class TransformerWrapper(nn.Module):
         | 
| 594 | 
            +
                def __init__(
         | 
| 595 | 
            +
                    self,
         | 
| 596 | 
            +
                    *,
         | 
| 597 | 
            +
                    num_tokens,
         | 
| 598 | 
            +
                    max_seq_len,
         | 
| 599 | 
            +
                    attn_layers,
         | 
| 600 | 
            +
                    emb_dim=None,
         | 
| 601 | 
            +
                    max_mem_len=0.0,
         | 
| 602 | 
            +
                    emb_dropout=0.0,
         | 
| 603 | 
            +
                    num_memory_tokens=None,
         | 
| 604 | 
            +
                    tie_embedding=False,
         | 
| 605 | 
            +
                    use_pos_emb=True,
         | 
| 606 | 
            +
                ):
         | 
| 607 | 
            +
                    super().__init__()
         | 
| 608 | 
            +
                    assert isinstance(
         | 
| 609 | 
            +
                        attn_layers, AttentionLayers
         | 
| 610 | 
            +
                    ), "attention layers must be one of Encoder or Decoder"
         | 
| 611 | 
            +
             | 
| 612 | 
            +
                    dim = attn_layers.dim
         | 
| 613 | 
            +
                    emb_dim = default(emb_dim, dim)
         | 
| 614 | 
            +
             | 
| 615 | 
            +
                    self.max_seq_len = max_seq_len
         | 
| 616 | 
            +
                    self.max_mem_len = max_mem_len
         | 
| 617 | 
            +
                    self.num_tokens = num_tokens
         | 
| 618 | 
            +
             | 
| 619 | 
            +
                    self.token_emb = nn.Embedding(num_tokens, emb_dim)
         | 
| 620 | 
            +
                    self.pos_emb = (
         | 
| 621 | 
            +
                        AbsolutePositionalEmbedding(emb_dim, max_seq_len)
         | 
| 622 | 
            +
                        if (use_pos_emb and not attn_layers.has_pos_emb)
         | 
| 623 | 
            +
                        else always(0)
         | 
| 624 | 
            +
                    )
         | 
| 625 | 
            +
                    self.emb_dropout = nn.Dropout(emb_dropout)
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                    self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         | 
| 628 | 
            +
                    self.attn_layers = attn_layers
         | 
| 629 | 
            +
                    self.norm = nn.LayerNorm(dim)
         | 
| 630 | 
            +
             | 
| 631 | 
            +
                    self.init_()
         | 
| 632 | 
            +
             | 
| 633 | 
            +
                    self.to_logits = (
         | 
| 634 | 
            +
                        nn.Linear(dim, num_tokens)
         | 
| 635 | 
            +
                        if not tie_embedding
         | 
| 636 | 
            +
                        else lambda t: t @ self.token_emb.weight.t()
         | 
| 637 | 
            +
                    )
         | 
| 638 | 
            +
             | 
| 639 | 
            +
                    # memory tokens (like [cls]) from Memory Transformers paper
         | 
| 640 | 
            +
                    num_memory_tokens = default(num_memory_tokens, 0)
         | 
| 641 | 
            +
                    self.num_memory_tokens = num_memory_tokens
         | 
| 642 | 
            +
                    if num_memory_tokens > 0:
         | 
| 643 | 
            +
                        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
         | 
| 644 | 
            +
             | 
| 645 | 
            +
                        # let funnel encoder know number of memory tokens, if specified
         | 
| 646 | 
            +
                        if hasattr(attn_layers, "num_memory_tokens"):
         | 
| 647 | 
            +
                            attn_layers.num_memory_tokens = num_memory_tokens
         | 
| 648 | 
            +
             | 
| 649 | 
            +
                def init_(self):
         | 
| 650 | 
            +
                    nn.init.normal_(self.token_emb.weight, std=0.02)
         | 
| 651 | 
            +
             | 
| 652 | 
            +
                def forward(
         | 
| 653 | 
            +
                    self,
         | 
| 654 | 
            +
                    x,
         | 
| 655 | 
            +
                    return_embeddings=False,
         | 
| 656 | 
            +
                    mask=None,
         | 
| 657 | 
            +
                    return_mems=False,
         | 
| 658 | 
            +
                    return_attn=False,
         | 
| 659 | 
            +
                    mems=None,
         | 
| 660 | 
            +
                    **kwargs,
         | 
| 661 | 
            +
                ):
         | 
| 662 | 
            +
                    b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
         | 
| 663 | 
            +
                    x = self.token_emb(x)
         | 
| 664 | 
            +
                    x += self.pos_emb(x)
         | 
| 665 | 
            +
                    x = self.emb_dropout(x)
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                    x = self.project_emb(x)
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                    if num_mem > 0:
         | 
| 670 | 
            +
                        mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
         | 
| 671 | 
            +
                        x = torch.cat((mem, x), dim=1)
         | 
| 672 | 
            +
             | 
| 673 | 
            +
                        # auto-handle masking after appending memory tokens
         | 
| 674 | 
            +
                        if exists(mask):
         | 
| 675 | 
            +
                            mask = F.pad(mask, (num_mem, 0), value=True)
         | 
| 676 | 
            +
             | 
| 677 | 
            +
                    x, intermediates = self.attn_layers(
         | 
| 678 | 
            +
                        x, mask=mask, mems=mems, return_hiddens=True, **kwargs
         | 
| 679 | 
            +
                    )
         | 
| 680 | 
            +
                    x = self.norm(x)
         | 
| 681 | 
            +
             | 
| 682 | 
            +
                    mem, x = x[:, :num_mem], x[:, num_mem:]
         | 
| 683 | 
            +
             | 
| 684 | 
            +
                    out = self.to_logits(x) if not return_embeddings else x
         | 
| 685 | 
            +
             | 
| 686 | 
            +
                    if return_mems:
         | 
| 687 | 
            +
                        hiddens = intermediates.hiddens
         | 
| 688 | 
            +
                        new_mems = (
         | 
| 689 | 
            +
                            list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens)))
         | 
| 690 | 
            +
                            if exists(mems)
         | 
| 691 | 
            +
                            else hiddens
         | 
| 692 | 
            +
                        )
         | 
| 693 | 
            +
                        new_mems = list(
         | 
| 694 | 
            +
                            map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
         | 
| 695 | 
            +
                        )
         | 
| 696 | 
            +
                        return out, new_mems
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                    if return_attn:
         | 
| 699 | 
            +
                        attn_maps = list(
         | 
| 700 | 
            +
                            map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
         | 
| 701 | 
            +
                        )
         | 
| 702 | 
            +
                        return out, attn_maps
         | 
| 703 | 
            +
             | 
| 704 | 
            +
                    return out
         | 
    	
        pipeline/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        pipeline/__pycache__/__init__.cpython-312.pyc
    ADDED
    
    | Binary file (148 Bytes). View file | 
|  | 
    	
        pipeline/__pycache__/t2v_turbo_vc2_pipeline.cpython-312.pyc
    ADDED
    
    | Binary file (9.19 kB). View file | 
|  | 
    	
        pipeline/t2v_turbo_vc2_pipeline.py
    ADDED
    
    | @@ -0,0 +1,221 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from diffusers import DiffusionPipeline
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from typing import List, Optional, Union, Dict, Any
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from diffusers import logging
         | 
| 7 | 
            +
            from diffusers.utils.torch_utils import randn_tensor
         | 
| 8 | 
            +
            from lvdm.models.ddpm3d import LatentDiffusion
         | 
| 9 | 
            +
            from scheduler.t2v_turbo_scheduler import T2VTurboScheduler
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class T2VTurboVC2Pipeline(DiffusionPipeline):
         | 
| 15 | 
            +
                def __init__(
         | 
| 16 | 
            +
                    self,
         | 
| 17 | 
            +
                    pretrained_t2v: LatentDiffusion,
         | 
| 18 | 
            +
                    scheduler: T2VTurboScheduler,
         | 
| 19 | 
            +
                    model_config: Dict[str, Any] = None,
         | 
| 20 | 
            +
                ):
         | 
| 21 | 
            +
                    super().__init__()
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    self.register_modules(
         | 
| 24 | 
            +
                        pretrained_t2v=pretrained_t2v,
         | 
| 25 | 
            +
                        scheduler=scheduler,
         | 
| 26 | 
            +
                    )
         | 
| 27 | 
            +
                    self.vae = pretrained_t2v.first_stage_model
         | 
| 28 | 
            +
                    self.unet = pretrained_t2v.model.diffusion_model
         | 
| 29 | 
            +
                    self.text_encoder = pretrained_t2v.cond_stage_model
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    self.model_config = model_config
         | 
| 32 | 
            +
                    self.vae_scale_factor = 8
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def _encode_prompt(
         | 
| 35 | 
            +
                    self,
         | 
| 36 | 
            +
                    prompt,
         | 
| 37 | 
            +
                    device,
         | 
| 38 | 
            +
                    num_videos_per_prompt,
         | 
| 39 | 
            +
                    prompt_embeds: None,
         | 
| 40 | 
            +
                ):
         | 
| 41 | 
            +
                    r"""
         | 
| 42 | 
            +
                    Encodes the prompt into text encoder hidden states.
         | 
| 43 | 
            +
                    Args:
         | 
| 44 | 
            +
                        prompt (`str` or `List[str]`, *optional*):
         | 
| 45 | 
            +
                            prompt to be encoded
         | 
| 46 | 
            +
                        device: (`torch.device`):
         | 
| 47 | 
            +
                            torch device
         | 
| 48 | 
            +
                        num_videos_per_prompt (`int`):
         | 
| 49 | 
            +
                            number of images that should be generated per prompt
         | 
| 50 | 
            +
                        prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 51 | 
            +
                            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         | 
| 52 | 
            +
                            provided, text embeddings will be generated from `prompt` input argument.
         | 
| 53 | 
            +
                    """
         | 
| 54 | 
            +
                    if prompt_embeds is None:
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                        prompt_embeds = self.text_encoder(prompt)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    prompt_embeds = prompt_embeds.to(device=device)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    bs_embed, seq_len, _ = prompt_embeds.shape
         | 
| 61 | 
            +
                    # duplicate text embeddings for each generation per prompt, using mps friendly method
         | 
| 62 | 
            +
                    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         | 
| 63 | 
            +
                    prompt_embeds = prompt_embeds.view(
         | 
| 64 | 
            +
                        bs_embed * num_videos_per_prompt, seq_len, -1
         | 
| 65 | 
            +
                    )
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    # Don't need to get uncond prompt embedding because of LCM Guided Distillation
         | 
| 68 | 
            +
                    return prompt_embeds
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def prepare_latents(
         | 
| 71 | 
            +
                    self,
         | 
| 72 | 
            +
                    batch_size,
         | 
| 73 | 
            +
                    num_channels_latents,
         | 
| 74 | 
            +
                    frames,
         | 
| 75 | 
            +
                    height,
         | 
| 76 | 
            +
                    width,
         | 
| 77 | 
            +
                    dtype,
         | 
| 78 | 
            +
                    device,
         | 
| 79 | 
            +
                    generator,
         | 
| 80 | 
            +
                    latents=None,
         | 
| 81 | 
            +
                ):
         | 
| 82 | 
            +
                    shape = (
         | 
| 83 | 
            +
                        batch_size,
         | 
| 84 | 
            +
                        num_channels_latents,
         | 
| 85 | 
            +
                        frames,
         | 
| 86 | 
            +
                        height // self.vae_scale_factor,
         | 
| 87 | 
            +
                        width // self.vae_scale_factor,
         | 
| 88 | 
            +
                    )
         | 
| 89 | 
            +
                    if latents is None:
         | 
| 90 | 
            +
                        latents = randn_tensor(
         | 
| 91 | 
            +
                            shape, generator=generator, device=device, dtype=dtype
         | 
| 92 | 
            +
                        )
         | 
| 93 | 
            +
                    else:
         | 
| 94 | 
            +
                        latents = latents.to(device)
         | 
| 95 | 
            +
                    # scale the initial noise by the standard deviation required by the scheduler
         | 
| 96 | 
            +
                    latents = latents * self.scheduler.init_noise_sigma
         | 
| 97 | 
            +
                    return latents
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         | 
| 100 | 
            +
                    """
         | 
| 101 | 
            +
                    see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
         | 
| 102 | 
            +
                    Args:
         | 
| 103 | 
            +
                    timesteps: torch.Tensor: generate embedding vectors at these timesteps
         | 
| 104 | 
            +
                    embedding_dim: int: dimension of the embeddings to generate
         | 
| 105 | 
            +
                    dtype: data type of the generated embeddings
         | 
| 106 | 
            +
                    Returns:
         | 
| 107 | 
            +
                    embedding vectors with shape `(len(timesteps), embedding_dim)`
         | 
| 108 | 
            +
                    """
         | 
| 109 | 
            +
                    assert len(w.shape) == 1
         | 
| 110 | 
            +
                    w = w * 1000.0
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    half_dim = embedding_dim // 2
         | 
| 113 | 
            +
                    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
         | 
| 114 | 
            +
                    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
         | 
| 115 | 
            +
                    emb = w.to(dtype)[:, None] * emb[None, :]
         | 
| 116 | 
            +
                    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
         | 
| 117 | 
            +
                    if embedding_dim % 2 == 1:  # zero pad
         | 
| 118 | 
            +
                        emb = torch.nn.functional.pad(emb, (0, 1))
         | 
| 119 | 
            +
                    assert emb.shape == (w.shape[0], embedding_dim)
         | 
| 120 | 
            +
                    return emb
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                @torch.no_grad()
         | 
| 123 | 
            +
                def __call__(
         | 
| 124 | 
            +
                    self,
         | 
| 125 | 
            +
                    prompt: Union[str, List[str]] = None,
         | 
| 126 | 
            +
                    height: Optional[int] = 320,
         | 
| 127 | 
            +
                    width: Optional[int] = 512,
         | 
| 128 | 
            +
                    frames: int = 16,
         | 
| 129 | 
            +
                    fps: int = 16,
         | 
| 130 | 
            +
                    guidance_scale: float = 7.5,
         | 
| 131 | 
            +
                    motion_gs: float = 0.1,
         | 
| 132 | 
            +
                    use_motion_cond: bool = False,
         | 
| 133 | 
            +
                    percentage: float = 0.3,
         | 
| 134 | 
            +
                    num_videos_per_prompt: Optional[int] = 1,
         | 
| 135 | 
            +
                    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         | 
| 136 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 137 | 
            +
                    num_inference_steps: int = 4,
         | 
| 138 | 
            +
                    lcm_origin_steps: int = 50,
         | 
| 139 | 
            +
                    prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 140 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 141 | 
            +
                ):
         | 
| 142 | 
            +
                    unet_config = self.model_config["params"]["unet_config"]
         | 
| 143 | 
            +
                    # 0. Default height and width to unet
         | 
| 144 | 
            +
                    frames = self.pretrained_t2v.temporal_length if frames < 0 else frames
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    # 2. Define call parameters
         | 
| 147 | 
            +
                    if prompt is not None and isinstance(prompt, str):
         | 
| 148 | 
            +
                        batch_size = 1
         | 
| 149 | 
            +
                    elif prompt is not None and isinstance(prompt, list):
         | 
| 150 | 
            +
                        batch_size = len(prompt)
         | 
| 151 | 
            +
                    else:
         | 
| 152 | 
            +
                        batch_size = prompt_embeds.shape[0]
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    device = self._execution_device
         | 
| 155 | 
            +
                    # do_classifier_free_guidance = guidance_scale > 0.0  # In LCM Implementation:  cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    # 3. Encode input prompt
         | 
| 158 | 
            +
                    prompt_embeds = self._encode_prompt(
         | 
| 159 | 
            +
                        prompt,
         | 
| 160 | 
            +
                        device,
         | 
| 161 | 
            +
                        num_videos_per_prompt,
         | 
| 162 | 
            +
                        prompt_embeds=prompt_embeds,
         | 
| 163 | 
            +
                    )
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    # 4. Prepare timesteps
         | 
| 166 | 
            +
                    self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps)
         | 
| 167 | 
            +
                    timesteps = self.scheduler.timesteps
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    # 5. Prepare latent variable
         | 
| 170 | 
            +
                    num_channels_latents = unet_config["params"]["in_channels"]
         | 
| 171 | 
            +
                    latents = self.prepare_latents(
         | 
| 172 | 
            +
                        batch_size * num_videos_per_prompt,
         | 
| 173 | 
            +
                        num_channels_latents,
         | 
| 174 | 
            +
                        frames,
         | 
| 175 | 
            +
                        height,
         | 
| 176 | 
            +
                        width,
         | 
| 177 | 
            +
                        prompt_embeds.dtype,
         | 
| 178 | 
            +
                        device,
         | 
| 179 | 
            +
                        generator,
         | 
| 180 | 
            +
                        latents,
         | 
| 181 | 
            +
                    )
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    bs = batch_size * num_videos_per_prompt
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    context = {"context": torch.cat([prompt_embeds.to(self.dtype)], 1), "fps": fps}
         | 
| 186 | 
            +
                    # 6. Get Guidance Scale Embedding
         | 
| 187 | 
            +
                    w = torch.tensor(guidance_scale).repeat(bs)
         | 
| 188 | 
            +
                    w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device)
         | 
| 189 | 
            +
                    context["timestep_cond"] = w_embedding.to(self.dtype)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    ms_t_threshold = self.scheduler.config.num_train_timesteps * (1 - percentage)
         | 
| 192 | 
            +
                    # 7. LCM MultiStep Sampling Loop:
         | 
| 193 | 
            +
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         | 
| 194 | 
            +
                        for i, t in enumerate(timesteps):
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                            ts = torch.full((bs,), t, device=device, dtype=torch.long)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                            if use_motion_cond:
         | 
| 199 | 
            +
                                motion_gs_pt = torch.tensor(motion_gs).repeat(bs)
         | 
| 200 | 
            +
                                if t < ms_t_threshold:
         | 
| 201 | 
            +
                                    motion_gs_pt = torch.zeros_like(motion_gs_pt)
         | 
| 202 | 
            +
                                motion_gs_embedding = self.get_w_embedding(
         | 
| 203 | 
            +
                                    motion_gs_pt, embedding_dim=256, dtype=self.dtype
         | 
| 204 | 
            +
                                ).to(device)
         | 
| 205 | 
            +
                                context["motion_cond"] = motion_gs_embedding
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                            # model prediction (v-prediction, eps, x)
         | 
| 208 | 
            +
                            model_pred = self.unet(latents, ts, **context)
         | 
| 209 | 
            +
                            # compute the previous noisy sample x_t -> x_t-1
         | 
| 210 | 
            +
                            latents, denoised = self.scheduler.step(
         | 
| 211 | 
            +
                                model_pred, i, t, latents, generator=generator, return_dict=False
         | 
| 212 | 
            +
                            )
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                            progress_bar.update()
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    if not output_type == "latent":
         | 
| 217 | 
            +
                        videos = self.pretrained_t2v.decode_first_stage_2DAE(denoised)
         | 
| 218 | 
            +
                    else:
         | 
| 219 | 
            +
                        videos = denoised
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    return videos
         | 
    	
        scheduler/__pycache__/t2v_turbo_scheduler.cpython-312.pyc
    ADDED
    
    | Binary file (24 kB). View file | 
|  | 
    	
        scheduler/t2v_turbo_scheduler.py
    ADDED
    
    | @@ -0,0 +1,524 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
         | 
| 16 | 
            +
            # and https://github.com/hojonathanho/diffusion
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import math
         | 
| 19 | 
            +
            from dataclasses import dataclass
         | 
| 20 | 
            +
            from typing import List, Optional, Tuple, Union
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            import numpy as np
         | 
| 23 | 
            +
            import torch
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            from diffusers import ConfigMixin, SchedulerMixin
         | 
| 26 | 
            +
            from diffusers.configuration_utils import register_to_config
         | 
| 27 | 
            +
            from diffusers.utils import BaseOutput
         | 
| 28 | 
            +
            from diffusers.utils.torch_utils import randn_tensor
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def extract_into_tensor(a, t, x_shape):
         | 
| 32 | 
            +
                b, *_ = t.shape
         | 
| 33 | 
            +
                out = a.gather(-1, t)
         | 
| 34 | 
            +
                return out.reshape(b, *((1,) * (len(x_shape) - 1)))
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            @dataclass
         | 
| 38 | 
            +
            # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
         | 
| 39 | 
            +
            class T2VTurboSchedulerOutput(BaseOutput):
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
                Output class for the scheduler's `step` function output.
         | 
| 42 | 
            +
                Args:
         | 
| 43 | 
            +
                    prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
         | 
| 44 | 
            +
                        Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
         | 
| 45 | 
            +
                        denoising loop.
         | 
| 46 | 
            +
                    pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
         | 
| 47 | 
            +
                        The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
         | 
| 48 | 
            +
                        `pred_original_sample` can be used to preview progress or for guidance.
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                prev_sample: torch.FloatTensor
         | 
| 52 | 
            +
                denoised: Optional[torch.FloatTensor] = None
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
         | 
| 56 | 
            +
            def betas_for_alpha_bar(
         | 
| 57 | 
            +
                num_diffusion_timesteps,
         | 
| 58 | 
            +
                max_beta=0.999,
         | 
| 59 | 
            +
                alpha_transform_type="cosine",
         | 
| 60 | 
            +
            ):
         | 
| 61 | 
            +
                """
         | 
| 62 | 
            +
                Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
         | 
| 63 | 
            +
                (1-beta) over time from t = [0,1].
         | 
| 64 | 
            +
                Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
         | 
| 65 | 
            +
                to that part of the diffusion process.
         | 
| 66 | 
            +
                Args:
         | 
| 67 | 
            +
                    num_diffusion_timesteps (`int`): the number of betas to produce.
         | 
| 68 | 
            +
                    max_beta (`float`): the maximum beta to use; use values lower than 1 to
         | 
| 69 | 
            +
                                 prevent singularities.
         | 
| 70 | 
            +
                    alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
         | 
| 71 | 
            +
                                 Choose from `cosine` or `exp`
         | 
| 72 | 
            +
                Returns:
         | 
| 73 | 
            +
                    betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
         | 
| 74 | 
            +
                """
         | 
| 75 | 
            +
                if alpha_transform_type == "cosine":
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    def alpha_bar_fn(t):
         | 
| 78 | 
            +
                        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                elif alpha_transform_type == "exp":
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    def alpha_bar_fn(t):
         | 
| 83 | 
            +
                        return math.exp(t * -12.0)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                else:
         | 
| 86 | 
            +
                    raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                betas = []
         | 
| 89 | 
            +
                for i in range(num_diffusion_timesteps):
         | 
| 90 | 
            +
                    t1 = i / num_diffusion_timesteps
         | 
| 91 | 
            +
                    t2 = (i + 1) / num_diffusion_timesteps
         | 
| 92 | 
            +
                    betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
         | 
| 93 | 
            +
                return torch.tensor(betas, dtype=torch.float32)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
            def rescale_zero_terminal_snr(betas):
         | 
| 97 | 
            +
                """
         | 
| 98 | 
            +
                Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
         | 
| 99 | 
            +
                Args:
         | 
| 100 | 
            +
                    betas (`torch.FloatTensor`):
         | 
| 101 | 
            +
                        the betas that the scheduler is being initialized with.
         | 
| 102 | 
            +
                Returns:
         | 
| 103 | 
            +
                    `torch.FloatTensor`: rescaled betas with zero terminal SNR
         | 
| 104 | 
            +
                """
         | 
| 105 | 
            +
                # Convert betas to alphas_bar_sqrt
         | 
| 106 | 
            +
                alphas = 1.0 - betas
         | 
| 107 | 
            +
                alphas_cumprod = torch.cumprod(alphas, dim=0)
         | 
| 108 | 
            +
                alphas_bar_sqrt = alphas_cumprod.sqrt()
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                # Store old values.
         | 
| 111 | 
            +
                alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
         | 
| 112 | 
            +
                alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                # Shift so the last timestep is zero.
         | 
| 115 | 
            +
                alphas_bar_sqrt -= alphas_bar_sqrt_T
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                # Scale so the first timestep is back to the old value.
         | 
| 118 | 
            +
                alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                # Convert alphas_bar_sqrt to betas
         | 
| 121 | 
            +
                alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
         | 
| 122 | 
            +
                alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
         | 
| 123 | 
            +
                alphas = torch.cat([alphas_bar[0:1], alphas])
         | 
| 124 | 
            +
                betas = 1 - alphas
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                return betas
         | 
| 127 | 
            +
             | 
| 128 | 
            +
             | 
| 129 | 
            +
            class T2VTurboScheduler(SchedulerMixin, ConfigMixin):
         | 
| 130 | 
            +
                """
         | 
| 131 | 
            +
                `T2VTurboScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
         | 
| 132 | 
            +
                non-Markovian guidance.
         | 
| 133 | 
            +
                This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
         | 
| 134 | 
            +
                methods the library implements for all schedulers such as loading and saving.
         | 
| 135 | 
            +
                Args:
         | 
| 136 | 
            +
                    num_train_timesteps (`int`, defaults to 1000):
         | 
| 137 | 
            +
                        The number of diffusion steps to train the model.
         | 
| 138 | 
            +
                    beta_start (`float`, defaults to 0.0001):
         | 
| 139 | 
            +
                        The starting `beta` value of inference.
         | 
| 140 | 
            +
                    beta_end (`float`, defaults to 0.02):
         | 
| 141 | 
            +
                        The final `beta` value.
         | 
| 142 | 
            +
                    beta_schedule (`str`, defaults to `"linear"`):
         | 
| 143 | 
            +
                        The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
         | 
| 144 | 
            +
                        `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
         | 
| 145 | 
            +
                    trained_betas (`np.ndarray`, *optional*):
         | 
| 146 | 
            +
                        Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         | 
| 147 | 
            +
                    clip_sample (`bool`, defaults to `True`):
         | 
| 148 | 
            +
                        Clip the predicted sample for numerical stability.
         | 
| 149 | 
            +
                    clip_sample_range (`float`, defaults to 1.0):
         | 
| 150 | 
            +
                        The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
         | 
| 151 | 
            +
                    set_alpha_to_one (`bool`, defaults to `True`):
         | 
| 152 | 
            +
                        Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
         | 
| 153 | 
            +
                        there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
         | 
| 154 | 
            +
                        otherwise it uses the alpha value at step 0.
         | 
| 155 | 
            +
                    steps_offset (`int`, defaults to 0):
         | 
| 156 | 
            +
                        An offset added to the inference steps. You can use a combination of `offset=1` and
         | 
| 157 | 
            +
                        `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
         | 
| 158 | 
            +
                        Diffusion.
         | 
| 159 | 
            +
                    prediction_type (`str`, defaults to `epsilon`, *optional*):
         | 
| 160 | 
            +
                        Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
         | 
| 161 | 
            +
                        `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
         | 
| 162 | 
            +
                        Video](https://imagen.research.google/video/paper.pdf) paper).
         | 
| 163 | 
            +
                    thresholding (`bool`, defaults to `False`):
         | 
| 164 | 
            +
                        Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
         | 
| 165 | 
            +
                        as Stable Diffusion.
         | 
| 166 | 
            +
                    dynamic_thresholding_ratio (`float`, defaults to 0.995):
         | 
| 167 | 
            +
                        The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
         | 
| 168 | 
            +
                    sample_max_value (`float`, defaults to 1.0):
         | 
| 169 | 
            +
                        The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
         | 
| 170 | 
            +
                    timestep_spacing (`str`, defaults to `"leading"`):
         | 
| 171 | 
            +
                        The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
         | 
| 172 | 
            +
                        Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         | 
| 173 | 
            +
                    rescale_betas_zero_snr (`bool`, defaults to `False`):
         | 
| 174 | 
            +
                        Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
         | 
| 175 | 
            +
                        dark samples instead of limiting it to samples with medium brightness. Loosely related to
         | 
| 176 | 
            +
                        [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
         | 
| 177 | 
            +
                """
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
         | 
| 180 | 
            +
                order = 1
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                @register_to_config
         | 
| 183 | 
            +
                def __init__(
         | 
| 184 | 
            +
                    self,
         | 
| 185 | 
            +
                    num_train_timesteps: int = 1000,
         | 
| 186 | 
            +
                    linear_start: float = 0.00085,
         | 
| 187 | 
            +
                    linear_end: float = 0.012,
         | 
| 188 | 
            +
                    beta_schedule: str = "scaled_linear",
         | 
| 189 | 
            +
                    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         | 
| 190 | 
            +
                    clip_sample: bool = True,
         | 
| 191 | 
            +
                    set_alpha_to_one: bool = True,
         | 
| 192 | 
            +
                    steps_offset: int = 0,
         | 
| 193 | 
            +
                    prediction_type: str = "epsilon",
         | 
| 194 | 
            +
                    thresholding: bool = False,
         | 
| 195 | 
            +
                    dynamic_thresholding_ratio: float = 0.995,
         | 
| 196 | 
            +
                    clip_sample_range: float = 1.0,
         | 
| 197 | 
            +
                    sample_max_value: float = 1.0,
         | 
| 198 | 
            +
                    timestep_spacing: str = "leading",
         | 
| 199 | 
            +
                    rescale_betas_zero_snr: bool = False,
         | 
| 200 | 
            +
                ):
         | 
| 201 | 
            +
                    assert beta_schedule == "scaled_linear"
         | 
| 202 | 
            +
                    assert trained_betas is None
         | 
| 203 | 
            +
                    if trained_betas is not None:
         | 
| 204 | 
            +
                        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         | 
| 205 | 
            +
                    elif beta_schedule == "linear":
         | 
| 206 | 
            +
                        self.betas = torch.linspace(
         | 
| 207 | 
            +
                            linear_start, linear_end, num_train_timesteps, dtype=torch.float32
         | 
| 208 | 
            +
                        )
         | 
| 209 | 
            +
                    elif beta_schedule == "scaled_linear":
         | 
| 210 | 
            +
                        # this schedule is very specific to the latent diffusion model.
         | 
| 211 | 
            +
                        self.betas = (
         | 
| 212 | 
            +
                            torch.linspace(
         | 
| 213 | 
            +
                                linear_start**0.5,
         | 
| 214 | 
            +
                                linear_end**0.5,
         | 
| 215 | 
            +
                                num_train_timesteps,
         | 
| 216 | 
            +
                                dtype=torch.float32,
         | 
| 217 | 
            +
                            )
         | 
| 218 | 
            +
                            ** 2
         | 
| 219 | 
            +
                        )
         | 
| 220 | 
            +
                    elif beta_schedule == "squaredcos_cap_v2":
         | 
| 221 | 
            +
                        # Glide cosine schedule
         | 
| 222 | 
            +
                        self.betas = betas_for_alpha_bar(num_train_timesteps)
         | 
| 223 | 
            +
                    else:
         | 
| 224 | 
            +
                        raise NotImplementedError(
         | 
| 225 | 
            +
                            f"{beta_schedule} does is not implemented for {self.__class__}"
         | 
| 226 | 
            +
                        )
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    # Rescale for zero SNR
         | 
| 229 | 
            +
                    if rescale_betas_zero_snr:
         | 
| 230 | 
            +
                        self.betas = rescale_zero_terminal_snr(self.betas)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    self.alphas = 1.0 - self.betas
         | 
| 233 | 
            +
                    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    # At every step in ddim, we are looking into the previous alphas_cumprod
         | 
| 236 | 
            +
                    # For the final step, there is no previous alphas_cumprod because we are already at 0
         | 
| 237 | 
            +
                    # `set_alpha_to_one` decides whether we set this parameter simply to one or
         | 
| 238 | 
            +
                    # whether we use the final alpha of the "non-previous" one.
         | 
| 239 | 
            +
                    self.final_alpha_cumprod = (
         | 
| 240 | 
            +
                        torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
         | 
| 241 | 
            +
                    )
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    # standard deviation of the initial noise distribution
         | 
| 244 | 
            +
                    self.init_noise_sigma = 1.0
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    # setable values
         | 
| 247 | 
            +
                    self.num_inference_steps = None
         | 
| 248 | 
            +
                    self.timesteps = torch.from_numpy(
         | 
| 249 | 
            +
                        np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
         | 
| 250 | 
            +
                    )
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def scale_model_input(
         | 
| 253 | 
            +
                    self, sample: torch.FloatTensor, timestep: Optional[int] = None
         | 
| 254 | 
            +
                ) -> torch.FloatTensor:
         | 
| 255 | 
            +
                    """
         | 
| 256 | 
            +
                    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         | 
| 257 | 
            +
                    current timestep.
         | 
| 258 | 
            +
                    Args:
         | 
| 259 | 
            +
                        sample (`torch.FloatTensor`):
         | 
| 260 | 
            +
                            The input sample.
         | 
| 261 | 
            +
                        timestep (`int`, *optional*):
         | 
| 262 | 
            +
                            The current timestep in the diffusion chain.
         | 
| 263 | 
            +
                    Returns:
         | 
| 264 | 
            +
                        `torch.FloatTensor`:
         | 
| 265 | 
            +
                            A scaled input sample.
         | 
| 266 | 
            +
                    """
         | 
| 267 | 
            +
                    return sample
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def _get_variance(self, timestep, prev_timestep):
         | 
| 270 | 
            +
                    alpha_prod_t = self.alphas_cumprod[timestep]
         | 
| 271 | 
            +
                    alpha_prod_t_prev = (
         | 
| 272 | 
            +
                        self.alphas_cumprod[prev_timestep]
         | 
| 273 | 
            +
                        if prev_timestep >= 0
         | 
| 274 | 
            +
                        else self.final_alpha_cumprod
         | 
| 275 | 
            +
                    )
         | 
| 276 | 
            +
                    beta_prod_t = 1 - alpha_prod_t
         | 
| 277 | 
            +
                    beta_prod_t_prev = 1 - alpha_prod_t_prev
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    variance = (beta_prod_t_prev / beta_prod_t) * (
         | 
| 280 | 
            +
                        1 - alpha_prod_t / alpha_prod_t_prev
         | 
| 281 | 
            +
                    )
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    return variance
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
         | 
| 286 | 
            +
                def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         | 
| 287 | 
            +
                    """
         | 
| 288 | 
            +
                    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         | 
| 289 | 
            +
                    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
         | 
| 290 | 
            +
                    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
         | 
| 291 | 
            +
                    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
         | 
| 292 | 
            +
                    photorealism as well as better image-text alignment, especially when using very large guidance weights."
         | 
| 293 | 
            +
                    https://arxiv.org/abs/2205.11487
         | 
| 294 | 
            +
                    """
         | 
| 295 | 
            +
                    dtype = sample.dtype
         | 
| 296 | 
            +
                    batch_size, channels, height, width = sample.shape
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    if dtype not in (torch.float32, torch.float64):
         | 
| 299 | 
            +
                        sample = (
         | 
| 300 | 
            +
                            sample.float()
         | 
| 301 | 
            +
                        )  # upcast for quantile calculation, and clamp not implemented for cpu half
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    # Flatten sample for doing quantile calculation along each image
         | 
| 304 | 
            +
                    sample = sample.reshape(batch_size, channels * height * width)
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
         | 
| 309 | 
            +
                    s = torch.clamp(
         | 
| 310 | 
            +
                        s, min=1, max=self.config.sample_max_value
         | 
| 311 | 
            +
                    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
         | 
| 314 | 
            +
                    sample = (
         | 
| 315 | 
            +
                        torch.clamp(sample, -s, s) / s
         | 
| 316 | 
            +
                    )  # "we threshold xt0 to the range [-s, s] and then divide by s"
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                    sample = sample.reshape(batch_size, channels, height, width)
         | 
| 319 | 
            +
                    sample = sample.to(dtype)
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    return sample
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                def set_timesteps(
         | 
| 324 | 
            +
                    self,
         | 
| 325 | 
            +
                    num_inference_steps: int,
         | 
| 326 | 
            +
                    lcm_origin_steps: int,
         | 
| 327 | 
            +
                    device: Union[str, torch.device] = None,
         | 
| 328 | 
            +
                ):
         | 
| 329 | 
            +
                    """
         | 
| 330 | 
            +
                    Sets the discrete timesteps used for the diffusion chain (to be run before inference).
         | 
| 331 | 
            +
                    Args:
         | 
| 332 | 
            +
                        num_inference_steps (`int`):
         | 
| 333 | 
            +
                            The number of diffusion steps used when generating samples with a pre-trained model.
         | 
| 334 | 
            +
                    """
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    if num_inference_steps > self.config.num_train_timesteps:
         | 
| 337 | 
            +
                        raise ValueError(
         | 
| 338 | 
            +
                            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
         | 
| 339 | 
            +
                            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
         | 
| 340 | 
            +
                            f" maximal {self.config.num_train_timesteps} timesteps."
         | 
| 341 | 
            +
                        )
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    self.num_inference_steps = num_inference_steps
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    # LCM Timesteps Setting:  # Linear Spacing
         | 
| 346 | 
            +
                    c = self.config.num_train_timesteps // lcm_origin_steps
         | 
| 347 | 
            +
                    lcm_origin_timesteps = (
         | 
| 348 | 
            +
                        np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1
         | 
| 349 | 
            +
                    )  # LCM Training  Steps Schedule
         | 
| 350 | 
            +
                    skipping_step = len(lcm_origin_timesteps) // num_inference_steps
         | 
| 351 | 
            +
                    timesteps = lcm_origin_timesteps[::-skipping_step][
         | 
| 352 | 
            +
                        :num_inference_steps
         | 
| 353 | 
            +
                    ]  # LCM Inference Steps Schedule
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                    ## From VideoCrafter 2
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                def get_scalings_for_boundary_condition_discrete(self, t):
         | 
| 360 | 
            +
                    self.sigma_data = 0.5  # Default: 0.5
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                    # By dividing 0.1: This is almost a delta function at t=0.
         | 
| 363 | 
            +
                    c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
         | 
| 364 | 
            +
                    c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
         | 
| 365 | 
            +
                    return c_skip, c_out
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                def step(
         | 
| 368 | 
            +
                    self,
         | 
| 369 | 
            +
                    model_output: torch.FloatTensor,
         | 
| 370 | 
            +
                    timeindex: int,
         | 
| 371 | 
            +
                    timestep: int,
         | 
| 372 | 
            +
                    sample: torch.FloatTensor,
         | 
| 373 | 
            +
                    eta: float = 0.0,
         | 
| 374 | 
            +
                    use_clipped_model_output: bool = False,
         | 
| 375 | 
            +
                    generator=None,
         | 
| 376 | 
            +
                    variance_noise: Optional[torch.FloatTensor] = None,
         | 
| 377 | 
            +
                    return_dict: bool = True,
         | 
| 378 | 
            +
                ) -> Union[T2VTurboSchedulerOutput, Tuple]:
         | 
| 379 | 
            +
                    """
         | 
| 380 | 
            +
                    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         | 
| 381 | 
            +
                    process from the learned model outputs (most often the predicted noise).
         | 
| 382 | 
            +
                    Args:
         | 
| 383 | 
            +
                        model_output (`torch.FloatTensor`):
         | 
| 384 | 
            +
                            The direct output from learned diffusion model.
         | 
| 385 | 
            +
                        timestep (`float`):
         | 
| 386 | 
            +
                            The current discrete timestep in the diffusion chain.
         | 
| 387 | 
            +
                        sample (`torch.FloatTensor`):
         | 
| 388 | 
            +
                            A current instance of a sample created by the diffusion process.
         | 
| 389 | 
            +
                        eta (`float`):
         | 
| 390 | 
            +
                            The weight of noise for added noise in diffusion step.
         | 
| 391 | 
            +
                        use_clipped_model_output (`bool`, defaults to `False`):
         | 
| 392 | 
            +
                            If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
         | 
| 393 | 
            +
                            because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
         | 
| 394 | 
            +
                            clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
         | 
| 395 | 
            +
                            `use_clipped_model_output` has no effect.
         | 
| 396 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 397 | 
            +
                            A random number generator.
         | 
| 398 | 
            +
                        variance_noise (`torch.FloatTensor`):
         | 
| 399 | 
            +
                            Alternative to generating noise with `generator` by directly providing the noise for the variance
         | 
| 400 | 
            +
                            itself. Useful for methods such as [`CycleDiffusion`].
         | 
| 401 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 402 | 
            +
                            Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
         | 
| 403 | 
            +
                    Returns:
         | 
| 404 | 
            +
                        [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
         | 
| 405 | 
            +
                            If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
         | 
| 406 | 
            +
                            tuple is returned where the first element is the sample tensor.
         | 
| 407 | 
            +
                    """
         | 
| 408 | 
            +
                    if self.num_inference_steps is None:
         | 
| 409 | 
            +
                        raise ValueError(
         | 
| 410 | 
            +
                            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
         | 
| 411 | 
            +
                        )
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                    # 1. get previous step value
         | 
| 414 | 
            +
                    prev_timeindex = timeindex + 1
         | 
| 415 | 
            +
                    if prev_timeindex < len(self.timesteps):
         | 
| 416 | 
            +
                        prev_timestep = self.timesteps[prev_timeindex]
         | 
| 417 | 
            +
                    else:
         | 
| 418 | 
            +
                        prev_timestep = timestep
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                    # 2. compute alphas, betas
         | 
| 421 | 
            +
                    alpha_prod_t = self.alphas_cumprod[timestep]
         | 
| 422 | 
            +
                    alpha_prod_t_prev = (
         | 
| 423 | 
            +
                        self.alphas_cumprod[prev_timestep]
         | 
| 424 | 
            +
                        if prev_timestep >= 0
         | 
| 425 | 
            +
                        else self.final_alpha_cumprod
         | 
| 426 | 
            +
                    )
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    beta_prod_t = 1 - alpha_prod_t
         | 
| 429 | 
            +
                    beta_prod_t_prev = 1 - alpha_prod_t_prev
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                    # 3. Get scalings for boundary conditions
         | 
| 432 | 
            +
                    c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                    # 4. Different Parameterization:
         | 
| 435 | 
            +
                    parameterization = self.config.prediction_type
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    if parameterization == "epsilon":  # noise-prediction
         | 
| 438 | 
            +
                        pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                    elif parameterization == "sample":  # x-prediction
         | 
| 441 | 
            +
                        pred_x0 = model_output
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                    elif parameterization == "v_prediction":  # v-prediction
         | 
| 444 | 
            +
                        pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                    # 4. Denoise model output using boundary conditions
         | 
| 447 | 
            +
                    denoised = c_out * pred_x0 + c_skip * sample
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    # 5. Sample z ~ N(0, I), For MultiStep Inference
         | 
| 450 | 
            +
                    # Noise is not used for one-step sampling.
         | 
| 451 | 
            +
                    if len(self.timesteps) > 1:
         | 
| 452 | 
            +
                        noise = randn_tensor(
         | 
| 453 | 
            +
                            denoised.shape,
         | 
| 454 | 
            +
                            generator=generator,
         | 
| 455 | 
            +
                            device=denoised.device,
         | 
| 456 | 
            +
                            dtype=denoised.dtype,
         | 
| 457 | 
            +
                        )
         | 
| 458 | 
            +
                        prev_sample = (
         | 
| 459 | 
            +
                            alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
         | 
| 460 | 
            +
                        )
         | 
| 461 | 
            +
                    else:
         | 
| 462 | 
            +
                        prev_sample = denoised
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    if not return_dict:
         | 
| 465 | 
            +
                        return (prev_sample, denoised)
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                    return T2VTurboSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
         | 
| 470 | 
            +
                def add_noise(
         | 
| 471 | 
            +
                    self,
         | 
| 472 | 
            +
                    original_samples: torch.FloatTensor,
         | 
| 473 | 
            +
                    noise: torch.FloatTensor,
         | 
| 474 | 
            +
                    timesteps: torch.IntTensor,
         | 
| 475 | 
            +
                ) -> torch.FloatTensor:
         | 
| 476 | 
            +
                    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         | 
| 477 | 
            +
                    alphas_cumprod = self.alphas_cumprod.to(
         | 
| 478 | 
            +
                        device=original_samples.device, dtype=original_samples.dtype
         | 
| 479 | 
            +
                    )
         | 
| 480 | 
            +
                    timesteps = timesteps.to(original_samples.device)
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         | 
| 483 | 
            +
                    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         | 
| 484 | 
            +
                    while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
         | 
| 485 | 
            +
                        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         | 
| 488 | 
            +
                    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         | 
| 489 | 
            +
                    while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
         | 
| 490 | 
            +
                        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                    noisy_samples = (
         | 
| 493 | 
            +
                        sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         | 
| 494 | 
            +
                    )
         | 
| 495 | 
            +
                    return noisy_samples
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
         | 
| 498 | 
            +
                def get_velocity(
         | 
| 499 | 
            +
                    self,
         | 
| 500 | 
            +
                    sample: torch.FloatTensor,
         | 
| 501 | 
            +
                    noise: torch.FloatTensor,
         | 
| 502 | 
            +
                    timesteps: torch.IntTensor,
         | 
| 503 | 
            +
                ) -> torch.FloatTensor:
         | 
| 504 | 
            +
                    # Make sure alphas_cumprod and timestep have same device and dtype as sample
         | 
| 505 | 
            +
                    alphas_cumprod = self.alphas_cumprod.to(
         | 
| 506 | 
            +
                        device=sample.device, dtype=sample.dtype
         | 
| 507 | 
            +
                    )
         | 
| 508 | 
            +
                    timesteps = timesteps.to(sample.device)
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         | 
| 511 | 
            +
                    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         | 
| 512 | 
            +
                    while len(sqrt_alpha_prod.shape) < len(sample.shape):
         | 
| 513 | 
            +
                        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         | 
| 516 | 
            +
                    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         | 
| 517 | 
            +
                    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
         | 
| 518 | 
            +
                        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         | 
| 521 | 
            +
                    return velocity
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                def __len__(self):
         | 
| 524 | 
            +
                    return self.config.num_train_timesteps
         | 
    	
        utils/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        utils/__pycache__/__init__.cpython-312.pyc
    ADDED
    
    | Binary file (145 Bytes). View file | 
|  | 
    	
        utils/__pycache__/common_utils.cpython-312.pyc
    ADDED
    
    | Binary file (19.8 kB). View file | 
|  | 
    	
        utils/__pycache__/lora.cpython-312.pyc
    ADDED
    
    | Binary file (52.5 kB). View file | 
|  | 
    	
        utils/__pycache__/lora_handler.cpython-312.pyc
    ADDED
    
    | Binary file (5.54 kB). View file | 
|  | 
    	
        utils/__pycache__/utils.cpython-312.pyc
    ADDED
    
    | Binary file (5.81 kB). View file | 
|  | 
    	
        utils/common_utils.py
    ADDED
    
    | @@ -0,0 +1,511 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import ast
         | 
| 2 | 
            +
            import gc
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import cv2
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from diffusers.models.attention_processor import AttnProcessor2_0
         | 
| 11 | 
            +
            from diffusers.models.attention import BasicTransformerBlock
         | 
| 12 | 
            +
            from decord import VideoReader
         | 
| 13 | 
            +
            import wandb
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def extract_into_tensor(a, t, x_shape):
         | 
| 17 | 
            +
                b, *_ = t.shape
         | 
| 18 | 
            +
                out = a.gather(-1, t)
         | 
| 19 | 
            +
                return out.reshape(b, *((1,) * (len(x_shape) - 1)))
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def is_attn(name):
         | 
| 23 | 
            +
                return "attn1" or "attn2" == name.split(".")[-1]
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def set_processors(attentions):
         | 
| 27 | 
            +
                for attn in attentions:
         | 
| 28 | 
            +
                    attn.set_processor(AttnProcessor2_0())
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def set_torch_2_attn(unet):
         | 
| 32 | 
            +
                optim_count = 0
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                for name, module in unet.named_modules():
         | 
| 35 | 
            +
                    if is_attn(name):
         | 
| 36 | 
            +
                        if isinstance(module, torch.nn.ModuleList):
         | 
| 37 | 
            +
                            for m in module:
         | 
| 38 | 
            +
                                if isinstance(m, BasicTransformerBlock):
         | 
| 39 | 
            +
                                    set_processors([m.attn1, m.attn2])
         | 
| 40 | 
            +
                                    optim_count += 1
         | 
| 41 | 
            +
                if optim_count > 0:
         | 
| 42 | 
            +
                    print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            # From LatentConsistencyModel.get_guidance_scale_embedding
         | 
| 46 | 
            +
            def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
         | 
| 47 | 
            +
                """
         | 
| 48 | 
            +
                See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                Args:
         | 
| 51 | 
            +
                    timesteps (`torch.Tensor`):
         | 
| 52 | 
            +
                        generate embedding vectors at these timesteps
         | 
| 53 | 
            +
                    embedding_dim (`int`, *optional*, defaults to 512):
         | 
| 54 | 
            +
                        dimension of the embeddings to generate
         | 
| 55 | 
            +
                    dtype:
         | 
| 56 | 
            +
                        data type of the generated embeddings
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                Returns:
         | 
| 59 | 
            +
                    `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
         | 
| 60 | 
            +
                """
         | 
| 61 | 
            +
                assert len(w.shape) == 1
         | 
| 62 | 
            +
                w = w * 1000.0
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                half_dim = embedding_dim // 2
         | 
| 65 | 
            +
                emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
         | 
| 66 | 
            +
                emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
         | 
| 67 | 
            +
                emb = w.to(dtype)[:, None] * emb[None, :]
         | 
| 68 | 
            +
                emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
         | 
| 69 | 
            +
                if embedding_dim % 2 == 1:  # zero pad
         | 
| 70 | 
            +
                    emb = torch.nn.functional.pad(emb, (0, 1))
         | 
| 71 | 
            +
                assert emb.shape == (w.shape[0], embedding_dim)
         | 
| 72 | 
            +
                return emb
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def append_dims(x, target_dims):
         | 
| 76 | 
            +
                """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
         | 
| 77 | 
            +
                dims_to_append = target_dims - x.ndim
         | 
| 78 | 
            +
                if dims_to_append < 0:
         | 
| 79 | 
            +
                    raise ValueError(
         | 
| 80 | 
            +
                        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
                return x[(...,) + (None,) * dims_to_append]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            # From LCMScheduler.get_scalings_for_boundary_condition_discrete
         | 
| 86 | 
            +
            def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
         | 
| 87 | 
            +
                scaled_timestep = timestep_scaling * timestep
         | 
| 88 | 
            +
                c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
         | 
| 89 | 
            +
                c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
         | 
| 90 | 
            +
                return c_skip, c_out
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            # Compare LCMScheduler.step, Step 4
         | 
| 94 | 
            +
            def get_predicted_original_sample(
         | 
| 95 | 
            +
                model_output, timesteps, sample, prediction_type, alphas, sigmas
         | 
| 96 | 
            +
            ):
         | 
| 97 | 
            +
                alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         | 
| 98 | 
            +
                sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
         | 
| 99 | 
            +
                if prediction_type == "epsilon":
         | 
| 100 | 
            +
                    pred_x_0 = (sample - sigmas * model_output) / alphas
         | 
| 101 | 
            +
                elif prediction_type == "sample":
         | 
| 102 | 
            +
                    pred_x_0 = model_output
         | 
| 103 | 
            +
                elif prediction_type == "v_prediction":
         | 
| 104 | 
            +
                    pred_x_0 = alphas * sample - sigmas * model_output
         | 
| 105 | 
            +
                else:
         | 
| 106 | 
            +
                    raise ValueError(
         | 
| 107 | 
            +
                        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
         | 
| 108 | 
            +
                        f" are supported."
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                return pred_x_0
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
            # Based on step 4 in DDIMScheduler.step
         | 
| 115 | 
            +
            def get_predicted_noise(
         | 
| 116 | 
            +
                model_output, timesteps, sample, prediction_type, alphas, sigmas
         | 
| 117 | 
            +
            ):
         | 
| 118 | 
            +
                alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         | 
| 119 | 
            +
                sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
         | 
| 120 | 
            +
                if prediction_type == "epsilon":
         | 
| 121 | 
            +
                    pred_epsilon = model_output
         | 
| 122 | 
            +
                elif prediction_type == "sample":
         | 
| 123 | 
            +
                    pred_epsilon = (sample - alphas * model_output) / sigmas
         | 
| 124 | 
            +
                elif prediction_type == "v_prediction":
         | 
| 125 | 
            +
                    pred_epsilon = alphas * model_output + sigmas * sample
         | 
| 126 | 
            +
                else:
         | 
| 127 | 
            +
                    raise ValueError(
         | 
| 128 | 
            +
                        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
         | 
| 129 | 
            +
                        f" are supported."
         | 
| 130 | 
            +
                    )
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                return pred_epsilon
         | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
            # From LatentConsistencyModel.get_guidance_scale_embedding
         | 
| 136 | 
            +
            def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
         | 
| 137 | 
            +
                """
         | 
| 138 | 
            +
                See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                Args:
         | 
| 141 | 
            +
                    timesteps (`torch.Tensor`):
         | 
| 142 | 
            +
                        generate embedding vectors at these timesteps
         | 
| 143 | 
            +
                    embedding_dim (`int`, *optional*, defaults to 512):
         | 
| 144 | 
            +
                        dimension of the embeddings to generate
         | 
| 145 | 
            +
                    dtype:
         | 
| 146 | 
            +
                        data type of the generated embeddings
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                Returns:
         | 
| 149 | 
            +
                    `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
         | 
| 150 | 
            +
                """
         | 
| 151 | 
            +
                assert len(w.shape) == 1
         | 
| 152 | 
            +
                w = w * 1000.0
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                half_dim = embedding_dim // 2
         | 
| 155 | 
            +
                emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
         | 
| 156 | 
            +
                emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
         | 
| 157 | 
            +
                emb = w.to(dtype)[:, None] * emb[None, :]
         | 
| 158 | 
            +
                emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
         | 
| 159 | 
            +
                if embedding_dim % 2 == 1:  # zero pad
         | 
| 160 | 
            +
                    emb = torch.nn.functional.pad(emb, (0, 1))
         | 
| 161 | 
            +
                assert emb.shape == (w.shape[0], embedding_dim)
         | 
| 162 | 
            +
                return emb
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            def append_dims(x, target_dims):
         | 
| 166 | 
            +
                """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
         | 
| 167 | 
            +
                dims_to_append = target_dims - x.ndim
         | 
| 168 | 
            +
                if dims_to_append < 0:
         | 
| 169 | 
            +
                    raise ValueError(
         | 
| 170 | 
            +
                        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
         | 
| 171 | 
            +
                    )
         | 
| 172 | 
            +
                return x[(...,) + (None,) * dims_to_append]
         | 
| 173 | 
            +
             | 
| 174 | 
            +
             | 
| 175 | 
            +
            # From LCMScheduler.get_scalings_for_boundary_condition_discrete
         | 
| 176 | 
            +
            def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
         | 
| 177 | 
            +
                scaled_timestep = timestep_scaling * timestep
         | 
| 178 | 
            +
                c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
         | 
| 179 | 
            +
                c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
         | 
| 180 | 
            +
                return c_skip, c_out
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            # Compare LCMScheduler.step, Step 4
         | 
| 184 | 
            +
            def get_predicted_original_sample(
         | 
| 185 | 
            +
                model_output, timesteps, sample, prediction_type, alphas, sigmas
         | 
| 186 | 
            +
            ):
         | 
| 187 | 
            +
                alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         | 
| 188 | 
            +
                sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
         | 
| 189 | 
            +
                if prediction_type == "epsilon":
         | 
| 190 | 
            +
                    pred_x_0 = (sample - sigmas * model_output) / alphas
         | 
| 191 | 
            +
                elif prediction_type == "sample":
         | 
| 192 | 
            +
                    pred_x_0 = model_output
         | 
| 193 | 
            +
                elif prediction_type == "v_prediction":
         | 
| 194 | 
            +
                    pred_x_0 = alphas * sample - sigmas * model_output
         | 
| 195 | 
            +
                else:
         | 
| 196 | 
            +
                    raise ValueError(
         | 
| 197 | 
            +
                        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
         | 
| 198 | 
            +
                        f" are supported."
         | 
| 199 | 
            +
                    )
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                return pred_x_0
         | 
| 202 | 
            +
             | 
| 203 | 
            +
             | 
| 204 | 
            +
            # Based on step 4 in DDIMScheduler.step
         | 
| 205 | 
            +
            def get_predicted_noise(
         | 
| 206 | 
            +
                model_output, timesteps, sample, prediction_type, alphas, sigmas
         | 
| 207 | 
            +
            ):
         | 
| 208 | 
            +
                alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         | 
| 209 | 
            +
                sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
         | 
| 210 | 
            +
                if prediction_type == "epsilon":
         | 
| 211 | 
            +
                    pred_epsilon = model_output
         | 
| 212 | 
            +
                elif prediction_type == "sample":
         | 
| 213 | 
            +
                    pred_epsilon = (sample - alphas * model_output) / sigmas
         | 
| 214 | 
            +
                elif prediction_type == "v_prediction":
         | 
| 215 | 
            +
                    pred_epsilon = alphas * model_output + sigmas * sample
         | 
| 216 | 
            +
                else:
         | 
| 217 | 
            +
                    raise ValueError(
         | 
| 218 | 
            +
                        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
         | 
| 219 | 
            +
                        f" are supported."
         | 
| 220 | 
            +
                    )
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                return pred_epsilon
         | 
| 223 | 
            +
             | 
| 224 | 
            +
             | 
| 225 | 
            +
            def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
         | 
| 226 | 
            +
                extra_params = extra_params if len(extra_params.keys()) > 0 else None
         | 
| 227 | 
            +
                return {
         | 
| 228 | 
            +
                    "model": model,
         | 
| 229 | 
            +
                    "condition": condition,
         | 
| 230 | 
            +
                    "extra_params": extra_params,
         | 
| 231 | 
            +
                    "is_lora": is_lora,
         | 
| 232 | 
            +
                    "negation": negation,
         | 
| 233 | 
            +
                }
         | 
| 234 | 
            +
             | 
| 235 | 
            +
             | 
| 236 | 
            +
            def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None):
         | 
| 237 | 
            +
                params = {"name": name, "params": params, "lr": lr}
         | 
| 238 | 
            +
                if extra_params is not None:
         | 
| 239 | 
            +
                    for k, v in extra_params.items():
         | 
| 240 | 
            +
                        params[k] = v
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                return params
         | 
| 243 | 
            +
             | 
| 244 | 
            +
             | 
| 245 | 
            +
            def create_optimizer_params(model_list, lr):
         | 
| 246 | 
            +
                import itertools
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                optimizer_params = []
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                for optim in model_list:
         | 
| 251 | 
            +
                    model, condition, extra_params, is_lora, negation = optim.values()
         | 
| 252 | 
            +
                    # Check if we are doing LoRA training.
         | 
| 253 | 
            +
                    if is_lora and condition and isinstance(model, list):
         | 
| 254 | 
            +
                        params = create_optim_params(
         | 
| 255 | 
            +
                            params=itertools.chain(*model), extra_params=extra_params
         | 
| 256 | 
            +
                        )
         | 
| 257 | 
            +
                        optimizer_params.append(params)
         | 
| 258 | 
            +
                        continue
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    if is_lora and condition and not isinstance(model, list):
         | 
| 261 | 
            +
                        for n, p in model.named_parameters():
         | 
| 262 | 
            +
                            if "lora" in n:
         | 
| 263 | 
            +
                                params = create_optim_params(n, p, lr, extra_params)
         | 
| 264 | 
            +
                                optimizer_params.append(params)
         | 
| 265 | 
            +
                        continue
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    # If this is true, we can train it.
         | 
| 268 | 
            +
                    if condition:
         | 
| 269 | 
            +
                        for n, p in model.named_parameters():
         | 
| 270 | 
            +
                            should_negate = "lora" in n and not is_lora
         | 
| 271 | 
            +
                            if should_negate:
         | 
| 272 | 
            +
                                continue
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                            params = create_optim_params(n, p, lr, extra_params)
         | 
| 275 | 
            +
                            optimizer_params.append(params)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                return optimizer_params
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            def handle_trainable_modules(
         | 
| 281 | 
            +
                model, trainable_modules=None, is_enabled=True, negation=None
         | 
| 282 | 
            +
            ):
         | 
| 283 | 
            +
                acc = []
         | 
| 284 | 
            +
                unfrozen_params = 0
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                if trainable_modules is not None:
         | 
| 287 | 
            +
                    unlock_all = any([name == "all" for name in trainable_modules])
         | 
| 288 | 
            +
                    if unlock_all:
         | 
| 289 | 
            +
                        model.requires_grad_(True)
         | 
| 290 | 
            +
                        unfrozen_params = len(list(model.parameters()))
         | 
| 291 | 
            +
                    else:
         | 
| 292 | 
            +
                        model.requires_grad_(False)
         | 
| 293 | 
            +
                        for name, param in model.named_parameters():
         | 
| 294 | 
            +
                            for tm in trainable_modules:
         | 
| 295 | 
            +
                                if all([tm in name, name not in acc, "lora" not in name]):
         | 
| 296 | 
            +
                                    param.requires_grad_(is_enabled)
         | 
| 297 | 
            +
                                    acc.append(name)
         | 
| 298 | 
            +
                                    unfrozen_params += 1
         | 
| 299 | 
            +
             | 
| 300 | 
            +
             | 
| 301 | 
            +
            def huber_loss(pred, target, huber_c=0.001):
         | 
| 302 | 
            +
                loss = torch.sqrt((pred.float() - target.float()) ** 2 + huber_c**2) - huber_c
         | 
| 303 | 
            +
                return loss.mean()
         | 
| 304 | 
            +
             | 
| 305 | 
            +
             | 
| 306 | 
            +
            @torch.no_grad()
         | 
| 307 | 
            +
            def update_ema(target_params, source_params, rate=0.99):
         | 
| 308 | 
            +
                """
         | 
| 309 | 
            +
                Update target parameters to be closer to those of source parameters using
         | 
| 310 | 
            +
                an exponential moving average.
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                :param target_params: the target parameter sequence.
         | 
| 313 | 
            +
                :param source_params: the source parameter sequence.
         | 
| 314 | 
            +
                :param rate: the EMA rate (closer to 1 means slower).
         | 
| 315 | 
            +
                """
         | 
| 316 | 
            +
                for targ, src in zip(target_params, source_params):
         | 
| 317 | 
            +
                    src_to_dtype = src.to(targ.dtype)
         | 
| 318 | 
            +
                    targ.detach().mul_(rate).add_(src_to_dtype, alpha=1 - rate)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
             | 
| 321 | 
            +
            def log_validation_video(pipeline, args, accelerator, save_fps):
         | 
| 322 | 
            +
                if args.seed is None:
         | 
| 323 | 
            +
                    generator = None
         | 
| 324 | 
            +
                else:
         | 
| 325 | 
            +
                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                validation_prompts = [
         | 
| 328 | 
            +
                    "An astronaut riding a horse.",
         | 
| 329 | 
            +
                    "Darth vader surfing in waves.",
         | 
| 330 | 
            +
                    "Robot dancing in times square.",
         | 
| 331 | 
            +
                    "Clown fish swimming through the coral reef.",
         | 
| 332 | 
            +
                    "A child excitedly swings on a rusty swing set, laughter filling the air.",
         | 
| 333 | 
            +
                    "With the style of van gogh, A young couple dances under the moonlight by the lake.",
         | 
| 334 | 
            +
                    "A young woman with glasses is jogging in the park wearing a pink headband.",
         | 
| 335 | 
            +
                    "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
         | 
| 336 | 
            +
                    "Wolf, turns its head, in the wild",
         | 
| 337 | 
            +
                    "Iron man, walks, on the moon, 8k, high detailed, best quality",
         | 
| 338 | 
            +
                    "With the style of low-poly game art, A majestic, white horse gallops gracefully",
         | 
| 339 | 
            +
                    "a rabbit, low-poly game art style",
         | 
| 340 | 
            +
                ]
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                video_logs = []
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                if getattr(args, "use_motion_cond", False):
         | 
| 345 | 
            +
                    use_motion_cond = True
         | 
| 346 | 
            +
                else:
         | 
| 347 | 
            +
                    use_motion_cond = False
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                for _, prompt in enumerate(validation_prompts):
         | 
| 350 | 
            +
                    if use_motion_cond:
         | 
| 351 | 
            +
                        motin_gs_unit = (args.motion_gs_max - args.motion_gs_min) / 2
         | 
| 352 | 
            +
                        for i in range(3):
         | 
| 353 | 
            +
                            with torch.autocast("cuda"):
         | 
| 354 | 
            +
                                videos = pipeline(
         | 
| 355 | 
            +
                                    prompt=prompt,
         | 
| 356 | 
            +
                                    frames=args.n_frames,
         | 
| 357 | 
            +
                                    num_inference_steps=8,
         | 
| 358 | 
            +
                                    num_videos_per_prompt=1,
         | 
| 359 | 
            +
                                    fps=args.fps,
         | 
| 360 | 
            +
                                    use_motion_cond=True,
         | 
| 361 | 
            +
                                    motion_gs=motin_gs_unit * i,
         | 
| 362 | 
            +
                                    lcm_origin_steps=args.num_ddim_timesteps,
         | 
| 363 | 
            +
                                    generator=generator,
         | 
| 364 | 
            +
                                )
         | 
| 365 | 
            +
                                videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
         | 
| 366 | 
            +
                                videos = (
         | 
| 367 | 
            +
                                    (videos * 255)
         | 
| 368 | 
            +
                                    .to(torch.uint8)
         | 
| 369 | 
            +
                                    .permute(0, 2, 1, 3, 4)
         | 
| 370 | 
            +
                                    .cpu()
         | 
| 371 | 
            +
                                    .numpy()
         | 
| 372 | 
            +
                                )
         | 
| 373 | 
            +
                            video_logs.append(
         | 
| 374 | 
            +
                                {
         | 
| 375 | 
            +
                                    "validation_prompt": f"GS={i * motin_gs_unit}, {prompt}",
         | 
| 376 | 
            +
                                    "videos": videos,
         | 
| 377 | 
            +
                                }
         | 
| 378 | 
            +
                            )
         | 
| 379 | 
            +
                    else:
         | 
| 380 | 
            +
                        for i in range(2):
         | 
| 381 | 
            +
                            with torch.autocast("cuda"):
         | 
| 382 | 
            +
                                videos = pipeline(
         | 
| 383 | 
            +
                                    prompt=prompt,
         | 
| 384 | 
            +
                                    frames=args.n_frames,
         | 
| 385 | 
            +
                                    num_inference_steps=4 * (i + 1),
         | 
| 386 | 
            +
                                    num_videos_per_prompt=1,
         | 
| 387 | 
            +
                                    fps=args.fps,
         | 
| 388 | 
            +
                                    use_motion_cond=False,
         | 
| 389 | 
            +
                                    lcm_origin_steps=args.num_ddim_timesteps,
         | 
| 390 | 
            +
                                    generator=generator,
         | 
| 391 | 
            +
                                )
         | 
| 392 | 
            +
                                videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
         | 
| 393 | 
            +
                                videos = (
         | 
| 394 | 
            +
                                    (videos * 255)
         | 
| 395 | 
            +
                                    .to(torch.uint8)
         | 
| 396 | 
            +
                                    .permute(0, 2, 1, 3, 4)
         | 
| 397 | 
            +
                                    .cpu()
         | 
| 398 | 
            +
                                    .numpy()
         | 
| 399 | 
            +
                                )
         | 
| 400 | 
            +
                            video_logs.append(
         | 
| 401 | 
            +
                                {
         | 
| 402 | 
            +
                                    "validation_prompt": f"Steps={4 * (i + 1)}, {prompt}",
         | 
| 403 | 
            +
                                    "videos": videos,
         | 
| 404 | 
            +
                                }
         | 
| 405 | 
            +
                            )
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                for tracker in accelerator.trackers:
         | 
| 408 | 
            +
                    if tracker.name == "wandb":
         | 
| 409 | 
            +
                        formatted_videos = []
         | 
| 410 | 
            +
                        for log in video_logs:
         | 
| 411 | 
            +
                            videos = log["videos"]
         | 
| 412 | 
            +
                            validation_prompt = log["validation_prompt"]
         | 
| 413 | 
            +
                            for video in videos:
         | 
| 414 | 
            +
                                video = wandb.Video(video, caption=validation_prompt, fps=save_fps)
         | 
| 415 | 
            +
                                formatted_videos.append(video)
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                        tracker.log({f"validation": formatted_videos})
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    del pipeline
         | 
| 420 | 
            +
                    gc.collect()
         | 
| 421 | 
            +
             | 
| 422 | 
            +
             | 
| 423 | 
            +
            def tuple_type(s):
         | 
| 424 | 
            +
                if isinstance(s, tuple):
         | 
| 425 | 
            +
                    return s
         | 
| 426 | 
            +
                value = ast.literal_eval(s)
         | 
| 427 | 
            +
                if isinstance(value, tuple):
         | 
| 428 | 
            +
                    return value
         | 
| 429 | 
            +
                raise TypeError("Argument must be a tuple")
         | 
| 430 | 
            +
             | 
| 431 | 
            +
             | 
| 432 | 
            +
            def load_model_checkpoint(model, ckpt):
         | 
| 433 | 
            +
                def load_checkpoint(model, ckpt, full_strict):
         | 
| 434 | 
            +
                    state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
         | 
| 435 | 
            +
                    if "state_dict" in list(state_dict.keys()):
         | 
| 436 | 
            +
                        state_dict = state_dict["state_dict"]
         | 
| 437 | 
            +
                    model.load_state_dict(state_dict, strict=full_strict)
         | 
| 438 | 
            +
                    del state_dict
         | 
| 439 | 
            +
                    gc.collect()
         | 
| 440 | 
            +
                    return model
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                load_checkpoint(model, ckpt, full_strict=True)
         | 
| 443 | 
            +
                print(">>> model checkpoint loaded.")
         | 
| 444 | 
            +
                return model
         | 
| 445 | 
            +
             | 
| 446 | 
            +
             | 
| 447 | 
            +
            def read_video_to_tensor(
         | 
| 448 | 
            +
                path_to_video, sample_fps, sample_frames, uniform_sampling=False
         | 
| 449 | 
            +
            ):
         | 
| 450 | 
            +
                video_reader = VideoReader(path_to_video)
         | 
| 451 | 
            +
                video_fps = video_reader.get_avg_fps()
         | 
| 452 | 
            +
                video_frames = video_reader._num_frame
         | 
| 453 | 
            +
                video_duration = video_frames / video_fps
         | 
| 454 | 
            +
                sample_duration = sample_frames / sample_fps
         | 
| 455 | 
            +
                stride = video_fps / sample_fps
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                if uniform_sampling or video_duration <= sample_duration:
         | 
| 458 | 
            +
                    index_range = np.linspace(0, video_frames - 1, sample_frames).astype(np.int32)
         | 
| 459 | 
            +
                else:
         | 
| 460 | 
            +
                    max_start_frame = video_frames - np.ceil(sample_frames * stride).astype(
         | 
| 461 | 
            +
                        np.int32
         | 
| 462 | 
            +
                    )
         | 
| 463 | 
            +
                    if max_start_frame > 0:
         | 
| 464 | 
            +
                        start_frame = random.randint(0, max_start_frame)
         | 
| 465 | 
            +
                    else:
         | 
| 466 | 
            +
                        start_frame = 0
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                    index_range = start_frame + np.arange(sample_frames) * stride
         | 
| 469 | 
            +
                    index_range = np.round(index_range).astype(np.int32)
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                sampled_frames = video_reader.get_batch(index_range).asnumpy()
         | 
| 472 | 
            +
                pixel_values = torch.from_numpy(sampled_frames).permute(0, 3, 1, 2).contiguous()
         | 
| 473 | 
            +
                pixel_values = pixel_values / 255.0
         | 
| 474 | 
            +
                del video_reader
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                return pixel_values
         | 
| 477 | 
            +
             | 
| 478 | 
            +
             | 
| 479 | 
            +
            def calculate_motion_rank_new(tensor_ref, tensor_gen, rank_k=1):
         | 
| 480 | 
            +
                if rank_k == 0:
         | 
| 481 | 
            +
                    loss = torch.tensor(0.0, device=tensor_ref.device)
         | 
| 482 | 
            +
                elif rank_k > tensor_ref.shape[-1]:
         | 
| 483 | 
            +
                    raise ValueError(
         | 
| 484 | 
            +
                        "The value of rank_k cannot be larger than the number of frames"
         | 
| 485 | 
            +
                    )
         | 
| 486 | 
            +
                else:
         | 
| 487 | 
            +
                    # Sort the reference tensor along the frames dimension
         | 
| 488 | 
            +
                    _, sorted_indices = torch.sort(tensor_ref, dim=-1)
         | 
| 489 | 
            +
                    # Create a mask to select the top rank_k frames
         | 
| 490 | 
            +
                    mask = torch.zeros_like(tensor_ref, dtype=torch.bool)
         | 
| 491 | 
            +
                    mask.scatter_(-1, sorted_indices[..., -rank_k:], True)
         | 
| 492 | 
            +
                    # Compute the mean squared error loss only on the masked elements
         | 
| 493 | 
            +
                    loss = F.mse_loss(tensor_ref[mask].detach(), tensor_gen[mask])
         | 
| 494 | 
            +
                return loss
         | 
| 495 | 
            +
             | 
| 496 | 
            +
             | 
| 497 | 
            +
            def compute_temp_loss(attention_prob, attention_prob_example):
         | 
| 498 | 
            +
                temp_attn_prob_loss = []
         | 
| 499 | 
            +
                # 1. Loop though all layers to get the query, key, and Compute the PCA loss
         | 
| 500 | 
            +
                for name in attention_prob.keys():
         | 
| 501 | 
            +
                    attn_prob_example = attention_prob_example[name]
         | 
| 502 | 
            +
                    attn_prob = attention_prob[name]
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                    module_attn_loss = calculate_motion_rank_new(
         | 
| 505 | 
            +
                        attn_prob_example.detach(), attn_prob, rank_k=1
         | 
| 506 | 
            +
                    )
         | 
| 507 | 
            +
                    temp_attn_prob_loss.append(module_attn_loss)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                loss_temp = torch.stack(temp_attn_prob_loss) * 100
         | 
| 510 | 
            +
                loss = loss_temp.mean()
         | 
| 511 | 
            +
                return loss
         | 
    	
        utils/lora.py
    ADDED
    
    | @@ -0,0 +1,1349 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            from itertools import groupby
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import PIL
         | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.nn as nn
         | 
| 11 | 
            +
            import torch.nn.functional as F
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from safetensors.torch import safe_open
         | 
| 14 | 
            +
            from safetensors.torch import save_file as safe_save
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            safetensors_available = True
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class LoraInjectedLinear(nn.Module):
         | 
| 20 | 
            +
                def __init__(
         | 
| 21 | 
            +
                    self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
         | 
| 22 | 
            +
                ):
         | 
| 23 | 
            +
                    super().__init__()
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    if r > min(in_features, out_features):
         | 
| 26 | 
            +
                        # raise ValueError(
         | 
| 27 | 
            +
                        #    f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
         | 
| 28 | 
            +
                        # )
         | 
| 29 | 
            +
                        print(
         | 
| 30 | 
            +
                            f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}"
         | 
| 31 | 
            +
                        )
         | 
| 32 | 
            +
                        r = min(in_features, out_features)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    self.r = r
         | 
| 35 | 
            +
                    self.linear = nn.Linear(in_features, out_features, bias)
         | 
| 36 | 
            +
                    self.lora_down = nn.Linear(in_features, r, bias=False)
         | 
| 37 | 
            +
                    self.dropout = nn.Dropout(dropout_p)
         | 
| 38 | 
            +
                    self.lora_up = nn.Linear(r, out_features, bias=False)
         | 
| 39 | 
            +
                    self.scale = scale
         | 
| 40 | 
            +
                    self.selector = nn.Identity()
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    nn.init.normal_(self.lora_down.weight, std=1 / r)
         | 
| 43 | 
            +
                    nn.init.zeros_(self.lora_up.weight)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def forward(self, input):
         | 
| 46 | 
            +
                    return (
         | 
| 47 | 
            +
                        self.linear(input)
         | 
| 48 | 
            +
                        + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
         | 
| 49 | 
            +
                        * self.scale
         | 
| 50 | 
            +
                    )
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def realize_as_lora(self):
         | 
| 53 | 
            +
                    return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def set_selector_from_diag(self, diag: torch.Tensor):
         | 
| 56 | 
            +
                    # diag is a 1D tensor of size (r,)
         | 
| 57 | 
            +
                    assert diag.shape == (self.r,)
         | 
| 58 | 
            +
                    self.selector = nn.Linear(self.r, self.r, bias=False)
         | 
| 59 | 
            +
                    self.selector.weight.data = torch.diag(diag)
         | 
| 60 | 
            +
                    self.selector.weight.data = self.selector.weight.data.to(
         | 
| 61 | 
            +
                        self.lora_up.weight.device
         | 
| 62 | 
            +
                    ).to(self.lora_up.weight.dtype)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            class LoraInjectedConv2d(nn.Module):
         | 
| 66 | 
            +
                def __init__(
         | 
| 67 | 
            +
                    self,
         | 
| 68 | 
            +
                    in_channels: int,
         | 
| 69 | 
            +
                    out_channels: int,
         | 
| 70 | 
            +
                    kernel_size,
         | 
| 71 | 
            +
                    stride=1,
         | 
| 72 | 
            +
                    padding=0,
         | 
| 73 | 
            +
                    dilation=1,
         | 
| 74 | 
            +
                    groups: int = 1,
         | 
| 75 | 
            +
                    bias: bool = True,
         | 
| 76 | 
            +
                    r: int = 4,
         | 
| 77 | 
            +
                    dropout_p: float = 0.1,
         | 
| 78 | 
            +
                    scale: float = 1.0,
         | 
| 79 | 
            +
                ):
         | 
| 80 | 
            +
                    super().__init__()
         | 
| 81 | 
            +
                    if r > min(in_channels, out_channels):
         | 
| 82 | 
            +
                        print(
         | 
| 83 | 
            +
                            f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}"
         | 
| 84 | 
            +
                        )
         | 
| 85 | 
            +
                        r = min(in_channels, out_channels)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    self.r = r
         | 
| 88 | 
            +
                    self.conv = nn.Conv2d(
         | 
| 89 | 
            +
                        in_channels=in_channels,
         | 
| 90 | 
            +
                        out_channels=out_channels,
         | 
| 91 | 
            +
                        kernel_size=kernel_size,
         | 
| 92 | 
            +
                        stride=stride,
         | 
| 93 | 
            +
                        padding=padding,
         | 
| 94 | 
            +
                        dilation=dilation,
         | 
| 95 | 
            +
                        groups=groups,
         | 
| 96 | 
            +
                        bias=bias,
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.lora_down = nn.Conv2d(
         | 
| 100 | 
            +
                        in_channels=in_channels,
         | 
| 101 | 
            +
                        out_channels=r,
         | 
| 102 | 
            +
                        kernel_size=kernel_size,
         | 
| 103 | 
            +
                        stride=stride,
         | 
| 104 | 
            +
                        padding=padding,
         | 
| 105 | 
            +
                        dilation=dilation,
         | 
| 106 | 
            +
                        groups=groups,
         | 
| 107 | 
            +
                        bias=False,
         | 
| 108 | 
            +
                    )
         | 
| 109 | 
            +
                    self.dropout = nn.Dropout(dropout_p)
         | 
| 110 | 
            +
                    self.lora_up = nn.Conv2d(
         | 
| 111 | 
            +
                        in_channels=r,
         | 
| 112 | 
            +
                        out_channels=out_channels,
         | 
| 113 | 
            +
                        kernel_size=1,
         | 
| 114 | 
            +
                        stride=1,
         | 
| 115 | 
            +
                        padding=0,
         | 
| 116 | 
            +
                        bias=False,
         | 
| 117 | 
            +
                    )
         | 
| 118 | 
            +
                    self.selector = nn.Identity()
         | 
| 119 | 
            +
                    self.scale = scale
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    nn.init.normal_(self.lora_down.weight, std=1 / r)
         | 
| 122 | 
            +
                    nn.init.zeros_(self.lora_up.weight)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                def forward(self, input):
         | 
| 125 | 
            +
                    return (
         | 
| 126 | 
            +
                        self.conv(input)
         | 
| 127 | 
            +
                        + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
         | 
| 128 | 
            +
                        * self.scale
         | 
| 129 | 
            +
                    )
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                def realize_as_lora(self):
         | 
| 132 | 
            +
                    return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def set_selector_from_diag(self, diag: torch.Tensor):
         | 
| 135 | 
            +
                    # diag is a 1D tensor of size (r,)
         | 
| 136 | 
            +
                    assert diag.shape == (self.r,)
         | 
| 137 | 
            +
                    self.selector = nn.Conv2d(
         | 
| 138 | 
            +
                        in_channels=self.r,
         | 
| 139 | 
            +
                        out_channels=self.r,
         | 
| 140 | 
            +
                        kernel_size=1,
         | 
| 141 | 
            +
                        stride=1,
         | 
| 142 | 
            +
                        padding=0,
         | 
| 143 | 
            +
                        bias=False,
         | 
| 144 | 
            +
                    )
         | 
| 145 | 
            +
                    self.selector.weight.data = torch.diag(diag)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    # same device + dtype as lora_up
         | 
| 148 | 
            +
                    self.selector.weight.data = self.selector.weight.data.to(
         | 
| 149 | 
            +
                        self.lora_up.weight.device
         | 
| 150 | 
            +
                    ).to(self.lora_up.weight.dtype)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
            class LoraInjectedConv3d(nn.Module):
         | 
| 154 | 
            +
                def __init__(
         | 
| 155 | 
            +
                    self,
         | 
| 156 | 
            +
                    in_channels: int,
         | 
| 157 | 
            +
                    out_channels: int,
         | 
| 158 | 
            +
                    kernel_size: Tuple[int, int, int],  # (3, 1, 1)
         | 
| 159 | 
            +
                    padding: Tuple[int, int, int],  # (1, 0, 0)
         | 
| 160 | 
            +
                    bias: bool = False,
         | 
| 161 | 
            +
                    r: int = 4,
         | 
| 162 | 
            +
                    dropout_p: float = 0,
         | 
| 163 | 
            +
                    scale: float = 1.0,
         | 
| 164 | 
            +
                ):
         | 
| 165 | 
            +
                    super().__init__()
         | 
| 166 | 
            +
                    if r > min(in_channels, out_channels):
         | 
| 167 | 
            +
                        print(
         | 
| 168 | 
            +
                            f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}"
         | 
| 169 | 
            +
                        )
         | 
| 170 | 
            +
                        r = min(in_channels, out_channels)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    self.r = r
         | 
| 173 | 
            +
                    self.kernel_size = kernel_size
         | 
| 174 | 
            +
                    self.padding = padding
         | 
| 175 | 
            +
                    self.conv = nn.Conv3d(
         | 
| 176 | 
            +
                        in_channels=in_channels,
         | 
| 177 | 
            +
                        out_channels=out_channels,
         | 
| 178 | 
            +
                        kernel_size=kernel_size,
         | 
| 179 | 
            +
                        padding=padding,
         | 
| 180 | 
            +
                    )
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    self.lora_down = nn.Conv3d(
         | 
| 183 | 
            +
                        in_channels=in_channels,
         | 
| 184 | 
            +
                        out_channels=r,
         | 
| 185 | 
            +
                        kernel_size=kernel_size,
         | 
| 186 | 
            +
                        bias=False,
         | 
| 187 | 
            +
                        padding=padding,
         | 
| 188 | 
            +
                    )
         | 
| 189 | 
            +
                    self.dropout = nn.Dropout(dropout_p)
         | 
| 190 | 
            +
                    self.lora_up = nn.Conv3d(
         | 
| 191 | 
            +
                        in_channels=r,
         | 
| 192 | 
            +
                        out_channels=out_channels,
         | 
| 193 | 
            +
                        kernel_size=1,
         | 
| 194 | 
            +
                        stride=1,
         | 
| 195 | 
            +
                        padding=0,
         | 
| 196 | 
            +
                        bias=False,
         | 
| 197 | 
            +
                    )
         | 
| 198 | 
            +
                    self.selector = nn.Identity()
         | 
| 199 | 
            +
                    self.scale = scale
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    nn.init.normal_(self.lora_down.weight, std=1 / r)
         | 
| 202 | 
            +
                    nn.init.zeros_(self.lora_up.weight)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                def forward(self, input):
         | 
| 205 | 
            +
                    return (
         | 
| 206 | 
            +
                        self.conv(input)
         | 
| 207 | 
            +
                        + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
         | 
| 208 | 
            +
                        * self.scale
         | 
| 209 | 
            +
                    )
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                def realize_as_lora(self):
         | 
| 212 | 
            +
                    return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                def set_selector_from_diag(self, diag: torch.Tensor):
         | 
| 215 | 
            +
                    # diag is a 1D tensor of size (r,)
         | 
| 216 | 
            +
                    assert diag.shape == (self.r,)
         | 
| 217 | 
            +
                    self.selector = nn.Conv3d(
         | 
| 218 | 
            +
                        in_channels=self.r,
         | 
| 219 | 
            +
                        out_channels=self.r,
         | 
| 220 | 
            +
                        kernel_size=1,
         | 
| 221 | 
            +
                        stride=1,
         | 
| 222 | 
            +
                        padding=0,
         | 
| 223 | 
            +
                        bias=False,
         | 
| 224 | 
            +
                    )
         | 
| 225 | 
            +
                    self.selector.weight.data = torch.diag(diag)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # same device + dtype as lora_up
         | 
| 228 | 
            +
                    self.selector.weight.data = self.selector.weight.data.to(
         | 
| 229 | 
            +
                        self.lora_up.weight.device
         | 
| 230 | 
            +
                    ).to(self.lora_up.weight.dtype)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
             | 
| 233 | 
            +
            UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
         | 
| 234 | 
            +
             | 
| 235 | 
            +
            UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
         | 
| 236 | 
            +
             | 
| 237 | 
            +
            TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
         | 
| 238 | 
            +
             | 
| 239 | 
            +
            TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
         | 
| 240 | 
            +
             | 
| 241 | 
            +
            DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
         | 
| 242 | 
            +
             | 
| 243 | 
            +
            EMBED_FLAG = "<embed>"
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
            def _find_children(
         | 
| 247 | 
            +
                model,
         | 
| 248 | 
            +
                search_class: List[Type[nn.Module]] = [nn.Linear],
         | 
| 249 | 
            +
            ):
         | 
| 250 | 
            +
                """
         | 
| 251 | 
            +
                Find all modules of a certain class (or union of classes).
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                Returns all matching modules, along with the parent of those moduless and the
         | 
| 254 | 
            +
                names they are referenced by.
         | 
| 255 | 
            +
                """
         | 
| 256 | 
            +
                # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
         | 
| 257 | 
            +
                for parent in model.modules():
         | 
| 258 | 
            +
                    for name, module in parent.named_children():
         | 
| 259 | 
            +
                        if any([isinstance(module, _class) for _class in search_class]):
         | 
| 260 | 
            +
                            yield parent, name, module
         | 
| 261 | 
            +
             | 
| 262 | 
            +
             | 
| 263 | 
            +
            def _find_modules_v2(
         | 
| 264 | 
            +
                model,
         | 
| 265 | 
            +
                ancestor_class: Optional[Set[str]] = None,
         | 
| 266 | 
            +
                search_class: List[Type[nn.Module]] = [nn.Linear],
         | 
| 267 | 
            +
                exclude_children_of: Optional[List[Type[nn.Module]]] = [
         | 
| 268 | 
            +
                    LoraInjectedLinear,
         | 
| 269 | 
            +
                    LoraInjectedConv2d,
         | 
| 270 | 
            +
                    LoraInjectedConv3d,
         | 
| 271 | 
            +
                ],
         | 
| 272 | 
            +
            ):
         | 
| 273 | 
            +
                """
         | 
| 274 | 
            +
                Find all modules of a certain class (or union of classes) that are direct or
         | 
| 275 | 
            +
                indirect descendants of other modules of a certain class (or union of classes).
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                Returns all matching modules, along with the parent of those moduless and the
         | 
| 278 | 
            +
                names they are referenced by.
         | 
| 279 | 
            +
                """
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                # Get the targets we should replace all linears under
         | 
| 282 | 
            +
                if ancestor_class is not None:
         | 
| 283 | 
            +
                    ancestors = (
         | 
| 284 | 
            +
                        module
         | 
| 285 | 
            +
                        for module in model.modules()
         | 
| 286 | 
            +
                        if module.__class__.__name__ in ancestor_class
         | 
| 287 | 
            +
                    )
         | 
| 288 | 
            +
                else:
         | 
| 289 | 
            +
                    # this, incase you want to naively iterate over all modules.
         | 
| 290 | 
            +
                    ancestors = [module for module in model.modules()]
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
         | 
| 293 | 
            +
                for ancestor in ancestors:
         | 
| 294 | 
            +
                    for fullname, module in ancestor.named_modules():
         | 
| 295 | 
            +
                        if any([isinstance(module, _class) for _class in search_class]):
         | 
| 296 | 
            +
                            # Find the direct parent if this is a descendant, not a child, of target
         | 
| 297 | 
            +
                            *path, name = fullname.split(".")
         | 
| 298 | 
            +
                            parent = ancestor
         | 
| 299 | 
            +
                            while path:
         | 
| 300 | 
            +
                                parent = parent.get_submodule(path.pop(0))
         | 
| 301 | 
            +
                            # Skip this linear if it's a child of a LoraInjectedLinear
         | 
| 302 | 
            +
                            if exclude_children_of and any(
         | 
| 303 | 
            +
                                [isinstance(parent, _class) for _class in exclude_children_of]
         | 
| 304 | 
            +
                            ):
         | 
| 305 | 
            +
                                continue
         | 
| 306 | 
            +
                            # Otherwise, yield it
         | 
| 307 | 
            +
                            yield parent, name, module
         | 
| 308 | 
            +
             | 
| 309 | 
            +
             | 
| 310 | 
            +
            def _find_modules_old(
         | 
| 311 | 
            +
                model,
         | 
| 312 | 
            +
                ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
         | 
| 313 | 
            +
                search_class: List[Type[nn.Module]] = [nn.Linear],
         | 
| 314 | 
            +
                exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
         | 
| 315 | 
            +
            ):
         | 
| 316 | 
            +
                ret = []
         | 
| 317 | 
            +
                for _module in model.modules():
         | 
| 318 | 
            +
                    if _module.__class__.__name__ in ancestor_class:
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                        for name, _child_module in _module.named_modules():
         | 
| 321 | 
            +
                            if _child_module.__class__ in search_class:
         | 
| 322 | 
            +
                                ret.append((_module, name, _child_module))
         | 
| 323 | 
            +
                print(ret)
         | 
| 324 | 
            +
                return ret
         | 
| 325 | 
            +
             | 
| 326 | 
            +
             | 
| 327 | 
            +
            _find_modules = _find_modules_v2
         | 
| 328 | 
            +
             | 
| 329 | 
            +
             | 
| 330 | 
            +
            def inject_trainable_lora(
         | 
| 331 | 
            +
                model: nn.Module,
         | 
| 332 | 
            +
                target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
         | 
| 333 | 
            +
                r: int = 4,
         | 
| 334 | 
            +
                loras=None,  # path to lora .pt
         | 
| 335 | 
            +
                verbose: bool = False,
         | 
| 336 | 
            +
                dropout_p: float = 0.0,
         | 
| 337 | 
            +
                scale: float = 1.0,
         | 
| 338 | 
            +
            ):
         | 
| 339 | 
            +
                """
         | 
| 340 | 
            +
                inject lora into model, and returns lora parameter groups.
         | 
| 341 | 
            +
                """
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                require_grad_params = []
         | 
| 344 | 
            +
                names = []
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                if loras != None:
         | 
| 347 | 
            +
                    loras = torch.load(loras)
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                for _module, name, _child_module in _find_modules(
         | 
| 350 | 
            +
                    model, target_replace_module, search_class=[nn.Linear]
         | 
| 351 | 
            +
                ):
         | 
| 352 | 
            +
                    weight = _child_module.weight
         | 
| 353 | 
            +
                    bias = _child_module.bias
         | 
| 354 | 
            +
                    if verbose:
         | 
| 355 | 
            +
                        print("LoRA Injection : injecting lora into ", name)
         | 
| 356 | 
            +
                        print("LoRA Injection : weight shape", weight.shape)
         | 
| 357 | 
            +
                    _tmp = LoraInjectedLinear(
         | 
| 358 | 
            +
                        _child_module.in_features,
         | 
| 359 | 
            +
                        _child_module.out_features,
         | 
| 360 | 
            +
                        _child_module.bias is not None,
         | 
| 361 | 
            +
                        r=r,
         | 
| 362 | 
            +
                        dropout_p=dropout_p,
         | 
| 363 | 
            +
                        scale=scale,
         | 
| 364 | 
            +
                    )
         | 
| 365 | 
            +
                    _tmp.linear.weight = weight
         | 
| 366 | 
            +
                    if bias is not None:
         | 
| 367 | 
            +
                        _tmp.linear.bias = bias
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    # switch the module
         | 
| 370 | 
            +
                    _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
         | 
| 371 | 
            +
                    _module._modules[name] = _tmp
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    require_grad_params.append(_module._modules[name].lora_up.parameters())
         | 
| 374 | 
            +
                    require_grad_params.append(_module._modules[name].lora_down.parameters())
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    if loras != None:
         | 
| 377 | 
            +
                        _module._modules[name].lora_up.weight = loras.pop(0)
         | 
| 378 | 
            +
                        _module._modules[name].lora_down.weight = loras.pop(0)
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    _module._modules[name].lora_up.weight.requires_grad = True
         | 
| 381 | 
            +
                    _module._modules[name].lora_down.weight.requires_grad = True
         | 
| 382 | 
            +
                    names.append(name)
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                return require_grad_params, names
         | 
| 385 | 
            +
             | 
| 386 | 
            +
             | 
| 387 | 
            +
            def inject_trainable_lora_extended(
         | 
| 388 | 
            +
                model: nn.Module,
         | 
| 389 | 
            +
                target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
         | 
| 390 | 
            +
                r: int = 4,
         | 
| 391 | 
            +
                loras=None,  # path to lora .pt
         | 
| 392 | 
            +
            ):
         | 
| 393 | 
            +
                """
         | 
| 394 | 
            +
                inject lora into model, and returns lora parameter groups.
         | 
| 395 | 
            +
                """
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                require_grad_params = []
         | 
| 398 | 
            +
                names = []
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                if loras != None:
         | 
| 401 | 
            +
                    loras = torch.load(loras)
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                for _module, name, _child_module in _find_modules(
         | 
| 404 | 
            +
                    model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
         | 
| 405 | 
            +
                ):
         | 
| 406 | 
            +
                    if _child_module.__class__ == nn.Linear:
         | 
| 407 | 
            +
                        weight = _child_module.weight
         | 
| 408 | 
            +
                        bias = _child_module.bias
         | 
| 409 | 
            +
                        _tmp = LoraInjectedLinear(
         | 
| 410 | 
            +
                            _child_module.in_features,
         | 
| 411 | 
            +
                            _child_module.out_features,
         | 
| 412 | 
            +
                            _child_module.bias is not None,
         | 
| 413 | 
            +
                            r=r,
         | 
| 414 | 
            +
                        )
         | 
| 415 | 
            +
                        _tmp.linear.weight = weight
         | 
| 416 | 
            +
                        if bias is not None:
         | 
| 417 | 
            +
                            _tmp.linear.bias = bias
         | 
| 418 | 
            +
                    elif _child_module.__class__ == nn.Conv2d:
         | 
| 419 | 
            +
                        weight = _child_module.weight
         | 
| 420 | 
            +
                        bias = _child_module.bias
         | 
| 421 | 
            +
                        _tmp = LoraInjectedConv2d(
         | 
| 422 | 
            +
                            _child_module.in_channels,
         | 
| 423 | 
            +
                            _child_module.out_channels,
         | 
| 424 | 
            +
                            _child_module.kernel_size,
         | 
| 425 | 
            +
                            _child_module.stride,
         | 
| 426 | 
            +
                            _child_module.padding,
         | 
| 427 | 
            +
                            _child_module.dilation,
         | 
| 428 | 
            +
                            _child_module.groups,
         | 
| 429 | 
            +
                            _child_module.bias is not None,
         | 
| 430 | 
            +
                            r=r,
         | 
| 431 | 
            +
                        )
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                        _tmp.conv.weight = weight
         | 
| 434 | 
            +
                        if bias is not None:
         | 
| 435 | 
            +
                            _tmp.conv.bias = bias
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    elif _child_module.__class__ == nn.Conv3d:
         | 
| 438 | 
            +
                        weight = _child_module.weight
         | 
| 439 | 
            +
                        bias = _child_module.bias
         | 
| 440 | 
            +
                        _tmp = LoraInjectedConv3d(
         | 
| 441 | 
            +
                            _child_module.in_channels,
         | 
| 442 | 
            +
                            _child_module.out_channels,
         | 
| 443 | 
            +
                            bias=_child_module.bias is not None,
         | 
| 444 | 
            +
                            kernel_size=_child_module.kernel_size,
         | 
| 445 | 
            +
                            padding=_child_module.padding,
         | 
| 446 | 
            +
                            r=r,
         | 
| 447 | 
            +
                        )
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                        _tmp.conv.weight = weight
         | 
| 450 | 
            +
                        if bias is not None:
         | 
| 451 | 
            +
                            _tmp.conv.bias = bias
         | 
| 452 | 
            +
                    else:
         | 
| 453 | 
            +
                        # ignore module which are not included in search_class
         | 
| 454 | 
            +
                        # For example:
         | 
| 455 | 
            +
                        # zeroscope_v2_576w model, which has <class 'diffusers.models.lora.LoRACompatibleLinear'> and <class 'diffusers.models.lora.LoRACompatibleConv'>
         | 
| 456 | 
            +
                        continue
         | 
| 457 | 
            +
                    # switch the module
         | 
| 458 | 
            +
                    _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
         | 
| 459 | 
            +
                    if bias is not None:
         | 
| 460 | 
            +
                        _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    _module._modules[name] = _tmp
         | 
| 463 | 
            +
                    require_grad_params.append(_module._modules[name].lora_up.parameters())
         | 
| 464 | 
            +
                    require_grad_params.append(_module._modules[name].lora_down.parameters())
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                    if loras != None:
         | 
| 467 | 
            +
                        param = loras.pop(0)
         | 
| 468 | 
            +
                        if isinstance(param, torch.FloatTensor):
         | 
| 469 | 
            +
                            _module._modules[name].lora_up.weight = nn.Parameter(param)
         | 
| 470 | 
            +
                        else:
         | 
| 471 | 
            +
                            _module._modules[name].lora_up.weight = param
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                        param = loras.pop(0)
         | 
| 474 | 
            +
                        if isinstance(param, torch.FloatTensor):
         | 
| 475 | 
            +
                            _module._modules[name].lora_down.weight = nn.Parameter(param)
         | 
| 476 | 
            +
                        else:
         | 
| 477 | 
            +
                            _module._modules[name].lora_down.weight = param
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                        # _module._modules[name].lora_up.weight = loras.pop(0)
         | 
| 480 | 
            +
                        # _module._modules[name].lora_down.weight = loras.pop(0)
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    _module._modules[name].lora_up.weight.requires_grad = True
         | 
| 483 | 
            +
                    _module._modules[name].lora_down.weight.requires_grad = True
         | 
| 484 | 
            +
                    names.append(name)
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                return require_grad_params, names
         | 
| 487 | 
            +
             | 
| 488 | 
            +
             | 
| 489 | 
            +
            def inject_inferable_lora(
         | 
| 490 | 
            +
                model,
         | 
| 491 | 
            +
                lora_path="",
         | 
| 492 | 
            +
                unet_replace_modules=["UNet3DConditionModel"],
         | 
| 493 | 
            +
                text_encoder_replace_modules=["CLIPEncoderLayer"],
         | 
| 494 | 
            +
                is_extended=False,
         | 
| 495 | 
            +
                r=16,
         | 
| 496 | 
            +
            ):
         | 
| 497 | 
            +
                from transformers.models.clip import CLIPTextModel
         | 
| 498 | 
            +
                from diffusers import UNet3DConditionModel
         | 
| 499 | 
            +
             | 
| 500 | 
            +
                def is_text_model(f):
         | 
| 501 | 
            +
                    return "text_encoder" in f and isinstance(model.text_encoder, CLIPTextModel)
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                def is_unet(f):
         | 
| 504 | 
            +
                    return "unet" in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
         | 
| 505 | 
            +
             | 
| 506 | 
            +
                if os.path.exists(lora_path):
         | 
| 507 | 
            +
                    try:
         | 
| 508 | 
            +
                        for f in os.listdir(lora_path):
         | 
| 509 | 
            +
                            if f.endswith(".pt"):
         | 
| 510 | 
            +
                                lora_file = os.path.join(lora_path, f)
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                                if is_text_model(f):
         | 
| 513 | 
            +
                                    monkeypatch_or_replace_lora(
         | 
| 514 | 
            +
                                        model.text_encoder,
         | 
| 515 | 
            +
                                        torch.load(lora_file),
         | 
| 516 | 
            +
                                        target_replace_module=text_encoder_replace_modules,
         | 
| 517 | 
            +
                                        r=r,
         | 
| 518 | 
            +
                                    )
         | 
| 519 | 
            +
                                    print("Successfully loaded Text Encoder LoRa.")
         | 
| 520 | 
            +
                                    continue
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                                if is_unet(f):
         | 
| 523 | 
            +
                                    monkeypatch_or_replace_lora_extended(
         | 
| 524 | 
            +
                                        model.unet,
         | 
| 525 | 
            +
                                        torch.load(lora_file),
         | 
| 526 | 
            +
                                        target_replace_module=unet_replace_modules,
         | 
| 527 | 
            +
                                        r=r,
         | 
| 528 | 
            +
                                    )
         | 
| 529 | 
            +
                                    print("Successfully loaded UNET LoRa.")
         | 
| 530 | 
            +
                                    continue
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                                print(
         | 
| 533 | 
            +
                                    "Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)"
         | 
| 534 | 
            +
                                )
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                    except Exception as e:
         | 
| 537 | 
            +
                        print(e)
         | 
| 538 | 
            +
                        print("Couldn't inject LoRA's due to an error.")
         | 
| 539 | 
            +
             | 
| 540 | 
            +
             | 
| 541 | 
            +
            def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                loras = []
         | 
| 544 | 
            +
             | 
| 545 | 
            +
                for _m, _n, _child_module in _find_modules(
         | 
| 546 | 
            +
                    model,
         | 
| 547 | 
            +
                    target_replace_module,
         | 
| 548 | 
            +
                    search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
         | 
| 549 | 
            +
                ):
         | 
| 550 | 
            +
                    loras.append((_child_module.lora_up, _child_module.lora_down))
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                if len(loras) == 0:
         | 
| 553 | 
            +
                    raise ValueError("No lora injected.")
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                return loras
         | 
| 556 | 
            +
             | 
| 557 | 
            +
             | 
| 558 | 
            +
            def extract_lora_as_tensor(
         | 
| 559 | 
            +
                model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
         | 
| 560 | 
            +
            ):
         | 
| 561 | 
            +
             | 
| 562 | 
            +
                loras = []
         | 
| 563 | 
            +
             | 
| 564 | 
            +
                for _m, _n, _child_module in _find_modules(
         | 
| 565 | 
            +
                    model,
         | 
| 566 | 
            +
                    target_replace_module,
         | 
| 567 | 
            +
                    search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
         | 
| 568 | 
            +
                ):
         | 
| 569 | 
            +
                    up, down = _child_module.realize_as_lora()
         | 
| 570 | 
            +
                    if as_fp16:
         | 
| 571 | 
            +
                        up = up.to(torch.float16)
         | 
| 572 | 
            +
                        down = down.to(torch.float16)
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                    loras.append((up, down))
         | 
| 575 | 
            +
             | 
| 576 | 
            +
                if len(loras) == 0:
         | 
| 577 | 
            +
                    raise ValueError("No lora injected.")
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                return loras
         | 
| 580 | 
            +
             | 
| 581 | 
            +
             | 
| 582 | 
            +
            def save_lora_weight(
         | 
| 583 | 
            +
                model,
         | 
| 584 | 
            +
                path="./lora.pt",
         | 
| 585 | 
            +
                target_replace_module=DEFAULT_TARGET_REPLACE,
         | 
| 586 | 
            +
            ):
         | 
| 587 | 
            +
                weights = []
         | 
| 588 | 
            +
                for _up, _down in extract_lora_ups_down(
         | 
| 589 | 
            +
                    model, target_replace_module=target_replace_module
         | 
| 590 | 
            +
                ):
         | 
| 591 | 
            +
                    weights.append(_up.weight.to("cpu").to(torch.float32))
         | 
| 592 | 
            +
                    weights.append(_down.weight.to("cpu").to(torch.float32))
         | 
| 593 | 
            +
             | 
| 594 | 
            +
                torch.save(weights, path)
         | 
| 595 | 
            +
             | 
| 596 | 
            +
             | 
| 597 | 
            +
            def save_lora_as_json(model, path="./lora.json"):
         | 
| 598 | 
            +
                weights = []
         | 
| 599 | 
            +
                for _up, _down in extract_lora_ups_down(model):
         | 
| 600 | 
            +
                    weights.append(_up.weight.detach().cpu().numpy().tolist())
         | 
| 601 | 
            +
                    weights.append(_down.weight.detach().cpu().numpy().tolist())
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                import json
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                with open(path, "w") as f:
         | 
| 606 | 
            +
                    json.dump(weights, f)
         | 
| 607 | 
            +
             | 
| 608 | 
            +
             | 
| 609 | 
            +
            def save_safeloras_with_embeds(
         | 
| 610 | 
            +
                modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
         | 
| 611 | 
            +
                embeds: Dict[str, torch.Tensor] = {},
         | 
| 612 | 
            +
                outpath="./lora.safetensors",
         | 
| 613 | 
            +
            ):
         | 
| 614 | 
            +
                """
         | 
| 615 | 
            +
                Saves the Lora from multiple modules in a single safetensor file.
         | 
| 616 | 
            +
             | 
| 617 | 
            +
                modelmap is a dictionary of {
         | 
| 618 | 
            +
                    "module name": (module, target_replace_module)
         | 
| 619 | 
            +
                }
         | 
| 620 | 
            +
                """
         | 
| 621 | 
            +
                weights = {}
         | 
| 622 | 
            +
                metadata = {}
         | 
| 623 | 
            +
             | 
| 624 | 
            +
                for name, (model, target_replace_module) in modelmap.items():
         | 
| 625 | 
            +
                    metadata[name] = json.dumps(list(target_replace_module))
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                    for i, (_up, _down) in enumerate(
         | 
| 628 | 
            +
                        extract_lora_as_tensor(model, target_replace_module)
         | 
| 629 | 
            +
                    ):
         | 
| 630 | 
            +
                        rank = _down.shape[0]
         | 
| 631 | 
            +
             | 
| 632 | 
            +
                        metadata[f"{name}:{i}:rank"] = str(rank)
         | 
| 633 | 
            +
                        weights[f"{name}:{i}:up"] = _up
         | 
| 634 | 
            +
                        weights[f"{name}:{i}:down"] = _down
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                for token, tensor in embeds.items():
         | 
| 637 | 
            +
                    metadata[token] = EMBED_FLAG
         | 
| 638 | 
            +
                    weights[token] = tensor
         | 
| 639 | 
            +
             | 
| 640 | 
            +
                print(f"Saving weights to {outpath}")
         | 
| 641 | 
            +
                safe_save(weights, outpath, metadata)
         | 
| 642 | 
            +
             | 
| 643 | 
            +
             | 
| 644 | 
            +
            def save_safeloras(
         | 
| 645 | 
            +
                modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
         | 
| 646 | 
            +
                outpath="./lora.safetensors",
         | 
| 647 | 
            +
            ):
         | 
| 648 | 
            +
                return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
         | 
| 649 | 
            +
             | 
| 650 | 
            +
             | 
| 651 | 
            +
            def convert_loras_to_safeloras_with_embeds(
         | 
| 652 | 
            +
                modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
         | 
| 653 | 
            +
                embeds: Dict[str, torch.Tensor] = {},
         | 
| 654 | 
            +
                outpath="./lora.safetensors",
         | 
| 655 | 
            +
            ):
         | 
| 656 | 
            +
                """
         | 
| 657 | 
            +
                Converts the Lora from multiple pytorch .pt files into a single safetensor file.
         | 
| 658 | 
            +
             | 
| 659 | 
            +
                modelmap is a dictionary of {
         | 
| 660 | 
            +
                    "module name": (pytorch_model_path, target_replace_module, rank)
         | 
| 661 | 
            +
                }
         | 
| 662 | 
            +
                """
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                weights = {}
         | 
| 665 | 
            +
                metadata = {}
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                for name, (path, target_replace_module, r) in modelmap.items():
         | 
| 668 | 
            +
                    metadata[name] = json.dumps(list(target_replace_module))
         | 
| 669 | 
            +
             | 
| 670 | 
            +
                    lora = torch.load(path)
         | 
| 671 | 
            +
                    for i, weight in enumerate(lora):
         | 
| 672 | 
            +
                        is_up = i % 2 == 0
         | 
| 673 | 
            +
                        i = i // 2
         | 
| 674 | 
            +
             | 
| 675 | 
            +
                        if is_up:
         | 
| 676 | 
            +
                            metadata[f"{name}:{i}:rank"] = str(r)
         | 
| 677 | 
            +
                            weights[f"{name}:{i}:up"] = weight
         | 
| 678 | 
            +
                        else:
         | 
| 679 | 
            +
                            weights[f"{name}:{i}:down"] = weight
         | 
| 680 | 
            +
             | 
| 681 | 
            +
                for token, tensor in embeds.items():
         | 
| 682 | 
            +
                    metadata[token] = EMBED_FLAG
         | 
| 683 | 
            +
                    weights[token] = tensor
         | 
| 684 | 
            +
             | 
| 685 | 
            +
                print(f"Saving weights to {outpath}")
         | 
| 686 | 
            +
                safe_save(weights, outpath, metadata)
         | 
| 687 | 
            +
             | 
| 688 | 
            +
             | 
| 689 | 
            +
            def convert_loras_to_safeloras(
         | 
| 690 | 
            +
                modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
         | 
| 691 | 
            +
                outpath="./lora.safetensors",
         | 
| 692 | 
            +
            ):
         | 
| 693 | 
            +
                convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
         | 
| 694 | 
            +
             | 
| 695 | 
            +
             | 
| 696 | 
            +
            def parse_safeloras(
         | 
| 697 | 
            +
                safeloras,
         | 
| 698 | 
            +
            ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
         | 
| 699 | 
            +
                """
         | 
| 700 | 
            +
                Converts a loaded safetensor file that contains a set of module Loras
         | 
| 701 | 
            +
                into Parameters and other information
         | 
| 702 | 
            +
             | 
| 703 | 
            +
                Output is a dictionary of {
         | 
| 704 | 
            +
                    "module name": (
         | 
| 705 | 
            +
                        [list of weights],
         | 
| 706 | 
            +
                        [list of ranks],
         | 
| 707 | 
            +
                        target_replacement_modules
         | 
| 708 | 
            +
                    )
         | 
| 709 | 
            +
                }
         | 
| 710 | 
            +
                """
         | 
| 711 | 
            +
                loras = {}
         | 
| 712 | 
            +
                metadata = safeloras.metadata()
         | 
| 713 | 
            +
             | 
| 714 | 
            +
                get_name = lambda k: k.split(":")[0]
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                keys = list(safeloras.keys())
         | 
| 717 | 
            +
                keys.sort(key=get_name)
         | 
| 718 | 
            +
             | 
| 719 | 
            +
                for name, module_keys in groupby(keys, get_name):
         | 
| 720 | 
            +
                    info = metadata.get(name)
         | 
| 721 | 
            +
             | 
| 722 | 
            +
                    if not info:
         | 
| 723 | 
            +
                        raise ValueError(
         | 
| 724 | 
            +
                            f"Tensor {name} has no metadata - is this a Lora safetensor?"
         | 
| 725 | 
            +
                        )
         | 
| 726 | 
            +
             | 
| 727 | 
            +
                    # Skip Textual Inversion embeds
         | 
| 728 | 
            +
                    if info == EMBED_FLAG:
         | 
| 729 | 
            +
                        continue
         | 
| 730 | 
            +
             | 
| 731 | 
            +
                    # Handle Loras
         | 
| 732 | 
            +
                    # Extract the targets
         | 
| 733 | 
            +
                    target = json.loads(info)
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                    # Build the result lists - Python needs us to preallocate lists to insert into them
         | 
| 736 | 
            +
                    module_keys = list(module_keys)
         | 
| 737 | 
            +
                    ranks = [4] * (len(module_keys) // 2)
         | 
| 738 | 
            +
                    weights = [None] * len(module_keys)
         | 
| 739 | 
            +
             | 
| 740 | 
            +
                    for key in module_keys:
         | 
| 741 | 
            +
                        # Split the model name and index out of the key
         | 
| 742 | 
            +
                        _, idx, direction = key.split(":")
         | 
| 743 | 
            +
                        idx = int(idx)
         | 
| 744 | 
            +
             | 
| 745 | 
            +
                        # Add the rank
         | 
| 746 | 
            +
                        ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
         | 
| 747 | 
            +
             | 
| 748 | 
            +
                        # Insert the weight into the list
         | 
| 749 | 
            +
                        idx = idx * 2 + (1 if direction == "down" else 0)
         | 
| 750 | 
            +
                        weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                    loras[name] = (weights, ranks, target)
         | 
| 753 | 
            +
             | 
| 754 | 
            +
                return loras
         | 
| 755 | 
            +
             | 
| 756 | 
            +
             | 
| 757 | 
            +
            def parse_safeloras_embeds(
         | 
| 758 | 
            +
                safeloras,
         | 
| 759 | 
            +
            ) -> Dict[str, torch.Tensor]:
         | 
| 760 | 
            +
                """
         | 
| 761 | 
            +
                Converts a loaded safetensor file that contains Textual Inversion embeds into
         | 
| 762 | 
            +
                a dictionary of embed_token: Tensor
         | 
| 763 | 
            +
                """
         | 
| 764 | 
            +
                embeds = {}
         | 
| 765 | 
            +
                metadata = safeloras.metadata()
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                for key in safeloras.keys():
         | 
| 768 | 
            +
                    # Only handle Textual Inversion embeds
         | 
| 769 | 
            +
                    meta = metadata.get(key)
         | 
| 770 | 
            +
                    if not meta or meta != EMBED_FLAG:
         | 
| 771 | 
            +
                        continue
         | 
| 772 | 
            +
             | 
| 773 | 
            +
                    embeds[key] = safeloras.get_tensor(key)
         | 
| 774 | 
            +
             | 
| 775 | 
            +
                return embeds
         | 
| 776 | 
            +
             | 
| 777 | 
            +
             | 
| 778 | 
            +
            def load_safeloras(path, device="cpu"):
         | 
| 779 | 
            +
                safeloras = safe_open(path, framework="pt", device=device)
         | 
| 780 | 
            +
                return parse_safeloras(safeloras)
         | 
| 781 | 
            +
             | 
| 782 | 
            +
             | 
| 783 | 
            +
            def load_safeloras_embeds(path, device="cpu"):
         | 
| 784 | 
            +
                safeloras = safe_open(path, framework="pt", device=device)
         | 
| 785 | 
            +
                return parse_safeloras_embeds(safeloras)
         | 
| 786 | 
            +
             | 
| 787 | 
            +
             | 
| 788 | 
            +
            def load_safeloras_both(path, device="cpu"):
         | 
| 789 | 
            +
                safeloras = safe_open(path, framework="pt", device=device)
         | 
| 790 | 
            +
                return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
         | 
| 791 | 
            +
             | 
| 792 | 
            +
             | 
| 793 | 
            +
            def collapse_lora(
         | 
| 794 | 
            +
                model,
         | 
| 795 | 
            +
                replace_modules=UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
         | 
| 796 | 
            +
                alpha=1.0,
         | 
| 797 | 
            +
            ):
         | 
| 798 | 
            +
             | 
| 799 | 
            +
                search_class = [LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]
         | 
| 800 | 
            +
                for _module, name, _child_module in _find_modules(
         | 
| 801 | 
            +
                    model, replace_modules, search_class=search_class
         | 
| 802 | 
            +
                ):
         | 
| 803 | 
            +
             | 
| 804 | 
            +
                    if isinstance(_child_module, LoraInjectedLinear):
         | 
| 805 | 
            +
                        print("Collapsing Lin Lora in", name)
         | 
| 806 | 
            +
             | 
| 807 | 
            +
                        _child_module.linear.weight = nn.Parameter(
         | 
| 808 | 
            +
                            _child_module.linear.weight.data
         | 
| 809 | 
            +
                            + alpha
         | 
| 810 | 
            +
                            * (
         | 
| 811 | 
            +
                                _child_module.lora_up.weight.data
         | 
| 812 | 
            +
                                @ _child_module.lora_down.weight.data
         | 
| 813 | 
            +
                            )
         | 
| 814 | 
            +
                            .type(_child_module.linear.weight.dtype)
         | 
| 815 | 
            +
                            .to(_child_module.linear.weight.device)
         | 
| 816 | 
            +
                        )
         | 
| 817 | 
            +
             | 
| 818 | 
            +
                    else:
         | 
| 819 | 
            +
                        print("Collapsing Conv Lora in", name)
         | 
| 820 | 
            +
                        _child_module.conv.weight = nn.Parameter(
         | 
| 821 | 
            +
                            _child_module.conv.weight.data
         | 
| 822 | 
            +
                            + alpha
         | 
| 823 | 
            +
                            * (
         | 
| 824 | 
            +
                                _child_module.lora_up.weight.data.flatten(start_dim=1)
         | 
| 825 | 
            +
                                @ _child_module.lora_down.weight.data.flatten(start_dim=1)
         | 
| 826 | 
            +
                            )
         | 
| 827 | 
            +
                            .reshape(_child_module.conv.weight.data.shape)
         | 
| 828 | 
            +
                            .type(_child_module.conv.weight.dtype)
         | 
| 829 | 
            +
                            .to(_child_module.conv.weight.device)
         | 
| 830 | 
            +
                        )
         | 
| 831 | 
            +
             | 
| 832 | 
            +
             | 
| 833 | 
            +
            def monkeypatch_or_replace_lora(
         | 
| 834 | 
            +
                model,
         | 
| 835 | 
            +
                loras,
         | 
| 836 | 
            +
                target_replace_module=DEFAULT_TARGET_REPLACE,
         | 
| 837 | 
            +
                r: Union[int, List[int]] = 4,
         | 
| 838 | 
            +
            ):
         | 
| 839 | 
            +
                for _module, name, _child_module in _find_modules(
         | 
| 840 | 
            +
                    model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
         | 
| 841 | 
            +
                ):
         | 
| 842 | 
            +
                    _source = (
         | 
| 843 | 
            +
                        _child_module.linear
         | 
| 844 | 
            +
                        if isinstance(_child_module, LoraInjectedLinear)
         | 
| 845 | 
            +
                        else _child_module
         | 
| 846 | 
            +
                    )
         | 
| 847 | 
            +
             | 
| 848 | 
            +
                    weight = _source.weight
         | 
| 849 | 
            +
                    bias = _source.bias
         | 
| 850 | 
            +
                    _tmp = LoraInjectedLinear(
         | 
| 851 | 
            +
                        _source.in_features,
         | 
| 852 | 
            +
                        _source.out_features,
         | 
| 853 | 
            +
                        _source.bias is not None,
         | 
| 854 | 
            +
                        r=r.pop(0) if isinstance(r, list) else r,
         | 
| 855 | 
            +
                    )
         | 
| 856 | 
            +
                    _tmp.linear.weight = weight
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                    if bias is not None:
         | 
| 859 | 
            +
                        _tmp.linear.bias = bias
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                    # switch the module
         | 
| 862 | 
            +
                    _module._modules[name] = _tmp
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                    up_weight = loras.pop(0)
         | 
| 865 | 
            +
                    down_weight = loras.pop(0)
         | 
| 866 | 
            +
             | 
| 867 | 
            +
                    _module._modules[name].lora_up.weight = nn.Parameter(
         | 
| 868 | 
            +
                        up_weight.type(weight.dtype)
         | 
| 869 | 
            +
                    )
         | 
| 870 | 
            +
                    _module._modules[name].lora_down.weight = nn.Parameter(
         | 
| 871 | 
            +
                        down_weight.type(weight.dtype)
         | 
| 872 | 
            +
                    )
         | 
| 873 | 
            +
             | 
| 874 | 
            +
                    _module._modules[name].to(weight.device)
         | 
| 875 | 
            +
             | 
| 876 | 
            +
             | 
| 877 | 
            +
            def monkeypatch_or_replace_lora_extended(
         | 
| 878 | 
            +
                model,
         | 
| 879 | 
            +
                loras,
         | 
| 880 | 
            +
                target_replace_module=DEFAULT_TARGET_REPLACE,
         | 
| 881 | 
            +
                r: Union[int, List[int]] = 4,
         | 
| 882 | 
            +
            ):
         | 
| 883 | 
            +
                for _module, name, _child_module in _find_modules(
         | 
| 884 | 
            +
                    model,
         | 
| 885 | 
            +
                    target_replace_module,
         | 
| 886 | 
            +
                    search_class=[
         | 
| 887 | 
            +
                        nn.Linear,
         | 
| 888 | 
            +
                        nn.Conv2d,
         | 
| 889 | 
            +
                        nn.Conv3d,
         | 
| 890 | 
            +
                        LoraInjectedLinear,
         | 
| 891 | 
            +
                        LoraInjectedConv2d,
         | 
| 892 | 
            +
                        LoraInjectedConv3d,
         | 
| 893 | 
            +
                    ],
         | 
| 894 | 
            +
                ):
         | 
| 895 | 
            +
             | 
| 896 | 
            +
                    if (_child_module.__class__ == nn.Linear) or (
         | 
| 897 | 
            +
                        _child_module.__class__ == LoraInjectedLinear
         | 
| 898 | 
            +
                    ):
         | 
| 899 | 
            +
                        if len(loras[0].shape) != 2:
         | 
| 900 | 
            +
                            continue
         | 
| 901 | 
            +
             | 
| 902 | 
            +
                        _source = (
         | 
| 903 | 
            +
                            _child_module.linear
         | 
| 904 | 
            +
                            if isinstance(_child_module, LoraInjectedLinear)
         | 
| 905 | 
            +
                            else _child_module
         | 
| 906 | 
            +
                        )
         | 
| 907 | 
            +
             | 
| 908 | 
            +
                        weight = _source.weight
         | 
| 909 | 
            +
                        bias = _source.bias
         | 
| 910 | 
            +
                        _tmp = LoraInjectedLinear(
         | 
| 911 | 
            +
                            _source.in_features,
         | 
| 912 | 
            +
                            _source.out_features,
         | 
| 913 | 
            +
                            _source.bias is not None,
         | 
| 914 | 
            +
                            r=r.pop(0) if isinstance(r, list) else r,
         | 
| 915 | 
            +
                        )
         | 
| 916 | 
            +
                        _tmp.linear.weight = weight
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                        if bias is not None:
         | 
| 919 | 
            +
                            _tmp.linear.bias = bias
         | 
| 920 | 
            +
             | 
| 921 | 
            +
                    elif (_child_module.__class__ == nn.Conv2d) or (
         | 
| 922 | 
            +
                        _child_module.__class__ == LoraInjectedConv2d
         | 
| 923 | 
            +
                    ):
         | 
| 924 | 
            +
                        if len(loras[0].shape) != 4:
         | 
| 925 | 
            +
                            continue
         | 
| 926 | 
            +
                        _source = (
         | 
| 927 | 
            +
                            _child_module.conv
         | 
| 928 | 
            +
                            if isinstance(_child_module, LoraInjectedConv2d)
         | 
| 929 | 
            +
                            else _child_module
         | 
| 930 | 
            +
                        )
         | 
| 931 | 
            +
             | 
| 932 | 
            +
                        weight = _source.weight
         | 
| 933 | 
            +
                        bias = _source.bias
         | 
| 934 | 
            +
                        _tmp = LoraInjectedConv2d(
         | 
| 935 | 
            +
                            _source.in_channels,
         | 
| 936 | 
            +
                            _source.out_channels,
         | 
| 937 | 
            +
                            _source.kernel_size,
         | 
| 938 | 
            +
                            _source.stride,
         | 
| 939 | 
            +
                            _source.padding,
         | 
| 940 | 
            +
                            _source.dilation,
         | 
| 941 | 
            +
                            _source.groups,
         | 
| 942 | 
            +
                            _source.bias is not None,
         | 
| 943 | 
            +
                            r=r.pop(0) if isinstance(r, list) else r,
         | 
| 944 | 
            +
                        )
         | 
| 945 | 
            +
             | 
| 946 | 
            +
                        _tmp.conv.weight = weight
         | 
| 947 | 
            +
             | 
| 948 | 
            +
                        if bias is not None:
         | 
| 949 | 
            +
                            _tmp.conv.bias = bias
         | 
| 950 | 
            +
             | 
| 951 | 
            +
                    elif _child_module.__class__ == nn.Conv3d or (
         | 
| 952 | 
            +
                        _child_module.__class__ == LoraInjectedConv3d
         | 
| 953 | 
            +
                    ):
         | 
| 954 | 
            +
             | 
| 955 | 
            +
                        if len(loras[0].shape) != 5:
         | 
| 956 | 
            +
                            continue
         | 
| 957 | 
            +
             | 
| 958 | 
            +
                        _source = (
         | 
| 959 | 
            +
                            _child_module.conv
         | 
| 960 | 
            +
                            if isinstance(_child_module, LoraInjectedConv3d)
         | 
| 961 | 
            +
                            else _child_module
         | 
| 962 | 
            +
                        )
         | 
| 963 | 
            +
             | 
| 964 | 
            +
                        weight = _source.weight
         | 
| 965 | 
            +
                        bias = _source.bias
         | 
| 966 | 
            +
                        _tmp = LoraInjectedConv3d(
         | 
| 967 | 
            +
                            _source.in_channels,
         | 
| 968 | 
            +
                            _source.out_channels,
         | 
| 969 | 
            +
                            bias=_source.bias is not None,
         | 
| 970 | 
            +
                            kernel_size=_source.kernel_size,
         | 
| 971 | 
            +
                            padding=_source.padding,
         | 
| 972 | 
            +
                            r=r.pop(0) if isinstance(r, list) else r,
         | 
| 973 | 
            +
                        )
         | 
| 974 | 
            +
             | 
| 975 | 
            +
                        _tmp.conv.weight = weight
         | 
| 976 | 
            +
             | 
| 977 | 
            +
                        if bias is not None:
         | 
| 978 | 
            +
                            _tmp.conv.bias = bias
         | 
| 979 | 
            +
                    else:
         | 
| 980 | 
            +
                        # ignore module which are not included in search_class
         | 
| 981 | 
            +
                        # For example:
         | 
| 982 | 
            +
                        # zeroscope_v2_576w model, which has <class 'diffusers.models.lora.LoRACompatibleLinear'> and <class 'diffusers.models.lora.LoRACompatibleConv'>
         | 
| 983 | 
            +
                        continue
         | 
| 984 | 
            +
                    # switch the module
         | 
| 985 | 
            +
                    _module._modules[name] = _tmp
         | 
| 986 | 
            +
             | 
| 987 | 
            +
                    up_weight = loras.pop(0)
         | 
| 988 | 
            +
                    down_weight = loras.pop(0)
         | 
| 989 | 
            +
             | 
| 990 | 
            +
                    _module._modules[name].lora_up.weight = nn.Parameter(
         | 
| 991 | 
            +
                        up_weight.type(weight.dtype)
         | 
| 992 | 
            +
                    )
         | 
| 993 | 
            +
                    _module._modules[name].lora_down.weight = nn.Parameter(
         | 
| 994 | 
            +
                        down_weight.type(weight.dtype)
         | 
| 995 | 
            +
                    )
         | 
| 996 | 
            +
             | 
| 997 | 
            +
                    _module._modules[name].to(weight.device)
         | 
| 998 | 
            +
             | 
| 999 | 
            +
             | 
| 1000 | 
            +
            def monkeypatch_or_replace_safeloras(models, safeloras):
         | 
| 1001 | 
            +
                loras = parse_safeloras(safeloras)
         | 
| 1002 | 
            +
             | 
| 1003 | 
            +
                for name, (lora, ranks, target) in loras.items():
         | 
| 1004 | 
            +
                    model = getattr(models, name, None)
         | 
| 1005 | 
            +
             | 
| 1006 | 
            +
                    if not model:
         | 
| 1007 | 
            +
                        print(f"No model provided for {name}, contained in Lora")
         | 
| 1008 | 
            +
                        continue
         | 
| 1009 | 
            +
             | 
| 1010 | 
            +
                    monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
         | 
| 1011 | 
            +
             | 
| 1012 | 
            +
             | 
| 1013 | 
            +
            def monkeypatch_remove_lora(model):
         | 
| 1014 | 
            +
                for _module, name, _child_module in _find_modules(
         | 
| 1015 | 
            +
                    model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]
         | 
| 1016 | 
            +
                ):
         | 
| 1017 | 
            +
                    if isinstance(_child_module, LoraInjectedLinear):
         | 
| 1018 | 
            +
                        _source = _child_module.linear
         | 
| 1019 | 
            +
                        weight, bias = _source.weight, _source.bias
         | 
| 1020 | 
            +
             | 
| 1021 | 
            +
                        _tmp = nn.Linear(
         | 
| 1022 | 
            +
                            _source.in_features, _source.out_features, bias is not None
         | 
| 1023 | 
            +
                        )
         | 
| 1024 | 
            +
             | 
| 1025 | 
            +
                        _tmp.weight = weight
         | 
| 1026 | 
            +
                        if bias is not None:
         | 
| 1027 | 
            +
                            _tmp.bias = bias
         | 
| 1028 | 
            +
             | 
| 1029 | 
            +
                    else:
         | 
| 1030 | 
            +
                        _source = _child_module.conv
         | 
| 1031 | 
            +
                        weight, bias = _source.weight, _source.bias
         | 
| 1032 | 
            +
             | 
| 1033 | 
            +
                        if isinstance(_source, nn.Conv2d):
         | 
| 1034 | 
            +
                            _tmp = nn.Conv2d(
         | 
| 1035 | 
            +
                                in_channels=_source.in_channels,
         | 
| 1036 | 
            +
                                out_channels=_source.out_channels,
         | 
| 1037 | 
            +
                                kernel_size=_source.kernel_size,
         | 
| 1038 | 
            +
                                stride=_source.stride,
         | 
| 1039 | 
            +
                                padding=_source.padding,
         | 
| 1040 | 
            +
                                dilation=_source.dilation,
         | 
| 1041 | 
            +
                                groups=_source.groups,
         | 
| 1042 | 
            +
                                bias=bias is not None,
         | 
| 1043 | 
            +
                            )
         | 
| 1044 | 
            +
             | 
| 1045 | 
            +
                            _tmp.weight = weight
         | 
| 1046 | 
            +
                            if bias is not None:
         | 
| 1047 | 
            +
                                _tmp.bias = bias
         | 
| 1048 | 
            +
             | 
| 1049 | 
            +
                        if isinstance(_source, nn.Conv3d):
         | 
| 1050 | 
            +
                            _tmp = nn.Conv3d(
         | 
| 1051 | 
            +
                                _source.in_channels,
         | 
| 1052 | 
            +
                                _source.out_channels,
         | 
| 1053 | 
            +
                                bias=_source.bias is not None,
         | 
| 1054 | 
            +
                                kernel_size=_source.kernel_size,
         | 
| 1055 | 
            +
                                padding=_source.padding,
         | 
| 1056 | 
            +
                            )
         | 
| 1057 | 
            +
             | 
| 1058 | 
            +
                        _tmp.weight = weight
         | 
| 1059 | 
            +
                        if bias is not None:
         | 
| 1060 | 
            +
                            _tmp.bias = bias
         | 
| 1061 | 
            +
             | 
| 1062 | 
            +
                    _module._modules[name] = _tmp
         | 
| 1063 | 
            +
             | 
| 1064 | 
            +
             | 
| 1065 | 
            +
            def monkeypatch_add_lora(
         | 
| 1066 | 
            +
                model,
         | 
| 1067 | 
            +
                loras,
         | 
| 1068 | 
            +
                target_replace_module=DEFAULT_TARGET_REPLACE,
         | 
| 1069 | 
            +
                alpha: float = 1.0,
         | 
| 1070 | 
            +
                beta: float = 1.0,
         | 
| 1071 | 
            +
            ):
         | 
| 1072 | 
            +
                for _module, name, _child_module in _find_modules(
         | 
| 1073 | 
            +
                    model, target_replace_module, search_class=[LoraInjectedLinear]
         | 
| 1074 | 
            +
                ):
         | 
| 1075 | 
            +
                    weight = _child_module.linear.weight
         | 
| 1076 | 
            +
             | 
| 1077 | 
            +
                    up_weight = loras.pop(0)
         | 
| 1078 | 
            +
                    down_weight = loras.pop(0)
         | 
| 1079 | 
            +
             | 
| 1080 | 
            +
                    _module._modules[name].lora_up.weight = nn.Parameter(
         | 
| 1081 | 
            +
                        up_weight.type(weight.dtype).to(weight.device) * alpha
         | 
| 1082 | 
            +
                        + _module._modules[name].lora_up.weight.to(weight.device) * beta
         | 
| 1083 | 
            +
                    )
         | 
| 1084 | 
            +
                    _module._modules[name].lora_down.weight = nn.Parameter(
         | 
| 1085 | 
            +
                        down_weight.type(weight.dtype).to(weight.device) * alpha
         | 
| 1086 | 
            +
                        + _module._modules[name].lora_down.weight.to(weight.device) * beta
         | 
| 1087 | 
            +
                    )
         | 
| 1088 | 
            +
             | 
| 1089 | 
            +
                    _module._modules[name].to(weight.device)
         | 
| 1090 | 
            +
             | 
| 1091 | 
            +
             | 
| 1092 | 
            +
            def tune_lora_scale(model, alpha: float = 1.0):
         | 
| 1093 | 
            +
                for _module in model.modules():
         | 
| 1094 | 
            +
                    if _module.__class__.__name__ in [
         | 
| 1095 | 
            +
                        "LoraInjectedLinear",
         | 
| 1096 | 
            +
                        "LoraInjectedConv2d",
         | 
| 1097 | 
            +
                        "LoraInjectedConv3d",
         | 
| 1098 | 
            +
                    ]:
         | 
| 1099 | 
            +
                        _module.scale = alpha
         | 
| 1100 | 
            +
             | 
| 1101 | 
            +
             | 
| 1102 | 
            +
            def set_lora_diag(model, diag: torch.Tensor):
         | 
| 1103 | 
            +
                for _module in model.modules():
         | 
| 1104 | 
            +
                    if _module.__class__.__name__ in [
         | 
| 1105 | 
            +
                        "LoraInjectedLinear",
         | 
| 1106 | 
            +
                        "LoraInjectedConv2d",
         | 
| 1107 | 
            +
                        "LoraInjectedConv3d",
         | 
| 1108 | 
            +
                    ]:
         | 
| 1109 | 
            +
                        _module.set_selector_from_diag(diag)
         | 
| 1110 | 
            +
             | 
| 1111 | 
            +
             | 
| 1112 | 
            +
            def _text_lora_path(path: str) -> str:
         | 
| 1113 | 
            +
                assert path.endswith(".pt"), "Only .pt files are supported"
         | 
| 1114 | 
            +
                return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
         | 
| 1115 | 
            +
             | 
| 1116 | 
            +
             | 
| 1117 | 
            +
            def _ti_lora_path(path: str) -> str:
         | 
| 1118 | 
            +
                assert path.endswith(".pt"), "Only .pt files are supported"
         | 
| 1119 | 
            +
                return ".".join(path.split(".")[:-1] + ["ti", "pt"])
         | 
| 1120 | 
            +
             | 
| 1121 | 
            +
             | 
| 1122 | 
            +
            def apply_learned_embed_in_clip(
         | 
| 1123 | 
            +
                learned_embeds,
         | 
| 1124 | 
            +
                text_encoder,
         | 
| 1125 | 
            +
                tokenizer,
         | 
| 1126 | 
            +
                token: Optional[Union[str, List[str]]] = None,
         | 
| 1127 | 
            +
                idempotent=False,
         | 
| 1128 | 
            +
            ):
         | 
| 1129 | 
            +
                if isinstance(token, str):
         | 
| 1130 | 
            +
                    trained_tokens = [token]
         | 
| 1131 | 
            +
                elif isinstance(token, list):
         | 
| 1132 | 
            +
                    assert len(learned_embeds.keys()) == len(
         | 
| 1133 | 
            +
                        token
         | 
| 1134 | 
            +
                    ), "The number of tokens and the number of embeds should be the same"
         | 
| 1135 | 
            +
                    trained_tokens = token
         | 
| 1136 | 
            +
                else:
         | 
| 1137 | 
            +
                    trained_tokens = list(learned_embeds.keys())
         | 
| 1138 | 
            +
             | 
| 1139 | 
            +
                for token in trained_tokens:
         | 
| 1140 | 
            +
                    print(token)
         | 
| 1141 | 
            +
                    embeds = learned_embeds[token]
         | 
| 1142 | 
            +
             | 
| 1143 | 
            +
                    # cast to dtype of text_encoder
         | 
| 1144 | 
            +
                    dtype = text_encoder.get_input_embeddings().weight.dtype
         | 
| 1145 | 
            +
                    num_added_tokens = tokenizer.add_tokens(token)
         | 
| 1146 | 
            +
             | 
| 1147 | 
            +
                    i = 1
         | 
| 1148 | 
            +
                    if not idempotent:
         | 
| 1149 | 
            +
                        while num_added_tokens == 0:
         | 
| 1150 | 
            +
                            print(f"The tokenizer already contains the token {token}.")
         | 
| 1151 | 
            +
                            token = f"{token[:-1]}-{i}>"
         | 
| 1152 | 
            +
                            print(f"Attempting to add the token {token}.")
         | 
| 1153 | 
            +
                            num_added_tokens = tokenizer.add_tokens(token)
         | 
| 1154 | 
            +
                            i += 1
         | 
| 1155 | 
            +
                    elif num_added_tokens == 0 and idempotent:
         | 
| 1156 | 
            +
                        print(f"The tokenizer already contains the token {token}.")
         | 
| 1157 | 
            +
                        print(f"Replacing {token} embedding.")
         | 
| 1158 | 
            +
             | 
| 1159 | 
            +
                    # resize the token embeddings
         | 
| 1160 | 
            +
                    text_encoder.resize_token_embeddings(len(tokenizer))
         | 
| 1161 | 
            +
             | 
| 1162 | 
            +
                    # get the id for the token and assign the embeds
         | 
| 1163 | 
            +
                    token_id = tokenizer.convert_tokens_to_ids(token)
         | 
| 1164 | 
            +
                    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
         | 
| 1165 | 
            +
                return token
         | 
| 1166 | 
            +
             | 
| 1167 | 
            +
             | 
| 1168 | 
            +
            def load_learned_embed_in_clip(
         | 
| 1169 | 
            +
                learned_embeds_path,
         | 
| 1170 | 
            +
                text_encoder,
         | 
| 1171 | 
            +
                tokenizer,
         | 
| 1172 | 
            +
                token: Optional[Union[str, List[str]]] = None,
         | 
| 1173 | 
            +
                idempotent=False,
         | 
| 1174 | 
            +
            ):
         | 
| 1175 | 
            +
                learned_embeds = torch.load(learned_embeds_path)
         | 
| 1176 | 
            +
                apply_learned_embed_in_clip(
         | 
| 1177 | 
            +
                    learned_embeds, text_encoder, tokenizer, token, idempotent
         | 
| 1178 | 
            +
                )
         | 
| 1179 | 
            +
             | 
| 1180 | 
            +
             | 
| 1181 | 
            +
            def patch_pipe(
         | 
| 1182 | 
            +
                pipe,
         | 
| 1183 | 
            +
                maybe_unet_path,
         | 
| 1184 | 
            +
                token: Optional[str] = None,
         | 
| 1185 | 
            +
                r: int = 4,
         | 
| 1186 | 
            +
                patch_unet=True,
         | 
| 1187 | 
            +
                patch_text=True,
         | 
| 1188 | 
            +
                patch_ti=True,
         | 
| 1189 | 
            +
                idempotent_token=True,
         | 
| 1190 | 
            +
                unet_target_replace_module=DEFAULT_TARGET_REPLACE,
         | 
| 1191 | 
            +
                text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
         | 
| 1192 | 
            +
            ):
         | 
| 1193 | 
            +
                if maybe_unet_path.endswith(".pt"):
         | 
| 1194 | 
            +
                    # torch format
         | 
| 1195 | 
            +
             | 
| 1196 | 
            +
                    if maybe_unet_path.endswith(".ti.pt"):
         | 
| 1197 | 
            +
                        unet_path = maybe_unet_path[:-6] + ".pt"
         | 
| 1198 | 
            +
                    elif maybe_unet_path.endswith(".text_encoder.pt"):
         | 
| 1199 | 
            +
                        unet_path = maybe_unet_path[:-16] + ".pt"
         | 
| 1200 | 
            +
                    else:
         | 
| 1201 | 
            +
                        unet_path = maybe_unet_path
         | 
| 1202 | 
            +
             | 
| 1203 | 
            +
                    ti_path = _ti_lora_path(unet_path)
         | 
| 1204 | 
            +
                    text_path = _text_lora_path(unet_path)
         | 
| 1205 | 
            +
             | 
| 1206 | 
            +
                    if patch_unet:
         | 
| 1207 | 
            +
                        print("LoRA : Patching Unet")
         | 
| 1208 | 
            +
                        monkeypatch_or_replace_lora(
         | 
| 1209 | 
            +
                            pipe.unet,
         | 
| 1210 | 
            +
                            torch.load(unet_path),
         | 
| 1211 | 
            +
                            r=r,
         | 
| 1212 | 
            +
                            target_replace_module=unet_target_replace_module,
         | 
| 1213 | 
            +
                        )
         | 
| 1214 | 
            +
             | 
| 1215 | 
            +
                    if patch_text:
         | 
| 1216 | 
            +
                        print("LoRA : Patching text encoder")
         | 
| 1217 | 
            +
                        monkeypatch_or_replace_lora(
         | 
| 1218 | 
            +
                            pipe.text_encoder,
         | 
| 1219 | 
            +
                            torch.load(text_path),
         | 
| 1220 | 
            +
                            target_replace_module=text_target_replace_module,
         | 
| 1221 | 
            +
                            r=r,
         | 
| 1222 | 
            +
                        )
         | 
| 1223 | 
            +
                    if patch_ti:
         | 
| 1224 | 
            +
                        print("LoRA : Patching token input")
         | 
| 1225 | 
            +
                        token = load_learned_embed_in_clip(
         | 
| 1226 | 
            +
                            ti_path,
         | 
| 1227 | 
            +
                            pipe.text_encoder,
         | 
| 1228 | 
            +
                            pipe.tokenizer,
         | 
| 1229 | 
            +
                            token=token,
         | 
| 1230 | 
            +
                            idempotent=idempotent_token,
         | 
| 1231 | 
            +
                        )
         | 
| 1232 | 
            +
             | 
| 1233 | 
            +
                elif maybe_unet_path.endswith(".safetensors"):
         | 
| 1234 | 
            +
                    safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
         | 
| 1235 | 
            +
                    monkeypatch_or_replace_safeloras(pipe, safeloras)
         | 
| 1236 | 
            +
                    tok_dict = parse_safeloras_embeds(safeloras)
         | 
| 1237 | 
            +
                    if patch_ti:
         | 
| 1238 | 
            +
                        apply_learned_embed_in_clip(
         | 
| 1239 | 
            +
                            tok_dict,
         | 
| 1240 | 
            +
                            pipe.text_encoder,
         | 
| 1241 | 
            +
                            pipe.tokenizer,
         | 
| 1242 | 
            +
                            token=token,
         | 
| 1243 | 
            +
                            idempotent=idempotent_token,
         | 
| 1244 | 
            +
                        )
         | 
| 1245 | 
            +
                    return tok_dict
         | 
| 1246 | 
            +
             | 
| 1247 | 
            +
             | 
| 1248 | 
            +
            def train_patch_pipe(pipe, patch_unet, patch_text):
         | 
| 1249 | 
            +
                if patch_unet:
         | 
| 1250 | 
            +
                    print("LoRA : Patching Unet")
         | 
| 1251 | 
            +
                    collapse_lora(pipe.unet)
         | 
| 1252 | 
            +
                    monkeypatch_remove_lora(pipe.unet)
         | 
| 1253 | 
            +
             | 
| 1254 | 
            +
                if patch_text:
         | 
| 1255 | 
            +
                    print("LoRA : Patching text encoder")
         | 
| 1256 | 
            +
             | 
| 1257 | 
            +
                    collapse_lora(pipe.text_encoder)
         | 
| 1258 | 
            +
                    monkeypatch_remove_lora(pipe.text_encoder)
         | 
| 1259 | 
            +
             | 
| 1260 | 
            +
             | 
| 1261 | 
            +
            @torch.no_grad()
         | 
| 1262 | 
            +
            def inspect_lora(model):
         | 
| 1263 | 
            +
                moved = {}
         | 
| 1264 | 
            +
             | 
| 1265 | 
            +
                for name, _module in model.named_modules():
         | 
| 1266 | 
            +
                    if _module.__class__.__name__ in [
         | 
| 1267 | 
            +
                        "LoraInjectedLinear",
         | 
| 1268 | 
            +
                        "LoraInjectedConv2d",
         | 
| 1269 | 
            +
                        "LoraInjectedConv3d",
         | 
| 1270 | 
            +
                    ]:
         | 
| 1271 | 
            +
                        ups = _module.lora_up.weight.data.clone()
         | 
| 1272 | 
            +
                        downs = _module.lora_down.weight.data.clone()
         | 
| 1273 | 
            +
             | 
| 1274 | 
            +
                        wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
         | 
| 1275 | 
            +
             | 
| 1276 | 
            +
                        dist = wght.flatten().abs().mean().item()
         | 
| 1277 | 
            +
                        if name in moved:
         | 
| 1278 | 
            +
                            moved[name].append(dist)
         | 
| 1279 | 
            +
                        else:
         | 
| 1280 | 
            +
                            moved[name] = [dist]
         | 
| 1281 | 
            +
             | 
| 1282 | 
            +
                return moved
         | 
| 1283 | 
            +
             | 
| 1284 | 
            +
             | 
| 1285 | 
            +
            def save_all(
         | 
| 1286 | 
            +
                unet,
         | 
| 1287 | 
            +
                text_encoder,
         | 
| 1288 | 
            +
                save_path,
         | 
| 1289 | 
            +
                placeholder_token_ids=None,
         | 
| 1290 | 
            +
                placeholder_tokens=None,
         | 
| 1291 | 
            +
                save_lora=True,
         | 
| 1292 | 
            +
                save_ti=True,
         | 
| 1293 | 
            +
                target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
         | 
| 1294 | 
            +
                target_replace_module_unet=DEFAULT_TARGET_REPLACE,
         | 
| 1295 | 
            +
                safe_form=True,
         | 
| 1296 | 
            +
            ):
         | 
| 1297 | 
            +
                if not safe_form:
         | 
| 1298 | 
            +
                    # save ti
         | 
| 1299 | 
            +
                    if save_ti:
         | 
| 1300 | 
            +
                        ti_path = _ti_lora_path(save_path)
         | 
| 1301 | 
            +
                        learned_embeds_dict = {}
         | 
| 1302 | 
            +
                        for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
         | 
| 1303 | 
            +
                            learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
         | 
| 1304 | 
            +
                            print(
         | 
| 1305 | 
            +
                                f"Current Learned Embeddings for {tok}:, id {tok_id} ",
         | 
| 1306 | 
            +
                                learned_embeds[:4],
         | 
| 1307 | 
            +
                            )
         | 
| 1308 | 
            +
                            learned_embeds_dict[tok] = learned_embeds.detach().cpu()
         | 
| 1309 | 
            +
             | 
| 1310 | 
            +
                        torch.save(learned_embeds_dict, ti_path)
         | 
| 1311 | 
            +
                        print("Ti saved to ", ti_path)
         | 
| 1312 | 
            +
             | 
| 1313 | 
            +
                    # save text encoder
         | 
| 1314 | 
            +
                    if save_lora:
         | 
| 1315 | 
            +
                        save_lora_weight(
         | 
| 1316 | 
            +
                            unet, save_path, target_replace_module=target_replace_module_unet
         | 
| 1317 | 
            +
                        )
         | 
| 1318 | 
            +
                        print("Unet saved to ", save_path)
         | 
| 1319 | 
            +
             | 
| 1320 | 
            +
                        save_lora_weight(
         | 
| 1321 | 
            +
                            text_encoder,
         | 
| 1322 | 
            +
                            _text_lora_path(save_path),
         | 
| 1323 | 
            +
                            target_replace_module=target_replace_module_text,
         | 
| 1324 | 
            +
                        )
         | 
| 1325 | 
            +
                        print("Text Encoder saved to ", _text_lora_path(save_path))
         | 
| 1326 | 
            +
             | 
| 1327 | 
            +
                else:
         | 
| 1328 | 
            +
                    assert save_path.endswith(
         | 
| 1329 | 
            +
                        ".safetensors"
         | 
| 1330 | 
            +
                    ), f"Save path : {save_path} should end with .safetensors"
         | 
| 1331 | 
            +
             | 
| 1332 | 
            +
                    loras = {}
         | 
| 1333 | 
            +
                    embeds = {}
         | 
| 1334 | 
            +
             | 
| 1335 | 
            +
                    if save_lora:
         | 
| 1336 | 
            +
             | 
| 1337 | 
            +
                        loras["unet"] = (unet, target_replace_module_unet)
         | 
| 1338 | 
            +
                        loras["text_encoder"] = (text_encoder, target_replace_module_text)
         | 
| 1339 | 
            +
             | 
| 1340 | 
            +
                    if save_ti:
         | 
| 1341 | 
            +
                        for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
         | 
| 1342 | 
            +
                            learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
         | 
| 1343 | 
            +
                            print(
         | 
| 1344 | 
            +
                                f"Current Learned Embeddings for {tok}:, id {tok_id} ",
         | 
| 1345 | 
            +
                                learned_embeds[:4],
         | 
| 1346 | 
            +
                            )
         | 
| 1347 | 
            +
                            embeds[tok] = learned_embeds.detach().cpu()
         | 
| 1348 | 
            +
             | 
| 1349 | 
            +
                    save_safeloras_with_embeds(loras, embeds, save_path)
         | 
    	
        utils/lora_handler.py
    ADDED
    
    | @@ -0,0 +1,153 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from types import SimpleNamespace
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from .lora import (
         | 
| 5 | 
            +
                extract_lora_ups_down,
         | 
| 6 | 
            +
                inject_trainable_lora_extended,
         | 
| 7 | 
            +
                monkeypatch_or_replace_lora_extended,
         | 
| 8 | 
            +
            )
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            lora_versions = dict(stable_lora="stable_lora", cloneofsimo="cloneofsimo")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            lora_func_types = dict(loader="loader", injector="injector")
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            lora_args = dict(
         | 
| 17 | 
            +
                model=None,
         | 
| 18 | 
            +
                loras=None,
         | 
| 19 | 
            +
                target_replace_module=[],
         | 
| 20 | 
            +
                target_module=[],
         | 
| 21 | 
            +
                r=4,
         | 
| 22 | 
            +
                search_class=[torch.nn.Linear],
         | 
| 23 | 
            +
                dropout=0,
         | 
| 24 | 
            +
                lora_bias="none",
         | 
| 25 | 
            +
            )
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            LoraVersions = SimpleNamespace(**lora_versions)
         | 
| 28 | 
            +
            LoraFuncTypes = SimpleNamespace(**lora_func_types)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
         | 
| 31 | 
            +
            LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            def filter_dict(_dict, keys=[]):
         | 
| 35 | 
            +
                if len(keys) == 0:
         | 
| 36 | 
            +
                    assert "Keys cannot empty for filtering return dict."
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                for k in keys:
         | 
| 39 | 
            +
                    if k not in lora_args.keys():
         | 
| 40 | 
            +
                        assert f"{k} does not exist in available LoRA arguments"
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                return {k: v for k, v in _dict.items() if k in keys}
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            class LoraHandler(object):
         | 
| 46 | 
            +
                def __init__(
         | 
| 47 | 
            +
                    self,
         | 
| 48 | 
            +
                    version: str = LoraVersions.cloneofsimo,
         | 
| 49 | 
            +
                    use_unet_lora: bool = False,
         | 
| 50 | 
            +
                    use_text_lora: bool = False,
         | 
| 51 | 
            +
                    save_for_webui: bool = False,
         | 
| 52 | 
            +
                    only_for_webui: bool = False,
         | 
| 53 | 
            +
                    lora_bias: str = "none",
         | 
| 54 | 
            +
                    unet_replace_modules: list = ["UNet3DConditionModel"],
         | 
| 55 | 
            +
                ):
         | 
| 56 | 
            +
                    self.version = version
         | 
| 57 | 
            +
                    assert self.is_cloneofsimo_lora()
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
         | 
| 60 | 
            +
                    self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
         | 
| 61 | 
            +
                    self.lora_bias = lora_bias
         | 
| 62 | 
            +
                    self.use_unet_lora = use_unet_lora
         | 
| 63 | 
            +
                    self.use_text_lora = use_text_lora
         | 
| 64 | 
            +
                    self.save_for_webui = save_for_webui
         | 
| 65 | 
            +
                    self.only_for_webui = only_for_webui
         | 
| 66 | 
            +
                    self.unet_replace_modules = unet_replace_modules
         | 
| 67 | 
            +
                    self.use_lora = any([use_text_lora, use_unet_lora])
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    if self.use_lora:
         | 
| 70 | 
            +
                        print(f"Using LoRA Version: {self.version}")
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def is_cloneofsimo_lora(self):
         | 
| 73 | 
            +
                    return self.version == LoraVersions.cloneofsimo
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
         | 
| 76 | 
            +
                    if func_type == LoraFuncTypes.loader:
         | 
| 77 | 
            +
                        return monkeypatch_or_replace_lora_extended
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    if func_type == LoraFuncTypes.injector:
         | 
| 80 | 
            +
                        return inject_trainable_lora_extended
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    assert "LoRA Version does not exist."
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def get_lora_func_args(
         | 
| 85 | 
            +
                    self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
         | 
| 86 | 
            +
                ):
         | 
| 87 | 
            +
                    return_dict = lora_args.copy()
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
         | 
| 90 | 
            +
                    return_dict.update(
         | 
| 91 | 
            +
                        {
         | 
| 92 | 
            +
                            "model": model,
         | 
| 93 | 
            +
                            "loras": lora_path,
         | 
| 94 | 
            +
                            "target_replace_module": replace_modules,
         | 
| 95 | 
            +
                            "r": r,
         | 
| 96 | 
            +
                        }
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    return return_dict
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def do_lora_injection(
         | 
| 102 | 
            +
                    self,
         | 
| 103 | 
            +
                    model,
         | 
| 104 | 
            +
                    replace_modules,
         | 
| 105 | 
            +
                    bias="none",
         | 
| 106 | 
            +
                    dropout=0,
         | 
| 107 | 
            +
                    r=4,
         | 
| 108 | 
            +
                    lora_loader_args=None,
         | 
| 109 | 
            +
                ):
         | 
| 110 | 
            +
                    REPLACE_MODULES = replace_modules
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    params = None
         | 
| 113 | 
            +
                    negation = None
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    injector_args = lora_loader_args
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    params, negation = self.lora_injector(**injector_args)
         | 
| 118 | 
            +
                    for _up, _down in extract_lora_ups_down(
         | 
| 119 | 
            +
                        model, target_replace_module=REPLACE_MODULES
         | 
| 120 | 
            +
                    ):
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                        if all(x is not None for x in [_up, _down]):
         | 
| 123 | 
            +
                            print(
         | 
| 124 | 
            +
                                f"Lora successfully injected into {model.__class__.__name__}."
         | 
| 125 | 
            +
                            )
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        break
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    return params, negation
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                def add_lora_to_model(
         | 
| 132 | 
            +
                    self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
         | 
| 133 | 
            +
                ):
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    params = None
         | 
| 136 | 
            +
                    negation = None
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    lora_loader_args = self.get_lora_func_args(
         | 
| 139 | 
            +
                        lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
         | 
| 140 | 
            +
                    )
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    if use_lora:
         | 
| 143 | 
            +
                        params, negation = self.do_lora_injection(
         | 
| 144 | 
            +
                            model,
         | 
| 145 | 
            +
                            replace_modules,
         | 
| 146 | 
            +
                            bias=self.lora_bias,
         | 
| 147 | 
            +
                            lora_loader_args=lora_loader_args,
         | 
| 148 | 
            +
                            dropout=dropout,
         | 
| 149 | 
            +
                            r=r,
         | 
| 150 | 
            +
                        )
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    params = model if params is None else params
         | 
| 153 | 
            +
                    return params, negation
         | 
    	
        utils/utils.py
    ADDED
    
    | @@ -0,0 +1,99 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import importlib
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import cv2
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.distributed as dist
         | 
| 7 | 
            +
            import torchvision
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def count_params(model, verbose=False):
         | 
| 11 | 
            +
                total_params = sum(p.numel() for p in model.parameters())
         | 
| 12 | 
            +
                if verbose:
         | 
| 13 | 
            +
                    print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
         | 
| 14 | 
            +
                return total_params
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def check_istarget(name, para_list):
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
                name: full name of source para
         | 
| 20 | 
            +
                para_list: partial name of target para
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                istarget = False
         | 
| 23 | 
            +
                for para in para_list:
         | 
| 24 | 
            +
                    if para in name:
         | 
| 25 | 
            +
                        return True
         | 
| 26 | 
            +
                return istarget
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def instantiate_from_config(config):
         | 
| 30 | 
            +
                if not "target" in config:
         | 
| 31 | 
            +
                    if config == "__is_first_stage__":
         | 
| 32 | 
            +
                        return None
         | 
| 33 | 
            +
                    elif config == "__is_unconditional__":
         | 
| 34 | 
            +
                        return None
         | 
| 35 | 
            +
                    raise KeyError("Expected key `target` to instantiate.")
         | 
| 36 | 
            +
                return get_obj_from_str(config["target"])(**config.get("params", dict()))
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def get_obj_from_str(string, reload=False):
         | 
| 40 | 
            +
                module, cls = string.rsplit(".", 1)
         | 
| 41 | 
            +
                if reload:
         | 
| 42 | 
            +
                    module_imp = importlib.import_module(module)
         | 
| 43 | 
            +
                    importlib.reload(module_imp)
         | 
| 44 | 
            +
                return getattr(importlib.import_module(module, package=None), cls)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def load_npz_from_dir(data_dir):
         | 
| 48 | 
            +
                data = [
         | 
| 49 | 
            +
                    np.load(os.path.join(data_dir, data_name))["arr_0"]
         | 
| 50 | 
            +
                    for data_name in os.listdir(data_dir)
         | 
| 51 | 
            +
                ]
         | 
| 52 | 
            +
                data = np.concatenate(data, axis=0)
         | 
| 53 | 
            +
                return data
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            def load_npz_from_paths(data_paths):
         | 
| 57 | 
            +
                data = [np.load(data_path)["arr_0"] for data_path in data_paths]
         | 
| 58 | 
            +
                data = np.concatenate(data, axis=0)
         | 
| 59 | 
            +
                return data
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
         | 
| 63 | 
            +
                h, w = image.shape[:2]
         | 
| 64 | 
            +
                if resize_short_edge is not None:
         | 
| 65 | 
            +
                    k = resize_short_edge / min(h, w)
         | 
| 66 | 
            +
                else:
         | 
| 67 | 
            +
                    k = max_resolution / (h * w)
         | 
| 68 | 
            +
                    k = k**0.5
         | 
| 69 | 
            +
                h = int(np.round(h * k / 64)) * 64
         | 
| 70 | 
            +
                w = int(np.round(w * k / 64)) * 64
         | 
| 71 | 
            +
                image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
         | 
| 72 | 
            +
                return image
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def setup_dist(args):
         | 
| 76 | 
            +
                if dist.is_initialized():
         | 
| 77 | 
            +
                    return
         | 
| 78 | 
            +
                torch.cuda.set_device(args.local_rank)
         | 
| 79 | 
            +
                torch.distributed.init_process_group("nccl", init_method="env://")
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def save_videos(batch_tensors, savedir, filenames, fps=16):
         | 
| 83 | 
            +
                # b,samples,c,t,h,w
         | 
| 84 | 
            +
                n_samples = batch_tensors.shape[1]
         | 
| 85 | 
            +
                for idx, vid_tensor in enumerate(batch_tensors):
         | 
| 86 | 
            +
                    video = vid_tensor.detach().cpu()
         | 
| 87 | 
            +
                    video = torch.clamp(video.float(), -1.0, 1.0)
         | 
| 88 | 
            +
                    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
         | 
| 89 | 
            +
                    frame_grids = [
         | 
| 90 | 
            +
                        torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
         | 
| 91 | 
            +
                        for framesheet in video
         | 
| 92 | 
            +
                    ]  # [3, 1*h, n*w]
         | 
| 93 | 
            +
                    grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
         | 
| 94 | 
            +
                    grid = (grid + 1.0) / 2.0
         | 
| 95 | 
            +
                    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
         | 
| 96 | 
            +
                    savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
         | 
| 97 | 
            +
                    torchvision.io.write_video(
         | 
| 98 | 
            +
                        savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
         | 
| 99 | 
            +
                    )
         | 
