import argparse
import logging
import math
import os
import gc
import copy
import json
import shutil

import cv2
import imageio
import numpy as np
from omegaconf import OmegaConf
import torch
import torch.utils.checkpoint
import diffusers
import transformers
from tqdm.auto import tqdm
from accelerate import Accelerator
from accelerate.logging import get_logger
from models.unet.unet_3d_condition import UNet3DConditionModel
from diffusers.models import AutoencoderKL
from diffusers import DDIMScheduler, TextToVideoSDPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from utils.ddim_utils import inverse_video
from utils.gpu_utils import handle_memory_attention, unet_and_text_g_c
from utils.func_utils import *
from dataset import *
from loss import *
from noise_init import *
from attn_ctrl import register_attention_control

logger = get_logger(__name__, log_level="INFO")

def log_validation(accelerator, config, batch, global_step, text_prompt, unet, text_encoder, vae, output_dir):
    with accelerator.autocast():
        unet.eval()
        text_encoder.eval()
        unet_and_text_g_c(unet, text_encoder, False, False)

        # Handle the spatial LoRA: disable its contribution during validation sampling.
        if config.loss.type == 'DebiasedHybrid':
            loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
            for lora_i in loras:
                lora_i.scale = 0

        pipeline = TextToVideoSDPipeline.from_pretrained(
            config.model.pretrained_model_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet
        )

        prompt_list = text_prompt if len(config.val.prompt) <= 0 else config.val.prompt
        for seed in config.val.seeds:
            noisy_latent = batch['inversion_noise']
            shape = noisy_latent.shape
            noise = torch.randn(
                shape,
                device=noisy_latent.device,
                generator=torch.Generator(noisy_latent.device).manual_seed(seed)
            ).to(noisy_latent.dtype)

            # Handle the configured noise initialization strategy: resolve the init
            # function by name and pass the remaining config entries as kwargs.
            init_func_name = f'{config.noise_init.type}'
            # config.noise_init is assumed to be a DictConfig, so convert it to a plain dict.
            init_params_dict = OmegaConf.to_container(config.noise_init, resolve=True)
            # Drop the 'type' key; the default None avoids a KeyError if it is absent.
            init_params_dict.pop('type', None)
            init_func_to_call = globals().get(init_func_name)
            init_noise = init_func_to_call(noisy_latent, noise, **init_params_dict)
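            # The resolved init function lives in noise_init and, from the call above,
            # is assumed to take (noisy_latent, noise, **params) and return a latent of
            # the same shape. A hypothetical example, for illustration only:
            #     def blend_init(noisy_latent, noise, blend_ratio=0.5):
            #         return blend_ratio * noisy_latent + (1.0 - blend_ratio) * noise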

            for prompt in prompt_list:
                file_name = f"{prompt.replace(' ', '_')}_seed_{seed}.mp4"
                file_path = f"{output_dir}/samples_{global_step}/"
                if not os.path.exists(file_path):
                    os.makedirs(file_path)

                with torch.no_grad():
                    video_frames = pipeline(
                        prompt=prompt,
                        negative_prompt=config.val.negative_prompt,
                        width=config.val.width,
                        height=config.val.height,
                        num_frames=config.val.num_frames,
                        num_inference_steps=config.val.num_inference_steps,
                        guidance_scale=config.val.guidance_scale,
                        latents=init_noise,
                    ).frames[0]
                export_to_video(video_frames, os.path.join(file_path, file_name), config.dataset.fps)
                logger.info(f"Saved a new sample to {os.path.join(file_path, file_name)}")

    del pipeline
    torch.cuda.empty_cache()

def create_logging(logging, logger, accelerator):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

def accelerate_set_verbose(accelerator):
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

def export_to_video(video_frames, output_video_path, fps):
    video_writer = imageio.get_writer(output_video_path, fps=fps)
    for img in video_frames:
        video_writer.append_data(np.array(img))
    video_writer.close()
    return output_video_path

def create_output_folders(output_dir, config):
    out_dir = os.path.join(output_dir)
    os.makedirs(out_dir, exist_ok=True)
    OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))
    shutil.copyfile(config.dataset.single_video_path, os.path.join(out_dir, 'source.mp4'))
    return out_dir

def load_primary_models(pretrained_model_path):
    noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
    return noise_scheduler, tokenizer, text_encoder, vae, unet

def freeze_models(models_to_freeze):
    for model in models_to_freeze:
        if model is not None:
            model.requires_grad_(False)

def is_mixed_precision(accelerator):
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    return weight_dtype

def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
    for model in model_list:
        if model is not None:
            model.to(accelerator.device, dtype=weight_dtype)

def handle_cache_latents(
    should_cache,
    output_dir,
    train_dataloader,
    train_batch_size,
    vae,
    unet,
    pretrained_model_path,
    cached_latent_dir=None,
):
    # Cache latents (and their DDIM inversion noise) to disk. This speeds up
    # training and saves memory, since frames are no longer encoded in the train loop.
    if not should_cache:
        return None
    vae.to('cuda', dtype=torch.float16)
    vae.enable_slicing()

    pipe = TextToVideoSDPipeline.from_pretrained(
        pretrained_model_path,
        vae=vae,
        unet=copy.deepcopy(unet).to('cuda', dtype=torch.float16)
    )
    pipe.text_encoder.to('cuda', dtype=torch.float16)

    cached_latent_dir = (
        os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
    )

    if cached_latent_dir is None:
        cache_save_dir = f"{output_dir}/cached_latents"
        os.makedirs(cache_save_dir, exist_ok=True)

        for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
            save_name = f"cached_{i}"
            full_out_path = f"{cache_save_dir}/{save_name}.pt"

            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
            batch['latents'] = tensor_to_vae_latent(pixel_values, vae)
            batch['inversion_noise'] = inverse_video(pipe, batch['latents'], 50)
            for k, v in batch.items():
                batch[k] = v[0]
            torch.save(batch, full_out_path)

            del pixel_values
            del batch

            # We do this to avoid fragmentation from casting latents between devices.
            torch.cuda.empty_cache()
    else:
        cache_save_dir = cached_latent_dir

    return torch.utils.data.DataLoader(
        CachedDataset(cache_dir=cache_save_dir),
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=0
    )
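
# Each cached .pt file holds the original batch entries from the dataset plus the
# computed 'latents' and 'inversion_noise' tensors (with the batch dimension
# squeezed above), so CachedDataset can serve them without re-running the VAE
# encoder or the DDIM inversion.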

def should_sample(global_step, validation_steps, validation_data):
    return (global_step == 1 or global_step % validation_steps == 0) and validation_data.sample_preview

def save_pipe(
    path,
    global_step,
    accelerator,
    unet,
    text_encoder,
    vae,
    output_dir,
    is_checkpoint=False,
    save_pretrained_model=False,
    **extra_params
):
    if is_checkpoint:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        os.makedirs(save_path, exist_ok=True)
    else:
        save_path = output_dir

    # Save the dtypes so we can continue training at the same precision.
    u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype

    # Copy the models without keeping a reference to them. This preserves the
    # state of our LoRA training, if enabled.
    unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False))
    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder.cpu(), keep_fp32_wrapper=False))

    pipeline = TextToVideoSDPipeline.from_pretrained(
        path,
        unet=unet_out,
        text_encoder=text_encoder_out,
        vae=vae,
    ).to(torch_dtype=torch.float32)

    lora_managers_spatial = extra_params.get('lora_managers_spatial', [None])
    lora_manager_spatial = lora_managers_spatial[-1]
    if lora_manager_spatial is not None:
        lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path + '/spatial', step=global_step)

    save_motion_embeddings(unet_out, os.path.join(save_path, 'motion_embed.pt'))

    if save_pretrained_model:
        pipeline.save_pretrained(save_path)

    if is_checkpoint:
        unet, text_encoder = accelerator.prepare(unet, text_encoder)
        models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
        for model, dtype in models_to_cast_back:
            model.to(accelerator.device, dtype=dtype)

    logger.info(f"Saved model at {save_path} on step {global_step}")

    del pipeline
    del unet_out
    del text_encoder_out
    torch.cuda.empty_cache()
    gc.collect()
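
# A checkpoint directory produced above thus contains the learned motion
# embeddings in 'motion_embed.pt', the spatial LoRA weights under 'spatial/'
# when a spatial LoRA manager is active, and the full pipeline weights only
# when save_pretrained_model is set.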

def main(config):
    # Initialize the Accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=config.train.gradient_accumulation_steps,
        mixed_precision=config.train.mixed_precision,
        log_with=config.train.logger_type,
        project_dir=config.train.output_dir
    )

    # Read the source video's resolution and frame count so that training and
    # validation both match the source clip. The fps is fixed at 8.
    video_path = config.dataset.single_video_path
    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = 8
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    config.dataset.width = width
    config.dataset.height = height
    config.dataset.fps = fps
    config.dataset.n_sample_frames = frame_count
    config.dataset.single_video_path = video_path
    config.val.width = width
    config.val.height = height
    config.val.num_frames = frame_count

    # Create output directories and set up logging
    if accelerator.is_main_process:
        output_dir = create_output_folders(config.train.output_dir, config)
    create_logging(logging, logger, accelerator)
    accelerate_set_verbose(accelerator)

    # Load primary models
    noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(config.model.pretrained_model_path)

    # Load the VideoCrafter2 UNet for better video quality, if requested
    if config.model.unet == 'videoCrafter2':
        unet = UNet3DConditionModel.from_pretrained("/hpc2hdd/home/lwang592/ziyang/cache/videocrafterv2", subfolder='unet')
    elif config.model.unet == 'zeroscope_v2_576w':
        # By default we use zeroscope_v2_576w, so this UNet is already loaded
        pass
    else:
        raise ValueError("Invalid UNet model")

    freeze_models([vae, text_encoder])
    handle_memory_attention(unet)

    train_dataloader, train_dataset = prepare_data(config, tokenizer)

    # Handle latents caching
    cached_data_loader = handle_cache_latents(
        config.train.cache_latents,
        output_dir,
        train_dataloader,
        config.train.train_batch_size,
        vae,
        unet,
        config.model.pretrained_model_path,
        config.train.cached_latent_dir,
    )
    if cached_data_loader is not None:
        train_dataloader = cached_data_loader

    # Prepare parameters and optimization
    params, extra_params = prepare_params(unet, config, train_dataset)
    optimizers, lr_schedulers = prepare_optimizers(params, config, **extra_params)

    # Prepare models and data for training
    unet, optimizers, train_dataloader, lr_schedulers, text_encoder = accelerator.prepare(
        unet, optimizers, train_dataloader, lr_schedulers, text_encoder
    )

    # Additional model setups
    unet_and_text_g_c(unet, text_encoder)
    vae.enable_slicing()

    # Setup for mixed precision training
    weight_dtype = is_mixed_precision(accelerator)
    cast_to_gpu_and_type([text_encoder, vae], accelerator, weight_dtype)

    # Recalculate training steps and epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.train.gradient_accumulation_steps)
    num_train_epochs = math.ceil(config.train.max_train_steps / num_update_steps_per_epoch)

    # Initialize trackers and store configuration
    if accelerator.is_main_process:
        accelerator.init_trackers("motion-inversion")

    # Train!
    total_batch_size = config.train.train_batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.train.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {config.train.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, config.train.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Register the attention control, for the Motion Value Embedding(s)
    register_attention_control(unet, config=config)

    for epoch in range(first_epoch, num_train_epochs):
        train_loss_temporal = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if config.train.resume_from_checkpoint and epoch == first_epoch and step < config.train.resume_step:
                if step % config.train.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
                for optimizer in optimizers:
                    optimizer.zero_grad(set_to_none=True)

                with accelerator.autocast():
                    if global_step == 0:
                        unet.train()

                    # Resolve the configured loss function by name (defined in loss).
                    loss_func_to_call = globals().get(f'{config.loss.type}')
                    loss_temporal, train_loss_temporal = loss_func_to_call(
                        train_loss_temporal,
                        accelerator,
                        optimizers,
                        lr_schedulers,
                        unet,
                        vae,
                        text_encoder,
                        noise_scheduler,
                        batch,
                        step,
                        config
                    )

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss_temporal}, step=global_step)
                train_loss_temporal = 0.0

                if global_step % config.train.checkpointing_steps == 0 and global_step > 0:
                    save_pipe(
                        config.model.pretrained_model_path,
                        global_step,
                        accelerator,
                        unet,
                        text_encoder,
                        vae,
                        output_dir,
                        is_checkpoint=True,
                        **extra_params
                    )

            if loss_temporal is not None:
                accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step)

            if global_step >= config.train.max_train_steps:
                break

    # Training finished: synchronize all processes and end tracking.
    accelerator.wait_for_everyone()
    accelerator.end_training()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='configs/config.yaml')
    parser.add_argument("--single_video_path", type=str)
    parser.add_argument("--prompts", type=str, help="JSON string of prompts")
    args = parser.parse_args()

    # Load the base configuration
    config = OmegaConf.load(args.config)

    # Override the config with the command-line arguments
    if args.single_video_path:
        config.dataset.single_video_path = args.single_video_path
        # Name the output dir after the video file
        config.train.output_dir = os.path.join(config.train.output_dir, os.path.basename(args.single_video_path).split('.')[0])
    if args.prompts:
        config.val.prompt = json.loads(args.prompts)

    main(config)
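
# Example invocation (assuming this file is saved as train.py; the video path
# and prompt below are placeholders):
#   python train.py --config configs/config.yaml \
#       --single_video_path ./videos/example.mp4 \
#       --prompts '["a car drifting on a snowy road"]'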