Upload 9 files

Files changed:
- scripts/evaluation/__pycache__/funcs.cpython-39.pyc +0 -0
- scripts/evaluation/__pycache__/inference.cpython-39.pyc +0 -0
- scripts/evaluation/ddp_wrapper.py +46 -46
- scripts/evaluation/funcs.py +225 -204
- scripts/evaluation/inference.py +346 -328
- scripts/gradio/__pycache__/i2v_test.cpython-39.pyc +0 -0
- scripts/gradio/i2v_test.py +105 -101
- scripts/run.sh +45 -9
- scripts/run_mp.sh +102 -0
    	
scripts/evaluation/__pycache__/funcs.cpython-39.pyc  CHANGED
Binary files a/scripts/evaluation/__pycache__/funcs.cpython-39.pyc and b/scripts/evaluation/__pycache__/funcs.cpython-39.pyc differ
    	
scripts/evaluation/__pycache__/inference.cpython-39.pyc  ADDED
Binary file (12.1 kB)
    	
scripts/evaluation/ddp_wrapper.py  CHANGED

@@ -1,47 +1,47 @@
-import datetime
-import argparse, importlib
-from pytorch_lightning import seed_everything
-
-import torch
-import torch.distributed as dist
-
-def setup_dist(local_rank):
-    if dist.is_initialized():
-        return
-    torch.cuda.set_device(local_rank)
-    torch.distributed.init_process_group('nccl', init_method='env://')
-
-
-def get_dist_info():
-    if dist.is_available():
-        initialized = dist.is_initialized()
-    else:
-        initialized = False
-    if initialized:
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-if __name__ == '__main__':
-    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--module", type=str, help="module name", default="inference")
-    parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
-    args, unknown = parser.parse_known_args()
-    inference_api = importlib.import_module(args.module, package=None)
-
-    inference_parser = inference_api.get_parser()
-    inference_args, unknown = inference_parser.parse_known_args()
-
-    seed_everything(inference_args.seed)
-    setup_dist(args.local_rank)
-    torch.backends.cudnn.benchmark = True
-    rank, gpu_num = get_dist_info()
-
-    inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed)
-    print("@
+import datetime
+import argparse, importlib
+from pytorch_lightning import seed_everything
+
+import torch
+import torch.distributed as dist
+
+def setup_dist(local_rank):
+    if dist.is_initialized():
+        return
+    torch.cuda.set_device(local_rank)
+    torch.distributed.init_process_group('nccl', init_method='env://')
+
+
+def get_dist_info():
+    if dist.is_available():
+        initialized = dist.is_initialized()
+    else:
+        initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+if __name__ == '__main__':
+    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--module", type=str, help="module name", default="inference")
+    parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
+    args, unknown = parser.parse_known_args()
+    inference_api = importlib.import_module(args.module, package=None)
+
+    inference_parser = inference_api.get_parser()
+    inference_args, unknown = inference_parser.parse_known_args()
+
+    seed_everything(inference_args.seed)
+    setup_dist(args.local_rank)
+    torch.backends.cudnn.benchmark = True
+    rank, gpu_num = get_dist_info()
+
+    # inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed)
+    print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now))
     inference_api.run_inference(inference_args, gpu_num, rank)
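ddp_wrapper.py treats the module named by --module as a plugin: it must expose get_parser() and run_inference(args, gpu_num, rank), and its parser must accept --seed, since the wrapper calls seed_everything(inference_args.seed). A minimal sketch of that contract with a hypothetical module; everything besides --seed below is illustrative, not the real inference parser:

import argparse

# Hypothetical plugin module for ddp_wrapper.py.
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=123)       # consumed by seed_everything in the wrapper
    parser.add_argument("--savedir", type=str, default="out")  # illustrative extra option
    return parser

def run_inference(args, gpu_num, rank):
    # each rank processes its own shard of the workload
    print(f"[rank {rank}/{gpu_num}] writing results to {args.savedir}")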
    	
scripts/evaluation/funcs.py  CHANGED

@@ -1,205 +1,226 @@
-import os, sys, glob
-import numpy as np
-from collections import OrderedDict
-from decord import VideoReader, cpu
-import cv2
-
-import torch
-import torchvision
-sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
-from lvdm.models.samplers.ddim import DDIMSampler
-from einops import rearrange
-
-
-def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
-                        cfg_scale=1.0, temporal_cfg_scale=None, **kwargs):
-    ddim_sampler = DDIMSampler(model)
-    uncond_type = model.uncond_type
-    batch_size = noise_shape[0]
-    fs = cond["fs"]
-    del cond["fs"]
-[old lines 21-204 are not recoverable from this diff view]
+import os, sys, glob
+import numpy as np
+from collections import OrderedDict
+from decord import VideoReader, cpu
+import cv2
+
+import torch
+import torchvision
+sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
+from lvdm.models.samplers.ddim import DDIMSampler
+from einops import rearrange
+
+
+def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
+                        cfg_scale=1.0, temporal_cfg_scale=None, **kwargs):
+    ddim_sampler = DDIMSampler(model)
+    uncond_type = model.uncond_type
+    batch_size = noise_shape[0]
+    fs = cond["fs"]
+    del cond["fs"]
+    if noise_shape[-1] == 32:
+        timestep_spacing = "uniform"
+        guidance_rescale = 0.0
+    else:
+        timestep_spacing = "uniform_trailing"
+        guidance_rescale = 0.7
+    ## construct unconditional guidance
+    if cfg_scale != 1.0:
+        if uncond_type == "empty_seq":
+            prompts = batch_size * [""]
+            #prompts = N * T * [""]  ## if is_imgbatch=True
+            uc_emb = model.get_learned_conditioning(prompts)
+        elif uncond_type == "zero_embed":
+            c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
+            uc_emb = torch.zeros_like(c_emb)
+
+        ## process image embedding token
+        if hasattr(model, 'embedder'):
+            uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device)
+            ## img: b c h w >> b l c
+            uc_img = model.embedder(uc_img)
+            uc_img = model.image_proj_model(uc_img)
+            uc_emb = torch.cat([uc_emb, uc_img], dim=1)
+
+        if isinstance(cond, dict):
+            uc = {key:cond[key] for key in cond.keys()}
+            uc.update({'c_crossattn': [uc_emb]})
+        else:
+            uc = uc_emb
+    else:
+        uc = None
+
+    x_T = None
+    batch_variants = []
+
+    for _ in range(n_samples):
+        if ddim_sampler is not None:
+            kwargs.update({"clean_cond": True})
+            samples, _ = ddim_sampler.sample(S=ddim_steps,
+                                            conditioning=cond,
+                                            batch_size=noise_shape[0],
+                                            shape=noise_shape[1:],
+                                            verbose=False,
+                                            unconditional_guidance_scale=cfg_scale,
+                                            unconditional_conditioning=uc,
+                                            eta=ddim_eta,
+                                            temporal_length=noise_shape[2],
+                                            conditional_guidance_scale_temporal=temporal_cfg_scale,
+                                            x_T=x_T,
+                                            fs=fs,
+                                            timestep_spacing=timestep_spacing,
+                                            guidance_rescale=guidance_rescale,
+                                            **kwargs
+                                            )
+        ## reconstruct from latent to pixel space
+        batch_images = model.decode_first_stage(samples)
+        batch_variants.append(batch_images)
+    ## batch, <samples>, c, t, h, w
+    batch_variants = torch.stack(batch_variants, dim=1)
+    return batch_variants
+
+
+def get_filelist(data_dir, ext='*'):
+    file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
+    file_list.sort()
+    return file_list
+
+def get_dirlist(path):
+    list = []
+    if (os.path.exists(path)):
+        files = os.listdir(path)
+        for file in files:
+            m = os.path.join(path,file)
+            if (os.path.isdir(m)):
+                list.append(m)
+    list.sort()
+    return list
+
+
+def load_model_checkpoint(model, ckpt):
+    def load_checkpoint(model, ckpt, full_strict):
+        state_dict = torch.load(ckpt, map_location="cpu")
+        if "state_dict" in list(state_dict.keys()):
+            state_dict = state_dict["state_dict"]
+            try:
+                model.load_state_dict(state_dict, strict=full_strict)
+            except:
+                ## rename the keys for 256x256 model
+                new_pl_sd = OrderedDict()
+                for k,v in state_dict.items():
+                    new_pl_sd[k] = v
+
+                for k in list(new_pl_sd.keys()):
+                    if "framestride_embed" in k:
+                        new_key = k.replace("framestride_embed", "fps_embedding")
+                        new_pl_sd[new_key] = new_pl_sd[k]
+                        del new_pl_sd[k]
+                model.load_state_dict(new_pl_sd, strict=full_strict)
+        else:
+            ## deepspeed
+            new_pl_sd = OrderedDict()
+            for key in state_dict['module'].keys():
+                new_pl_sd[key[16:]]=state_dict['module'][key]
+            model.load_state_dict(new_pl_sd, strict=full_strict)
+
+        return model
+    load_checkpoint(model, ckpt, full_strict=True)
+    print('>>> model checkpoint loaded.')
+    return model
+
+
+def load_prompts(prompt_file):
+    f = open(prompt_file, 'r')
+    prompt_list = []
+    for idx, line in enumerate(f.readlines()):
+        l = line.strip()
+        if len(l) != 0:
+            prompt_list.append(l)
+        f.close()
+    return prompt_list
+
+
+def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
+    '''
+    Notice about some special cases:
+    1. video_frames=-1 means to take all the frames (with fs=1)
+    2. when the total video frames is less than required, padding strategy will be used (repeated last frame)
+    '''
+    fps_list = []
+    batch_tensor = []
+    assert frame_stride > 0, "valid frame stride should be a positive interge!"
+    for filepath in filepath_list:
+        padding_num = 0
+        vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
+        fps = vidreader.get_avg_fps()
+        total_frames = len(vidreader)
+        max_valid_frames = (total_frames-1) // frame_stride + 1
+        if video_frames < 0:
+            ## all frames are collected: fs=1 is a must
+            required_frames = total_frames
+            frame_stride = 1
+        else:
+            required_frames = video_frames
+        query_frames = min(required_frames, max_valid_frames)
+        frame_indices = [frame_stride*i for i in range(query_frames)]
+
+        ## [t,h,w,c] -> [c,t,h,w]
+        frames = vidreader.get_batch(frame_indices)
+        frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
+        frame_tensor = (frame_tensor / 255. - 0.5) * 2
+        if max_valid_frames < required_frames:
+            padding_num = required_frames - max_valid_frames
+            frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1)
+            print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
+        batch_tensor.append(frame_tensor)
+        sample_fps = int(fps/frame_stride)
+        fps_list.append(sample_fps)
+
+    return torch.stack(batch_tensor, dim=0)
+
+from PIL import Image
+def load_image_batch(filepath_list, image_size=(256,256)):
+    batch_tensor = []
+    for filepath in filepath_list:
+        _, filename = os.path.split(filepath)
+        _, ext = os.path.splitext(filename)
+        if ext == '.mp4':
+            vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
+            frame = vidreader.get_batch([0])
+            img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
+        elif ext == '.png' or ext == '.jpg':
+            img = Image.open(filepath).convert("RGB")
+            rgb_img = np.array(img, np.float32)
+            #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR)
+            #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
+            rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR)
+            img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
+        else:
+            print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]')
+            raise NotImplementedError
+        img_tensor = (img_tensor / 255. - 0.5) * 2
+        batch_tensor.append(img_tensor)
+    return torch.stack(batch_tensor, dim=0)
+
+
+def save_videos(batch_tensors, savedir, filenames, fps=10):
+    # b,samples,c,t,h,w
+    n_samples = batch_tensors.shape[1]
+    for idx, vid_tensor in enumerate(batch_tensors):
+        video = vid_tensor.detach().cpu()
+        video = torch.clamp(video.float(), -1., 1.)
+        video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
+        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
+        grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
+        grid = (grid + 1.0) / 2.0
+        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+        savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
+        torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
+
+
+def get_latent_z(model, videos):
+    b, c, t, h, w = videos.shape
+    x = rearrange(videos, 'b c t h w -> (b t) c h w')
+    z = model.encode_first_stage(x)
+    z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
     return z
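For orientation, a hedged usage sketch of the helpers above. model is assumed to be a loaded DynamiCrafter checkpoint and cond a conditioning dict built as in the inference script; neither is constructed here, so those calls stay commented, and the cfg_scale value is illustrative:

import torch
# from funcs import get_latent_z, batch_ddim_sampling, save_videos

videos = torch.randn(1, 3, 16, 256, 256)               # b, c, t, h, w in [-1, 1]
# z = get_latent_z(model, videos)                      # b, c', t, h', w' latents
# cond = {"c_crossattn": [text_emb], "fs": fs_tensor}  # "fs" is popped by batch_ddim_sampling
# batch_variants = batch_ddim_sampling(model, cond, noise_shape=list(z.shape),
#                                      n_samples=1, ddim_steps=50, cfg_scale=7.5)
# save_videos(batch_variants, "results", filenames=["demo"], fps=8)  # writes results/demo.mp4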
    	
        scripts/evaluation/inference.py
    CHANGED
    
    | @@ -1,329 +1,347 @@ | |
| 1 | 
            -
            import argparse, os, sys, glob
         | 
| 2 | 
            -
            import datetime, time
         | 
| 3 | 
            -
            from omegaconf import OmegaConf
         | 
| 4 | 
            -
            from tqdm import tqdm
         | 
| 5 | 
            -
            from einops import rearrange, repeat
         | 
| 6 | 
            -
            from collections import OrderedDict
         | 
| 7 | 
            -
             | 
| 8 | 
            -
            import torch
         | 
| 9 | 
            -
            import torchvision
         | 
| 10 | 
            -
            import torchvision.transforms as transforms
         | 
| 11 | 
            -
            from pytorch_lightning import seed_everything
         | 
| 12 | 
            -
            from PIL import Image
         | 
| 13 | 
            -
            sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
         | 
| 14 | 
            -
            from lvdm.models.samplers.ddim import DDIMSampler
         | 
| 15 | 
            -
            from lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
         | 
| 16 | 
            -
            from utils.utils import instantiate_from_config
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
            def get_filelist(data_dir, postfixes):
         | 
| 20 | 
            -
                patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
         | 
| 21 | 
            -
                file_list = []
         | 
| 22 | 
            -
                for pattern in patterns:
         | 
| 23 | 
            -
                    file_list.extend(glob.glob(pattern))
         | 
| 24 | 
            -
                file_list.sort()
         | 
| 25 | 
            -
                return file_list
         | 
| 26 | 
            -
             | 
| 27 | 
            -
            def load_model_checkpoint(model, ckpt):
         | 
| 28 | 
            -
                state_dict = torch.load(ckpt, map_location="cpu")
         | 
| 29 | 
            -
                if "state_dict" in list(state_dict.keys()):
         | 
| 30 | 
            -
                    state_dict = state_dict["state_dict"]
         | 
| 31 | 
            -
                     | 
| 32 | 
            -
             | 
| 33 | 
            -
                     | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
                         | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
                     | 
| 47 | 
            -
             | 
| 48 | 
            -
                     | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
                 | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
                 | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
                 | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 | 
            -
                 | 
| 66 | 
            -
             | 
| 67 | 
            -
             | 
| 68 | 
            -
             | 
| 69 | 
            -
             | 
| 70 | 
            -
                 | 
| 71 | 
            -
                 | 
| 72 | 
            -
                 | 
| 73 | 
            -
                 | 
| 74 | 
            -
                 | 
| 75 | 
            -
             | 
| 76 | 
            -
             | 
| 77 | 
            -
                     | 
| 78 | 
            -
             | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
                 | 
| 84 | 
            -
             | 
| 85 | 
            -
             | 
| 86 | 
            -
             | 
| 87 | 
            -
                 | 
| 88 | 
            -
             | 
| 89 | 
            -
             | 
| 90 | 
            -
             | 
| 91 | 
            -
             | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
            -
                     | 
| 95 | 
            -
             | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 | 
            -
             | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
             | 
| 107 | 
            -
                     | 
| 108 | 
            -
             | 
| 109 | 
            -
                     | 
| 110 | 
            -
             | 
| 111 | 
            -
             | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
             | 
| 117 | 
            -
             | 
| 118 | 
            -
             | 
| 119 | 
            -
                     | 
| 120 | 
            -
             | 
| 121 | 
            -
                     | 
| 122 | 
            -
                     | 
| 123 | 
            -
             | 
| 124 | 
            -
             | 
| 125 | 
            -
             | 
| 126 | 
            -
             | 
| 127 | 
            -
             | 
| 128 | 
            -
             | 
| 129 | 
            -
             | 
| 130 | 
            -
             | 
| 131 | 
            -
             | 
| 132 | 
            -
             | 
| 133 | 
            -
             | 
| 134 | 
            -
             | 
| 135 | 
            -
             | 
| 136 | 
            -
             | 
| 137 | 
            -
             | 
| 138 | 
            -
             | 
| 139 | 
            -
             | 
| 140 | 
            -
             | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
             | 
| 146 | 
            -
             | 
| 147 | 
            -
             | 
| 148 | 
            -
                 | 
| 149 | 
            -
             | 
| 150 | 
            -
             | 
| 151 | 
            -
                 | 
| 152 | 
            -
                 | 
| 153 | 
            -
             | 
| 154 | 
            -
             | 
| 155 | 
            -
             | 
| 156 | 
            -
             | 
| 157 | 
            -
                 | 
| 158 | 
            -
             | 
| 159 | 
            -
             | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
            -
             | 
| 163 | 
            -
             | 
| 164 | 
            -
             | 
| 165 | 
            -
             | 
| 166 | 
            -
             | 
| 167 | 
            -
             | 
| 168 | 
            -
                 | 
| 169 | 
            -
             | 
| 170 | 
            -
             | 
| 171 | 
            -
             | 
| 172 | 
            -
                     | 
| 173 | 
            -
                         | 
| 174 | 
            -
             | 
| 175 | 
            -
             | 
| 176 | 
            -
                     | 
| 177 | 
            -
             | 
| 178 | 
            -
                         | 
| 179 | 
            -
             | 
| 180 | 
            -
             | 
| 181 | 
            -
             | 
| 182 | 
            -
             | 
| 183 | 
            -
             | 
| 184 | 
            -
             | 
| 185 | 
            -
                     | 
| 186 | 
            -
                         | 
| 187 | 
            -
                     | 
| 188 | 
            -
             | 
| 189 | 
            -
                     | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
                 | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
                 | 
| 196 | 
            -
             | 
| 197 | 
            -
                     | 
| 198 | 
            -
             | 
| 199 | 
            -
                         | 
| 200 | 
            -
                     | 
| 201 | 
            -
             | 
| 202 | 
            -
                     | 
| 203 | 
            -
             | 
| 204 | 
            -
             | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 207 | 
            -
             | 
| 208 | 
            -
             | 
| 209 | 
            -
             | 
| 210 | 
            -
             | 
| 211 | 
            -
             | 
| 212 | 
            -
             | 
| 213 | 
            -
             | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
| 217 | 
            -
             | 
| 218 | 
            -
             | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
             | 
| 222 | 
            -
             | 
| 223 | 
            -
             | 
| 224 | 
            -
             | 
| 225 | 
            -
             | 
| 226 | 
            -
             | 
| 227 | 
            -
             | 
| 228 | 
            -
             | 
| 229 | 
            -
             | 
| 230 | 
            -
             | 
| 231 | 
            -
             | 
| 232 | 
            -
             | 
| 233 | 
            -
             | 
| 234 | 
            -
             | 
| 235 | 
            -
             | 
| 236 | 
            -
             | 
| 237 | 
            -
                 | 
| 238 | 
            -
                 | 
| 239 | 
            -
                 | 
| 240 | 
            -
             | 
| 241 | 
            -
             | 
| 242 | 
            -
             | 
| 243 | 
            -
                 | 
| 244 | 
            -
                 | 
| 245 | 
            -
                 | 
| 246 | 
            -
                 | 
| 247 | 
            -
                 | 
| 248 | 
            -
                 | 
| 249 | 
            -
                 | 
| 250 | 
            -
             | 
| 251 | 
            -
                 | 
| 252 | 
            -
                 | 
| 253 | 
            -
             | 
| 254 | 
            -
                 | 
| 255 | 
            -
             | 
| 256 | 
            -
             | 
| 257 | 
            -
                 | 
| 258 | 
            -
                assert  | 
| 259 | 
            -
                 | 
| 260 | 
            -
                 | 
| 261 | 
            -
                 | 
| 262 | 
            -
                 | 
| 263 | 
            -
                 | 
| 264 | 
            -
                 | 
| 265 | 
            -
             | 
| 266 | 
            -
                 | 
| 267 | 
            -
                 | 
| 268 | 
            -
             | 
| 269 | 
            -
                 | 
| 270 | 
            -
                 | 
| 271 | 
            -
             | 
| 272 | 
            -
             | 
| 273 | 
            -
             | 
| 274 | 
            -
             | 
| 275 | 
            -
             | 
| 276 | 
            -
             | 
| 277 | 
            -
             | 
| 278 | 
            -
             | 
| 279 | 
            -
             | 
| 280 | 
            -
             | 
| 281 | 
            -
             | 
| 282 | 
            -
             | 
| 283 | 
            -
             | 
| 284 | 
            -
             | 
| 285 | 
            -
             | 
| 286 | 
            -
             | 
| 287 | 
            -
             | 
| 288 | 
            -
             | 
| 289 | 
            -
             | 
| 290 | 
            -
             | 
| 291 | 
            -
             | 
| 292 | 
            -
             | 
| 293 | 
            -
             | 
| 294 | 
            -
             | 
| 295 | 
            -
             | 
| 296 | 
            -
             | 
| 297 | 
            -
             | 
| 298 | 
            -
             | 
| 299 | 
            -
             | 
| 300 | 
            -
             | 
| 301 | 
            -
             | 
| 302 | 
            -
             | 
| 303 | 
            -
             | 
| 304 | 
            -
             | 
| 305 | 
            -
             | 
| 306 | 
            -
                 | 
| 307 | 
            -
             | 
| 308 | 
            -
             | 
| 309 | 
            -
             | 
| 310 | 
            -
                parser | 
| 311 | 
            -
                parser.add_argument("-- | 
| 312 | 
            -
                parser.add_argument("-- | 
| 313 | 
            -
                parser.add_argument("-- | 
| 314 | 
            -
             | 
| 315 | 
            -
                 | 
| 316 | 
            -
                parser.add_argument("-- | 
| 317 | 
            -
                parser.add_argument("-- | 
| 318 | 
            -
                 | 
| 319 | 
            -
             | 
| 320 | 
            -
             | 
| 321 | 
            -
             | 
| 322 | 
            -
                 | 
| 323 | 
            -
                 | 
| 324 | 
            -
                parser =  | 
| 325 | 
            -
                 | 
| 326 | 
            -
                
         | 
| 327 | 
            -
                 | 
| 328 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 329 | 
             
                run_inference(args, gpu_num, rank)
         | 
|  | |
| 1 | 
            +
            import argparse, os, sys, glob
         | 
| 2 | 
            +
            import datetime, time
         | 
| 3 | 
            +
            from omegaconf import OmegaConf
         | 
| 4 | 
            +
            from tqdm import tqdm
         | 
| 5 | 
            +
            from einops import rearrange, repeat
         | 
| 6 | 
            +
            from collections import OrderedDict
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import torchvision
         | 
| 10 | 
            +
            import torchvision.transforms as transforms
         | 
| 11 | 
            +
            from pytorch_lightning import seed_everything
         | 
| 12 | 
            +
            from PIL import Image
         | 
| 13 | 
            +
            sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
         | 
| 14 | 
            +
            from lvdm.models.samplers.ddim import DDIMSampler
         | 
| 15 | 
            +
            from lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
         | 
| 16 | 
            +
            from utils.utils import instantiate_from_config
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def get_filelist(data_dir, postfixes):
         | 
| 20 | 
            +
                patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
         | 
| 21 | 
            +
                file_list = []
         | 
| 22 | 
            +
                for pattern in patterns:
         | 
| 23 | 
            +
                    file_list.extend(glob.glob(pattern))
         | 
| 24 | 
            +
                file_list.sort()
         | 
| 25 | 
            +
                return file_list
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            def load_model_checkpoint(model, ckpt):
         | 
| 28 | 
            +
                state_dict = torch.load(ckpt, map_location="cpu")
         | 
| 29 | 
            +
                if "state_dict" in list(state_dict.keys()):
         | 
| 30 | 
            +
                    state_dict = state_dict["state_dict"]
         | 
| 31 | 
            +
                    try:
         | 
| 32 | 
            +
                        model.load_state_dict(state_dict, strict=True)
         | 
| 33 | 
            +
                    except:
         | 
| 34 | 
            +
                        ## rename the keys for 256x256 model
         | 
| 35 | 
            +
                        new_pl_sd = OrderedDict()
         | 
| 36 | 
            +
                        for k,v in state_dict.items():
         | 
| 37 | 
            +
                            new_pl_sd[k] = v
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                        for k in list(new_pl_sd.keys()):
         | 
| 40 | 
            +
                            if "framestride_embed" in k:
         | 
| 41 | 
            +
                                new_key = k.replace("framestride_embed", "fps_embedding")
         | 
| 42 | 
            +
                                new_pl_sd[new_key] = new_pl_sd[k]
         | 
| 43 | 
            +
                                del new_pl_sd[k]
         | 
| 44 | 
            +
                        model.load_state_dict(new_pl_sd, strict=True)
         | 
| 45 | 
            +
                else:
         | 
| 46 | 
            +
                    # deepspeed
         | 
| 47 | 
            +
                    new_pl_sd = OrderedDict()
         | 
| 48 | 
            +
                    for key in state_dict['module'].keys():
         | 
| 49 | 
            +
                        new_pl_sd[key[16:]]=state_dict['module'][key]
         | 
| 50 | 
            +
                    model.load_state_dict(new_pl_sd)
         | 
| 51 | 
            +
                print('>>> model checkpoint loaded.')
         | 
| 52 | 
            +
                return model
         | 
| 53 | 
            +
             | 
| 54 | 
            +
def load_prompts(prompt_file):
    prompt_list = []
    with open(prompt_file, 'r') as f:
        for line in f.readlines():
            l = line.strip()
            if len(l) != 0:
                prompt_list.append(l)
    return prompt_list

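## One prompt per line; blank lines are skipped. e.g. a file
## 'prompts/256/test_prompts.txt' (hypothetical name) containing:
##   a man fishing in a boat at sunset
##   a red panda eating bamboo
## yields ['a man fishing in a boat at sunset', 'a red panda eating bamboo']
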
def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, gfi=False):
    transform = transforms.Compose([
        transforms.Resize(min(video_size)),
        transforms.CenterCrop(video_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    ## load prompts
    prompt_file = get_filelist(data_dir, ['txt'])
    assert len(prompt_file) > 0, "Error: found NO prompt file!"
    ## default prompt file: only use the first one (sorted by name) if multiple exist
    default_idx = 0
    default_idx = min(default_idx, len(prompt_file)-1)
    if len(prompt_file) > 1:
        print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.")

    ## load conditioning images
    file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
    # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!"
    data_list = []
    filename_list = []
    prompt_list = load_prompts(prompt_file[default_idx])
    n_samples = len(prompt_list)
    for idx in range(n_samples):
        image = Image.open(file_list[idx]).convert('RGB')
        image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
        frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)

        data_list.append(frame_tensor)
        _, filename = os.path.split(file_list[idx])
        filename_list.append(filename)

    return filename_list, data_list, prompt_list

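## Pairing convention: the i-th image (after name-sorting) goes with the i-th
## prompt line, so data_dir must contain at least as many images as prompts;
## fewer images than prompt lines would raise an IndexError in the loop above.
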
def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
    filename = filename.split('.')[0] + '.mp4'
    prompt = prompt[0] if isinstance(prompt, list) else prompt

    ## save video
    videos = [samples]
    savedirs = [fakedir]
    for idx, video in enumerate(videos):
        if video is None:
            continue
        # b,c,t,h,w
        video = video.detach().cpu()
        video = torch.clamp(video.float(), -1., 1.)
        n = video.shape[0]
        video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
        if loop:
            video = video[:-1,...]

        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] # [3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w]
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        path = os.path.join(savedirs[idx], filename)
        torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## lower crf -> higher quality

def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False):
    prompt = prompt[0] if isinstance(prompt, list) else prompt

    ## save video
    videos = [samples]
    savedirs = [fakedir]
    for idx, video in enumerate(videos):
        if video is None:
            continue
        # b,c,t,h,w
        video = video.detach().cpu()
        if loop: # remove the last frame
            video = video[:,:,:-1,...]
        video = torch.clamp(video.float(), -1., 1.)
        n = video.shape[0]
        for i in range(n):
            grid = video[i,...]
            grid = (grid + 1.0) / 2.0
            grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) # t,h,w,c
            path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4')
            torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'})

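## Note: save_results writes one grid video per prompt into `samples/`, while
## save_results_seperate writes each sample as its own mp4 into the sibling
## `samples_separate/` directory (via the .replace('samples', 'samples_separate') above).
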
def get_latent_z(model, videos):
    b, c, t, h, w = videos.shape
    x = rearrange(videos, 'b c t h w -> (b t) c h w')
    z = model.encode_first_stage(x)
    z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
    return z

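## Shape sketch: frames are encoded independently by folding time into the batch
## dimension before encode_first_stage. With the 8x-downsampling autoencoder
## assumed here (h//8, w//8 in run_inference below), a (1, 3, 16, 320, 512)
## pixel video maps to a latent z of shape (1, C_lat, 16, 40, 64).
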
def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
                        unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, gfi=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
    ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
    batch_size = noise_shape[0]
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    if not text_input:
        prompts = [""] * batch_size

    img = videos[:,:,0] # b,c,h,w
    img_emb = model.embedder(img) ## b,l,c
    img_emb = model.image_proj_model(img_emb)

    cond_emb = model.get_learned_conditioning(prompts)
    cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
    if model.model.conditioning_key == 'hybrid':
        z = get_latent_z(model, videos) # b c t h w
        if loop or gfi:
            ## condition on both the first and the last frame
            img_cat_cond = torch.zeros_like(z)
            img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
            img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
        else:
            ## condition on the first frame only, repeated over time
            img_cat_cond = z[:,:,:1,:,:]
            img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
        cond["c_concat"] = [img_cat_cond] # b c t h w

    if unconditional_guidance_scale != 1.0:
        if model.uncond_type == "empty_seq":
            prompts = batch_size * [""]
            uc_emb = model.get_learned_conditioning(prompts)
        elif model.uncond_type == "zero_embed":
            uc_emb = torch.zeros_like(cond_emb)
        uc_img_emb = model.embedder(torch.zeros_like(img)) ## b,l,c
        uc_img_emb = model.image_proj_model(uc_img_emb)
        uc = {"c_crossattn": [torch.cat([uc_emb, uc_img_emb], dim=1)]}
        if model.model.conditioning_key == 'hybrid':
            uc["c_concat"] = [img_cat_cond]
    else:
        uc = None

    ## we need one more unconditional branch: image=yes, text=""
    if multiple_cond_cfg and cfg_img != 1.0:
        uc_2 = {"c_crossattn": [torch.cat([uc_emb, img_emb], dim=1)]}
        if model.model.conditioning_key == 'hybrid':
            uc_2["c_concat"] = [img_cat_cond]
        kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
    else:
        kwargs.update({"unconditional_conditioning_img_nonetext": None})

    z0 = None
    cond_mask = None

    batch_variants = []
    for _ in range(n_samples):
        if z0 is not None:
            cond_z0 = z0.clone()
            kwargs.update({"clean_cond": True})
        else:
            cond_z0 = None
        if ddim_sampler is not None:
            samples, _ = ddim_sampler.sample(S=ddim_steps,
                                            conditioning=cond,
                                            batch_size=batch_size,
                                            shape=noise_shape[1:],
                                            verbose=False,
                                            unconditional_guidance_scale=unconditional_guidance_scale,
                                            unconditional_conditioning=uc,
                                            eta=ddim_eta,
                                            cfg_img=cfg_img,
                                            mask=cond_mask,
                                            x0=cond_z0,
                                            fs=fs,
                                            timestep_spacing=timestep_spacing,
                                            guidance_rescale=guidance_rescale,
                                            **kwargs
                                            )

        ## reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
        batch_variants.append(batch_images)
    ## variants, batch, c, t, h, w
    batch_variants = torch.stack(batch_variants)
    return batch_variants.permute(1, 0, 2, 3, 4, 5)

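## Call sketch (shapes as used by run_inference below): `videos` is the input
## image broadcast over time, [bs, 3, T, H, W]; the return value stacks the
## n_samples DDIM draws and is permuted to [bs, n_samples, c, T, h, w] before
## saving. Note the multiple_cond_cfg branch reuses uc_emb and img_cat_cond, so
## it assumes unconditional_guidance_scale != 1.0 and a 'hybrid' conditioning key.
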
def run_inference(args, gpu_num, gpu_no):
    ## model config
    config = OmegaConf.load(args.config)
    model_config = config.pop("model", OmegaConf.create())

    ## set use_checkpoint to False: with deepspeed it otherwise fails with "deepspeed backend not set"
    model_config['params']['unet_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(model_config)
    model = model.cuda(gpu_no)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()

    ## run over data
    assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only supports [batch size = 1]!"
    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'Inference with {n_frames} frames')
    noise_shape = [args.bs, channels, n_frames, h, w]

    fakedir = os.path.join(args.savedir, "samples")
    fakedir_separate = os.path.join(args.savedir, "samples_separate")

    # os.makedirs(fakedir, exist_ok=True)
    os.makedirs(fakedir_separate, exist_ok=True)

    ## prompt file setting
    assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
    filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, gfi=args.gfi)
    num_samples = len(prompt_list)
    samples_split = num_samples // gpu_num
    print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))
    #indices = random.choices(list(range(0, num_samples)), k=samples_per_device)
    indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
    prompt_list_rank = [prompt_list[i] for i in indices]
    data_list_rank = [data_list[i] for i in indices]
    filename_list_rank = [filename_list[i] for i in indices]

    start = time.time()
    with torch.no_grad(), torch.cuda.amp.autocast():
        for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'):
            prompts = prompt_list_rank[indice:indice+args.bs]
            videos = data_list_rank[indice:indice+args.bs]
            filenames = filename_list_rank[indice:indice+args.bs]
            if isinstance(videos, list):
                videos = torch.stack(videos, dim=0).to("cuda")
            else:
                videos = videos.unsqueeze(0).to("cuda")

            batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
                                args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.gfi, args.timestep_spacing, args.guidance_rescale)

            ## save each example individually
            for nn, samples in enumerate(batch_samples):
                ## samples : [n_samples,c,t,h,w]
                prompt = prompts[nn]
                filename = filenames[nn]
                # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
                save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop)

    print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")

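## Sharding note: rank r handles indices [samples_split*r, samples_split*(r+1)).
## Because samples_split = num_samples // gpu_num, e.g. 10 prompts on 4 GPUs give
## samples_split=2 and prompts 8-9 are never assigned; pad the prompt list to a
## multiple of gpu_num if every prompt must be rendered.
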
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir", type=str, default=None, help="results saving path")
    parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts")
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
    parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
    parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
    parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
    parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
    parser.add_argument("--frame_stride", type=int, default=3, help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
    parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
    parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
    parser.add_argument("--video_length", type=int, default=16, help="inference video length")
    parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt")
    parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not")
    parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
    parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
    parser.add_argument("--timestep_spacing", type=str, default="uniform", help="The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.")
    parser.add_argument("--guidance_rescale", type=float, default=0.0, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)")
    parser.add_argument("--perframe_ae", action='store_true', default=False, help="use per-frame AE decoding to save GPU memory, especially for the 576x1024 model")

    ## looping video and generative frame interpolation are not supported yet
    parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
    parser.add_argument("--gfi", action='store_true', default=False, help="generative frame interpolation (gfi) or not")
    return parser

            if __name__ == '__main__':
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    print("@DynamiCrafter cond-Inference: %s" % now)
    parser = get_parser()
    args = parser.parse_args()

    seed_everything(args.seed)
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)

scripts/gradio/__pycache__/i2v_test.cpython-39.pyc
CHANGED

Binary files a/scripts/gradio/__pycache__/i2v_test.cpython-39.pyc and b/scripts/gradio/__pycache__/i2v_test.cpython-39.pyc differ

scripts/gradio/i2v_test.py
CHANGED

@@ -1,102 +1,106 @@
import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from pytorch_lightning import seed_everything


class Image2Video():
    def __init__(self, result_dir='./tmp/', gpu_num=1, resolution='256_256') -> None:
        self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) # h,w
        self.download_model()

        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)
        ckpt_path = 'checkpoints/dynamicrafter_' + resolution.split('_')[1] + '_v1/model.ckpt'
        config_file = 'configs/inference_' + resolution.split('_')[1] + '_v1.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        model_config['params']['unet_config']['params']['use_checkpoint'] = False
        model_list = []
        for gpu_id in range(gpu_num):
            model = instantiate_from_config(model_config)
            # model = model.cuda(gpu_id)
            assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        self.model_list = model_list
        self.save_fps = 8

    def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
        seed_everything(seed)
        transform = transforms.Compose([
            transforms.Resize(min(self.resolution)),
            transforms.CenterCrop(self.resolution),
            ])
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        gpu_id = 0
        if steps > 60:
            steps = 60
        model = self.model_list[gpu_id]
        model = model.cuda()
        batch_size = 1
        channels = model.model.diffusion_model.out_channels
        frames = model.temporal_length
        h, w = self.resolution[0] // 8, self.resolution[1] // 8
        noise_shape = [batch_size, channels, frames, h, w]

        # text cond
        text_emb = model.get_learned_conditioning([prompt])

        # img cond (`image` is an HxWx3 uint8 array, as supplied by the Gradio image component)
        img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
        img_tensor = (img_tensor / 255. - 0.5) * 2

        image_tensor_resized = transform(img_tensor) # 3,h,w
        videos = image_tensor_resized.unsqueeze(0) # b,c,h,w

        z = get_latent_z(model, videos.unsqueeze(2)) # b,c,1,h,w

        img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)

        cond_images = model.embedder(img_tensor.unsqueeze(0)) ## b,l,c
        img_emb = model.image_proj_model(cond_images)

        imtext_cond = torch.cat([text_emb, img_emb], dim=1)

        fs = torch.tensor([fs], dtype=torch.long, device=model.device)
        cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}

        ## inference
        batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
        ## b,samples,c,t,h,w
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
        prompt_str = prompt_str[:40]
        if len(prompt_str) == 0:
            prompt_str = 'empty_prompt'

        save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
        model = model.cpu()
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    def download_model(self):
        REPO_ID = 'Doubiiu/DynamiCrafter'
        filename_list = ['model.ckpt']
        if not os.path.exists('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/'):
            os.makedirs('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/', filename)
            if not os.path.exists(local_file):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_' + str(self.resolution[1]) + '_v1/', local_dir_use_symlinks=False)

if __name__ == '__main__':
    i2v = Image2Video()
    ## note: get_image expects a decoded image array rather than a file path
    video_path = i2v.get_image('prompts/art.png', 'man fishing in a boat at sunset')
    print('done', video_path)
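## Design note: __init__ builds one model replica per entry in gpu_num but leaves
## the weights on the CPU; get_image moves the chosen replica onto the GPU for
## sampling and back to the CPU afterwards, so device memory is only held while a
## request is running.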
    	
scripts/run.sh
CHANGED

@@ -1,25 +1,61 @@
version=$1 ##1024, 512, 256
seed=123
name=dynamicrafter_$1_seed${seed}

ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt
config=configs/inference_$1_v1.0.yaml

prompt_dir=prompts/$1/
res_dir="results"

if [ "$1" == "256" ]; then
    H=256
    FS=3  ## This model adopts frame stride=3, range recommended: 1-6 (larger value -> larger motion)
elif [ "$1" == "512" ]; then
    H=320
    FS=24 ## This model adopts FPS=24, range recommended: 15-30 (smaller value -> larger motion)
elif [ "$1" == "1024" ]; then
    H=576
    FS=10 ## This model adopts FPS=10, range recommended: 5-15 (smaller value -> larger motion)
else
    echo "Invalid input. Please enter 256, 512, or 1024."
    exit 1
fi

if [ "$1" == "256" ]; then
CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height ${H} --width $1 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS}
else
CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height ${H} --width $1 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae
fi


## multi-cond CFG: the <unconditional_guidance_scale> is s_txt, <cfg_img> is s_img
#--multiple_cond_cfg --cfg_img 7.5
#--loop
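## usage (e.g.): sh scripts/run.sh 1024   # the argument selects the 256/512/1024 model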
    	
scripts/run_mp.sh
ADDED

@@ -0,0 +1,102 @@
version=$1 ##1024, 512, 256
seed=123

name=dynamicrafter_$1_mp_seed${seed}

ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt
config=configs/inference_$1_v1.0.yaml

prompt_dir=prompts/$1/
res_dir="results"

if [ "$1" == "256" ]; then
    H=256
    FS=3  ## This model adopts frame stride=3
elif [ "$1" == "512" ]; then
    H=320
    FS=24 ## This model adopts FPS=24
elif [ "$1" == "1024" ]; then
    H=576
    FS=10 ## This model adopts FPS=10
else
    echo "Invalid input. Please enter 256, 512, or 1024."
    exit 1
fi

# if [ "$1" == "256" ]; then
# CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
# --seed 123 \
# --ckpt_path $ckpt \
# --config $config \
# --savedir $res_dir/$name \
# --n_samples 1 \
# --bs 1 --height ${H} --width $1 \
# --unconditional_guidance_scale 7.5 \
# --ddim_steps 50 \
# --ddim_eta 1.0 \
# --prompt_dir $prompt_dir \
# --text_input \
# --video_length 16 \
# --frame_stride ${FS}
# else
# CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
# --seed 123 \
# --ckpt_path $ckpt \
# --config $config \
# --savedir $res_dir/$name \
# --n_samples 1 \
# --bs 1 --height ${H} --width $1 \
# --unconditional_guidance_scale 7.5 \
# --ddim_steps 50 \
# --ddim_eta 1.0 \
# --prompt_dir $prompt_dir \
# --text_input \
# --video_length 16 \
# --frame_stride ${FS} \
# --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7
# fi


## multi-cond CFG: the <unconditional_guidance_scale> is s_txt, <cfg_img> is s_img
#--multiple_cond_cfg --cfg_img 7.5
#--loop

## inference using a single node with multiple GPUs:
if [ "$1" == "256" ]; then
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
scripts/evaluation/ddp_wrapper.py \
--module 'inference' \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height ${H} --width $1 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS}
else
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
scripts/evaluation/ddp_wrapper.py \
--module 'inference' \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir $res_dir/$name \
--n_samples 1 \
--bs 1 --height ${H} --width $1 \
--unconditional_guidance_scale 7.5 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir $prompt_dir \
--text_input \
--video_length 16 \
--frame_stride ${FS} \
--timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae
fi
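## DDP wiring: torch.distributed.launch starts nproc_per_node processes, each with
## a distinct --local_rank; ddp_wrapper.py binds that rank to a CUDA device,
## initializes the NCCL process group, and hands the remaining flags to the module
## named by --module, which shards the prompt list across ranks as in run_inference.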
