diff --git a/app.py b/app.py index 719ba5efb4be785d411488fb307f1f467779a4a6..21c4bfbd3669a90bee4bf6b69ab7025343a35dd4 100644 --- a/app.py +++ b/app.py @@ -1,27 +1,50 @@ -import time +import time -from easyanimate.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api -from easyanimate.ui.ui import ui_modelscope, ui_eas, ui +import torch + +from easyanimate.api.api import (infer_forward_api, + update_diffusion_transformer_api, + update_edition_api) +from easyanimate.ui.ui import ui, ui_eas, ui_modelscope if __name__ == "__main__": # Choose the ui mode ui_mode = "eas" + + # GPU memory mode, which can be choosen in ["model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"]. + # "model_cpu_offload" means that the entire model will be moved to the CPU after use, which can save some GPU memory. + # + # "model_cpu_offload_and_qfloat8" indicates that the entire model will be moved to the CPU after use, + # and the transformer model has been quantized to float8, which can save more GPU memory. + # + # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use, + # resulting in slower speeds but saving a large amount of GPU memory. + GPU_memory_mode = "model_cpu_offload_and_qfloat8" + # Use torch.float16 if GPU does not support torch.bfloat16 + # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 + weight_dtype = torch.bfloat16 + # Server ip server_name = "0.0.0.0" server_port = 7860 # Params below is used when ui_mode = "modelscope" - edition = "v3" - config_path = "config/easyanimate_video_slicevae_motion_module_v3.yaml" - model_name = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512" + edition = "v5" + # Config + config_path = "config/easyanimate_video_v5_magvit_multi_text_encoder.yaml" + # Model path of the pretrained model + model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP" + # "Inpaint" or "Control" + model_type = "Inpaint" + # Save dir savedir_sample = "samples" if ui_mode == "modelscope": - demo, controller = ui_modelscope(edition, config_path, model_name, savedir_sample) + demo, controller = ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, GPU_memory_mode, weight_dtype) elif ui_mode == "eas": demo, controller = ui_eas(edition, config_path, model_name, savedir_sample) else: - demo, controller = ui() + demo, controller = ui(GPU_memory_mode, weight_dtype) # launch gradio app, _, _ = demo.queue(status_update_rate=1).launch( diff --git a/config/easyanimate_image_magvit_v2.yaml b/config/easyanimate_image_magvit_v2.yaml deleted file mode 100644 index 8a781081885f300a557b28c9eeb30afa78cc8112..0000000000000000000000000000000000000000 --- a/config/easyanimate_image_magvit_v2.yaml +++ /dev/null @@ -1,8 +0,0 @@ -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - -vae_kwargs: - enable_magvit: true \ No newline at end of file diff --git a/config/easyanimate_image_normal_v1.yaml b/config/easyanimate_image_normal_v1.yaml deleted file mode 100644 index 8b926c1c15586e94c5458f80f8468b6651327ccb..0000000000000000000000000000000000000000 --- a/config/easyanimate_image_normal_v1.yaml +++ /dev/null @@ -1,8 +0,0 @@ -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - -vae_kwargs: - enable_magvit: false \ No newline at end of file diff --git a/config/easyanimate_image_slicevae_v3.yaml b/config/easyanimate_image_slicevae_v3.yaml deleted file mode 100644 index e41b63d64e605a70ef6be20f309e0c383a177495..0000000000000000000000000000000000000000 --- a/config/easyanimate_image_slicevae_v3.yaml +++ /dev/null @@ -1,9 +0,0 @@ -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - -vae_kwargs: - enable_magvit: true - slice_compression_vae: true \ No newline at end of file diff --git a/config/easyanimate_video_casual_motion_module_v1.yaml b/config/easyanimate_video_casual_motion_module_v1.yaml deleted file mode 100644 index 4ed53304ec0f78b764dfc671d2efb7e7af91cf1c..0000000000000000000000000000000000000000 --- a/config/easyanimate_video_casual_motion_module_v1.yaml +++ /dev/null @@ -1,27 +0,0 @@ -transformer_additional_kwargs: - patch_3d: false - fake_3d: false - casual_3d: true - casual_3d_upsampler_index: [16, 20] - time_patch_size: 4 - basic_block_type: "motionmodule" - time_position_encoding_before_transformer: false - motion_module_type: "VanillaGrid" - - motion_module_kwargs: - num_attention_heads: 8 - num_transformer_block: 1 - attention_block_types: [ "Temporal_Self", "Temporal_Self" ] - temporal_position_encoding: true - temporal_position_encoding_max_len: 4096 - temporal_attention_dim_div: 1 - block_size: 2 - -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - -vae_kwargs: - enable_magvit: false \ No newline at end of file diff --git a/config/easyanimate_video_long_sequence_v1.yaml b/config/easyanimate_video_long_sequence_v1.yaml deleted file mode 100644 index 0538352aaf2a11faa1bd374b78dcf47435c8df37..0000000000000000000000000000000000000000 --- a/config/easyanimate_video_long_sequence_v1.yaml +++ /dev/null @@ -1,14 +0,0 @@ -transformer_additional_kwargs: - patch_3d: false - fake_3d: false - basic_block_type: "selfattentiontemporal" - time_position_encoding_before_transformer: true - -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - -vae_kwargs: - enable_magvit: false \ No newline at end of file diff --git a/config/easyanimate_video_slicevae_motion_module_v3.yaml b/config/easyanimate_video_slicevae_motion_module_v3.yaml deleted file mode 100644 index e0e1bac132accd6bf5905ee59643004e2c401e9f..0000000000000000000000000000000000000000 --- a/config/easyanimate_video_slicevae_motion_module_v3.yaml +++ /dev/null @@ -1,27 +0,0 @@ -transformer_additional_kwargs: - patch_3d: false - fake_3d: false - basic_block_type: "motionmodule" - time_position_encoding_before_transformer: false - motion_module_type: "Vanilla" - enable_uvit: true - - motion_module_kwargs: - num_attention_heads: 8 - num_transformer_block: 1 - attention_block_types: [ "Temporal_Self", "Temporal_Self" ] - temporal_position_encoding: true - temporal_position_encoding_max_len: 4096 - temporal_attention_dim_div: 1 - block_size: 1 - -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - -vae_kwargs: - enable_magvit: true - slice_compression_vae: true - mini_batch_encoder: 8 \ No newline at end of file diff --git a/config/easyanimate_video_motion_module_v1.yaml b/config/easyanimate_video_v1_motion_module.yaml similarity index 82% rename from config/easyanimate_video_motion_module_v1.yaml rename to config/easyanimate_video_v1_motion_module.yaml index add62459d6ada289f2f8c61e571876d5f99ec5ac..f7b46d963315532dab15f6a73554ac883f16ef16 100644 --- a/config/easyanimate_video_motion_module_v1.yaml +++ b/config/easyanimate_video_v1_motion_module.yaml @@ -1,4 +1,5 @@ transformer_additional_kwargs: + transformer_type: "Transformer3DModel" patch_3d: false fake_3d: false basic_block_type: "motionmodule" @@ -14,11 +15,8 @@ transformer_additional_kwargs: temporal_attention_dim_div: 1 block_size: 2 -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - vae_kwargs: - enable_magvit: false \ No newline at end of file + vae_type: "AutoencoderKL" + +text_encoder_kwargs: + enable_multi_text_encoder: false \ No newline at end of file diff --git a/config/easyanimate_video_magvit_motion_module_v2.yaml b/config/easyanimate_video_v2_magvit_motion_module.yaml similarity index 70% rename from config/easyanimate_video_magvit_motion_module_v2.yaml rename to config/easyanimate_video_v2_magvit_motion_module.yaml index 723ad1e0aee1a8c3c6c77c3ce0549a0ea4babb90..5f68faf17944e6ae74b607264c984bf59e264438 100644 --- a/config/easyanimate_video_magvit_motion_module_v2.yaml +++ b/config/easyanimate_video_v2_magvit_motion_module.yaml @@ -1,4 +1,5 @@ transformer_additional_kwargs: + transformer_type: "Transformer3DModel" patch_3d: false fake_3d: false basic_block_type: "motionmodule" @@ -15,12 +16,14 @@ transformer_additional_kwargs: temporal_attention_dim_div: 1 block_size: 1 -noise_scheduler_kwargs: - beta_start: 0.0001 - beta_end: 0.02 - beta_schedule: "linear" - steps_offset: 1 - vae_kwargs: - enable_magvit: true - mini_batch_encoder: 9 \ No newline at end of file + vae_type: "AutoencoderKLMagvit" + mini_batch_encoder: 9 + mini_batch_decoder: 3 + slice_mag_vae: true + slice_compression_vae: false + cache_compression_vae: false + cache_mag_vae: false + +text_encoder_kwargs: + enable_multi_text_encoder: false \ No newline at end of file diff --git a/config/easyanimate_video_v3_slicevae_motion_module.yaml b/config/easyanimate_video_v3_slicevae_motion_module.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ff9b97bbf449ca0be82f843233211b19cf792ef --- /dev/null +++ b/config/easyanimate_video_v3_slicevae_motion_module.yaml @@ -0,0 +1,39 @@ +transformer_additional_kwargs: + transformer_type: "Transformer3DModel" + patch_3d: false + fake_3d: false + basic_block_type: "global_motionmodule" + time_position_encoding_before_transformer: false + motion_module_type: "Vanilla" + enable_uvit: true + + motion_module_kwargs_even: + num_attention_heads: 16 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 4096 + temporal_attention_dim_div: 1 + block_size: 1 + remove_time_embedding_in_photo: false + motion_module_kwargs_odd: + num_attention_heads: 16 + num_transformer_block: 1 + attention_block_types: [ "Temporal_Self", "Global_Self" ] + temporal_position_encoding: true + temporal_position_encoding_max_len: 4096 + temporal_attention_dim_div: 1 + block_size: 1 + remove_time_embedding_in_photo: false + +vae_kwargs: + vae_type: "AutoencoderKLMagvit" + mini_batch_encoder: 8 + mini_batch_decoder: 2 + slice_mag_vae: false + slice_compression_vae: true + cache_compression_vae: false + cache_mag_vae: false + +text_encoder_kwargs: + enable_multi_text_encoder: false \ No newline at end of file diff --git a/config/easyanimate_video_v4_slicevae_multi_text_encoder.yaml b/config/easyanimate_video_v4_slicevae_multi_text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63ea9c89bbe74aab8b1d9e615838a5c28766aa5e --- /dev/null +++ b/config/easyanimate_video_v4_slicevae_multi_text_encoder.yaml @@ -0,0 +1,20 @@ +transformer_additional_kwargs: + transformer_type: "HunyuanTransformer3DModel" + basic_block_type: "basic" + after_norm: false + time_position_encoding_type: "2d_rope" + time_position_encoding: true + resize_inpaint_mask_directly: false + enable_clip_in_inpaint: true + +vae_kwargs: + vae_type: "AutoencoderKLMagvit" + mini_batch_encoder: 8 + mini_batch_decoder: 2 + slice_mag_vae: false + slice_compression_vae: false + cache_compression_vae: true + cache_mag_vae: false + +text_encoder_kwargs: + enable_multi_text_encoder: true \ No newline at end of file diff --git a/config/easyanimate_video_v5_magvit_multi_text_encoder.yaml b/config/easyanimate_video_v5_magvit_multi_text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e9d9ad7060fa3ee0e1497758c207f5ed1b9d3078 --- /dev/null +++ b/config/easyanimate_video_v5_magvit_multi_text_encoder.yaml @@ -0,0 +1,19 @@ +transformer_additional_kwargs: + transformer_type: "EasyAnimateTransformer3DModel" + after_norm: false + time_position_encoding_type: "3d_rope" + resize_inpaint_mask_directly: true + enable_text_attention_mask: false + enable_clip_in_inpaint: false + +vae_kwargs: + vae_type: "AutoencoderKLMagvit" + mini_batch_encoder: 4 + mini_batch_decoder: 1 + slice_mag_vae: false + slice_compression_vae: false + cache_compression_vae: false + cache_mag_vae: true + +text_encoder_kwargs: + enable_multi_text_encoder: true \ No newline at end of file diff --git a/config/zero_stage2_config.json b/config/zero_stage2_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e60ea05a563827d4286955c1421e82d3b1bbe5cc --- /dev/null +++ b/config/zero_stage2_config.json @@ -0,0 +1,16 @@ +{ + "bf16": { + "enabled": true + }, + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "dump_state": true, + "zero_optimization": { + "stage": 2, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8 + } +} \ No newline at end of file diff --git a/easyanimate/api/api.py b/easyanimate/api/api.py index 73cf9af9d3b0916d9563b879b36a5b5a999aef20..2168416340bf2eb0db447cb96ff639ff39a8eb9c 100644 --- a/easyanimate/api/api.py +++ b/easyanimate/api/api.py @@ -1,15 +1,17 @@ -import io -import gc import base64 -import torch -import gradio as gr -import tempfile +import gc import hashlib +import io +import os +import tempfile +from io import BytesIO +import gradio as gr +import torch from fastapi import FastAPI -from io import BytesIO from PIL import Image + # Function to encode a file to Base64 def encode_file_to_base64(file_path): with open(file_path, "rb") as file: @@ -53,6 +55,34 @@ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller): return {"message": comment} +def save_base64_video(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.mp4" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + + return file_path + +def save_base64_image(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.jpg" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + + return file_path + def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): @app.post("/easyanimate/infer_forward") def _infer_forward_api( @@ -63,7 +93,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): lora_model_path = datas.get('lora_model_path', 'none') lora_alpha_slider = datas.get('lora_alpha_slider', 0.55) prompt_textbox = datas.get('prompt_textbox', None) - negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion.') + negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Unclear, mutated, deformed, distorted, dark frames, fixed frames, comic book, comic book, small and indistinguishable subject.') sampler_dropdown = datas.get('sampler_dropdown', 'Euler') sample_step_slider = datas.get('sample_step_slider', 30) resize_method = datas.get('resize_method', "Generate by") @@ -72,17 +102,20 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): base_resolution = datas.get('base_resolution', 512) is_image = datas.get('is_image', False) generation_method = datas.get('generation_method', False) - length_slider = datas.get('length_slider', 144) + length_slider = datas.get('length_slider', 49) overlap_video_length = datas.get('overlap_video_length', 4) partial_video_length = datas.get('partial_video_length', 72) cfg_scale_slider = datas.get('cfg_scale_slider', 6) start_image = datas.get('start_image', None) end_image = datas.get('end_image', None) + validation_video = datas.get('validation_video', None) + validation_video_mask = datas.get('validation_video_mask', None) + control_video = datas.get('control_video', None) + denoise_strength = datas.get('denoise_strength', 0.70) seed_textbox = datas.get("seed_textbox", 43) generation_method = "Image Generation" if is_image else generation_method - temp_directory = tempfile.gettempdir() if start_image is not None: start_image = base64.b64decode(start_image) start_image = [Image.open(BytesIO(start_image))] @@ -91,6 +124,15 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): end_image = base64.b64decode(end_image) end_image = [Image.open(BytesIO(end_image))] + if validation_video is not None: + validation_video = save_base64_video(validation_video) + + if validation_video_mask is not None: + validation_video_mask = save_base64_image(validation_video_mask) + + if control_video is not None: + control_video = save_base64_video(control_video) + try: save_sample_path, comment = controller.generate( "", @@ -113,6 +155,10 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): cfg_scale_slider, start_image, end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, seed_textbox, is_api = True, ) diff --git a/easyanimate/api/post_infer.py b/easyanimate/api/post_infer.py index 11cdc6571cd8f2878966c43b57d2f06d20cb08b9..4fee31cf093519fb8c63e3f4ec01a91171823ea4 100644 --- a/easyanimate/api/post_infer.py +++ b/easyanimate/api/post_infer.py @@ -7,7 +7,6 @@ from io import BytesIO import cv2 import requests -import base64 def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'): diff --git a/easyanimate/data/dataset_image_video.py b/easyanimate/data/dataset_image_video.py index 2069c12a81b76ac3ef820316672c7d7ccb74bbaa..3065838ee9391b54bc799de902ccb92bd02e9133 100644 --- a/easyanimate/data/dataset_image_video.py +++ b/easyanimate/data/dataset_image_video.py @@ -1,24 +1,23 @@ import csv +import gc import io import json import math import os import random +from contextlib import contextmanager from threading import Thread import albumentations import cv2 -import gc import numpy as np import torch import torchvision.transforms as transforms - -from func_timeout import func_timeout, FunctionTimedOut from decord import VideoReader +from func_timeout import FunctionTimedOut, func_timeout from PIL import Image from torch.utils.data import BatchSampler, Sampler from torch.utils.data.dataset import Dataset -from contextlib import contextmanager VIDEO_READER_TIMEOUT = 20 @@ -26,9 +25,9 @@ def get_random_mask(shape): f, c, h, w = shape if f != 1: - mask_index = np.random.randint(1, 4) + mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05]) else: - mask_index = np.random.randint(1, 2) + mask_index = np.random.choice([0, 1], p = [0.2, 0.8]) mask = torch.zeros((f, 1, h, w), dtype=torch.uint8) if mask_index == 0: @@ -64,6 +63,40 @@ def get_random_mask(shape): mask_frame_before = np.random.randint(0, f // 2) mask_frame_after = np.random.randint(f // 2, f) mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 5: + mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8) + elif mask_index == 6: + num_frames_to_mask = random.randint(1, max(f // 2, 1)) + frames_to_mask = random.sample(range(f), num_frames_to_mask) + + for i in frames_to_mask: + block_height = random.randint(1, h // 4) + block_width = random.randint(1, w // 4) + top_left_y = random.randint(0, h - block_height) + top_left_x = random.randint(0, w - block_width) + mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1 + elif mask_index == 7: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴 + b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴 + + for i in range(h): + for j in range(w): + if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1: + mask[:, :, i, j] = 1 + elif mask_index == 8: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() + for i in range(h): + for j in range(w): + if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2: + mask[:, :, i, j] = 1 + elif mask_index == 9: + for idx in range(f): + if np.random.rand() > 0.5: + mask[idx, :, :, :] = 1 else: raise ValueError(f"The mask_index {mask_index} is not define") return mask @@ -128,19 +161,35 @@ def get_video_reader_batch(video_reader, batch_index): frames = video_reader.get_batch(batch_index).asnumpy() return frames +def resize_frame(frame, target_short_side): + h, w, _ = frame.shape + if h < w: + if target_short_side > h: + return frame + new_h = target_short_side + new_w = int(target_short_side * w / h) + else: + if target_short_side > w: + return frame + new_w = target_short_side + new_h = int(target_short_side * h / w) + + resized_frame = cv2.resize(frame, (new_w, new_h)) + return resized_frame + class ImageVideoDataset(Dataset): def __init__( - self, - ann_path, data_root=None, - video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, - image_sample_size=512, - video_repeat=0, - text_drop_ratio=-1, - enable_bucket=False, - video_length_drop_start=0.1, - video_length_drop_end=0.9, - enable_inpaint=False, - ): + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=-1, + enable_bucket=False, + video_length_drop_start=0.1, + video_length_drop_end=0.9, + enable_inpaint=False, + ): # Loading annotations from files print(f"loading annotations from {ann_path} ...") if ann_path.endswith('.csv'): @@ -176,11 +225,11 @@ class ImageVideoDataset(Dataset): # Video params self.video_sample_stride = video_sample_stride self.video_sample_n_frames = video_sample_n_frames - video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) self.video_transforms = transforms.Compose( [ - transforms.Resize(video_sample_size[0]), - transforms.CenterCrop(video_sample_size), + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] ) @@ -193,7 +242,9 @@ class ImageVideoDataset(Dataset): transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) ]) - + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + def get_batch(self, idx): data_info = self.dataset[idx % len(self.dataset)] @@ -208,7 +259,7 @@ class ImageVideoDataset(Dataset): with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: min_sample_n_frames = min( self.video_sample_n_frames, - int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start)) + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) ) if min_sample_n_frames == 0: raise ValueError(f"No Frames in video.") @@ -223,6 +274,12 @@ class ImageVideoDataset(Dataset): pixel_values = func_timeout( VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) except FunctionTimedOut: raise ValueError(f"Read {idx} timeout.") except Exception as e: @@ -291,6 +348,238 @@ class ImageVideoDataset(Dataset): clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 sample["clip_pixel_values"] = clip_pixel_values + ref_pixel_values = sample["pixel_values"][0].unsqueeze(0) + if (mask == 1).all(): + ref_pixel_values = torch.ones_like(ref_pixel_values) * -1 + sample["ref_pixel_values"] = ref_pixel_values + + return sample + + +class ImageVideoControlDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=-1, + enable_bucket=False, + video_length_drop_start=0.1, + video_length_drop_end=0.9, + enable_inpaint=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. + self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + if video_repeat > 0: + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + video_id, text = data_info['file_path'], data_info['text'] + + if data_info.get('type', 'image')=='video': + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + + control_video_id = data_info['control_file_path'] + + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, batch_index) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + control_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + return pixel_values, control_pixel_values, text, "video" + else: + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + if not self.enable_bucket: + image = self.image_transforms(image).unsqueeze(0) + else: + image = np.expand_dims(np.array(image), 0) + + if random.random() < self.text_drop_ratio: + text = '' + + control_image_id = data_info['control_file_path'] + + if self.data_root is None: + control_image_id = control_image_id + else: + control_image_id = os.path.join(self.data_root, control_image_id) + + control_image = Image.open(control_image_id).convert('RGB') + if not self.enable_bucket: + control_image = self.image_transforms(control_image).unsqueeze(0) + else: + control_image = np.expand_dims(np.array(control_image), 0) + return image, control_image, text, 'image' + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, control_pixel_values, name, data_type = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + ref_pixel_values = sample["pixel_values"][0].unsqueeze(0) + if (mask == 1).all(): + ref_pixel_values = torch.ones_like(ref_pixel_values) * -1 + sample["ref_pixel_values"] = ref_pixel_values + return sample if __name__ == "__main__": diff --git a/easyanimate/models/__init__.py b/easyanimate/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2988ded0700ca058add72405e60ef15062eb1c --- /dev/null +++ b/easyanimate/models/__init__.py @@ -0,0 +1,16 @@ +from .autoencoder_magvit import (AutoencoderKLCogVideoX, AutoencoderKLMagvit, AutoencoderKL) +from .transformer3d import (EasyAnimateTransformer3DModel, + HunyuanTransformer3DModel, + Transformer3DModel) + + +name_to_transformer3d = { + "Transformer3DModel": Transformer3DModel, + "HunyuanTransformer3DModel": HunyuanTransformer3DModel, + "EasyAnimateTransformer3DModel": EasyAnimateTransformer3DModel, +} +name_to_autoencoder_magvit = { + "AutoencoderKL": AutoencoderKL, + "AutoencoderKLMagvit": AutoencoderKLMagvit, + "AutoencoderKLCogVideoX": AutoencoderKLCogVideoX, +} \ No newline at end of file diff --git a/easyanimate/models/attention.py b/easyanimate/models/attention.py index 2684ac2b6be33ef37e7de9874987649d65424185..9e62da307dd34d1d4cb084cc226272dc46596d09 100644 --- a/easyanimate/models/attention.py +++ b/easyanimate/models/attention.py @@ -11,34 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple, Union import diffusers import pkg_resources import torch import torch.nn.functional as F import torch.nn.init as init - -installed_version = diffusers.__version__ - -if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"): - from diffusers.models.attention_processor import (Attention, - AttnProcessor2_0, - HunyuanAttnProcessor2_0) -else: - from diffusers.models.attention_processor import Attention, AttnProcessor2_0 - -from diffusers.models.attention import AdaLayerNorm, FeedForward -from diffusers.models.embeddings import SinusoidalPositionalEmbedding -from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero -from diffusers.utils import USE_PEFT_BACKEND +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.attention_processor import (Attention, + AttentionProcessor, + AttnProcessor2_0, + HunyuanAttnProcessor2_0) +from diffusers.models.embeddings import (SinusoidalPositionalEmbedding, + TimestepEmbedding, Timesteps, + get_3d_sincos_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero, + CogVideoXLayerNormZero) +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import maybe_allow_in_graph from einops import rearrange, repeat from torch import nn from .motion_module import PositionalEncoding, get_motion_module -from .norm import FP32LayerNorm +from .norm import AdaLayerNormShift, FP32LayerNorm, EasyAnimateLayerNormZero +from .processor import (EasyAnimateAttnProcessor2_0, + LazyKVCompressionProcessor2_0) + + if is_xformers_available(): import xformers @@ -53,7 +57,6 @@ def zero_module(module): p.detach().zero_() return module - @maybe_allow_in_graph class GatedSelfAttentionDense(nn.Module): r""" @@ -95,267 +98,33 @@ class GatedSelfAttentionDense(nn.Module): return x - -class KVCompressionCrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - """ - +class LazyKVCompressionAttention(Attention): def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias=False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - - self.scale = dim_head**-0.5 - - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - self._slice_size = None - self._use_memory_efficient_attention_xformers = True - self.added_kv_proj_dim = added_kv_proj_dim - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) - else: - self.group_norm = None - - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - - self.kv_compression = nn.Conv2d( - query_dim, - query_dim, - groups=query_dim, - kernel_size=2, - stride=2, + self, + sr_ratio=2, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.sr_ratio = sr_ratio + self.k_compression = nn.Conv2d( + kwargs["query_dim"], + kwargs["query_dim"], + groups=kwargs["query_dim"], + kernel_size=sr_ratio, + stride=sr_ratio, bias=True ) - self.kv_compression_norm = FP32LayerNorm(query_dim) - init.constant_(self.kv_compression.weight, 1 / 4) - if self.kv_compression.bias is not None: - init.constant_(self.kv_compression.bias, 0) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) - self.to_out.append(nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def set_attention_slice(self, slice_size): - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - - self._slice_size = slice_size - - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, num_frames: int = 16, height: int = 32, width: int = 32): - batch_size, sequence_length, _ = hidden_states.shape - - encoder_hidden_states = encoder_hidden_states - - if self.group_norm is not None: - hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = self.to_q(hidden_states) - dim = query.shape[-1] - query = self.reshape_heads_to_batch_dim(query) - - if self.added_kv_proj_dim is not None: - key = self.to_k(hidden_states) - value = self.to_v(hidden_states) - - encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) - - key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width) - key = self.kv_compression(key) - key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames) - key = self.kv_compression_norm(key) - key = key.to(query.dtype) - - value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width) - value = self.kv_compression(value) - value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames) - value = self.kv_compression_norm(value) - value = value.to(query.dtype) - - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) - encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) - - key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) - else: - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width) - key = self.kv_compression(key) - key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames) - key = self.kv_compression_norm(key) - key = key.to(query.dtype) - - value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width) - value = self.kv_compression(value) - value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames) - value = self.kv_compression_norm(value) - value = value.to(query.dtype) - - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if attention_mask is not None: - if attention_mask.shape[-1] != query.shape[1]: - target_length = query.shape[1] - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) - - # attention, what we cannot get enough of - if self._use_memory_efficient_attention_xformers: - hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) - # Some versions of xformers return output in fp32, cast it back to the dtype of the input - hidden_states = hidden_states.to(query.dtype) - else: - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value, attention_mask) - else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - - # dropout - hidden_states = self.to_out[1](hidden_states) - return hidden_states - - def _attention(self, query, key, value, attention_mask=None): - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - - # cast back to the original dtype - attention_probs = attention_probs.to(value.dtype) - - # compute attention output - hidden_states = torch.bmm(attention_probs, value) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + self.v_compression = nn.Conv2d( + kwargs["query_dim"], + kwargs["query_dim"], + groups=kwargs["query_dim"], + kernel_size=sr_ratio, + stride=sr_ratio, + bias=True ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - - if self.upcast_attention: - query_slice = query_slice.float() - key_slice = key_slice.float() - - attn_slice = torch.baddbmm( - torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), - query_slice, - key_slice.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - - if attention_mask is not None: - attn_slice = attn_slice + attention_mask[start_idx:end_idx] - - if self.upcast_softmax: - attn_slice = attn_slice.float() - - attn_slice = attn_slice.softmax(dim=-1) - - # cast back to the original dtype - attn_slice = attn_slice.to(value.dtype) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): - # TODO attention_mask - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - + init.constant_(self.k_compression.weight, 1 / (sr_ratio * sr_ratio)) + init.constant_(self.v_compression.weight, 1 / (sr_ratio * sr_ratio)) + init.constant_(self.k_compression.bias, 0) + init.constant_(self.v_compression.bias, 0) @maybe_allow_in_graph class TemporalTransformerBlock(nn.Module): @@ -413,8 +182,6 @@ class TemporalTransformerBlock(nn.Module): attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, - # kv compression - kvcompression: Optional[bool] = False, # motion module kwargs motion_module_type = "VanillaGrid", motion_module_kwargs = None, @@ -454,40 +221,17 @@ class TemporalTransformerBlock(nn.Module): else: self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - self.kvcompression = kvcompression - if kvcompression: - self.attn1 = KVCompressionCrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - else: - if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"): - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - qk_norm="layer_norm" if qk_norm else None, - processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), - ) - else: - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + qk_norm="layer_norm" if qk_norm else None, + processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), + ) self.attn_temporal = get_motion_module( in_channels = dim, @@ -505,28 +249,17 @@ class TemporalTransformerBlock(nn.Module): if self.use_ada_layer_norm else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) ) - if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"): - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - qk_norm="layer_norm" if qk_norm else None, - processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), - ) # is self-attn if encoder_hidden_states is none - else: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + qk_norm="layer_norm" if qk_norm else None, + processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None @@ -605,23 +338,12 @@ class TemporalTransformerBlock(nn.Module): gligen_kwargs = cross_attention_kwargs.pop("gligen", None) norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames) - if self.kvcompression: - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - num_frames=1, - height=height, - width=width, - **cross_attention_kwargs, - ) - else: - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output @@ -658,6 +380,9 @@ class TemporalTransformerBlock(nn.Module): if self.pos_embed is not None and self.use_ada_layer_norm_single is None: norm_hidden_states = self.pos_embed(norm_hidden_states) + if norm_hidden_states.dtype != encoder_hidden_states.dtype or norm_hidden_states.dtype != encoder_attention_mask.dtype: + norm_hidden_states = norm_hidden_states.to(encoder_hidden_states.dtype) + attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -760,7 +485,7 @@ class SelfAttentionTemporalTransformerBlock(nn.Module): double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_type: str = "layer_norm", norm_eps: float = 1e-5, final_dropout: bool = False, attention_type: str = "default", @@ -802,28 +527,17 @@ class SelfAttentionTemporalTransformerBlock(nn.Module): else: self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"): - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - qk_norm="layer_norm" if qk_norm else None, - processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), - ) - else: - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + qk_norm="layer_norm" if qk_norm else None, + processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), + ) # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: @@ -835,28 +549,17 @@ class SelfAttentionTemporalTransformerBlock(nn.Module): if self.use_ada_layer_norm else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) ) - if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"): - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - qk_norm="layer_norm" if qk_norm else None, - processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), - ) # is self-attn if encoder_hidden_states is none - else: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + qk_norm="layer_norm" if qk_norm else None, + processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None @@ -1017,340 +720,415 @@ class SelfAttentionTemporalTransformerBlock(nn.Module): hidden_states = hidden_states.squeeze(1) return hidden_states - + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out, norm_elementwise_affine): + super().__init__() + self.norm = FP32LayerNorm(dim_in, dim_in, norm_elementwise_affine) + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(self.norm(x)).chunk(2, dim=-1) + return x * F.gelu(gate) @maybe_allow_in_graph -class KVCompressionTransformerBlock(nn.Module): +class HunyuanDiTBlock(nn.Module): r""" - A Temporal Transformer block. + Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and + QKNorm Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of headsto use for multi-head attention. + cross_attention_dim (`int`,*optional*): + The size of the encoder_hidden_states vector for cross attention. + dropout(`float`, *optional*, defaults to 0.0): + The dropout probability to use. + activation_fn (`str`,*optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. . norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + norm_eps (`float`, *optional*, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. final_dropout (`bool` *optional*, defaults to False): Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, *optional*, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. + ff_inner_dim (`int`, *optional*): + The size of the hidden layer in the feed-forward block. Defaults to `None`. + ff_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the feed-forward block. + skip (`bool`, *optional*, defaults to `False`): + Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks. + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use normalization in QK calculation. Defaults to `True`. """ def __init__( self, dim: int, num_attention_heads: int, - attention_head_dim: int, + cross_attention_dim: int = 1024, dropout=0.0, - cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' - norm_eps: float = 1e-5, + norm_eps: float = 1e-6, final_dropout: bool = False, - attention_type: str = "default", - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - kvcompression: Optional[bool] = False, - qk_norm = False, - after_norm = False, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + skip: bool = False, + qk_norm: bool = True, + time_position_encoding: bool = False, + after_norm: bool = False, + is_local_attention: bool = False, + local_attention_frames: int = 2, + enable_inpaint: bool = False, + kvcompression = False, ): super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None # Define 3 blocks. Each block has its own normalization layer. + # NOTE: when new version comes, check norm2 and norm 3 # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.t_embed = PositionalEncoding(dim, dropout=0., max_len=512) \ + if time_position_encoding else nn.Identity() + self.is_local_attention = is_local_attention + self.local_attention_frames = local_attention_frames self.kvcompression = kvcompression if kvcompression: - self.attn1 = KVCompressionCrossAttention( + self.attn1 = LazyKVCompressionAttention( query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=LazyKVCompressionProcessor2_0(), ) else: - if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"): - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - qk_norm="layer_norm" if qk_norm else None, - processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), - ) - else: - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor2_0(), + ) # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + if self.is_local_attention: + from mamba_ssm import Mamba2 + self.mamba_norm_in = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.in_linear = nn.Linear(dim, 1536) + self.mamba_norm_1 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine) + self.mamba_norm_2 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine) + + self.mamba_block_1 = Mamba2( + d_model=1536, + d_state=64, + d_conv=4, + expand=2, ) - if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"): - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - qk_norm="layer_norm" if qk_norm else None, - processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(), - ) # is self-attn if encoder_hidden_states is none - else: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None + self.mamba_block_2 = Mamba2( + d_model=1536, + d_state=64, + d_conv=4, + expand=2, + ) + self.mamba_norm_after_mamba_block = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine) + + self.out_linear = nn.Linear(1536, dim) + self.out_linear = zero_module(self.out_linear) + self.mamba_norm_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor2_0(), + ) + if enable_inpaint: + self.norm_clip = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.attn_clip = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=HunyuanAttnProcessor2_0(), + ) + self.gate_clip = GEGLU(dim, dim, norm_elementwise_affine) + self.norm_clip_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.attn_clip = None + self.norm_clip = None + self.gate_clip = None + self.norm_clip_out = None + # 3. Feed-forward - if not self.use_ada_layer_norm_single: - self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, ### 0.0 + activation_fn=activation_fn, ### approx GeLU + final_dropout=final_dropout, ### 0.0 + inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) + bias=ff_bias, + ) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + # 4. Skip Connection + if skip: + self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None if after_norm: self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) else: self.norm4 = None - # 4. Fuser - if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) - - # 5. Scale-shift for PixArt-Alpha. - if self.use_ada_layer_norm_single: - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - num_frames: int = 16, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb=None, + skip=None, + num_frames: int = 1, height: int = 32, width: int = 32, - use_reentrant: bool = False, - ) -> torch.FloatTensor: + clip_encoder_hidden_states: Optional[torch.Tensor] = None, + disable_image_rotary_emb_in_attn1=False, + ) -> torch.Tensor: # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] - - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.use_layer_norm: - norm_hidden_states = self.norm1(hidden_states) - elif self.use_ada_layer_norm_single: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - norm_hidden_states = norm_hidden_states.squeeze(1) - else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - # 1. Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - - # 2. Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - - if self.kvcompression: + # 0. Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([hidden_states, skip], dim=-1) + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + if image_rotary_emb is not None: + image_rotary_emb = (torch.cat([image_rotary_emb[0] for i in range(num_frames)], dim=0), torch.cat([image_rotary_emb[1] for i in range(num_frames)], dim=0)) + + if num_frames != 1: + # add time embedding + hidden_states = rearrange(hidden_states, "b (f d) c -> (b d) f c", f=num_frames) + if self.t_embed is not None: + hidden_states = self.t_embed(hidden_states) + hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", d=height * width) + + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct + if num_frames > 2 and self.is_local_attention: + if image_rotary_emb is not None: + attn1_image_rotary_emb = (image_rotary_emb[0][:int(height * width * 2)], image_rotary_emb[1][:int(height * width * 2)]) + else: + attn1_image_rotary_emb = image_rotary_emb + norm_hidden_states_1 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d=height * width) + norm_hidden_states_1 = rearrange(norm_hidden_states_1, "b (f p) d c -> (b f) (p d) c", p = 2) + attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - num_frames=num_frames, - height=height, - width=width, - **cross_attention_kwargs, + norm_hidden_states_1, + image_rotary_emb=attn1_image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None, ) - else: - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, + attn_output = rearrange(attn_output, "(b f) (p d) c -> b (f p) d c", p = 2, f = num_frames // 2) + + norm_hidden_states_2 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d = height * width)[:, 1:-1] + local_attention_frames_num = norm_hidden_states_2.size()[1] // 2 + norm_hidden_states_2 = rearrange(norm_hidden_states_2, "b (f p) d c -> (b f) (p d) c", p = 2) + attn_output_2 = self.attn1( + norm_hidden_states_2, + image_rotary_emb=attn1_image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None, ) + attn_output_2 = rearrange(attn_output_2, "(b f) (p d) c -> b (f p) d c", p = 2, f = local_attention_frames_num) + attn_output[:, 1:-1] = (attn_output[:, 1:-1] + attn_output_2) / 2 - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.use_ada_layer_norm_single: - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 2.5 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - - # 3. Cross-Attention - if self.attn2 is not None: - if self.use_ada_layer_norm: - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.use_ada_layer_norm_zero or self.use_layer_norm: - norm_hidden_states = self.norm2(hidden_states) - elif self.use_ada_layer_norm_single: - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states + attn_output = rearrange(attn_output, "b f d c -> b (f d) c") + else: + if self.kvcompression: + norm_hidden_states = rearrange(norm_hidden_states, "b (f h w) c -> b c f h w", f = num_frames, h = height, w = width) + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None, + ) else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.use_ada_layer_norm_single is None: - norm_hidden_states = self.pos_embed(norm_hidden_states) + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb if not disable_image_rotary_emb_in_attn1 else None, + ) + hidden_states = hidden_states + attn_output + + if num_frames > 2 and self.is_local_attention: + hidden_states_in = self.in_linear(self.mamba_norm_in(hidden_states)) + hidden_states = hidden_states + self.mamba_norm_out( + self.out_linear( + self.mamba_norm_after_mamba_block( + self.mamba_block_1( + self.mamba_norm_1(hidden_states_in) + ) + + self.mamba_block_2( + self.mamba_norm_2(hidden_states_in.flip(1)) + ).flip(1) + ) + ) + ) + + # 2. Cross-Attention + hidden_states = hidden_states + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, + if self.attn_clip is not None: + hidden_states = hidden_states + self.norm_clip_out( + self.gate_clip( + self.attn_clip( + self.norm_clip(hidden_states), + encoder_hidden_states=clip_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + ) ) - hidden_states = attn_output + hidden_states - # 4. Feed-forward - if not self.use_ada_layer_norm_single: - norm_hidden_states = self.norm3(hidden_states) + # FFN Layer ### TODO: switch norm2 and norm3 in the state dict + mlp_inputs = self.norm3(hidden_states) + if self.norm4 is not None: + hidden_states = hidden_states + self.norm4(self.ff(mlp_inputs)) + else: + hidden_states = hidden_states + self.ff(mlp_inputs) - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + return hidden_states - if self.use_ada_layer_norm_single: - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp +@maybe_allow_in_graph +class EasyAnimateDiTBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-6, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + qk_norm: bool = True, + after_norm: bool = False, + norm_type: str="fp32_layer_norm" + ): + super().__init__() - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) + # Attention Part + self.norm1 = EasyAnimateLayerNormZero( + time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True + ) - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [ - self.ff(hid_slice, scale=lora_scale) - for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) - ], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states, scale=lora_scale) + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=EasyAnimateAttnProcessor2_0(), + ) + self.attn2 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=EasyAnimateAttnProcessor2_0(), + ) - if self.norm4 is not None: - ff_output = self.norm4(ff_output) + # FFN Part + self.norm2 = EasyAnimateLayerNormZero( + time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True + ) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + self.txt_ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + if after_norm: + self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + else: + self.norm3 = None - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - elif self.use_ada_layer_norm_single: - ff_output = gate_mlp * ff_output + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + # Norm + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) + # Attn + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attn2=self.attn2, + ) + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states - return hidden_states \ No newline at end of file + # Norm + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # FFN + if self.norm3 is not None: + norm_hidden_states = self.norm3(self.ff(norm_hidden_states)) + norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states)) + else: + norm_hidden_states = self.ff(norm_hidden_states) + norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states) + hidden_states = hidden_states + gate_ff * norm_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states + return hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/easyanimate/models/autoencoder_magvit.py b/easyanimate/models/autoencoder_magvit.py index 607fc5452d84a945cc5dc7a71e793b437b20e8fa..62ee173d01feb4f98a0b1abd26b6d84ac18edbde 100644 --- a/easyanimate/models/autoencoder_magvit.py +++ b/easyanimate/models/autoencoder_magvit.py @@ -15,8 +15,14 @@ from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook try: from diffusers.loaders import FromOriginalVAEMixin @@ -32,10 +38,16 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.utils.accelerate_utils import apply_forward_hook from torch import nn +from diffusers import AutoencoderKL +from ..vae.ldm.models.cogvideox_enc_dec import (CogVideoXCausalConv3d, + CogVideoXDecoder3D, + CogVideoXEncoder3D, + CogVideoXSafeConv3d) from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder +logger = logging.get_logger(__name__) # pylint: disable=invalid-name def str_eval(item): if type(item) == str: @@ -97,10 +109,19 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): latent_channels: int = 4, norm_num_groups: int = 32, scaling_factor: float = 0.1825, + slice_mag_vae=True, slice_compression_vae=False, + cache_compression_vae=False, + cache_mag_vae=False, use_tiling=False, + use_tiling_encoder=False, + use_tiling_decoder=False, mini_batch_encoder=9, mini_batch_decoder=3, + upcast_vae=False, + spatial_group_norm=False, + tile_sample_min_size=384, + tile_overlap_factor=0.25, ): super().__init__() down_block_types = str_eval(down_block_types) @@ -121,8 +142,12 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): act_fn=act_fn, num_attention_heads=num_attention_heads, double_z=True, + slice_mag_vae=slice_mag_vae, slice_compression_vae=slice_compression_vae, + cache_compression_vae=cache_compression_vae, + cache_mag_vae=cache_mag_vae, mini_batch_encoder=mini_batch_encoder, + spatial_group_norm=spatial_group_norm, ) self.decoder = omnigen_Mag_Decoder( @@ -140,20 +165,30 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): norm_num_groups=norm_num_groups, act_fn=act_fn, num_attention_heads=num_attention_heads, + slice_mag_vae=slice_mag_vae, slice_compression_vae=slice_compression_vae, + cache_compression_vae=cache_compression_vae, + cache_mag_vae=cache_mag_vae, mini_batch_decoder=mini_batch_decoder, + spatial_group_norm=spatial_group_norm, ) self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + self.slice_mag_vae = slice_mag_vae self.slice_compression_vae = slice_compression_vae + self.cache_compression_vae = cache_compression_vae + self.cache_mag_vae = cache_mag_vae self.mini_batch_encoder = mini_batch_encoder self.mini_batch_decoder = mini_batch_decoder self.use_slicing = False self.use_tiling = use_tiling - self.tile_sample_min_size = 384 - self.tile_overlap_factor = 0.25 + self.use_tiling_encoder = use_tiling_encoder + self.use_tiling_decoder = use_tiling_decoder + self.upcast_vae = upcast_vae + self.tile_sample_min_size = tile_sample_min_size + self.tile_overlap_factor = tile_overlap_factor self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1))) self.scaling_factor = scaling_factor @@ -253,8 +288,16 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ + if self.upcast_vae: + x = x.float() + self.encoder = self.encoder.float() + self.quant_conv = self.quant_conv.float() if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): - return self.tiled_encode(x, return_dict=return_dict) + x = self.tiled_encode(x, return_dict=return_dict) + return x + if self.use_tiling_encoder and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + x = self.tiled_encode(x, return_dict=return_dict) + return x if self.use_slicing and x.shape[0] > 1: encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] @@ -271,8 +314,15 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return AutoencoderKLOutput(latent_dist=posterior) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.upcast_vae: + z = z.float() + self.decoder = self.decoder.float() + self.post_quant_conv = self.post_quant_conv.float() if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): return self.tiled_decode(z, return_dict=return_dict) + if self.use_tiling_decoder and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + z = self.post_quant_conv(z) dec = self.decoder(z) @@ -408,6 +458,34 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): result_rows.append(torch.cat(result_row, dim=4)) dec = torch.cat(result_rows, dim=3) + + # Handle the lower right corner tile separately + lower_right_original = z[ + :, + :, + :, + -self.tile_latent_min_size:, + -self.tile_latent_min_size: + ] + quantized_lower_right = self.decoder(self.post_quant_conv(lower_right_original)) + + # Combine + H, W = quantized_lower_right.size(-2), quantized_lower_right.size(-1) + x_weights = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1) + y_weights = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W) + weights = torch.min(x_weights, y_weights) + + if len(dec.size()) == 4: + weights = weights.unsqueeze(0).unsqueeze(0) + elif len(dec.size()) == 5: + weights = weights.unsqueeze(0).unsqueeze(0).unsqueeze(0) + + weights = weights.to(dec.device) + quantized_area = dec[:, :, :, -H:, -W:] + combined = weights * quantized_lower_right + (1 - weights) * quantized_area + + dec[:, :, :, -H:, -W:] = combined + if not return_dict: return (dec,) @@ -507,3 +585,441 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") print(m, u) return model + + +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [CogVideoX](https://github.com/THUDM/CogVideo). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["CogVideoXResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types: Tuple[str] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int] = (128, 256, 256, 512), + latent_channels: int = 16, + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_height: int = 480, + sample_width: int = 720, + scaling_factor: float = 1.15258426, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + slice_mag_vae=False, + slice_compression_vae=False, + cache_compression_vae=False, + cache_mag_vae=True, + use_tiling=False, + mini_batch_encoder=4, + mini_batch_decoder=1, + ): + super().__init__() + + self.encoder = CogVideoXEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = CogVideoXDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = use_tiling + + # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not + # recommended because the temporal parts of the VAE, here, are tricky to understand. + # If you decode X latent frames together, the number of output frames is: + # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames + # + # Example with num_latent_frames_batch_size = 2: + # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together + # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 6 * 8 = 48 frames + # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together + # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + + # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 1 * 9 + 5 * 8 = 49 frames + # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that + # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different + # number of temporal frames. + self.num_latent_frames_batch_size = 2 + + # We make the minimum height and width of sample for tiling half that of the generally supported + self.tile_sample_min_height = sample_height // 2 + self.tile_sample_min_width = sample_width // 2 + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + + # These are experimental overlap factors that were chosen based on experimentation and seem to work best for + # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX + # and so the tiling implementation has only been tested on those specific resolutions. + self.tile_overlap_factor_height = 1 / 6 + self.tile_overlap_factor_width = 1 / 5 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): + module.gradient_checkpointing = value + + def _clear_fake_context_parallel_cache(self): + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") + module._clear_fake_context_parallel_cache() + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_overlap_factor_height: Optional[float] = None, + tile_overlap_factor_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + batch_size, num_channels, num_frames, height, width = x.shape + if num_frames == 1: + h = self.encoder(x) + if self.quant_conv is not None: + h = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(h) + else: + frame_batch_size = 4 + h = [] + for i in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + z_intermediate = x[:, :, start_frame:end_frame] + z_intermediate = self.encoder(z_intermediate) + if self.quant_conv is not None: + z_intermediate = self.quant_conv(z_intermediate) + h.append(z_intermediate) + self._clear_fake_context_parallel_cache() + h = torch.cat(h, dim=2) + posterior = DiagonalGaussianDistribution(h) + self._clear_fake_context_parallel_cache() + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + if num_frames == 1: + dec = [] + z_intermediate = z + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate = self.decoder(z_intermediate) + dec.append(z_intermediate) + else: + frame_batch_size = self.num_latent_frames_batch_size + dec = [] + for i in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate = self.decoder(z_intermediate) + dec.append(z_intermediate) + + self._clear_fake_context_parallel_cache() + dec = torch.cat(dec, dim=2) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). + # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + time = [] + for k in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile = self.decoder(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec diff --git a/easyanimate/models/embeddings.py b/easyanimate/models/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..dddef56856bc11dd7ea4803b264224f9585d473f --- /dev/null +++ b/easyanimate/models/embeddings.py @@ -0,0 +1,107 @@ +import math +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models.embeddings import (PixArtAlphaTextProjection, get_timestep_embedding, + TimestepEmbedding, Timesteps) +from einops import rearrange +from torch import nn + + +class HunyuanDiTAttentionPool(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) + + query = self.q_proj(x[:, :1]) + key = self.k_proj(x) + value = self.v_proj(x) + batch_size, _, _ = query.size() + + query = query.reshape(batch_size, -1, self.num_heads, query.size(-1) // self.num_heads).transpose(1, 2) # (1, H, N, E/H) + key = key.reshape(batch_size, -1, self.num_heads, key.size(-1) // self.num_heads).transpose(1, 2) # (L+1, H, N, E/H) + value = value.reshape(batch_size, -1, self.num_heads, value.size(-1) // self.num_heads).transpose(1, 2) # (L+1, H, N, E/H) + + x = F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=None, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2).reshape(batch_size, 1, -1) + x = x.to(query.dtype) + x = self.c_proj(x) + + return x.squeeze(1) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + # Here we use a default learned embedder layer for future extension. + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + # extra condition2: image meta size embdding + image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning + + +class TimePositionalEncoding(nn.Module): + def __init__( + self, + d_model, + dropout = 0., + max_len = 24 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + b, c, f, h, w = x.size() + x = rearrange(x, "b c f h w -> (b h w) f c") + x = x + self.pe[:, :x.size(1)] + x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w) + return self.dropout(x) \ No newline at end of file diff --git a/easyanimate/models/norm.py b/easyanimate/models/norm.py index d0e61b2afbdd484e325c6d2d2f6b50f7790576ba..9bb6dc0a149286a7170220fead19619cbc4cac6f 100644 --- a/easyanimate/models/norm.py +++ b/easyanimate/models/norm.py @@ -2,7 +2,8 @@ from typing import Any, Dict, Optional, Tuple import torch import torch.nn.functional as F -from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.embeddings import (CombinedTimestepLabelEmbeddings, + TimestepEmbedding, Timesteps) from torch import nn @@ -12,7 +13,6 @@ def zero_module(module): p.detach().zero_() return module - class FP32LayerNorm(nn.LayerNorm): def forward(self, inputs: torch.Tensor) -> torch.Tensor: origin_dtype = inputs.dtype @@ -95,3 +95,56 @@ class AdaLayerNormSingle(nn.Module): # No modulation happening here. embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep + +class AdaLayerNormShift(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim) + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype)) + x = self.norm(x) + shift.unsqueeze(dim=1) + return x + +class EasyAnimateLayerNormZero(nn.Module): + # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py + # Add fp32 layer norm + def __init__( + self, + conditioning_dim: int, + embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + norm_type: str = "fp32_layer_norm", + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] \ No newline at end of file diff --git a/easyanimate/models/patch.py b/easyanimate/models/patch.py index 4391bed8ac1432a03dfdd7f9bee6431ac3836780..f075acf5d13d4b87925b721576582e324c4b12b8 100644 --- a/easyanimate/models/patch.py +++ b/easyanimate/models/patch.py @@ -153,15 +153,6 @@ class TemporalUpsampler3D(Upsampler): x = torch.cat([first_frame, x], dim=2) return x -def cast_tuple(t, length = 1): - return t if isinstance(t, tuple) else ((t,) * length) - -def divisible_by(num, den): - return (num % den) == 0 - -def is_odd(n): - return not divisible_by(n, 2) - class CausalConv3d(nn.Conv3d): def __init__( self, diff --git a/easyanimate/models/processor.py b/easyanimate/models/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9224085976b068aa96f194d02363024532c097 --- /dev/null +++ b/easyanimate/models/processor.py @@ -0,0 +1,312 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.models.attention import Attention +from diffusers.models.embeddings import apply_rotary_emb +from einops import rearrange, repeat + + +class HunyuanAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class LazyKVCompressionProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the KVCompression model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + batch_size, channel, num_frames, height, width = hidden_states.shape + hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=num_frames, h=height, w=width) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width) + key = attn.k_compression(key) + key_shape = key.size() + key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames) + + value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width) + value = attn.v_compression(value) + value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + compression_image_rotary_emb = ( + rearrange(image_rotary_emb[0], "(f h w) c -> f c h w", f=num_frames, h=height, w=width), + rearrange(image_rotary_emb[1], "(f h w) c -> f c h w", f=num_frames, h=height, w=width), + ) + compression_image_rotary_emb = ( + F.interpolate(compression_image_rotary_emb[0], size=key_shape[-2:], mode='bilinear'), + F.interpolate(compression_image_rotary_emb[1], size=key_shape[-2:], mode='bilinear') + ) + compression_image_rotary_emb = ( + rearrange(compression_image_rotary_emb[0], "f c h w -> (f h w) c"), + rearrange(compression_image_rotary_emb[1], "f c h w -> (f h w) c"), + ) + + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, compression_image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class EasyAnimateAttnProcessor2_0: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + attn2: Attention = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn2 is None: + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if attn2 is not None: + query_txt = attn2.to_q(encoder_hidden_states) + key_txt = attn2.to_k(encoder_hidden_states) + value_txt = attn2.to_v(encoder_hidden_states) + + inner_dim = key_txt.shape[-1] + head_dim = inner_dim // attn.heads + + query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn2.norm_q is not None: + query_txt = attn2.norm_q(query_txt) + if attn2.norm_k is not None: + key_txt = attn2.norm_k(key_txt) + + query = torch.cat([query_txt, query], dim=2) + key = torch.cat([key_txt, key], dim=2) + value = torch.cat([value_txt, value], dim=2) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + if attn2 is None: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + else: + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + encoder_hidden_states = attn2.to_out[0](encoder_hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn2.to_out[1](encoder_hidden_states) + return hidden_states, encoder_hidden_states diff --git a/easyanimate/models/resampler.py b/easyanimate/models/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..43a0612131b8914d4e4c96a0ef3f3d14f01d6bfb --- /dev/null +++ b/easyanimate/models/resampler.py @@ -0,0 +1,146 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.init import normal_ + + +def get_abs_pos(abs_pos, tgt_size): + # abs_pos: L, C + # tgt_size: M + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + return F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + else: + return abs_pos + +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + +class Resampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + def __init__( + self, + grid_size, + embed_dim, + num_heads, + kv_dim=None, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.num_queries = grid_size ** 2 + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.pos_embed = nn.Parameter( + torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float() + ).requires_grad_(False) + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + normal_(self.query, std=.02) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + else: + self.kv_proj = nn.Identity() + + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x, key_padding_mask=None): + pos_embed = get_abs_pos(self.pos_embed, x.size(1)) + + x = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + key_padding_mask=key_padding_mask)[0] + return out.permute(1, 0, 2) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) \ No newline at end of file diff --git a/easyanimate/models/transformer2d.py b/easyanimate/models/transformer2d.py index 836d8dad22541bc8b5d4f6fbec6b1dea2648c755..971a02b7c2a37ef8dd2380dc017977676c321aab 100644 --- a/easyanimate/models/transformer2d.py +++ b/easyanimate/models/transformer2d.py @@ -37,10 +37,6 @@ except: from diffusers.models.embeddings import \ CaptionProjection as PixArtAlphaTextProjection -from .attention import (KVCompressionTransformerBlock, - SelfAttentionTemporalTransformerBlock, - TemporalTransformerBlock) - @dataclass class Transformer2DModelOutput(BaseOutput): @@ -196,58 +192,29 @@ class Transformer2DModel(ModelMixin, ConfigMixin): interpolation_scale=interpolation_scale, ) - basic_block = { - "basic": BasicTransformerBlock, - "kvcompression": KVCompressionTransformerBlock, - }[self.basic_block_type] - if self.basic_block_type == "kvcompression": - self.transformer_blocks = nn.ModuleList( - [ - basic_block( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - attention_type=attention_type, - kvcompression=False if d < 14 else True, - ) - for d in range(num_layers) - ] - ) - else: - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - attention_type=attention_type, - ) - for d in range(num_layers) - ] - ) + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels @@ -413,7 +380,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin): if self.training and self.gradient_checkpointing: args = { "basic": [], - "kvcompression": [1, height, width], }[self.basic_block_type] hidden_states = torch.utils.checkpoint.checkpoint( block, @@ -430,7 +396,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin): else: kwargs = { "basic": {}, - "kvcompression": {"num_frames":1, "height":height, "width":width}, }[self.basic_block_type] hidden_states = block( hidden_states, diff --git a/easyanimate/models/transformer3d.py b/easyanimate/models/transformer3d.py index 47f8e02570e714dca9828bc70628207447989b9a..eff2760667b2e8b672a25d6891bdf20428e342e5 100644 --- a/easyanimate/models/transformer3d.py +++ b/easyanimate/models/transformer3d.py @@ -11,34 +11,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import glob import json import math import os from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import numpy as np import torch import torch.nn.functional as F -import torch.nn.init as init from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.attention import BasicTransformerBlock, FeedForward +from diffusers.models.attention import BasicTransformerBlock from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection, - TimestepEmbedding, Timesteps) -from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + TimestepEmbedding, Timesteps, + get_2d_sincos_pos_embed) +from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging) from diffusers.utils.torch_utils import maybe_allow_in_graph from einops import rearrange from torch import nn -from .attention import (SelfAttentionTemporalTransformerBlock, - TemporalTransformerBlock) +from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock, + SelfAttentionTemporalTransformerBlock, + TemporalTransformerBlock, zero_module) +from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding, TimePositionalEncoding from .norm import AdaLayerNormSingle -from .patch import (CasualPatchEmbed3D, Patch1D, PatchEmbed3D, PatchEmbedF3D, +from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D, TemporalUpsampler3D, UnPatch1D) +from .resampler import Resampler try: from diffusers.models.embeddings import PixArtAlphaTextProjection @@ -46,12 +51,6 @@ except: from diffusers.models.embeddings import \ CaptionProjection as PixArtAlphaTextProjection -def zero_module(module): - # Zero out the parameters of a module and return it. - for p in module.parameters(): - p.detach().zero_() - return module - class CLIPProjection(nn.Module): """ @@ -72,28 +71,6 @@ class CLIPProjection(nn.Module): hidden_states = self.linear_2(hidden_states) return hidden_states -class TimePositionalEncoding(nn.Module): - def __init__( - self, - d_model, - dropout = 0., - max_len = 24 - ): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - pe = torch.zeros(1, max_len, d_model) - pe[0, :, 0::2] = torch.sin(position * div_term) - pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - - def forward(self, x): - b, c, f, h, w = x.size() - x = rearrange(x, "b c f h w -> (b h w) f c") - x = x + self.pe[:, :x.size(1)] - x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w) - return self.dropout(x) @dataclass class Transformer3DModelOutput(BaseOutput): @@ -189,6 +166,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin): qk_norm = False, after_norm = False, + resize_inpaint_mask_directly: bool = False, + enable_clip_in_inpaint: bool = True, + enable_text_attention_mask: bool = True, + add_noise_in_inpaint_model: bool = False, ): super().__init__() self.use_linear_projection = use_linear_projection @@ -202,9 +183,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin): self.casual_3d = casual_3d self.casual_3d_upsampler_index = casual_3d_upsampler_index - conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv - linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear - assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size" self.height = sample_size @@ -310,34 +288,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin): for d in range(num_layers) ] ) - elif self.basic_block_type == "kvcompression_motionmodule": - self.transformer_blocks = nn.ModuleList( - [ - TemporalTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - double_self_attention=double_self_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - attention_type=attention_type, - kvcompression=False if d < 14 else True, - motion_module_type=motion_module_type, - motion_module_kwargs=motion_module_kwargs, - qk_norm=qk_norm, - after_norm=after_norm, - ) - for d in range(num_layers) - ] - ) elif self.basic_block_type == "selfattentiontemporal": self.transformer_blocks = nn.ModuleList( [ @@ -448,6 +398,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): self, hidden_states: torch.Tensor, inpaint_latents: torch.Tensor = None, + control_latents: torch.Tensor = None, encoder_hidden_states: Optional[torch.Tensor] = None, clip_encoder_hidden_states: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, @@ -524,6 +475,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin): if inpaint_latents is not None: hidden_states = torch.concat([hidden_states, inpaint_latents], 1) + if control_latents is not None: + hidden_states = torch.concat([hidden_states, control_latents], 1) # 1. Input if self.casual_3d: video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size @@ -596,7 +549,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin): "motionmodule": [video_length, height, width], "global_motionmodule": [video_length, height, width], "selfattentiontemporal": [], - "kvcompression_motionmodule": [video_length, height, width], }[self.basic_block_type] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), @@ -616,7 +568,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin): "motionmodule": {"num_frames":video_length, "height":height, "width":width}, "global_motionmodule": {"num_frames":video_length, "height":height, "width":width}, "selfattentiontemporal": {}, - "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width}, }[self.basic_block_type] hidden_states = block( hidden_states, @@ -741,4 +692,745 @@ class Transformer3DModel(ModelMixin, ConfigMixin): params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()] print(f"### Attn temporal Parameters: {sum(params) / 1e6} M") + return model + +class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + patch_size (`int`, *optional*): + The size of the patch to use for the input. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. + sample_size (`int`, *optional*): + The width of the latent images. This is fixed during training since it is used to learn a number of + position embeddings. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The number of dimension in the clip text embedding. + hidden_size (`int`, *optional*): + The size of hidden layer in the conditioning embedding layers. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden layer size to the input size. + learn_sigma (`bool`, *optional*, defaults to `True`): + Whether to predict variance. + cross_attention_dim_t5 (`int`, *optional*): + The number dimensions in t5 text embedding. + pooled_projection_dim (`int`, *optional*): + The size of the pooled projection. + text_len (`int`, *optional*): + The length of the clip text embedding. + text_len_t5 (`int`, *optional*): + The length of the T5 text embedding. + """ + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + patch_size: Optional[int] = None, + + n_query=16, + projection_dim=768, + activation_fn: str = "gelu-approximate", + sample_size=32, + hidden_size=1152, + num_layers: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = True, + cross_attention_dim: int = 1024, + norm_type: str = "layer_norm", + cross_attention_dim_t5: int = 2048, + pooled_projection_dim: int = 1024, + text_len: int = 77, + text_len_t5: int = 256, + + # block type + basic_block_type: str = "basic", + + time_position_encoding = False, + time_position_encoding_type: str = "2d_rope", + after_norm = False, + resize_inpaint_mask_directly: bool = False, + enable_clip_in_inpaint: bool = True, + enable_text_attention_mask: bool = True, + add_noise_in_inpaint_model: bool = False, + ): + super().__init__() + # 4. Define output layers + if learn_sigma: + self.out_channels = in_channels * 2 if out_channels is None else out_channels + else: + self.out_channels = in_channels if out_channels is None else out_channels + self.enable_inpaint = in_channels * 2 != self.out_channels if learn_sigma else in_channels != self.out_channels + self.num_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.basic_block_type = basic_block_type + self.resize_inpaint_mask_directly = resize_inpaint_mask_directly + self.text_embedder = PixArtAlphaTextProjection( + in_features=cross_attention_dim_t5, + hidden_size=cross_attention_dim_t5 * 4, + out_features=cross_attention_dim, + act_fn="silu_fp32", + ) + + self.text_embedding_padding = nn.Parameter( + torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) + ) + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + in_channels=in_channels, + embed_dim=hidden_size, + patch_size=patch_size, + pos_embed_type=None, + ) + + self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding( + hidden_size, + pooled_projection_dim=pooled_projection_dim, + seq_len=text_len_t5, + cross_attention_dim=cross_attention_dim_t5, + ) + + # 3. Define transformers blocks + if self.basic_block_type == "hybrid_attention": + self.blocks = nn.ModuleList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=layer > num_layers // 2, + after_norm=after_norm, + time_position_encoding=time_position_encoding, + is_local_attention=False if layer % 2 == 0 else True, + local_attention_frames=2, + enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint, + ) + for layer in range(num_layers) + ] + ) + elif self.basic_block_type == "kvcompression_basic": + self.blocks = nn.ModuleList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=layer > num_layers // 2, + after_norm=after_norm, + time_position_encoding=time_position_encoding, + kvcompression=False if layer < num_layers // 2 else True, + enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint, + ) + for layer in range(num_layers) + ] + ) + else: + self.blocks = nn.ModuleList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=layer > num_layers // 2, + after_norm=after_norm, + time_position_encoding=time_position_encoding, + enable_inpaint=self.enable_inpaint and enable_clip_in_inpaint, + ) + for layer in range(num_layers) + ] + ) + + self.n_query = n_query + if self.enable_inpaint and enable_clip_in_inpaint: + self.clip_padding = nn.Parameter( + torch.randn((self.n_query, cross_attention_dim)) * 0.02 + ) + self.clip_projection = Resampler( + int(math.sqrt(n_query)), + embed_dim=cross_attention_dim, + num_heads=self.config.num_attention_heads, + kv_dim=projection_dim, + norm_layer=nn.LayerNorm, + ) + else: + self.clip_padding = None + self.clip_projection = None + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + inpaint_latents=None, + control_latents: torch.Tensor = None, + clip_encoder_hidden_states: Optional[torch.Tensor]=None, + clip_attention_mask: Optional[torch.Tensor]=None, + return_dict=True, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of `BertModel`. + text_embedding_mask: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of `BertModel`. + encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. + text_embedding_mask_t5: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of T5 Text Encoder. + image_meta_size (torch.Tensor): + Conditional embedding indicate the image sizes + style: torch.Tensor: + Conditional embedding indicate the style + image_rotary_emb (`torch.Tensor`): + The image rotary embeddings to apply on query and key tensors during attention calculation. + return_dict: bool + Whether to return a dictionary. + """ + if inpaint_latents is not None: + hidden_states = torch.concat([hidden_states, inpaint_latents], 1) + if control_latents is not None: + hidden_states = torch.concat([hidden_states, control_latents], 1) + + # unpatchify: (N, out_channels, H, W) + patch_size = self.pos_embed.patch_size + video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // patch_size, hidden_states.shape[-1] // patch_size + hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w") + hidden_states = self.pos_embed(hidden_states) + hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb = self.time_extra_emb( + timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype + ) # [B, D] + + # text projection + batch_size, sequence_length, _ = encoder_hidden_states_t5.shape + encoder_hidden_states_t5 = self.text_embedder( + encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) + ) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) + + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) + text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1) + text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() + + encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) + + if clip_encoder_hidden_states is not None: + batch_size = encoder_hidden_states.shape[0] + + clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states) + clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) + + clip_attention_mask = clip_attention_mask.unsqueeze(2).bool() + clip_encoder_hidden_states = torch.where(clip_attention_mask, clip_encoder_hidden_states, self.clip_padding) + + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.config.num_layers // 2: + skip = skips.pop() + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + args = { + "kvcompression_basic": [video_length, height, width, clip_encoder_hidden_states], + "basic": [video_length, height, width, clip_encoder_hidden_states], + "hybrid_attention": [video_length, height, width, clip_encoder_hidden_states], + }[self.basic_block_type] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + skip, + *args, + **ckpt_kwargs, + ) + else: + kwargs = { + "kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states}, + "basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states}, + "hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states}, + }[self.basic_block_type] + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + skip=skip, + **kwargs + ) # (N, L, D) + else: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + args = { + "kvcompression_basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False], + "basic": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False], + "hybrid_attention": [None, video_length, height, width, clip_encoder_hidden_states, True if layer==0 else False], + }[self.basic_block_type] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + *args, + **ckpt_kwargs, + ) + else: + kwargs = { + "kvcompression_basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states}, + "basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states}, + "hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states}, + }[self.basic_block_type] + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + disable_image_rotary_emb_in_attn1=True if layer==0 else False, + **kwargs + ) # (N, L, D) + + if layer < (self.config.num_layers // 2 - 1): + skips.append(hidden_states) + + # final layer + hidden_states = self.norm_out(hidden_states, temb.to(torch.float32)) + hidden_states = self.proj_out(hidden_states) + # (N, L, patch_size ** 2 * out_channels) + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], video_length, height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, video_length, height * patch_size, width * patch_size) + ) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **transformer_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + if os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size(): + new_shape = model.state_dict()['pos_embed.proj.weight'].size() + if len(new_shape) == 5: + state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone() + state_dict['pos_embed.proj.weight'][:, :, :-1] = 0 + else: + if model.state_dict()['pos_embed.proj.weight'].size()[1] > state_dict['pos_embed.proj.weight'].size()[1]: + model.state_dict()['pos_embed.proj.weight'][:, :state_dict['pos_embed.proj.weight'].size()[1], :, :] = state_dict['pos_embed.proj.weight'] + model.state_dict()['pos_embed.proj.weight'][:, state_dict['pos_embed.proj.weight'].size()[1]:, :, :] = 0 + state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight'] + else: + model.state_dict()['pos_embed.proj.weight'][:, :, :, :] = state_dict['pos_embed.proj.weight'][:, :model.state_dict()['pos_embed.proj.weight'].size()[1], :, :] + state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight'] + + if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size(): + if model.state_dict()['proj_out.weight'].size()[0] > state_dict['proj_out.weight'].size()[0]: + model.state_dict()['proj_out.weight'][:state_dict['proj_out.weight'].size()[0], :] = state_dict['proj_out.weight'] + state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight'] + else: + model.state_dict()['proj_out.weight'][:, :] = state_dict['proj_out.weight'][:model.state_dict()['proj_out.weight'].size()[0], :] + state_dict['proj_out.weight'] = model.state_dict()['proj_out.weight'] + + if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size(): + if model.state_dict()['proj_out.bias'].size()[0] > state_dict['proj_out.bias'].size()[0]: + model.state_dict()['proj_out.bias'][:state_dict['proj_out.bias'].size()[0]] = state_dict['proj_out.bias'] + state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias'] + else: + model.state_dict()['proj_out.bias'][:, :] = state_dict['proj_out.bias'][:model.state_dict()['proj_out.bias'].size()[0], :] + state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias'] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()] + print(f"### Mamba Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + + return model + +class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + patch_size: Optional[int] = None, + sample_width: int = 90, + sample_height: int = 60, + ref_channels: int = None, + clip_channels: int = None, + + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + freq_shift: int = 0, + num_layers: int = 30, + dropout: float = 0.0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + text_embed_dim_t5: int = 4096, + norm_eps: float = 1e-5, + + norm_elementwise_affine: bool = True, + flip_sin_to_cos: bool = True, + + time_position_encoding_type: str = "3d_rope", + after_norm = False, + resize_inpaint_mask_directly: bool = False, + enable_clip_in_inpaint: bool = True, + enable_text_attention_mask: bool = True, + add_noise_in_inpaint_model: bool = False, + ): + super().__init__() + self.num_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.resize_inpaint_mask_directly = resize_inpaint_mask_directly + self.patch_size = patch_size + + post_patch_height = sample_height // patch_size + post_patch_width = sample_width // patch_size + self.post_patch_height = post_patch_height + self.post_patch_width = post_patch_width + + self.time_proj = Timesteps(self.inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(self.inner_dim, time_embed_dim, timestep_activation_fn) + + self.proj = nn.Conv2d( + in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True + ) + self.text_proj = nn.Linear(text_embed_dim, self.inner_dim) + self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim) + + if ref_channels is not None: + self.ref_proj = nn.Conv2d( + ref_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True + ) + ref_pos_embedding = get_2d_sincos_pos_embed(self.inner_dim, (post_patch_height, post_patch_width)) + ref_pos_embedding = torch.from_numpy(ref_pos_embedding) + self.register_buffer("ref_pos_embedding", ref_pos_embedding, persistent=False) + + if clip_channels is not None: + self.clip_proj = nn.Linear(clip_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + EasyAnimateDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + after_norm=after_norm + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine) + + # 5. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * self.inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states, + timestep, + timestep_cond = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + text_embedding_mask: Optional[torch.Tensor] = None, + encoder_hidden_states_t5: Optional[torch.Tensor] = None, + text_embedding_mask_t5: Optional[torch.Tensor] = None, + image_meta_size = None, + style = None, + image_rotary_emb: Optional[torch.Tensor] = None, + inpaint_latents: Optional[torch.Tensor] = None, + control_latents: Optional[torch.Tensor] = None, + ref_latents: Optional[torch.Tensor] = None, + clip_encoder_hidden_states: Optional[torch.Tensor] = None, + clip_attention_mask: Optional[torch.Tensor] = None, + return_dict=True, + ): + batch_size, channels, video_length, height, width = hidden_states.size() + + # 1. Time embedding + temb = self.time_proj(timestep).to(dtype=hidden_states.dtype) + temb = self.time_embedding(temb, timestep_cond) + + # 2. Patch embedding + if inpaint_latents is not None: + hidden_states = torch.concat([hidden_states, inpaint_latents], 1) + if control_latents is not None: + hidden_states = torch.concat([hidden_states, control_latents], 1) + + hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w") + hidden_states = self.proj(hidden_states) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length, h=height // self.patch_size, w=width // self.patch_size) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + encoder_hidden_states = self.text_proj(encoder_hidden_states) + if encoder_hidden_states_t5 is not None: + encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5) + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous() + + if ref_latents is not None: + ref_batch, ref_channels, ref_video_length, ref_height, ref_width = ref_latents.shape + ref_latents = rearrange(ref_latents, "b c f h w ->(b f) c h w") + ref_latents = self.ref_proj(ref_latents) + ref_latents = rearrange(ref_latents, "(b f) c h w -> b c f h w", f=ref_video_length, h=ref_height // self.patch_size, w=ref_width // self.patch_size) + ref_latents = ref_latents.flatten(2).transpose(1, 2) + + emb_size = hidden_states.size()[-1] + ref_pos_embedding = self.ref_pos_embedding + ref_pos_embedding_interpolate = ref_pos_embedding.view(1, 1, self.post_patch_height, self.post_patch_width, emb_size).permute([0, 4, 1, 2, 3]) + ref_pos_embedding_interpolate = F.interpolate( + ref_pos_embedding_interpolate, + size=[1, height // self.config.patch_size, width // self.config.patch_size], + mode='trilinear', align_corners=False + ) + ref_pos_embedding_interpolate = ref_pos_embedding_interpolate.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size) + ref_latents = ref_latents + ref_pos_embedding_interpolate + + encoder_hidden_states = ref_latents + + if clip_encoder_hidden_states is not None: + clip_encoder_hidden_states = self.clip_proj(clip_encoder_hidden_states) + + encoder_hidden_states = torch.concat([clip_encoder_hidden_states, ref_latents], dim=1) + + # 4. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, encoder_hidden_states.size()[1]:] + + # 5. Final block + hidden_states = self.norm_out(hidden_states, temb=temb) + hidden_states = self.proj_out(hidden_states) + + # 6. Unpatchify + p = self.config.patch_size + output = hidden_states.reshape(batch_size, video_length, height // p, width // p, channels, p, p) + output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config, **transformer_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file_safetensors = model_file.replace(".bin", ".safetensors") + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + else: + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for model_file_safetensors in model_files_safetensors: + _state_dict = load_file(model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] + + if model.state_dict()['proj.weight'].size() != state_dict['proj.weight'].size(): + new_shape = model.state_dict()['proj.weight'].size() + if len(new_shape) == 5: + state_dict['proj.weight'] = state_dict['proj.weight'].unsqueeze(2).expand(new_shape).clone() + state_dict['proj.weight'][:, :, :-1] = 0 + else: + if model.state_dict()['proj.weight'].size()[1] > state_dict['proj.weight'].size()[1]: + model.state_dict()['proj.weight'][:, :state_dict['proj.weight'].size()[1], :, :] = state_dict['proj.weight'] + model.state_dict()['proj.weight'][:, state_dict['proj.weight'].size()[1]:, :, :] = 0 + state_dict['proj.weight'] = model.state_dict()['proj.weight'] + else: + model.state_dict()['proj.weight'][:, :, :, :] = state_dict['proj.weight'][:, :model.state_dict()['proj.weight'].size()[1], :, :] + state_dict['proj.weight'] = model.state_dict()['proj.weight'] + + tmp_state_dict = {} + for key in state_dict: + if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size(): + tmp_state_dict[key] = state_dict[key] + else: + print(key, "Size don't match, skip") + + state_dict = tmp_state_dict + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + print(m) + + params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()] + print(f"### All Parameters: {sum(params) / 1e6} M") + + params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] + print(f"### attn1 Parameters: {sum(params) / 1e6} M") + return model \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate.py b/easyanimate/pipeline/pipeline_easyanimate.py index ef53107776a5bafdcbd98dce1b1b30909d22ef57..9d0ff0091a436b96f116c073176bfc0d15b5c987 100644 --- a/easyanimate/pipeline/pipeline_easyanimate.py +++ b/easyanimate/pipeline/pipeline_easyanimate.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import html import inspect -import copy import re import urllib.parse as ul from dataclasses import dataclass @@ -154,7 +154,8 @@ class EasyAnimatePipeline(DiffusionPipeline): self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - + self.enable_autocast_float8_transformer_flag = False + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py def mask_text_embeddings(self, emb, mask): if emb.shape[0] == 1: @@ -548,31 +549,13 @@ class EasyAnimatePipeline(DiffusionPipeline): prefix_index_before = mini_batch_encoder // 2 prefix_index_after = mini_batch_encoder - prefix_index_before pixel_values = video[:, :, prefix_index_before:-prefix_index_after] - - if self.vae.slice_compression_vae: - latents = self.vae.encode(pixel_values)[0] - latents = latents.sample() - else: - new_pixel_values = [] - for i in range(0, pixel_values.shape[2], mini_batch_encoder): - with torch.no_grad(): - pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :] - pixel_values_bs = self.vae.encode(pixel_values_bs)[0] - pixel_values_bs = pixel_values_bs.sample() - new_pixel_values.append(pixel_values_bs) - latents = torch.cat(new_pixel_values, dim = 2) - - if self.vae.slice_compression_vae: - middle_video = self.vae.decode(latents)[0] - else: - middle_video = [] - for i in range(0, latents.shape[2], mini_batch_decoder): - with torch.no_grad(): - start_index = i - end_index = i + mini_batch_decoder - latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0] - middle_video.append(latents_bs) - middle_video = torch.cat(middle_video, 2) + + # Encode middle videos + latents = self.vae.encode(pixel_values)[0] + latents = latents.mode() + # Decode middle videos + middle_video = self.vae.decode(latents)[0] + video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 return video @@ -582,17 +565,7 @@ class EasyAnimatePipeline(DiffusionPipeline): if self.vae.quant_conv.weight.ndim==5: mini_batch_encoder = self.vae.mini_batch_encoder mini_batch_decoder = self.vae.mini_batch_decoder - if self.vae.slice_compression_vae: - video = self.vae.decode(latents)[0] - else: - video = [] - for i in range(0, latents.shape[2], mini_batch_decoder): - with torch.no_grad(): - start_index = i - end_index = i + mini_batch_decoder - latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0] - video.append(latents_bs) - video = torch.cat(video, 2) + video = self.vae.decode(latents)[0] video = video.clamp(-1, 1) video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) else: @@ -607,6 +580,9 @@ class EasyAnimatePipeline(DiffusionPipeline): video = video.cpu().float().numpy() return video + def enable_autocast_float8_transformer(self): + self.enable_autocast_float8_transformer_flag = True + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -633,6 +609,7 @@ class EasyAnimatePipeline(DiffusionPipeline): callback_steps: int = 1, clean_caption: bool = True, max_sequence_length: int = 120, + comfyui_progressbar: bool = False, **kwargs, ) -> Union[EasyAnimatePipelineOutput, Tuple]: """ @@ -780,9 +757,16 @@ class EasyAnimatePipeline(DiffusionPipeline): added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + torch.cuda.empty_cache() + if self.enable_autocast_float8_transformer_flag: + origin_weight_dtype = self.transformer.dtype + self.transformer = self.transformer.to(torch.float8_e4m3fn) + # 7. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -834,6 +818,12 @@ class EasyAnimatePipeline(DiffusionPipeline): step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if comfyui_progressbar: + pbar.update(1) + + if self.enable_autocast_float8_transformer_flag: + self.transformer = self.transformer.to("cpu", origin_weight_dtype) + # Post-processing video = self.decode_latents(latents) diff --git a/easyanimate/pipeline/pipeline_easyanimate_inpaint.py b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py index cf80f918888fd4b3c10222c8e307b107a33a94cd..340b977479d0b42f1b3980188b2ae7dbb8208c02 100644 --- a/easyanimate/pipeline/pipeline_easyanimate_inpaint.py +++ b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import gc import html import inspect import re -import gc -import copy import urllib.parse as ul from dataclasses import dataclass -from PIL import Image from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -34,9 +33,10 @@ from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate, replace_example_docstring) from diffusers.utils.torch_utils import randn_tensor from einops import rearrange +from PIL import Image from tqdm import tqdm -from transformers import T5EncoderModel, T5Tokenizer -from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor +from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection, + T5EncoderModel, T5Tokenizer) from ..models.transformer3d import Transformer3DModel @@ -129,6 +129,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) + self.enable_autocast_float8_transformer_flag = False # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py def mask_text_embeddings(self, emb, mask): @@ -493,6 +494,60 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): return caption.strip() + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + video_length = mask.shape[2] + + mask = mask.to(device=device, dtype=self.vae.dtype) + if self.vae.quant_conv.weight.ndim==5: + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.sample() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + else: + if mask.shape[1] == 4: + mask = mask + else: + video_length = mask.shape[2] + mask = rearrange(mask, "b c f h w -> (b f) c h w") + mask = self._encode_vae_image(mask, generator=generator) + mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length) + + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + if self.vae.quant_conv.weight.ndim==5: + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.sample() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + + else: + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + video_length = mask.shape[2] + masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w") + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + def prepare_latents( self, batch_size, @@ -529,22 +584,11 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): bs = 1 mini_batch_encoder = self.vae.mini_batch_encoder new_video = [] - if self.vae.slice_compression_vae: - for i in range(0, video.shape[0], bs): - video_bs = video[i : i + bs] - video_bs = self.vae.encode(video_bs)[0] - video_bs = video_bs.sample() - new_video.append(video_bs) - else: - for i in range(0, video.shape[0], bs): - new_video_mini_batch = [] - for j in range(0, video.shape[2], mini_batch_encoder): - video_bs = video[i : i + bs, :, j: j + mini_batch_encoder, :, :] - video_bs = self.vae.encode(video_bs)[0] - video_bs = video_bs.sample() - new_video_mini_batch.append(video_bs) - new_video_mini_batch = torch.cat(new_video_mini_batch, dim = 2) - new_video.append(new_video_mini_batch) + for i in range(0, video.shape[0], bs): + video_bs = video[i : i + bs] + video_bs = self.vae.encode(video_bs)[0] + video_bs = video_bs.sample() + new_video.append(video_bs) video = torch.cat(new_video, dim = 0) video = video * self.vae.config.scaling_factor @@ -585,31 +629,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): prefix_index_before = mini_batch_encoder // 2 prefix_index_after = mini_batch_encoder - prefix_index_before pixel_values = video[:, :, prefix_index_before:-prefix_index_after] - - if self.vae.slice_compression_vae: - latents = self.vae.encode(pixel_values)[0] - latents = latents.sample() - else: - new_pixel_values = [] - for i in range(0, pixel_values.shape[2], mini_batch_encoder): - with torch.no_grad(): - pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :] - pixel_values_bs = self.vae.encode(pixel_values_bs)[0] - pixel_values_bs = pixel_values_bs.sample() - new_pixel_values.append(pixel_values_bs) - latents = torch.cat(new_pixel_values, dim = 2) - - if self.vae.slice_compression_vae: - middle_video = self.vae.decode(latents)[0] - else: - middle_video = [] - for i in range(0, latents.shape[2], mini_batch_decoder): - with torch.no_grad(): - start_index = i - end_index = i + mini_batch_decoder - latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0] - middle_video.append(latents_bs) - middle_video = torch.cat(middle_video, 2) + + # Encode middle videos + latents = self.vae.encode(pixel_values)[0] + latents = latents.sample() + # Decode middle videos + middle_video = self.vae.decode(latents)[0] + video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 return video @@ -619,17 +645,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): if self.vae.quant_conv.weight.ndim==5: mini_batch_encoder = self.vae.mini_batch_encoder mini_batch_decoder = self.vae.mini_batch_decoder - if self.vae.slice_compression_vae: - video = self.vae.decode(latents)[0] - else: - video = [] - for i in range(0, latents.shape[2], mini_batch_decoder): - with torch.no_grad(): - start_index = i - end_index = i + mini_batch_decoder - latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0] - video.append(latents_bs) - video = torch.cat(video, 2) + video = self.vae.decode(latents)[0] video = video.clamp(-1, 1) video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) else: @@ -668,84 +684,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): return timesteps, num_inference_steps - t_start - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - video_length = mask.shape[2] - - mask = mask.to(device=device, dtype=self.vae.dtype) - if self.vae.quant_conv.weight.ndim==5: - bs = 1 - mini_batch_encoder = self.vae.mini_batch_encoder - new_mask = [] - if self.vae.slice_compression_vae: - for i in range(0, mask.shape[0], bs): - mask_bs = mask[i : i + bs] - mask_bs = self.vae.encode(mask_bs)[0] - mask_bs = mask_bs.sample() - new_mask.append(mask_bs) - else: - for i in range(0, mask.shape[0], bs): - new_mask_mini_batch = [] - for j in range(0, mask.shape[2], mini_batch_encoder): - mask_bs = mask[i : i + bs, :, j: j + mini_batch_encoder, :, :] - mask_bs = self.vae.encode(mask_bs)[0] - mask_bs = mask_bs.sample() - new_mask_mini_batch.append(mask_bs) - new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2) - new_mask.append(new_mask_mini_batch) - mask = torch.cat(new_mask, dim = 0) - mask = mask * self.vae.config.scaling_factor - - else: - if mask.shape[1] == 4: - mask = mask - else: - video_length = mask.shape[2] - mask = rearrange(mask, "b c f h w -> (b f) c h w") - mask = self._encode_vae_image(mask, generator=generator) - mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length) - - masked_image = masked_image.to(device=device, dtype=self.vae.dtype) - if self.vae.quant_conv.weight.ndim==5: - bs = 1 - mini_batch_encoder = self.vae.mini_batch_encoder - new_mask_pixel_values = [] - if self.vae.slice_compression_vae: - for i in range(0, masked_image.shape[0], bs): - mask_pixel_values_bs = masked_image[i : i + bs] - mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.sample() - new_mask_pixel_values.append(mask_pixel_values_bs) - else: - for i in range(0, masked_image.shape[0], bs): - new_mask_pixel_values_mini_batch = [] - for j in range(0, masked_image.shape[2], mini_batch_encoder): - mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch_encoder, :, :] - mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.sample() - new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs) - new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2) - new_mask_pixel_values.append(new_mask_pixel_values_mini_batch) - masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) - masked_image_latents = masked_image_latents * self.vae.config.scaling_factor - - else: - if masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - video_length = mask.shape[2] - masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w") - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) - masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length) + def enable_autocast_float8_transformer(self): + self.enable_autocast_float8_transformer_flag = True - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - return mask, masked_image_latents - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -779,6 +720,8 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): max_sequence_length: int = 120, clip_image: Image = None, clip_apply_ratio: float = 0.50, + comfyui_progressbar: bool = False, + **kwargs, ) -> Union[EasyAnimatePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -1057,10 +1000,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): gc.collect() torch.cuda.empty_cache() torch.cuda.ipc_collect() + if self.enable_autocast_float8_transformer_flag: + origin_weight_dtype = self.transformer.dtype + self.transformer = self.transformer.to(torch.float8_e4m3fn) # 10. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -1130,16 +1079,19 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if comfyui_progressbar: + pbar.update(1) + + if self.enable_autocast_float8_transformer_flag: + self.transformer = self.transformer.to("cpu", origin_weight_dtype) + gc.collect() torch.cuda.empty_cache() torch.cuda.ipc_collect() # Post-processing video = self.decode_latents(latents) - - gc.collect() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + # Convert to tensor if output_type == "latent": video = torch.from_numpy(video) diff --git a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8adde44bfb5b01b12b9e08c0ef567253e89ce5d5 --- /dev/null +++ b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py @@ -0,0 +1,925 @@ +# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import (get_2d_rotary_pos_embed, + get_3d_rotary_pos_embed) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import \ + StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import (is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from tqdm import tqdm +from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, + T5Tokenizer, T5EncoderModel) + +from .pipeline_easyanimate import EasyAnimatePipelineOutput +from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> pass + ``` +""" + + +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class EasyAnimatePipeline_Multi_Text_Encoder(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + HunyuanDiT team) + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + EasyAnimate uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: BertModel, + tokenizer: BertTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5Tokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.enable_autocast_float8_transformer_flag = False + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + super().enable_sequential_cpu_offload(*args, **kwargs) + if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None: + import accelerate + accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True) + self.transformer.clip_projection = self.transformer.clip_projection.to("cuda") + + def encode_prompt( + self, + prompt: str, + device: torch.device, + dtype: torch.dtype, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + actual_max_sequence_length: int = 256 + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length) + if text_encoder_index == 1: + max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length) + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + if text_input_ids.shape[-1] > actual_max_sequence_length: + reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + text_inputs = tokenizer( + reprompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) + removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {_actual_max_sequence_length} tokens: {removed_text}" + ) + prompt_attention_mask = text_inputs.attention_mask.to(device) + + if self.transformer.config.enable_text_attention_mask: + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + else: + prompt_embeds = text_encoder( + text_input_ids.to(device) + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + if uncond_input_ids.shape[-1] > actual_max_sequence_length: + reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + uncond_input = tokenizer( + reuncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device) + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + if self.vae.cache_mag_vae: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder): + if video.size()[2] <= mini_batch_encoder: + return video + prefix_index_before = mini_batch_encoder // 2 + prefix_index_after = mini_batch_encoder - prefix_index_before + pixel_values = video[:, :, prefix_index_before:-prefix_index_after] + + # Encode middle videos + latents = self.vae.encode(pixel_values)[0] + latents = latents.mode() + # Decode middle videos + middle_video = self.vae.decode(latents)[0] + + video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 + return video + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / self.vae.config.scaling_factor * latents + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + video = self.vae.decode(latents)[0] + video = video.clamp(-1, 1) + if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae: + video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) + else: + latents = rearrange(latents, "b c f h w -> (b f) c h w") + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def enable_autocast_float8_transformer(self): + self.enable_autocast_float8_transformer_flag = True + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + video_length: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "latent", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + comfyui_progressbar: bool = False, + ): + r""" + Generates images or video using the EasyAnimate pipeline based on the provided prompts. + + Examples: + prompt (`str` or `List[str]`, *optional*): + Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. + video_length (`int`, *optional*): + Length of the generated video (in frames). + height (`int`, *optional*): + Height of the generated image in pixels. + width (`int`, *optional*): + Width of the generated image in pixels. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Encourages the model to align outputs with prompts. A higher value may decrease image quality. + negative_prompt (`str` or `List[str]`, *optional*): + Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate for each prompt. + eta (`float`, *optional*, defaults to 0.0): + Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A generator to ensure reproducibility in image generation. + latents (`torch.Tensor`, *optional*): + Predefined latent tensors to condition generation. + prompt_embeds (`torch.Tensor`, *optional*): + Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary text embeddings to supplement or replace the initial prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Embeddings for negative prompts. Overrides string inputs if defined. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the primary prompt embeddings. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the secondary prompt embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for negative prompt embeddings. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for secondary negative prompt embeddings. + output_type (`str`, *optional*, defaults to "latent"): + Format of the generated output, either as a PIL image or as a NumPy array. + return_dict (`bool`, *optional*, defaults to `True`): + If `True`, returns a structured output. Otherwise returns a simple tuple. + callback_on_step_end (`Callable`, *optional*): + Functions called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor names to be included in callback function calls. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Adjusts noise levels based on guidance scale. + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + Original dimensions of the output. + target_size (`Tuple[int, int]`, *optional*): + Desired output dimensions for calculations. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + Coordinates for cropping. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + text_encoder_index=1, + ) + torch.cuda.empty_cache() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope": + base_size_width = 720 // 8 // self.transformer.config.patch_size + base_size_height = 480 // 8 // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + image_rotary_emb = get_3d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width), + temporal_size=latents.size(2), use_real=True, + ) + else: + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size, base_size + ) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width) + ) + + # Get other hunyuan params + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + torch.cuda.empty_cache() + if self.enable_autocast_float8_transformer_flag: + origin_weight_dtype = self.transformer.dtype + self.transformer = self.transformer.to(torch.float8_e4m3fn) + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + if self.enable_autocast_float8_transformer_flag: + self.transformer = self.transformer.to("cpu", origin_weight_dtype) + + torch.cuda.empty_cache() + # Post-processing + video = self.decode_latents(latents) + + # Convert to tensor + if output_type == "latent": + video = torch.from_numpy(video) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return video + + return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py new file mode 100644 index 0000000000000000000000000000000000000000..b23502e8db80260d6824ba7d98e606cb62d252c0 --- /dev/null +++ b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py @@ -0,0 +1,996 @@ +# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +import urllib.parse as ul +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers import DiffusionPipeline, ImagePipelineOutput +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, HunyuanDiT2DModel +from diffusers.models.embeddings import (get_2d_rotary_pos_embed, + get_3d_rotary_pos_embed) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import \ + StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler +from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate, + is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from PIL import Image +from tqdm import tqdm +from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, + CLIPVisionModelWithProjection, + T5EncoderModel, T5Tokenizer) + +from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from .pipeline_easyanimate import EasyAnimatePipelineOutput + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> pass + ``` +""" + +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + HunyuanDiT team) + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + EasyAnimate uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: BertModel, + tokenizer: BertTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5Tokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2 + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.enable_autocast_float8_transformer_flag = False + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + super().enable_sequential_cpu_offload(*args, **kwargs) + if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None: + import accelerate + accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True) + self.transformer.clip_projection = self.transformer.clip_projection.to("cuda") + + def encode_prompt( + self, + prompt: str, + device: torch.device, + dtype: torch.dtype, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + actual_max_sequence_length: int = 256 + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length) + if text_encoder_index == 1: + max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length) + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + if text_input_ids.shape[-1] > actual_max_sequence_length: + reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + text_inputs = tokenizer( + reprompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) + removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {_actual_max_sequence_length} tokens: {removed_text}" + ) + prompt_attention_mask = text_inputs.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + else: + prompt_embeds = text_encoder( + text_input_ids.to(device) + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + if uncond_input_ids.shape[-1] > actual_max_sequence_length: + reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + uncond_input = tokenizer( + reuncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device) + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + if self.vae.cache_mag_vae: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder): + if video.size()[2] <= mini_batch_encoder: + return video + prefix_index_before = mini_batch_encoder // 2 + prefix_index_after = mini_batch_encoder - prefix_index_before + pixel_values = video[:, :, prefix_index_before:-prefix_index_after] + + # Encode middle videos + latents = self.vae.encode(pixel_values)[0] + latents = latents.mode() + # Decode middle videos + middle_video = self.vae.decode(latents)[0] + + video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 + return video + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / self.vae.config.scaling_factor * latents + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + video = self.vae.decode(latents)[0] + video = video.clamp(-1, 1) + if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae: + video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) + else: + latents = rearrange(latents, "b c f h w -> (b f) c h w") + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def enable_autocast_float8_transformer(self): + self.enable_autocast_float8_transformer_flag = True + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + video_length: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + control_video: Union[torch.FloatTensor] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "latent", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + comfyui_progressbar: bool = False, + ): + r""" + Generates images or video using the EasyAnimate pipeline based on the provided prompts. + + Examples: + prompt (`str` or `List[str]`, *optional*): + Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. + video_length (`int`, *optional*): + Length of the generated video (in frames). + height (`int`, *optional*): + Height of the generated image in pixels. + width (`int`, *optional*): + Width of the generated image in pixels. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Encourages the model to align outputs with prompts. A higher value may decrease image quality. + negative_prompt (`str` or `List[str]`, *optional*): + Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate for each prompt. + eta (`float`, *optional*, defaults to 0.0): + Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A generator to ensure reproducibility in image generation. + latents (`torch.Tensor`, *optional*): + Predefined latent tensors to condition generation. + prompt_embeds (`torch.Tensor`, *optional*): + Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary text embeddings to supplement or replace the initial prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Embeddings for negative prompts. Overrides string inputs if defined. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the primary prompt embeddings. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the secondary prompt embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for negative prompt embeddings. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for secondary negative prompt embeddings. + output_type (`str`, *optional*, defaults to "latent"): + Format of the generated output, either as a PIL image or as a NumPy array. + return_dict (`bool`, *optional*, defaults to `True`): + If `True`, returns a structured output. Otherwise returns a simple tuple. + callback_on_step_end (`Callable`, *optional*): + Functions called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor names to be included in callback function calls. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Adjusts noise levels based on guidance scale. + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + Original dimensions of the output. + target_size (`Tuple[int, int]`, *optional*): + Desired output dimensions for calculations. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + Coordinates for cropping. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + text_encoder_index=1, + ) + torch.cuda.empty_cache() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + else: + control_video = None + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance + )[1] + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope": + base_size_width = 720 // 8 // self.transformer.config.patch_size + base_size_height = 480 // 8 // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + image_rotary_emb = get_3d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width), + temporal_size=latents.size(2), use_real=True, + ) + else: + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size, base_size + ) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width) + ) + + # Get other hunyuan params + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + torch.cuda.empty_cache() + if self.enable_autocast_float8_transformer_flag: + origin_weight_dtype = self.transformer.dtype + self.transformer = self.transformer.to(torch.float8_e4m3fn) + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + control_latents=control_latents, + )[0] + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + if self.enable_autocast_float8_transformer_flag: + self.transformer = self.transformer.to("cpu", origin_weight_dtype) + + torch.cuda.empty_cache() + # Post-processing + video = self.decode_latents(latents) + + # Convert to tensor + if output_type == "latent": + video = torch.from_numpy(video) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return video + + return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..2c45241bb5be766164ff5460500480d12d5c399c --- /dev/null +++ b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py @@ -0,0 +1,1334 @@ +# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers import DiffusionPipeline +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, HunyuanDiT2DModel +from diffusers.models.embeddings import (get_2d_rotary_pos_embed, + get_3d_rotary_pos_embed) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import \ + StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import (is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from PIL import Image +from tqdm import tqdm +from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, + CLIPVisionModelWithProjection, T5Tokenizer, + T5EncoderModel) + +from .pipeline_easyanimate import EasyAnimatePipelineOutput +from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> pass + ``` +""" + + +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +def add_noise_to_reference_video(image, ratio=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image + + +class EasyAnimatePipeline_Multi_Text_Encoder_Inpaint(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + HunyuanDiT team) + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + EasyAnimate uses a fine-tuned [bilingual CLIP]. + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + clip_image_processor (`CLIPImageProcessor`): + The CLIP image embedder. + clip_image_encoder (`CLIPVisionModelWithProjection`): + The image processor for the CLIP image embedder. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + "clip_image_encoder", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: BertModel, + tokenizer: BertTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5Tokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + clip_image_processor: CLIPImageProcessor = None, + clip_image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + clip_image_processor=clip_image_processor, + clip_image_encoder=clip_image_encoder, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.enable_autocast_float8_transformer_flag = False + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + super().enable_sequential_cpu_offload(*args, **kwargs) + if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None: + import accelerate + accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True) + self.transformer.clip_projection = self.transformer.clip_projection.to("cuda") + + def encode_prompt( + self, + prompt: str, + device: torch.device, + dtype: torch.dtype, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + actual_max_sequence_length: int = 256 + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length) + if text_encoder_index == 1: + max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length) + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + if text_input_ids.shape[-1] > actual_max_sequence_length: + reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + text_inputs = tokenizer( + reprompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) + removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {_actual_max_sequence_length} tokens: {removed_text}" + ) + prompt_attention_mask = text_inputs.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + else: + prompt_embeds = text_encoder( + text_input_ids.to(device) + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + if uncond_input_ids.shape[-1] > actual_max_sequence_length: + reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + uncond_input = tokenizer( + reuncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device) + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + else: + if mask.shape[1] == 4: + mask = mask + else: + video_length = mask.shape[2] + mask = rearrange(mask, "b c f h w -> (b f) c h w") + mask = self._encode_vae_image(mask, generator=generator) + mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length) + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + if self.transformer.config.add_noise_in_inpaint_model: + masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + + else: + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + video_length = masked_image.shape[2] + masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w") + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + else: + masked_image_latents = None + + return mask, masked_image_latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + video_length, + dtype, + device, + generator, + latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + if self.vae.cache_mag_vae: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=self.vae.dtype) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + bs = 1 + new_video = [] + for i in range(0, video.shape[0], bs): + video_bs = video[i : i + bs] + video_bs = self.vae.encode(video_bs)[0] + video_bs = video_bs.sample() + new_video.append(video_bs) + video = torch.cat(new_video, dim = 0) + video = video * self.vae.config.scaling_factor + + else: + if video.shape[1] == 4: + video = video + else: + video_length = video.shape[2] + video = rearrange(video, "b c f h w -> (b f) c h w") + video = self._encode_vae_image(video, generator=generator) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) + video_latents = video_latents.to(device=device, dtype=dtype) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + # scale the initial noise by the standard deviation required by the scheduler + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder): + if video.size()[2] <= mini_batch_encoder: + return video + prefix_index_before = mini_batch_encoder // 2 + prefix_index_after = mini_batch_encoder - prefix_index_before + pixel_values = video[:, :, prefix_index_before:-prefix_index_after] + + # Encode middle videos + latents = self.vae.encode(pixel_values)[0] + latents = latents.mode() + # Decode middle videos + middle_video = self.vae.decode(latents)[0] + + video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 + return video + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / self.vae.config.scaling_factor * latents + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + video = self.vae.decode(latents)[0] + video = video.clamp(-1, 1) + if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae: + video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) + else: + latents = rearrange(latents, "b c f h w -> (b f) c h w") + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def enable_autocast_float8_transformer(self): + self.enable_autocast_float8_transformer_flag = True + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + video_length: Optional[int] = None, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + masked_video_latents: Union[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "latent", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + clip_image: Image = None, + clip_apply_ratio: float = 0.40, + strength: float = 1.0, + noise_aug_strength: float = 0.0563, + comfyui_progressbar: bool = False, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Examples: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + video_length (`int`, *optional*): + Length of the video to be generated in seconds. This parameter influences the number of frames and + continuity of generated content. + video (`torch.FloatTensor`, *optional*): + A tensor representing an input video, which can be modified depending on the prompts provided. + mask_video (`torch.FloatTensor`, *optional*): + A tensor to specify areas of the video to be masked (omitted from generation). + masked_video_latents (`torch.FloatTensor`, *optional*): + Latents from masked portions of the video, utilized during image generation. + height (`int`, *optional*): + The height in pixels of the generated image or video frames. + width (`int`, *optional*): + The width in pixels of the generated image or video frames. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image but slower + inference time. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to exclude in image generation. If not defined, you need to + provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the + [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the + inference process. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting + random seeds which helps in making generation deterministic. + latents (`torch.Tensor`, *optional*): + A pre-computed latent representation which can be used to guide the generation process. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary set of pre-generated text embeddings, useful for advanced prompt weighting. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs. + If not provided, embeddings are generated from the `negative_prompt` argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary set of pre-generated negative text embeddings for further control. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using + `prompt_embeds`. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the secondary prompt embedding. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the secondary negative prompt embedding. + output_type (`str`, *optional*, defaults to `"latent"`): + The output format of the generated image. Choose between `PIL.Image` and `np.array` to define + how you want the results to be formatted. + return_dict (`bool`, *optional*, defaults to `True`): + If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned; + otherwise, a tuple containing the generated images and safety flags will be returned. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function (or a list of them) that will be executed at the end of each denoising step, + allowing for custom processing during generation. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Specifies which tensor inputs should be included in the callback function. If not defined, all tensor + inputs will be passed, facilitating enhanced logging or monitoring of the generation process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original dimensions of the image. Used to compute time ids during the generation process. + target_size (`Tuple[int, int]`, *optional*): + The targeted dimensions of the generated image, also utilized in the time id calculations. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + Coordinates defining the top left corner of any cropping, utilized while calculating the time ids. + clip_image (`Image`, *optional*): + An optional image to assist in the generation process. It may be used as an additional visual cue. + clip_apply_ratio (`float`, *optional*, defaults to 0.40): + Ratio indicating how much influence the clip image should exert over the generated content. + strength (`float`, *optional*, defaults to 1.0): + Affects the overall styling or quality of the generated output. Values closer to 1 usually provide direct + adherence to prompts. + comfyui_progressbar (`bool`, *optional*, defaults to `False`): + Enables a progress bar in ComfyUI, providing visual feedback during the generation process. + + Examples: + # Example usage of the function for generating images based on prompts. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + Returns either a structured output containing generated images and their metadata when `return_dict` is + `True`, or a simpler tuple, where the first element is a list of generated images and the second + element indicates if any of them contain "not-safe-for-work" (NSFW) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int(height // 16 * 16) + width = int(width // 16 * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + text_encoder_index=1, + ) + torch.cuda.empty_cache() + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 3) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + # Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_transformer = self.transformer.config.in_channels + return_image_latents = num_channels_transformer == num_channels_latents + + # 5. Prepare latents. + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + video_length, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_image_latents, + ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare clip latents if it needs. + if clip_image is not None and self.transformer.enable_clip_in_inpaint: + inputs = self.clip_image_processor(images=clip_image, return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype) + clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:] + clip_encoder_hidden_states_neg = torch.zeros( + [ + batch_size, + int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, + int(self.clip_image_encoder.config.hidden_size) + ] + ).to(latents.device, dtype=latents.dtype) + + clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype) + clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype) + + clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states + clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask + + elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint: + clip_encoder_hidden_states = torch.zeros( + [ + batch_size, + int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, + int(self.clip_image_encoder.config.hidden_size) + ] + ).to(latents.device, dtype=latents.dtype) + + clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query]) + clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype) + + clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states + clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask + + else: + clip_encoder_hidden_states_input = None + clip_attention_mask_input = None + if comfyui_progressbar: + pbar.update(1) + + # 7. Prepare inpaint latents if it needs. + if mask_video is not None: + if (mask_video == 255).all(): + # Use zero latents if we want to t2v. + if self.transformer.resize_inpaint_mask_directly: + mask_latents = torch.zeros_like(latents)[:, :1].to(latents.device, latents.dtype) + else: + mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) + else: + # Prepare mask latent variables + video_length = video.shape[2] + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + if num_channels_transformer != num_channels_latents: + mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) + if masked_video_latents is None: + masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 + else: + masked_video = masked_video_latents + + if self.transformer.resize_inpaint_mask_directly: + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae) + mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor + else: + mask_latents, masked_video_latents = self.prepare_mask_latents( + mask_condition_tile, + masked_video, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) + else: + inpaint_latents = None + + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + else: + if num_channels_transformer != num_channels_latents: + mask = torch.zeros_like(latents).to(latents.device, latents.dtype) + masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) + else: + mask = torch.zeros_like(init_video[:, :1]) + mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + + inpaint_latents = None + if comfyui_progressbar: + pbar.update(1) + + # Check that sizes of mask, masked image and latents match + if num_channels_transformer != num_channels_latents: + num_channels_mask = mask_latents.shape[1] + num_channels_masked_image = masked_video_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" + f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.transformer` or your `mask_image` or `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope": + base_size_width = 720 // 8 // self.transformer.config.patch_size + base_size_height = 480 // 8 // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + image_rotary_emb = get_3d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width), + temporal_size=latents.size(2), use_real=True, + ) + else: + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size, base_size + ) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width) + ) + + # Get other hunyuan params + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + torch.cuda.empty_cache() + if self.enable_autocast_float8_transformer_flag: + origin_weight_dtype = self.transformer.dtype + self.transformer = self.transformer.to(torch.float8_e4m3fn) + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None: + clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input) + clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input) + else: + clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input + clip_attention_mask_actual_input = clip_attention_mask_input + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + inpaint_latents=inpaint_latents, + clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input, + clip_attention_mask=clip_attention_mask_actual_input, + return_dict=False, + )[0] + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_transformer == 4: + init_latents_proper = image_latents + init_mask = mask + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + if self.enable_autocast_float8_transformer_flag: + self.transformer = self.transformer.to("cpu", origin_weight_dtype) + + torch.cuda.empty_cache() + # Post-processing + video = self.decode_latents(latents) + + # Convert to tensor + if output_type == "latent": + video = torch.from_numpy(video) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return video + + return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file diff --git a/easyanimate/ui/ui.py b/easyanimate/ui/ui.py index 748b252c3be77f9e2eba3ef17589685bd5297388..2c96e052acfc5650ea062263113ce09bf8ba7698 100644 --- a/easyanimate/ui/ui.py +++ b/easyanimate/ui/ui.py @@ -8,6 +8,7 @@ import random from datetime import datetime from glob import glob +import cv2 import gradio as gr import numpy as np import pkg_resources @@ -21,19 +22,28 @@ from diffusers.utils.import_utils import is_xformers_available from omegaconf import OmegaConf from PIL import Image from safetensors import safe_open -from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection, +from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, + CLIPVisionModelWithProjection, T5Tokenizer, T5EncoderModel, T5Tokenizer) from easyanimate.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from easyanimate.models import (name_to_autoencoder_magvit, + name_to_transformer3d) from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit -from easyanimate.models.transformer3d import Transformer3DModel +from easyanimate.models.transformer3d import (HunyuanTransformer3DModel, + Transformer3DModel) from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline from easyanimate.pipeline.pipeline_easyanimate_inpaint import \ EasyAnimateInpaintPipeline +from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder import \ + EasyAnimatePipeline_Multi_Text_Encoder +from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder_inpaint import \ + EasyAnimatePipeline_Multi_Text_Encoder_Inpaint from easyanimate.utils.lora_utils import merge_lora, unmerge_lora from easyanimate.utils.utils import ( - get_image_to_video_latent, + get_image_to_video_latent, get_video_to_video_latent, get_width_and_height_from_image_and_base_resolution, save_videos_grid) +from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper scheduler_dict = { "Euler": EulerDiscreteScheduler, @@ -56,7 +66,7 @@ css = """ """ class EasyAnimateController: - def __init__(self): + def __init__(self, GPU_memory_mode, weight_dtype): # config dirs self.basedir = os.getcwd() self.config_dir = os.path.join(self.basedir, "config") @@ -65,8 +75,7 @@ class EasyAnimateController: self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model") self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) self.savedir_sample = os.path.join(self.savedir, "sample") - self.edition = "v3" - self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml")) + self.model_type = "Inpaint" os.makedirs(self.savedir, exist_ok=True) self.diffusion_transformer_list = [] @@ -86,8 +95,11 @@ class EasyAnimateController: self.motion_module_path = "none" self.base_model_path = "none" self.lora_model_path = "none" + self.GPU_memory_mode = GPU_memory_mode - self.weight_dtype = torch.bfloat16 + self.weight_dtype = weight_dtype + self.edition = "v5" + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml")) def refresh_diffusion_transformer(self): self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/"))) @@ -99,72 +111,179 @@ class EasyAnimateController: def refresh_personalized_model(self): personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors"))) self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list] + + def update_model_type(self, model_type): + self.model_type = model_type def update_edition(self, edition): print("Update edition of EasyAnimate") self.edition = edition if edition == "v1": - self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml")) + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v1_motion_module.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \ gr.update(value=512, minimum=384, maximum=704, step=32), \ gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1) elif edition == "v2": - self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml")) + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v2_magvit_motion_module.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ - gr.update(value=672, minimum=128, maximum=1280, step=16), \ - gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9) - else: - self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml")) + gr.update(value=672, minimum=128, maximum=1344, step=16), \ + gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=9, maximum=144, step=9) + elif edition == "v3": + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v3_slicevae_motion_module.yaml")) + return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ + gr.update(value=672, minimum=128, maximum=1344, step=16), \ + gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8) + elif edition == "v4": + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v4_slicevae_multi_text_encoder.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ - gr.update(value=672, minimum=128, maximum=1280, step=16), \ - gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=8, maximum=144, step=8) + gr.update(value=672, minimum=128, maximum=1344, step=16), \ + gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8) + elif edition == "v5": + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml")) + return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ + gr.update(value=672, minimum=128, maximum=1344, step=16), \ + gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4) def update_diffusion_transformer(self, diffusion_transformer_dropdown): print("Update diffusion transformer") if diffusion_transformer_dropdown == "none": return gr.update() - if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']: - Choosen_AutoencoderKL = AutoencoderKLMagvit - else: - Choosen_AutoencoderKL = AutoencoderKL + Choosen_AutoencoderKL = name_to_autoencoder_magvit[ + self.inference_config['vae_kwargs'].get('vae_type', 'AutoencoderKL') + ] self.vae = Choosen_AutoencoderKL.from_pretrained( diffusion_transformer_dropdown, subfolder="vae", ).to(self.weight_dtype) - self.transformer = Transformer3DModel.from_pretrained_2d( + if self.inference_config['vae_kwargs'].get('vae_type', 'AutoencoderKL') == 'AutoencoderKLMagvit' and self.weight_dtype == torch.float16: + self.vae.upcast_vae = True + + transformer_additional_kwargs = OmegaConf.to_container(self.inference_config['transformer_additional_kwargs']) + if self.weight_dtype == torch.float16: + transformer_additional_kwargs["upcast_attention"] = True + + # Get Transformer + Choosen_Transformer3DModel = name_to_transformer3d[ + self.inference_config['transformer_additional_kwargs'].get('transformer_type', 'Transformer3DModel') + ] + + self.transformer = Choosen_Transformer3DModel.from_pretrained_2d( diffusion_transformer_dropdown, subfolder="transformer", - transformer_additional_kwargs=OmegaConf.to_container(self.inference_config.transformer_additional_kwargs) + transformer_additional_kwargs=transformer_additional_kwargs ).to(self.weight_dtype) - self.tokenizer = T5Tokenizer.from_pretrained(diffusion_transformer_dropdown, subfolder="tokenizer") - self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype) + + if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + tokenizer = BertTokenizer.from_pretrained( + diffusion_transformer_dropdown, subfolder="tokenizer" + ) + tokenizer_2 = T5Tokenizer.from_pretrained( + diffusion_transformer_dropdown, subfolder="tokenizer_2" + ) + else: + tokenizer = T5Tokenizer.from_pretrained( + diffusion_transformer_dropdown, subfolder="tokenizer" + ) + tokenizer_2 = None - # Get pipeline - if self.transformer.config.in_channels != 12: - self.pipeline = EasyAnimatePipeline( - vae=self.vae, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - transformer=self.transformer, - scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) + if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + text_encoder = BertModel.from_pretrained( + diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + diffusion_transformer_dropdown, subfolder="text_encoder_2", torch_dtype=self.weight_dtype ) else: + text_encoder = T5EncoderModel.from_pretrained( + diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) + text_encoder_2 = None + + # Get pipeline + if self.transformer.config.in_channels != self.vae.config.latent_channels and self.inference_config['transformer_additional_kwargs'].get('enable_clip_in_inpaint', True): clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained( diffusion_transformer_dropdown, subfolder="image_encoder" - ).to("cuda", self.weight_dtype) + ).to(self.weight_dtype) clip_image_processor = CLIPImageProcessor.from_pretrained( diffusion_transformer_dropdown, subfolder="image_encoder" ) - self.pipeline = EasyAnimateInpaintPipeline( - vae=self.vae, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - transformer=self.transformer, - scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)), - clip_image_encoder=clip_image_encoder, - clip_image_processor=clip_image_processor, - ) - + else: + clip_image_encoder = None + clip_image_processor = None + + # Get Scheduler + Choosen_Scheduler = scheduler_dict = { + "Euler": EulerDiscreteScheduler, + "Euler A": EulerAncestralDiscreteScheduler, + "DPM++": DPMSolverMultistepScheduler, + "PNDM": PNDMScheduler, + "DDIM": DDIMScheduler, + }["Euler"] + + scheduler = Choosen_Scheduler.from_pretrained( + diffusion_transformer_dropdown, + subfolder="scheduler" + ) + + if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained( + diffusion_transformer_dropdown, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype, + clip_image_encoder=clip_image_encoder, + clip_image_processor=clip_image_processor, + ) + else: + self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder.from_pretrained( + diffusion_transformer_dropdown, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype + ) + else: + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = EasyAnimateInpaintPipeline( + diffusion_transformer_dropdown, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype, + clip_image_encoder=clip_image_encoder, + clip_image_processor=clip_image_processor, + ) + else: + self.pipeline = EasyAnimatePipeline( + diffusion_transformer_dropdown, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype + ) + + if self.GPU_memory_mode == "sequential_cpu_offload": + self.pipeline.enable_sequential_cpu_offload() + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + self.pipeline.enable_model_cpu_offload() + self.pipeline.enable_autocast_float8_transformer() + convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype) + else: + self.GPU_memory_mode.enable_model_cpu_offload() print("Update diffusion transformer done") return gr.update() @@ -238,6 +357,10 @@ class EasyAnimateController: cfg_scale_slider, start_image, end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, seed_textbox, is_api = False, ): @@ -251,33 +374,47 @@ class EasyAnimateController: if self.base_model_path != base_model_dropdown: self.update_base_model(base_model_dropdown) - if self.motion_module_path != motion_module_dropdown: - self.update_motion_module(motion_module_dropdown) - if self.lora_model_path != lora_model_dropdown: print("Update lora model") self.update_lora_model(lora_model_dropdown) - - if resize_method == "Resize to the Start Image": - if start_image is None: + + if control_video is not None and self.model_type == "Inpaint": + if is_api: + return "", f"If specifying the control video, please set the model_type == \"Control\". " + else: + raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ") + + if control_video is None and self.model_type == "Control": + if is_api: + return "", f"If set the model_type == \"Control\", please specifying the control video. " + else: + raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ") + + if resize_method == "Resize according to Reference": + if start_image is None and validation_video is None and control_video is None: if is_api: - return "", f"Please upload an image when using \"Resize to the Start Image\"." + return "", f"Please upload an image when using \"Resize according to Reference\"." else: - raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".") + raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".") aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} - - original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size + if self.model_type == "Inpaint": + if validation_video is not None: + original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size + else: + original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size + else: + original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) height_slider, width_slider = [int(x / 16) * 16 for x in closest_size] - if self.transformer.config.in_channels != 12 and start_image is not None: + if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None: if is_api: return "", f"Please select an image to video pretrained model while using image to video." else: raise gr.Error(f"Please select an image to video pretrained model while using image to video.") - if self.transformer.config.in_channels != 12 and generation_method == "Long Video Generation": + if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation": if is_api: return "", f"Please select an image to video pretrained model while using long video generation." else: @@ -289,88 +426,118 @@ class EasyAnimateController: else: raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.") + fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition] is_image = True if generation_method == "Image Generation" else False - if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention() + if is_xformers_available() and not self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): self.transformer.enable_xformers_memory_efficient_attention() - self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) + self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) if self.lora_model_path != "none": # lora part self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) - self.pipeline.to("cuda") if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) else: seed_textbox = np.random.randint(0, 1e10) generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox)) try: - if self.transformer.config.in_channels == 12: - if generation_method == "Long Video Generation": - init_frames = 0 - last_frames = init_frames + partial_video_length - while init_frames < length_slider: - if last_frames >= length_slider: - if self.pipeline.vae.quant_conv.weight.ndim==5: - mini_batch_encoder = self.pipeline.vae.mini_batch_encoder + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if generation_method == "Long Video Generation": + if validation_video is not None: + raise gr.Error(f"Video to Video is not Support Long Video Generation now.") + init_frames = 0 + last_frames = init_frames + partial_video_length + while init_frames < length_slider: + if last_frames >= length_slider: _partial_video_length = length_slider - init_frames - _partial_video_length = int(_partial_video_length // mini_batch_encoder * mini_batch_encoder) + if self.vae.cache_mag_vae: + _partial_video_length = int((_partial_video_length - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 + else: + _partial_video_length = int(_partial_video_length // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + + if _partial_video_length <= 0: + break else: - _partial_video_length = length_slider - init_frames - - if _partial_video_length <= 0: - break - else: - _partial_video_length = partial_video_length + _partial_video_length = partial_video_length - if last_frames >= length_slider: - input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) - else: - input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) - - with torch.no_grad(): - sample = self.pipeline( - prompt_textbox, - negative_prompt = negative_prompt_textbox, - num_inference_steps = sample_step_slider, - guidance_scale = cfg_scale_slider, - width = width_slider, - height = height_slider, - video_length = _partial_video_length, - generator = generator, - - video = input_video, - mask_video = input_video_mask, - clip_image = clip_image, - strength = 1, - ).videos - - if init_frames != 0: - mix_ratio = torch.from_numpy( - np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32) - ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + if last_frames >= length_slider: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) + + with torch.no_grad(): + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + video_length = _partial_video_length, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + strength = 1, + ).videos - new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \ - sample[:, :, :overlap_video_length] * mix_ratio - new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2) + if init_frames != 0: + mix_ratio = torch.from_numpy( + np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32) + ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + + new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \ + sample[:, :, :overlap_video_length] * mix_ratio + new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2) - sample = new_sample - else: - new_sample = sample + sample = new_sample + else: + new_sample = sample - if last_frames >= length_slider: - break + if last_frames >= length_slider: + break - start_image = [ - Image.fromarray( - (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8) - ) for _index in range(-overlap_video_length, 0) - ] + start_image = [ + Image.fromarray( + (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8) + ) for _index in range(-overlap_video_length, 0) + ] - init_frames = init_frames + _partial_video_length - overlap_video_length - last_frames = init_frames + _partial_video_length + init_frames = init_frames + _partial_video_length - overlap_video_length + last_frames = init_frames + _partial_video_length + else: + if self.vae.cache_mag_vae: + length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 + else: + length_slider = int(length_slider // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + if validation_video is not None: + input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + strength = denoise_strength + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + strength = 1 + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + video_length = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + strength = strength, + ).videos else: - input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) - + if self.vae.cache_mag_vae: + length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 + else: + length_slider = int(length_slider // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + sample = self.pipeline( prompt_textbox, negative_prompt = negative_prompt_textbox, @@ -379,13 +546,15 @@ class EasyAnimateController: width = width_slider, height = height_slider, video_length = length_slider if not is_image else 1, - generator = generator, - - video = input_video, - mask_video = input_video_mask, - clip_image = clip_image, + generator = generator ).videos else: + if self.vae.cache_mag_vae: + length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 + else: + length_slider = int(length_slider // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps) + sample = self.pipeline( prompt_textbox, negative_prompt = negative_prompt_textbox, @@ -394,7 +563,9 @@ class EasyAnimateController: width = width_slider, height = height_slider, video_length = length_slider if not is_image else 1, - generator = generator + generator = generator, + + control_video = input_video, ).videos except Exception as e: gc.collect() @@ -457,7 +628,7 @@ class EasyAnimateController: return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" else: save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4") - save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24) + save_videos_grid(sample, save_sample_path, fps=fps) if is_api: return save_sample_path, "Success" @@ -468,8 +639,8 @@ class EasyAnimateController: return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" -def ui(): - controller = EasyAnimateController() +def ui(GPU_memory_mode, weight_dtype): + controller = EasyAnimateController(GPU_memory_mode, weight_dtype) with gr.Blocks(css=css) as demo: gr.Markdown( @@ -486,19 +657,32 @@ def ui(): with gr.Column(variant="panel"): gr.Markdown( """ - ### 1. EasyAnimate Edition (EasyAnimate版本). + ### 1. EasyAnimate Model Type (EasyAnimate模型的种类,正常模型还是控制模型). + """ + ) + with gr.Row(): + model_type = gr.Dropdown( + label="The model type of EasyAnimate (EasyAnimate模型的种类,正常模型还是控制模型)", + choices=["Inpaint", "Control"], + value="Inpaint", + interactive=True, + ) + with gr.Column(variant="panel"): + gr.Markdown( + """ + ### 2. EasyAnimate Edition (EasyAnimate版本). """ ) with gr.Row(): easyanimate_edition_dropdown = gr.Dropdown( label="The config of EasyAnimate Edition (EasyAnimate版本配置)", - choices=["v1", "v2", "v3"], - value="v3", + choices=["v1", "v2", "v3", "v4", "v5"], + value="v5", interactive=True, ) gr.Markdown( """ - ### 2. Model checkpoints (模型路径). + ### 3. Model checkpoints (模型路径). """ ) with gr.Row(): @@ -568,21 +752,21 @@ def ui(): ) prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.") - negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion." ) + negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Unclear, mutated, deformed, distorted, dark frames, fixed frames, comic book, comic book, small and indistinguishable subject." ) with gr.Row(): with gr.Column(): with gr.Row(): sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) - sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=30, minimum=10, maximum=100, step=1) + sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1) resize_method = gr.Radio( - ["Generate by", "Resize to the Start Image"], + ["Generate by", "Resize according to Reference"], value="Generate by", show_label=False, ) - width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16) - height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16) + width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1344, step=16) + height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1344, step=16) base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False) with gr.Group(): @@ -592,17 +776,25 @@ def ui(): show_label=False, ) with gr.Row(): - length_slider = gr.Slider(label="Animation length (视频帧数)", value=144, minimum=8, maximum=144, step=8) + length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=1, maximum=49, step=4) overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False) - partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=72, minimum=8, maximum=144, step=8, visible=False) + partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=49, step=4, visible=False) - with gr.Accordion("Image to Video (图片到视频)", open=False): - start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath") + source_method = gr.Radio( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], + value="Text to Video (文本到视频)", + show_label=False, + ) + with gr.Column(visible = False) as image_to_video_col: + start_image = gr.Image( + label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, + elem_id="i2v_start", sources="upload", type="filepath", + ) template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] def select_template(evt: gr.SelectData): text = { - "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", @@ -623,7 +815,37 @@ def ui(): with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False): end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath") - cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20) + with gr.Column(visible = False) as video_to_video_col: + with gr.Row(): + validation_video = gr.Video( + label="The video to convert (视频转视频的参考视频)", show_label=True, + elem_id="v2v", sources="upload", + ) + with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False): + gr.Markdown( + """ + - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70 + - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70) + """ + ) + validation_video_mask = gr.Image( + label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", + show_label=False, elem_id="v2v_mask", sources="upload", type="filepath" + ) + denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01) + + with gr.Column(visible = False) as control_video_col: + gr.Markdown( + """ + Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4). + """ + ) + control_video = gr.Video( + label="The control video (用于提供控制信号的video)", show_label=True, + elem_id="v2v_control", sources="upload", + ) + + cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20) with gr.Row(): seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43) @@ -645,17 +867,49 @@ def ui(): interactive=False ) - def upload_generation_method(generation_method): + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method, easyanimate_edition_dropdown): + if easyanimate_edition_dropdown == "v1": + f_maximum = 80 + f_value = 80 + elif easyanimate_edition_dropdown in ["v2", "v3", "v4"]: + f_maximum = 144 + f_value = 144 + else: + f_maximum = 49 + f_value = 49 + if generation_method == "Video Generation": - return [gr.update(visible=True, maximum=144, value=144), gr.update(visible=False), gr.update(visible=False)] + return [gr.update(visible=True, maximum=f_maximum, value=f_value), gr.update(visible=False), gr.update(visible=False)] elif generation_method == "Image Generation": return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)] else: - return [gr.update(visible=True, maximum=1440), gr.update(visible=True), gr.update(visible=True)] + return [gr.update(visible=True, maximum=1200), gr.update(visible=True), gr.update(visible=True)] generation_method.change( upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] ) + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + def upload_resize_method(resize_method): if resize_method == "Generate by": return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] @@ -701,6 +955,10 @@ def ui(): cfg_scale_slider, start_image, end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, seed_textbox, ], outputs=[result_image, result_video, infer_progress] @@ -709,10 +967,7 @@ def ui(): class EasyAnimateController_Modelscope: - def __init__(self, edition, config_path, model_name, savedir_sample): - # Weight Dtype - weight_dtype = torch.bfloat16 - + def __init__(self, model_type, edition, config_path, model_name, savedir_sample, GPU_memory_mode, weight_dtype): # Basic dir self.basedir = os.getcwd() self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model") @@ -722,65 +977,152 @@ class EasyAnimateController_Modelscope: os.makedirs(self.savedir_sample, exist_ok=True) # Config and model path + self.model_type = model_type self.edition = edition + self.weight_dtype = weight_dtype self.inference_config = OmegaConf.load(config_path) - # Get Transformer - self.transformer = Transformer3DModel.from_pretrained_2d( - model_name, - subfolder="transformer", - transformer_additional_kwargs=OmegaConf.to_container(self.inference_config['transformer_additional_kwargs']) - ).to(weight_dtype) - if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']: - Choosen_AutoencoderKL = AutoencoderKLMagvit - else: - Choosen_AutoencoderKL = AutoencoderKL + Choosen_AutoencoderKL = name_to_autoencoder_magvit[ + self.inference_config['vae_kwargs'].get('vae_type', 'AutoencoderKL') + ] self.vae = Choosen_AutoencoderKL.from_pretrained( model_name, - subfolder="vae" - ).to(weight_dtype) - self.tokenizer = T5Tokenizer.from_pretrained( - model_name, - subfolder="tokenizer" - ) - self.text_encoder = T5EncoderModel.from_pretrained( + subfolder="vae", + ).to(self.weight_dtype) + if self.inference_config['vae_kwargs'].get('vae_type', 'AutoencoderKL') == 'AutoencoderKLMagvit' and weight_dtype == torch.float16: + self.vae.upcast_vae = True + + transformer_additional_kwargs = OmegaConf.to_container(self.inference_config['transformer_additional_kwargs']) + if self.weight_dtype == torch.float16: + transformer_additional_kwargs["upcast_attention"] = True + + # Get Transformer + Choosen_Transformer3DModel = name_to_transformer3d[ + self.inference_config['transformer_additional_kwargs'].get('transformer_type', 'Transformer3DModel') + ] + + self.transformer = Choosen_Transformer3DModel.from_pretrained_2d( model_name, - subfolder="text_encoder", - torch_dtype=weight_dtype - ) - # Get pipeline - if self.transformer.config.in_channels != 12: - self.pipeline = EasyAnimatePipeline( - vae=self.vae, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - transformer=self.transformer, - scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) + subfolder="transformer", + transformer_additional_kwargs=transformer_additional_kwargs + ).to(self.weight_dtype) + + if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + tokenizer = BertTokenizer.from_pretrained( + model_name, subfolder="tokenizer" + ) + tokenizer_2 = T5Tokenizer.from_pretrained( + model_name, subfolder="tokenizer_2" + ) + else: + tokenizer = T5Tokenizer.from_pretrained( + model_name, subfolder="tokenizer" + ) + tokenizer_2 = None + + if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + text_encoder = BertModel.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + model_name, subfolder="text_encoder_2", torch_dtype=self.weight_dtype ) else: + text_encoder = T5EncoderModel.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) + text_encoder_2 = None + + # Get pipeline + if self.transformer.config.in_channels != self.vae.config.latent_channels and self.inference_config['transformer_additional_kwargs'].get('enable_clip_in_inpaint', True): clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained( model_name, subfolder="image_encoder" - ).to("cuda", weight_dtype) + ).to(self.weight_dtype) clip_image_processor = CLIPImageProcessor.from_pretrained( model_name, subfolder="image_encoder" ) - self.pipeline = EasyAnimateInpaintPipeline( - vae=self.vae, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - transformer=self.transformer, - scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)), - clip_image_encoder=clip_image_encoder, - clip_image_processor=clip_image_processor, - ) - - print("Update diffusion transformer done") + else: + clip_image_encoder = None + clip_image_processor = None + + # Get Scheduler + Choosen_Scheduler = scheduler_dict = { + "Euler": EulerDiscreteScheduler, + "Euler A": EulerAncestralDiscreteScheduler, + "DPM++": DPMSolverMultistepScheduler, + "PNDM": PNDMScheduler, + "DDIM": DDIMScheduler, + }["Euler"] + + scheduler = Choosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" + ) + if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained( + model_name, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype, + clip_image_encoder=clip_image_encoder, + clip_image_processor=clip_image_processor, + ) + else: + self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder.from_pretrained( + model_name, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype + ) + else: + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = EasyAnimateInpaintPipeline( + model_name, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype, + clip_image_encoder=clip_image_encoder, + clip_image_processor=clip_image_processor, + ) + else: + self.pipeline = EasyAnimatePipeline( + model_name, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + torch_dtype=self.weight_dtype + ) + + if GPU_memory_mode == "sequential_cpu_offload": + self.pipeline.enable_sequential_cpu_offload() + elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + self.pipeline.enable_model_cpu_offload() + self.pipeline.enable_autocast_float8_transformer() + convert_weight_dtype_wrapper(self.pipeline.transformer, weight_dtype) + else: + GPU_memory_mode.enable_model_cpu_offload() + print("Update diffusion transformer done") def refresh_personalized_model(self): personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors"))) self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list] - def update_lora_model(self, lora_model_dropdown): print("Update lora model") if lora_model_dropdown == "none": @@ -789,7 +1131,6 @@ class EasyAnimateController_Modelscope: lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown) self.lora_model_path = lora_model_dropdown return gr.update() - def generate( self, @@ -808,9 +1149,15 @@ class EasyAnimateController_Modelscope: base_resolution, generation_method, length_slider, + overlap_video_length, + partial_video_length, cfg_scale_slider, start_image, end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, seed_textbox, is_api = False, ): @@ -824,55 +1171,109 @@ class EasyAnimateController_Modelscope: if self.lora_model_path != lora_model_dropdown: print("Update lora model") self.update_lora_model(lora_model_dropdown) + + if control_video is not None and self.model_type == "Inpaint": + if is_api: + return "", f"If specifying the control video, please set the model_type == \"Control\". " + else: + raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ") + + if control_video is None and self.model_type == "Control": + if is_api: + return "", f"If set the model_type == \"Control\", please specifying the control video. " + else: + raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ") - if resize_method == "Resize to the Start Image": - if start_image is None: - raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".") + if resize_method == "Resize according to Reference": + if start_image is None and validation_video is None and control_video is None: + if is_api: + return "", f"Please upload an image when using \"Resize according to Reference\"." + else: + raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".") aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} - original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size + if self.model_type == "Inpaint": + if validation_video is not None: + original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size + else: + original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size + else: + original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) height_slider, width_slider = [int(x / 16) * 16 for x in closest_size] - if self.transformer.config.in_channels != 12 and start_image is not None: - raise gr.Error(f"Please select an image to video pretrained model while using image to video.") - + if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None: + if is_api: + return "", f"Please select an image to video pretrained model while using image to video." + else: + raise gr.Error(f"Please select an image to video pretrained model while using image to video.") + if start_image is None and end_image is not None: - raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.") + if is_api: + return "", f"If specifying the ending image of the video, please specify a starting image of the video." + else: + raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.") + fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition] is_image = True if generation_method == "Image Generation" else False - if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention() - - self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) + self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) if self.lora_model_path != "none": # lora part self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) - self.pipeline.to("cuda") if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) else: seed_textbox = np.random.randint(0, 1e10) generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox)) try: - if self.transformer.config.in_channels == 12: - input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + if self.model_type == "Inpaint": + if self.vae.cache_mag_vae: + length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 + else: + length_slider = int(length_slider // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + strength = denoise_strength + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + strength = 1 - sample = self.pipeline( - prompt_textbox, - negative_prompt = negative_prompt_textbox, - num_inference_steps = sample_step_slider, - guidance_scale = cfg_scale_slider, - width = width_slider, - height = height_slider, - video_length = length_slider if not is_image else 1, - generator = generator, + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + video_length = length_slider if not is_image else 1, + generator = generator, - video = input_video, - mask_video = input_video_mask, - clip_image = clip_image, - ).videos + video = input_video, + mask_video = input_video_mask, + strength = strength, + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + video_length = length_slider if not is_image else 1, + generator = generator + ).videos else: + if self.vae.cache_mag_vae: + length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 + else: + length_slider = int(length_slider // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + + input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps) + sample = self.pipeline( prompt_textbox, negative_prompt = negative_prompt_textbox, @@ -881,7 +1282,9 @@ class EasyAnimateController_Modelscope: width = width_slider, height = height_slider, video_length = length_slider if not is_image else 1, - generator = generator + generator = generator, + + control_video = input_video, ).videos except Exception as e: gc.collect() @@ -927,7 +1330,7 @@ class EasyAnimateController_Modelscope: return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" else: save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4") - save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24) + save_videos_grid(sample, save_sample_path, fps=fps) if is_api: return save_sample_path, "Success" else: @@ -937,8 +1340,8 @@ class EasyAnimateController_Modelscope: return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" -def ui_modelscope(edition, config_path, model_name, savedir_sample): - controller = EasyAnimateController_Modelscope(edition, config_path, model_name, savedir_sample) +def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, GPU_memory_mode, weight_dtype): + controller = EasyAnimateController_Modelscope(model_type, edition, config_path, model_name, savedir_sample, GPU_memory_mode, weight_dtype) with gr.Blocks(css=css) as demo: gr.Markdown( @@ -989,9 +1392,9 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample): with gr.Row(): lora_model_dropdown = gr.Dropdown( label="Select LoRA model", - choices=["none", "easyanimatev2_minimalism_lora.safetensors"], + choices=["none"], value="none", - interactive=True, + interactive=False, ) lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True) @@ -1004,13 +1407,13 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample): ) prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.") - negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion." ) + negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Unclear, mutated, deformed, distorted, dark frames, fixed frames, comic book, comic book, small and indistinguishable subject." ) with gr.Row(): with gr.Column(): with gr.Row(): sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) - sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=20, minimum=10, maximum=30, step=1, interactive=False) + sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False) if edition == "v1": width_slider = gr.Slider(label="Width (视频宽度)", value=512, minimum=384, maximum=704, step=32) @@ -1024,25 +1427,17 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample): visible=False, ) length_slider = gr.Slider(label="Animation length (视频帧数)", value=80, minimum=40, maximum=96, step=1) + overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False) + partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=72, minimum=8, maximum=144, step=8, visible=False) cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20) else: resize_method = gr.Radio( - ["Generate by", "Resize to the Start Image"], + ["Generate by", "Resize according to Reference"], value="Generate by", show_label=False, - ) - with gr.Column(): - gr.Markdown( - """ - We support video generation up to 720p with 144 frames, but for the trial experience, we have set certain limitations. We fix the max resolution of video to 384x672x48 (2s). - - If the start image you uploaded does not match this resolution, you can use the "Resize to the Start Image" option above. - - If you want to experience longer and larger video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/). - """ - ) - width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False) - height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False) + ) + width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1344, step=16, interactive=False) + height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1344, step=16, interactive=False) base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False) with gr.Group(): @@ -1052,16 +1447,26 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample): show_label=False, visible=True, ) - length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8, maximum=48, step=8) + if edition in ["v2", "v3", "v4"]: + length_slider = gr.Slider(label="Animation length (视频帧数)", value=144, minimum=8, maximum=144, step=8) + else: + length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=5, maximum=49, step=4) + overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False) + partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=72, minimum=8, maximum=144, step=8, visible=False) - with gr.Accordion("Image to Video (图片到视频)", open=True): + source_method = gr.Radio( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], + value="Text to Video (文本到视频)", + show_label=False, + ) + with gr.Column(visible = False) as image_to_video_col: with gr.Row(): start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath") template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] def select_template(evt: gr.SelectData): text = { - "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", @@ -1082,8 +1487,37 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample): with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False): end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath") + with gr.Column(visible = False) as video_to_video_col: + with gr.Row(): + validation_video = gr.Video( + label="The video to convert (视频转视频的参考视频)", show_label=True, + elem_id="v2v", sources="upload", + ) + with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False): + gr.Markdown( + """ + - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70 + - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70) + """ + ) + validation_video_mask = gr.Image( + label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", + show_label=False, elem_id="v2v_mask", sources="upload", type="filepath" + ) + denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01) + + with gr.Column(visible = False) as control_video_col: + gr.Markdown( + """ + Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4). + """ + ) + control_video = gr.Video( + label="The control video (用于提供控制信号的video)", show_label=True, + elem_id="v2v_control", sources="upload", + ) - cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20) + cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20) with gr.Row(): seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43) @@ -1106,14 +1540,40 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample): ) def upload_generation_method(generation_method): + if edition == "v1": + f_maximum = 80 + f_value = 80 + elif edition in ["v2", "v3", "v4"]: + f_maximum = 144 + f_value = 144 + else: + f_maximum = 49 + f_value = 49 + if generation_method == "Video Generation": - return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True) + return gr.update(visible=True, maximum=f_maximum, value=f_value) elif generation_method == "Image Generation": - return gr.update(minimum=1, maximum=1, value=1, interactive=False) + return gr.update(visible=False) generation_method.change( upload_generation_method, generation_method, [length_slider] ) + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + def upload_resize_method(resize_method): if resize_method == "Generate by": return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] @@ -1141,9 +1601,15 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample): base_resolution, generation_method, length_slider, + overlap_video_length, + partial_video_length, cfg_scale_slider, start_image, end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, seed_textbox, ], outputs=[result_image, result_video, infer_progress] @@ -1157,7 +1623,7 @@ def post_eas( prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, base_resolution, generation_method, length_slider, cfg_scale_slider, - start_image, end_image, seed_textbox, + start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox, ): if start_image is not None: with open(start_image, 'rb') as file: @@ -1171,6 +1637,19 @@ def post_eas( end_image_encoded_content = base64.b64encode(file_content) end_image = end_image_encoded_content.decode('utf-8') + if validation_video is not None: + with open(validation_video, 'rb') as file: + file_content = file.read() + validation_video_encoded_content = base64.b64encode(file_content) + validation_video = validation_video_encoded_content.decode('utf-8') + + if validation_video_mask is not None: + with open(validation_video_mask, 'rb') as file: + file_content = file.read() + validation_video_mask_encoded_content = base64.b64encode(file_content) + validation_video_mask = validation_video_mask_encoded_content.decode('utf-8') + + datas = { "base_model_path": base_model_dropdown, "motion_module_path": motion_module_dropdown, @@ -1189,6 +1668,9 @@ def post_eas( "cfg_scale_slider": cfg_scale_slider, "start_image": start_image, "end_image": end_image, + "validation_video": validation_video, + "validation_video_mask": validation_video_mask, + "denoise_strength": denoise_strength, "seed_textbox": seed_textbox, } @@ -1226,6 +1708,9 @@ class EasyAnimateController_EAS: cfg_scale_slider, start_image, end_image, + validation_video, + validation_video_mask, + denoise_strength, seed_textbox ): is_image = True if generation_method == "Image Generation" else False @@ -1236,7 +1721,7 @@ class EasyAnimateController_EAS: prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, base_resolution, generation_method, length_slider, cfg_scale_slider, - start_image, end_image, + start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox ) try: @@ -1321,9 +1806,9 @@ def ui_eas(edition, config_path, model_name, savedir_sample): with gr.Row(): lora_model_dropdown = gr.Dropdown( label="Select LoRA model", - choices=["none", "easyanimatev2_minimalism_lora.safetensors"], + choices=["none"], value="none", - interactive=True, + interactive=False, ) lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True) @@ -1336,13 +1821,13 @@ def ui_eas(edition, config_path, model_name, savedir_sample): ) prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.") - negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " ) + negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Unclear, mutated, deformed, distorted, dark frames, fixed frames, comic book, comic book, small and indistinguishable subject." ) with gr.Row(): with gr.Column(): with gr.Row(): sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) - sample_step_slider = gr.Slider(label="Sampling steps", value=20, minimum=10, maximum=30, step=1, interactive=False) + sample_step_slider = gr.Slider(label="Sampling steps", value=40, minimum=10, maximum=40, step=1, interactive=False) if edition == "v1": width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32) @@ -1359,22 +1844,12 @@ def ui_eas(edition, config_path, model_name, savedir_sample): cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20) else: resize_method = gr.Radio( - ["Generate by", "Resize to the Start Image"], + ["Generate by", "Resize according to Reference"], value="Generate by", show_label=False, - ) - with gr.Column(): - gr.Markdown( - """ - We support video generation up to 720p with 144 frames, but for the trial experience, we have set certain limitations. We fix the max resolution of video to 384x672x48 (2s). - - If the start image you uploaded does not match this resolution, you can use the "Resize to the Start Image" option above. - - If you want to experience longer and larger video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/). - """ - ) - width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False) - height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False) + ) + width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1344, step=16, interactive=False) + height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1344, step=16, interactive=False) base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False) with gr.Group(): @@ -1384,15 +1859,23 @@ def ui_eas(edition, config_path, model_name, savedir_sample): show_label=False, visible=True, ) - length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8, maximum=48, step=8) + if edition in ["v2", "v3", "v4"]: + length_slider = gr.Slider(label="Animation length (视频帧数)", value=144, minimum=8, maximum=144, step=8) + else: + length_slider = gr.Slider(label="Animation length (视频帧数)", value=21, minimum=5, maximum=21, step=4) - with gr.Accordion("Image to Video", open=True): + source_method = gr.Radio( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], + value="Text to Video (文本到视频)", + show_label=False, + ) + with gr.Column(visible = False) as image_to_video_col: start_image = gr.Image(label="The image at the beginning of the video", show_label=True, elem_id="i2v_start", sources="upload", type="filepath") template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] def select_template(evt: gr.SelectData): text = { - "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", @@ -1413,8 +1896,27 @@ def ui_eas(edition, config_path, model_name, savedir_sample): with gr.Accordion("The image at the ending of the video (Optional)", open=False): end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath") - cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20) - + with gr.Column(visible = False) as video_to_video_col: + with gr.Row(): + validation_video = gr.Video( + label="The video to convert (视频转视频的参考视频)", show_label=True, + elem_id="v2v", sources="upload", + ) + with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False): + gr.Markdown( + """ + - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70 + - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70) + """ + ) + validation_video_mask = gr.Image( + label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", + show_label=False, elem_id="v2v_mask", sources="upload", type="filepath" + ) + denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01) + + cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20) + with gr.Row(): seed_textbox = gr.Textbox(label="Seed", value=43) seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") @@ -1436,14 +1938,35 @@ def ui_eas(edition, config_path, model_name, savedir_sample): ) def upload_generation_method(generation_method): + if edition == "v1": + f_maximum = 80 + f_value = 80 + elif edition in ["v2", "v3", "v4"]: + f_maximum = 144 + f_value = 144 + else: + f_maximum = 49 + f_value = 49 + if generation_method == "Video Generation": - return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True) + return gr.update(visible=True, maximum=f_maximum, value=f_value) elif generation_method == "Image Generation": - return gr.update(minimum=1, maximum=1, value=1, interactive=False) + return gr.update(visible=False) generation_method.change( upload_generation_method, generation_method, [length_slider] ) + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + def upload_resize_method(resize_method): if resize_method == "Generate by": return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] @@ -1474,6 +1997,9 @@ def ui_eas(edition, config_path, model_name, savedir_sample): cfg_scale_slider, start_image, end_image, + validation_video, + validation_video_mask, + denoise_strength, seed_textbox, ], outputs=[result_image, result_video, infer_progress] diff --git a/easyanimate/utils/discrete_sampler.py b/easyanimate/utils/discrete_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..149dbe7beb94dfea2e6fe0ca3b5acf9437be60f7 --- /dev/null +++ b/easyanimate/utils/discrete_sampler.py @@ -0,0 +1,46 @@ +"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py +""" +import torch + +class DiscreteSampling: + def __init__(self, num_idx, uniform_sampling=False): + self.num_idx = num_idx + self.uniform_sampling = uniform_sampling + self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + + if self.is_distributed and self.uniform_sampling: + world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + + i = 1 + while True: + if world_size % i != 0 or num_idx % (world_size // i) != 0: + i += 1 + else: + self.group_num = world_size // i + break + assert self.group_num > 0 + assert world_size % self.group_num == 0 + # the number of rank in one group + self.group_width = world_size // self.group_num + self.sigma_interval = self.num_idx // self.group_num + print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % ( + self.rank, world_size, self.group_num, + self.group_width, self.sigma_interval)) + + def __call__(self, n_samples, generator=None, device=None): + if self.is_distributed and self.uniform_sampling: + group_index = self.rank // self.group_width + idx = torch.randint( + group_index * self.sigma_interval, + (group_index + 1) * self.sigma_interval, + (n_samples,), + generator=generator, device=device, + ) + print('proc[%d] idx=%s' % (self.rank, idx)) + else: + idx = torch.randint( + 0, self.num_idx, (n_samples,), + generator=generator, device=device, + ) + return idx \ No newline at end of file diff --git a/easyanimate/utils/fp8_optimization.py b/easyanimate/utils/fp8_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..d270605c8aabde645da8c42f9daccd9057cb3197 --- /dev/null +++ b/easyanimate/utils/fp8_optimization.py @@ -0,0 +1,28 @@ +"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper +""" +import torch +import torch.nn as nn + +def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs): + weight_dtype = cls.weight.dtype + cls.to(origin_dtype) + + # Convert all inputs to the original dtype + inputs = [input.to(origin_dtype) for input in inputs] + out = cls.original_forward(*inputs, **kwargs) + + cls.to(weight_dtype) + return out + +def convert_weight_dtype_wrapper(module, origin_dtype): + for name, module in module.named_modules(): + if name == "": + continue + original_forward = module.forward + if hasattr(module, "weight"): + setattr(module, "original_forward", original_forward) + setattr( + module, + "forward", + lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs) + ) diff --git a/easyanimate/utils/lora_utils.py b/easyanimate/utils/lora_utils.py index 9673a7d51d9949f49a10ea50382d9d0117cb2aa6..b50d11b33c60d40e1b73b9da4b7e4ff39cfb3f93 100644 --- a/easyanimate/utils/lora_utils.py +++ b/easyanimate/utils/lora_utils.py @@ -156,8 +156,8 @@ def precalculate_safetensors_hashes(tensors, metadata): class LoRANetwork(torch.nn.Module): - TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF"] + TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel", "HunyuanTransformer3DModel", "EasyAnimateTransformer3DModel"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"] LORA_PREFIX_TRANSFORMER = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" def __init__( @@ -238,9 +238,10 @@ class LoRANetwork(torch.nn.Module): self.text_encoder_loras = [] skipped_te = [] for i, text_encoder in enumerate(text_encoders): - text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - self.text_encoder_loras.extend(text_encoder_loras) - skipped_te += skipped + if text_encoder is not None: + text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE) @@ -368,6 +369,7 @@ def create_network( def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False): LORA_PREFIX_TRANSFORMER = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" + SPECIAL_LAYER_NAME = ["text_proj_t5"] if state_dict is None: state_dict = load_file(lora_path, device=device) else: @@ -389,21 +391,24 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3 layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_") curr_layer = pipeline.transformer - temp_name = layer_infos.pop(0) - while len(layer_infos) > -1: - try: - curr_layer = curr_layer.__getattr__(temp_name) - if len(layer_infos) > 0: - temp_name = layer_infos.pop(0) - elif len(layer_infos) == 0: - break - except Exception: - if len(layer_infos) == 0: - print('Error loading layer') - if len(temp_name) > 0: - temp_name += "_" + layer_infos.pop(0) - else: - temp_name = layer_infos.pop(0) + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print('Error loading layer') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) weight_up = elems['lora_up.weight'].to(dtype) weight_down = elems['lora_down.weight'].to(dtype) @@ -444,6 +449,7 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl curr_layer = pipeline.transformer temp_name = layer_infos.pop(0) + print(layer, curr_layer) while len(layer_infos) > -1: try: curr_layer = curr_layer.__getattr__(temp_name) diff --git a/easyanimate/utils/utils.py b/easyanimate/utils/utils.py index 9b8bfed5eb811c1890269dee360338aa4dfbba34..c1e2083a456053fdd6a50590d1c378572183f94b 100644 --- a/easyanimate/utils/utils.py +++ b/easyanimate/utils/utils.py @@ -1,13 +1,15 @@ +import gc import os +import cv2 import imageio import numpy as np import torch import torchvision -import cv2 from einops import rearrange from PIL import Image + def get_width_and_height_from_image_and_base_resolution(image, base_resolution): target_pixels = int(base_resolution) * int(base_resolution) original_width, original_height = Image.open(image).size @@ -73,13 +75,20 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size): if validation_image_start is not None and validation_image_end is not None: if type(validation_image_start) is str and os.path.isfile(validation_image_start): - image_start = clip_image = Image.open(validation_image_start) + image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) else: image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + if type(validation_image_end) is str and os.path.isfile(validation_image_end): - image_end = Image.open(validation_image_end) + image_end = Image.open(validation_image_end).convert("RGB") + image_end = image_end.resize([sample_size[1], sample_size[0]]) else: image_end = validation_image_end + image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end] if type(image_start) is list: clip_image = clip_image[0] @@ -119,8 +128,13 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide elif validation_image_start is not None: if type(validation_image_start) is str and os.path.isfile(validation_image_start): image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) else: image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + image_end = None if type(image_start) is list: clip_image = clip_image[0] @@ -142,30 +156,60 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide input_video_mask = torch.zeros_like(input_video[:, :1]) input_video_mask[:, :, 1:, ] = 255 else: + image_start = None + image_end = None input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]]) input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255 clip_image = None + del image_start + del image_end + gc.collect() + return input_video, input_video_mask, clip_image -def video_frames(input_video_path): - cap = cv2.VideoCapture(input_video_path) - frames = [] - while True: - ret, frame = cap.read() - if not ret: - break - frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - cap.release() - cv2.destroyAllWindows() - return frames - -def get_video_to_video_latent(validation_videos, video_length): - input_video = video_frames(validation_videos) +def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None): + if isinstance(input_video_path, str): + cap = cv2.VideoCapture(input_video_path) + input_video = [] + + original_fps = cap.get(cv2.CAP_PROP_FPS) + frame_skip = 1 if fps is None else int(original_fps // fps) + + frame_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + if frame_count % frame_skip == 0: + frame = cv2.resize(frame, (sample_size[1], sample_size[0])) + input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + + frame_count += 1 + + cap.release() + else: + input_video = input_video_path + input_video = torch.from_numpy(np.array(input_video))[:video_length] input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 - input_video_mask = torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, :] = 255 + if ref_image is not None: + ref_image = Image.open(ref_image) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + + if validation_video_mask is not None: + validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) + input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) + + input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 - return input_video, input_video_mask, None \ No newline at end of file + return input_video, input_video_mask, ref_image \ No newline at end of file diff --git a/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b777c76910fd494e9ff9d9e0c4b50b55a4c921df --- /dev/null +++ b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_cogvideox.yaml @@ -0,0 +1,64 @@ +model: + base_learning_rate: 1.0e-04 + target: easyanimate.vae.ldm.models.cogvideox_casual3dcnn.AutoencoderKLMagvit_CogVideoX + params: + latent_channels: 16 + temporal_compression_ratio: 4 + monitor: train/rec_loss + ckpt_path: vae/diffusion_pytorch_model.safetensors + down_block_types: ("CogVideoXDownBlock3D", "CogVideoXDownBlock3D", "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D",) + up_block_types: ("CogVideoXUpBlock3D", "CogVideoXUpBlock3D", "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D",) + lossconfig: + target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + l2_loss_weight: 0.1 + l1_loss_weight: 1.0 + perceptual_weight: 1.0 + +data: + target: train_vae.DataModuleFromConfig + + params: + batch_size: 1 + wrap: true + num_workers: 8 + train: + target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain + params: + data_json_path: pretrain.json + data_root: /your_data_root # This is used in relative path + size: 256 + degradation: pil_nearest + video_size: 256 + video_len: 49 + slice_interval: 1 + validation: + target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation + params: + data_json_path: pretrain.json + data_root: /your_data_root # This is used in relative path + size: 256 + degradation: pil_nearest + video_size: 256 + video_len: 49 + slice_interval: 1 + +lightning: + callbacks: + image_logger: + target: train_vae.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 1 + gpus: "0" + num_nodes: 1 \ No newline at end of file diff --git a/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e651ae24fb5577c3a643e18477795630ad6f9c50 --- /dev/null +++ b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag_v2.yaml @@ -0,0 +1,65 @@ +model: + base_learning_rate: 1.0e-04 + target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen + params: + spatial_group_norm: true + mid_block_attention_type: "spatial" + latent_channels: 16 + monitor: train/rec_loss + ckpt_path: vae/diffusion_pytorch_model.safetensors + down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D", + "SpatialTemporalDownBlock3D",) + up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D", + "SpatialTemporalUpBlock3D",) + lossconfig: + target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + l2_loss_weight: 0.1 + l1_loss_weight: 1.0 + perceptual_weight: 1.0 + +data: + target: train_vae.DataModuleFromConfig + + params: + batch_size: 1 + wrap: true + num_workers: 8 + train: + target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain + params: + data_json_path: pretrain.json + data_root: /your_data_root # This is used in relative path + size: 256 + degradation: pil_nearest + video_size: 256 + video_len: 49 + slice_interval: 1 + validation: + target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation + params: + data_json_path: pretrain.json + data_root: /your_data_root # This is used in relative path + size: 256 + degradation: pil_nearest + video_size: 256 + video_len: 49 + slice_interval: 1 + +lightning: + callbacks: + image_logger: + target: train_vae.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 1 + gpus: "0" + num_nodes: 1 \ No newline at end of file diff --git a/easyanimate/vae/ldm/data/dataset_callback.py b/easyanimate/vae/ldm/data/dataset_callback.py index d0d59435644faca534f24d7619d7b9cf3071c22c..f98a5528836135ccc0e4114b1c5228977b57fd61 100644 --- a/easyanimate/vae/ldm/data/dataset_callback.py +++ b/easyanimate/vae/ldm/data/dataset_callback.py @@ -1,6 +1,7 @@ #-*- encoding:utf-8 -*- from pytorch_lightning.callbacks import Callback + class DatasetCallback(Callback): def __init__(self): self.sampler_pos_start = 0 diff --git a/easyanimate/vae/ldm/data/dataset_image_video.py b/easyanimate/vae/ldm/data/dataset_image_video.py index bd0c4f6572a34659c81ca1b0d44e1acb5a58ee05..3eb85fbe31caea5f5b0fca9fff8af4c5be1ee657 100644 --- a/easyanimate/vae/ldm/data/dataset_image_video.py +++ b/easyanimate/vae/ldm/data/dataset_image_video.py @@ -17,7 +17,7 @@ from decord import VideoReader from func_timeout import FunctionTimedOut, func_set_timeout from omegaconf import OmegaConf from PIL import Image -from torch.utils.data import (BatchSampler, Dataset, Sampler) +from torch.utils.data import BatchSampler, Dataset, Sampler from tqdm import tqdm from ..modules.image_degradation import (degradation_fn_bsr, @@ -164,15 +164,18 @@ class ImageVideoDataset(Dataset): return self.base[index].get('type', 'image') def __getitem__(self, i): - @func_set_timeout(3) # time wait 3 seconds + @func_set_timeout(15) # time wait 3 seconds def get_video_item(example): if self.data_root is not None: video_reader = VideoReader(os.path.join(self.data_root, example['file_path'])) else: video_reader = VideoReader(example['file_path']) video_length = len(video_reader) - - clip_length = min(video_length, (self.video_len - 1) * self.slice_interval + 1) + if self.slice_interval == "rand": + slice_interval = np.random.choice([1, 2, 3]) + else: + slice_interval = int(self.slice_interval) + clip_length = min(video_length, (self.video_len - 1) * slice_interval + 1) start_idx = random.randint(0, video_length - clip_length) batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int) diff --git a/easyanimate/vae/ldm/models/casual3dcnn.py b/easyanimate/vae/ldm/models/casual3dcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e4a60ef24d9dc7c91697ee73252dae04c4b8f9 --- /dev/null +++ b/easyanimate/vae/ldm/models/casual3dcnn.py @@ -0,0 +1,337 @@ +import time +from contextlib import contextmanager + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +from ..modules.diffusionmodules.model import Decoder, Encoder +from ..modules.distributions.distributions import DiagonalGaussianDistribution +from ..util import instantiate_from_config +from .enc_dec import Decoder as Mag_Decoder +from .enc_dec import Encoder as Mag_Encoder + + +class AutoencoderKLMagvit(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Mag_Encoder() + self.decoder = Mag_Decoder() + self.loss = instantiate_from_config(lossconfig) + self.quant_conv = torch.nn.Conv3d(16, 16, 1) + self.post_quant_conv = torch.nn.Conv3d(8, 8, 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + if input.ndim==4: + input = input.unsqueeze(2) + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if x.ndim==5: + x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float() + return x + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # tic = time.time() + inputs = self.get_input(batch, self.image_key) + # print(f"get_input time {time.time() - tic}") + # tic = time.time() + reconstructions, posterior = self(inputs) + # print(f"model forward time {time.time() - tic}") + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + # print(f"cal loss time {time.time() - tic}") + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + # print(f"cal loss time {time.time() - tic}") + return discloss + + def validation_step(self, batch, batch_idx): + with torch.no_grad(): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # tic = time.time() + inputs = self.get_input(batch, self.image_key) + # print(f"get_input time {time.time() - tic}") + # tic = time.time() + reconstructions, posterior = self(inputs) + # print(f"model forward time {time.time() - tic}") + tic = time.time() + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + # print(f"cal loss time {time.time() - tic}") + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + # print(f"cal loss time {time.time() - tic}") + return discloss + + def validation_step(self, batch, batch_idx): + tic = time.time() + inputs = self.get_input(batch, self.image_key) + print(f"get_input time {time.time() - tic}") + tic = time.time() + reconstructions, posterior = self(inputs) + print(f"val forward time {time.time() - tic}") + tic = time.time() + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + print(f"val end time {time.time() - tic}") + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py b/easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a64c6160f52d6db2a7b9f60bef2ffe1fda615266 --- /dev/null +++ b/easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py @@ -0,0 +1,326 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ..util import instantiate_from_config +from .cogvideox_enc_dec import (CogVideoXDecoder3D, CogVideoXEncoder3D, + CogVideoXSafeConv3d) + + +class DiagonalGaussianDistribution: + def __init__( + self, + mean: torch.Tensor, + logvar: torch.Tensor, + deterministic: bool = False, + ): + self.mean = mean + self.logvar = torch.clamp(logvar, -30.0, 20.0) + self.deterministic = deterministic + + if deterministic: + self.var = self.std = torch.zeros_like(self.mean) + else: + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + + def sample(self, generator = None) -> torch.FloatTensor: + x = torch.randn( + self.mean.shape, + generator=generator, + device=self.mean.device, + dtype=self.mean.dtype, + ) + return self.mean + self.std * x + + def mode(self): + return self.mean + + def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor: + dims = list(range(1, self.mean.ndim)) + + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=dims, + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=dims, + ) + + def nll(self, sample: torch.Tensor) -> torch.Tensor: + dims = list(range(1, self.mean.ndim)) + + if self.deterministic: + return torch.Tensor([0.0]) + + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + +@dataclass +class EncoderOutput: + latent_dist: DiagonalGaussianDistribution + +@dataclass +class DecoderOutput: + sample: torch.Tensor + +def str_eval(item): + if type(item) == str: + return eval(item) + else: + return item + +class AutoencoderKLMagvit_CogVideoX(pl.LightningModule): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types: Tuple[str] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int] = (128, 256, 256, 512), + latent_channels: int = 16, + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + + mini_batch_encoder=4, + mini_batch_decoder=1, + + image_key="image", + train_decoder_only=False, + train_encoder_only=False, + monitor=None, + ckpt_path=None, + lossconfig=None, + ): + super().__init__() + self.image_key = image_key + down_block_types = str_eval(down_block_types) + up_block_types = str_eval(up_block_types) + + self.encoder = CogVideoXEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + + self.decoder = CogVideoXDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None + + self.mini_batch_encoder = mini_batch_encoder + self.mini_batch_decoder = mini_batch_decoder + self.train_decoder_only = train_decoder_only + self.train_encoder_only = train_encoder_only + if train_decoder_only: + self.encoder.requires_grad_(False) + if self.quant_conv is not None: + self.quant_conv.requires_grad_(False) + if train_encoder_only: + self.decoder.requires_grad_(False) + if self.post_quant_conv is not None: + self.post_quant_conv.requires_grad_(False) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys="loss") + if lossconfig is not None: + self.loss = instantiate_from_config(lossconfig) + + def init_from_ckpt(self, path, ignore_keys=list()): + if path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + sd = load_file(path) + else: + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + m, u = self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully + print(f"Restored from {path}") + print(f"missing keys: {str(m)}, unexpected keys: {str(u)}") + + def encode(self, x: torch.Tensor) -> EncoderOutput: + h = self.encoder(x) + self.encoder._clear_fake_context_parallel_cache() + + if self.quant_conv is not None: + moments: torch.Tensor = self.quant_conv(h) + else: + moments: torch.Tensor = h + mean, logvar = moments.chunk(2, dim=1) + posterior = DiagonalGaussianDistribution(mean, logvar) + + return posterior + + def decode(self, z: torch.Tensor) -> DecoderOutput: + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + decoded = self.decoder(z) + self.decoder._clear_fake_context_parallel_cache() + return decoded + + def forward(self, input, sample_posterior=True): + if input.ndim==4: + input = input.unsqueeze(2) + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + # print("stt latent shape", z.shape) + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if x.ndim==5: + x = x.permute(0, 4, 1, 2, 3).to(memory_format=torch.contiguous_format).float() + return x + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + with torch.no_grad(): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + if self.train_decoder_only: + if self.post_quant_conv is not None: + training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters()) + else: + training_list = list(self.decoder.parameters()) + opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + elif self.train_encoder_only: + if self.quant_conv is not None: + training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters()) + else: + training_list = list(self.encoder.parameters()) + opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + else: + training_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + if self.quant_conv is not None: + training_list = training_list + list(self.quant_conv.parameters()) + if self.post_quant_conv is not None: + training_list = training_list + list(self.post_quant_conv.parameters()) + opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam( + list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()), + lr=lr, betas=(0.5, 0.9) + ) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.conv.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x diff --git a/easyanimate/vae/ldm/models/cogvideox_enc_dec.py b/easyanimate/vae/ldm/models/cogvideox_enc_dec.py new file mode 100644 index 0000000000000000000000000000000000000000..36d9c0302843a87ecc4ff1a8f6c111c752a66fe5 --- /dev/null +++ b/easyanimate/vae/ldm/models/cogvideox_enc_dec.py @@ -0,0 +1,312 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +from diffusers.models.autoencoders.autoencoder_kl_cogvideox import ( + CogVideoXCausalConv3d, CogVideoXDownBlock3D, CogVideoXMidBlock3D, + CogVideoXSafeConv3d, CogVideoXSpatialNorm3D, CogVideoXUpBlock3D) +from diffusers.utils import logging + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogVideoXEncoder3D(nn.Module): + r""" + The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + # log2 of temporal_compress_times + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.down_blocks = nn.ModuleList([]) + + # down blocks + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if down_block_type == "CogVideoXDownBlock3D": + down_block = CogVideoXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + compress_time=compress_time, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=block_out_channels[-1], + temb_channels=0, + dropout=dropout, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def _clear_fake_context_parallel_cache(self): + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") + module._clear_fake_context_parallel_cache() + + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + r"""The forward method of the `CogVideoXEncoder3D` class.""" + hidden_states = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Down + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, temb, None + ) + + # 2. Mid + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, None + ) + else: + # 1. Down + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states, temb, None) + + # 2. Mid + hidden_states = self.mid_block(hidden_states, temb, None) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class CogVideoXDecoder3D(nn.Module): + r""" + The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + reversed_block_out_channels = list(reversed(block_out_channels)) + + self.conv_in = CogVideoXCausalConv3d( + in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode + ) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=reversed_block_out_channels[0], + temb_channels=0, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + pad_mode=pad_mode, + ) + + # up blocks + self.up_blocks = nn.ModuleList([]) + + output_channel = reversed_block_out_channels[0] + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if up_block_type == "CogVideoXUpBlock3D": + up_block = CogVideoXUpBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final_block, + compress_time=compress_time, + pad_mode=pad_mode, + ) + prev_output_channel = output_channel + else: + raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`") + + self.up_blocks.append(up_block) + + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def _clear_fake_context_parallel_cache(self): + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") + module._clear_fake_context_parallel_cache() + + def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + r"""The forward method of the `CogVideoXDecoder3D` class.""" + hidden_states = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # 1. Mid + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb, sample + ) + + # 2. Up + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, temb, sample + ) + else: + # 1. Mid + hidden_states = self.mid_block(hidden_states, temb, sample) + + # 2. Up + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, sample) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states \ No newline at end of file diff --git a/easyanimate/vae/ldm/models/enc_dec_pytorch.py b/easyanimate/vae/ldm/models/enc_dec.py similarity index 100% rename from easyanimate/vae/ldm/models/enc_dec_pytorch.py rename to easyanimate/vae/ldm/models/enc_dec.py diff --git a/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py b/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py index 651dc66069dfaf202955292d4822dd482e814a17..12692dadfc4a6aae9cc744f4cbba05096a8ac65d 100644 --- a/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py +++ b/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py @@ -1,4 +1,3 @@ -import itertools from dataclasses import dataclass from typing import Optional @@ -112,10 +111,15 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): monitor=None, ckpt_path=None, lossconfig=None, + slice_mag_vae=False, slice_compression_vae=False, + cache_compression_vae=False, + cache_mag_vae=False, + spatial_group_norm=False, mini_batch_encoder=9, mini_batch_decoder=3, train_decoder_only=False, + train_encoder_only=False, ): super().__init__() self.image_key = image_key @@ -137,7 +141,10 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): act_fn=act_fn, num_attention_heads=num_attention_heads, double_z=True, + slice_mag_vae=slice_mag_vae, slice_compression_vae=slice_compression_vae, + cache_compression_vae=cache_compression_vae, + spatial_group_norm=spatial_group_norm, mini_batch_encoder=mini_batch_encoder, ) @@ -156,7 +163,11 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): norm_num_groups=norm_num_groups, act_fn=act_fn, num_attention_heads=num_attention_heads, + slice_mag_vae=slice_mag_vae, slice_compression_vae=slice_compression_vae, + cache_compression_vae=cache_compression_vae, + cache_mag_vae=cache_mag_vae, + spatial_group_norm=spatial_group_norm, mini_batch_decoder=mini_batch_decoder, ) @@ -166,9 +177,15 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): self.mini_batch_encoder = mini_batch_encoder self.mini_batch_decoder = mini_batch_decoder self.train_decoder_only = train_decoder_only + self.train_encoder_only = train_encoder_only if train_decoder_only: self.encoder.requires_grad_(False) - self.quant_conv.requires_grad_(False) + if self.quant_conv is not None: + self.quant_conv.requires_grad_(False) + if train_encoder_only: + self.decoder.requires_grad_(False) + if self.post_quant_conv is not None: + self.post_quant_conv.requires_grad_(False) if monitor is not None: self.monitor = monitor if ckpt_path is not None: @@ -190,28 +207,28 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] - self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully + m, u = self.load_state_dict(sd, strict=False) # loss.item can be ignored successfully print(f"Restored from {path}") + print(f"missing keys: {str(m)}, unexpected keys: {str(u)}") def encode(self, x: torch.Tensor) -> EncoderOutput: h = self.encoder(x) - moments: torch.Tensor = self.quant_conv(h) + if self.quant_conv is not None: + moments: torch.Tensor = self.quant_conv(h) + else: + moments: torch.Tensor = h mean, logvar = moments.chunk(2, dim=1) posterior = DiagonalGaussianDistribution(mean, logvar) - # return EncoderOutput(latent_dist=posterior) return posterior def decode(self, z: torch.Tensor) -> DecoderOutput: - z = self.post_quant_conv(z) - + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) decoded = self.decoder(z) - - # return DecoderOutput(sample=decoded) return decoded - def forward(self, input, sample_posterior=True): if input.ndim==4: input = input.unsqueeze(2) @@ -235,30 +252,22 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): return x def training_step(self, batch, batch_idx, optimizer_idx): - # tic = time.time() inputs = self.get_input(batch, self.image_key) - # print(f"get_input time {time.time() - tic}") - # tic = time.time() reconstructions, posterior = self(inputs) - # print(f"model forward time {time.time() - tic}") if optimizer_idx == 0: - # train encoder+decoder+logvar aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train") self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - # print(f"cal loss time {time.time() - tic}") return aeloss if optimizer_idx == 1: - # train the discriminator discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train") self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - # print(f"cal loss time {time.time() - tic}") return discloss def validation_step(self, batch, batch_idx): @@ -279,17 +288,28 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): def configure_optimizers(self): lr = self.learning_rate if self.train_decoder_only: - opt_ae = torch.optim.Adam(list(self.decoder.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) + if self.post_quant_conv is not None: + training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters()) + else: + training_list = list(self.decoder.parameters()) + opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + elif self.train_encoder_only: + if self.quant_conv is not None: + training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters()) + else: + training_list = list(self.encoder.parameters()) + opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) else: - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()), - lr=lr, betas=(0.5, 0.9)) + training_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + if self.quant_conv is not None: + training_list = training_list + list(self.quant_conv.parameters()) + if self.post_quant_conv is not None: + training_list = training_list + list(self.post_quant_conv.parameters()) + opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam( + list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()), + lr=lr, betas=(0.5, 0.9) + ) return [opt_ae, opt_disc], [] def get_last_layer(self): diff --git a/easyanimate/vae/ldm/models/omnigen_enc_dec.py b/easyanimate/vae/ldm/models/omnigen_enc_dec.py index 99f44b4d1c016871027d443bae24d96ce2a6e0bb..ff501f14de342f0541d1023a870777bdd50703f5 100644 --- a/easyanimate/vae/ldm/models/omnigen_enc_dec.py +++ b/easyanimate/vae/ldm/models/omnigen_enc_dec.py @@ -1,6 +1,10 @@ +from typing import Any, Dict + import torch import torch.nn as nn -import numpy as np +from diffusers.utils import is_torch_version +from einops import rearrange + from ..modules.vaemodules.activations import get_activation from ..modules.vaemodules.common import CausalConv3d from ..modules.vaemodules.down_blocks import get_down_block @@ -8,6 +12,16 @@ from ..modules.vaemodules.mid_blocks import get_mid_block from ..modules.vaemodules.up_blocks import get_up_block +def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + class Encoder(nn.Module): r""" The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. @@ -54,7 +68,11 @@ class Encoder(nn.Module): act_fn: str = "silu", num_attention_heads: int = 1, double_z: bool = True, + slice_mag_vae: bool = False, slice_compression_vae: bool = False, + cache_compression_vae: bool = False, + cache_mag_vae: bool = False, + spatial_group_norm: bool = False, mini_batch_encoder: int = 9, verbose = False, ): @@ -118,9 +136,12 @@ class Encoder(nn.Module): conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + self.slice_mag_vae = slice_mag_vae self.slice_compression_vae = slice_compression_vae + self.cache_compression_vae = cache_compression_vae + self.cache_mag_vae = cache_mag_vae self.mini_batch_encoder = mini_batch_encoder - self.features_share = False + self.spatial_group_norm = spatial_group_norm self.verbose = verbose def set_padding_one_frame(self): @@ -145,36 +166,142 @@ class Encoder(nn.Module): for name, module in self.named_children(): _set_padding_more_frame(name, module) + def set_magvit_padding_one_frame(self): + def _set_magvit_padding_one_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 3 + for sub_name, sub_mod in module.named_children(): + _set_magvit_padding_one_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_magvit_padding_one_frame(name, module) + + def set_magvit_padding_more_frame(self): + def _set_magvit_padding_more_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 4 + for sub_name, sub_mod in module.named_children(): + _set_magvit_padding_more_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_magvit_padding_more_frame(name, module) + + def set_cache_slice_vae_padding_one_frame(self): + def _set_cache_slice_vae_padding_one_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 5 + for sub_name, sub_mod in module.named_children(): + _set_cache_slice_vae_padding_one_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_cache_slice_vae_padding_one_frame(name, module) + + def set_cache_slice_vae_padding_more_frame(self): + def _set_cache_slice_vae_padding_more_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 6 + for sub_name, sub_mod in module.named_children(): + _set_cache_slice_vae_padding_more_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_cache_slice_vae_padding_more_frame(name, module) + + def set_3dgroupnorm_for_submodule(self): + def _set_3dgroupnorm_for_submodule(name, module): + if hasattr(module, 'set_3dgroupnorm'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.set_3dgroupnorm = True + for sub_name, sub_mod in module.named_children(): + _set_3dgroupnorm_for_submodule(sub_name, sub_mod) + for name, module in self.named_children(): + _set_3dgroupnorm_for_submodule(name, module) + def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor: # x: (B, C, T, H, W) - if self.features_share and previous_features is not None and after_features is None: + if self.training: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if previous_features is not None and after_features is None: x = torch.concat([previous_features, x], 2) - elif self.features_share and previous_features is None and after_features is not None: + elif previous_features is None and after_features is not None: x = torch.concat([x, after_features], 2) - elif self.features_share and previous_features is not None and after_features is not None: + elif previous_features is not None and after_features is not None: x = torch.concat([previous_features, x, after_features], 2) - x = self.conv_in(x) - + if self.training: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.conv_in), + x, + **ckpt_kwargs, + ) + else: + x = self.conv_in(x) for down_block in self.down_blocks: - x = down_block(x) + if self.training: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), + x, + **ckpt_kwargs, + ) + else: + x = down_block(x) x = self.mid_block(x) - x = self.conv_norm_out(x) + if self.spatial_group_norm: + batch_size = x.shape[0] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv_norm_out(x) + x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + else: + x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) - if self.features_share and previous_features is not None and after_features is None: + if previous_features is not None and after_features is None: x = x[:, :, 1:] - elif self.features_share and previous_features is None and after_features is not None: + elif previous_features is None and after_features is not None: x = x[:, :, :2] - elif self.features_share and previous_features is not None and after_features is not None: + elif previous_features is not None and after_features is not None: x = x[:, :, 1:3] return x def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.slice_compression_vae: + if self.spatial_group_norm: + self.set_3dgroupnorm_for_submodule() + + if self.cache_mag_vae: + self.set_magvit_padding_one_frame() + first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None) + self.set_magvit_padding_more_frame() + new_pixel_values = [first_frames] + for i in range(1, x.shape[2], self.mini_batch_encoder): + next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None) + new_pixel_values.append(next_frames) + new_pixel_values = torch.cat(new_pixel_values, dim=2) + elif self.cache_compression_vae: + _, _, f, _, _ = x.size() + if f % 2 != 0: + self.set_padding_one_frame() + first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None) + self.set_padding_more_frame() + + new_pixel_values = [first_frames] + start_index = 1 + else: + self.set_padding_more_frame() + new_pixel_values = [] + start_index = 0 + + for i in range(start_index, x.shape[2], self.mini_batch_encoder): + next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None) + new_pixel_values.append(next_frames) + new_pixel_values = torch.cat(new_pixel_values, dim=2) + elif self.slice_compression_vae: _, _, f, _, _ = x.size() if f % 2 != 0: self.set_padding_one_frame() @@ -188,11 +315,15 @@ class Encoder(nn.Module): new_pixel_values = [] start_index = 0 - previous_features = None for i in range(start_index, x.shape[2], self.mini_batch_encoder): - after_features = x[:, :, i + self.mini_batch_encoder: i + self.mini_batch_encoder + 4, :, :] if i + self.mini_batch_encoder < x.shape[2] else None - next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], previous_features, after_features) - previous_features = x[:, :, i + self.mini_batch_encoder - 4: i + self.mini_batch_encoder, :, :] + next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None) + new_pixel_values.append(next_frames) + new_pixel_values = torch.cat(new_pixel_values, dim=2) + elif self.slice_mag_vae: + _, _, f, _, _ = x.size() + new_pixel_values = [] + for i in range(0, x.shape[2], self.mini_batch_encoder): + next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_encoder, :, :], None, None) new_pixel_values.append(next_frames) new_pixel_values = torch.cat(new_pixel_values, dim=2) else: @@ -242,7 +373,11 @@ class Decoder(nn.Module): norm_num_groups: int = 32, act_fn: str = "silu", num_attention_heads: int = 1, + slice_mag_vae: bool = False, slice_compression_vae: bool = False, + cache_compression_vae: bool = False, + cache_mag_vae: bool = False, + spatial_group_norm: bool = False, mini_batch_decoder: int = 3, verbose = False, ): @@ -309,9 +444,12 @@ class Decoder(nn.Module): self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + self.slice_mag_vae = slice_mag_vae self.slice_compression_vae = slice_compression_vae + self.cache_compression_vae = cache_compression_vae + self.cache_mag_vae = cache_mag_vae self.mini_batch_decoder = mini_batch_decoder - self.features_share = True + self.spatial_group_norm = spatial_group_norm self.verbose = verbose def set_padding_one_frame(self): @@ -335,22 +473,90 @@ class Decoder(nn.Module): _set_padding_more_frame(sub_name, sub_mod) for name, module in self.named_children(): _set_padding_more_frame(name, module) + + def set_magvit_padding_one_frame(self): + def _set_magvit_padding_one_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 3 + for sub_name, sub_mod in module.named_children(): + _set_magvit_padding_one_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_magvit_padding_one_frame(name, module) + + def set_magvit_padding_more_frame(self): + def _set_magvit_padding_more_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 4 + for sub_name, sub_mod in module.named_children(): + _set_magvit_padding_more_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_magvit_padding_more_frame(name, module) + + def set_cache_slice_vae_padding_one_frame(self): + def _set_cache_slice_vae_padding_one_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 5 + for sub_name, sub_mod in module.named_children(): + _set_cache_slice_vae_padding_one_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_cache_slice_vae_padding_one_frame(name, module) + + def set_cache_slice_vae_padding_more_frame(self): + def _set_cache_slice_vae_padding_more_frame(name, module): + if hasattr(module, 'padding_flag'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.padding_flag = 6 + for sub_name, sub_mod in module.named_children(): + _set_cache_slice_vae_padding_more_frame(sub_name, sub_mod) + for name, module in self.named_children(): + _set_cache_slice_vae_padding_more_frame(name, module) + + def set_3dgroupnorm_for_submodule(self): + def _set_3dgroupnorm_for_submodule(name, module): + if hasattr(module, 'set_3dgroupnorm'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.set_3dgroupnorm = True + for sub_name, sub_mod in module.named_children(): + _set_3dgroupnorm_for_submodule(sub_name, sub_mod) + for name, module in self.named_children(): + _set_3dgroupnorm_for_submodule(name, module) + + def clear_cache(self): + def _clear_cache(name, module): + if hasattr(module, 'prev_features'): + if self.verbose: + print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) + module.prev_features = None + for sub_name, sub_mod in module.named_children(): + _clear_cache(sub_name, sub_mod) + for name, module in self.named_children(): + _clear_cache(name, module) def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor: # x: (B, C, T, H, W) - if self.features_share and previous_features is not None and after_features is None: + if self.training: + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + if previous_features is not None and after_features is None: b, c, t, h, w = x.size() x = torch.concat([previous_features, x], 2) x = self.conv_in(x) x = self.mid_block(x) x = x[:, :, -t:] - elif self.features_share and previous_features is None and after_features is not None: + elif previous_features is None and after_features is not None: b, c, t, h, w = x.size() x = torch.concat([x, after_features], 2) x = self.conv_in(x) x = self.mid_block(x) x = x[:, :, :t] - elif self.features_share and previous_features is not None and after_features is not None: + elif previous_features is not None and after_features is not None: _, _, t_1, _, _ = previous_features.size() _, _, t_2, _, _ = x.size() x = torch.concat([previous_features, x, after_features], 2) @@ -358,20 +564,76 @@ class Decoder(nn.Module): x = self.mid_block(x) x = x[:, :, t_1:(t_1 + t_2)] else: - x = self.conv_in(x) - x = self.mid_block(x) - + if self.training: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.conv_in), + x, + **ckpt_kwargs, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + x, + **ckpt_kwargs, + ) + else: + x = self.conv_in(x) + x = self.mid_block(x) + for up_block in self.up_blocks: - x = up_block(x) + if self.training: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + x, + **ckpt_kwargs, + ) + else: + x = up_block(x) + + if self.spatial_group_norm: + batch_size = x.shape[0] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv_norm_out(x) + x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + else: + x = self.conv_norm_out(x) - x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) return x def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.slice_compression_vae: + if self.spatial_group_norm: + self.set_3dgroupnorm_for_submodule() + + if self.cache_mag_vae: + self.set_magvit_padding_one_frame() + first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None) + self.set_magvit_padding_more_frame() + new_pixel_values = [first_frames] + for i in range(1, x.shape[2], self.mini_batch_decoder): + next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None) + new_pixel_values.append(next_frames) + new_pixel_values = torch.cat(new_pixel_values, dim=2) + elif self.cache_compression_vae: + _, _, f, _, _ = x.size() + if f == 1: + self.set_padding_one_frame() + first_frames = self.single_forward(x[:, :, :1, :, :], None, None) + new_pixel_values = [first_frames] + start_index = 1 + else: + self.set_cache_slice_vae_padding_one_frame() + first_frames = self.single_forward(x[:, :, :self.mini_batch_decoder, :, :], None, None) + new_pixel_values = [first_frames] + start_index = self.mini_batch_decoder + + for i in range(start_index, x.shape[2], self.mini_batch_decoder): + self.set_cache_slice_vae_padding_more_frame() + next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None) + new_pixel_values.append(next_frames) + new_pixel_values = torch.cat(new_pixel_values, dim=2) + elif self.slice_compression_vae: _, _, f, _, _ = x.size() if f % 2 != 0: self.set_padding_one_frame() @@ -391,6 +653,13 @@ class Decoder(nn.Module): previous_features = x[:, :, i: i + self.mini_batch_decoder, :, :] new_pixel_values.append(next_frames) new_pixel_values = torch.cat(new_pixel_values, dim=2) + elif self.slice_mag_vae: + _, _, f, _, _ = x.size() + new_pixel_values = [] + for i in range(0, x.shape[2], self.mini_batch_decoder): + next_frames = self.single_forward(x[:, :, i: i + self.mini_batch_decoder, :, :], None, None) + new_pixel_values.append(next_frames) + new_pixel_values = torch.cat(new_pixel_values, dim=2) else: new_pixel_values = self.single_forward(x, None, None) return new_pixel_values diff --git a/easyanimate/vae/ldm/modules/ema.py b/easyanimate/vae/ldm/modules/ema.py index 3657208473f12355e44a2d3a8b114dae79d215bc..1da8e7728dedef2e7f0c9b3c48152e89df37364c 100644 --- a/easyanimate/vae/ldm/modules/ema.py +++ b/easyanimate/vae/ldm/modules/ema.py @@ -1,7 +1,8 @@ #-*- encoding:utf-8 -*- import torch -from torch import nn from pytorch_lightning.callbacks import Callback +from torch import nn + class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): diff --git a/easyanimate/vae/ldm/modules/losses/contperceptual.py b/easyanimate/vae/ldm/modules/losses/contperceptual.py index f2dfcda26f92ca356cdbc26c74e5333b29eba502..c344005ad243e885030a5f8b25f6b8f8bee46ec2 100644 --- a/easyanimate/vae/ldm/modules/losses/contperceptual.py +++ b/easyanimate/vae/ldm/modules/losses/contperceptual.py @@ -2,8 +2,10 @@ import torch import torch.nn as nn import torch.nn.functional as F from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + from ..vaemodules.discriminator import Discriminator3D + class LPIPSWithDiscriminator(nn.Module): def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, @@ -62,15 +64,6 @@ class LPIPSWithDiscriminator(nn.Module): # get new loss_weight loss_weights = 1 - # b, _ ,f, _, _ = reconstructions.size() - # loss_weights = torch.ones([b, f]).view(b, 1, f, 1, 1) - # loss_weights[:, :, 0] = 3 - # for i in range(1, f, 8): - # loss_weights[:, :, i - 1] = 3 - # loss_weights[:, :, i] = 3 - # loss_weights[:, :, -1] = 3 - # loss_weights = loss_weights.permute(0, 2, 1, 3, 4).flatten(0, 1).to(reconstructions.device) - inputs = inputs.permute(0, 2, 1, 3, 4).flatten(0, 1) reconstructions = reconstructions.permute(0, 2, 1, 3, 4).flatten(0, 1) diff --git a/easyanimate/vae/ldm/modules/vaemodules/common.py b/easyanimate/vae/ldm/modules/vaemodules/common.py index a49999dd518e4439cb1b184972f8ac1d66abd3cd..f85bf39652f0ff95585df31906eaef2bd215e479 100755 --- a/easyanimate/vae/ldm/modules/vaemodules/common.py +++ b/easyanimate/vae/ldm/modules/vaemodules/common.py @@ -38,7 +38,7 @@ class CausalConv3d(nn.Conv3d): assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead." t_ks, h_ks, w_ks = kernel_size - _, h_stride, w_stride = stride + self.t_stride, h_stride, w_stride = stride t_dilation, h_dilation, w_dilation = dilation t_pad = (t_ks - 1) * t_dilation @@ -54,6 +54,7 @@ class CausalConv3d(nn.Conv3d): self.temporal_padding = t_pad self.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2) self.padding_flag = 0 + self.prev_features = None super().__init__( in_channels=in_channels, @@ -67,38 +68,81 @@ class CausalConv3d(nn.Conv3d): def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, C, T, H, W) + dtype = x.dtype + x = x.float() if self.padding_flag == 0: x = F.pad( x, pad=(0, 0, 0, 0, self.temporal_padding, 0), mode="replicate", # TODO: check if this is necessary ) + x = x.to(dtype=dtype) + return super().forward(x) + elif self.padding_flag == 3: + x = F.pad( + x, + pad=(0, 0, 0, 0, self.temporal_padding, 0), + mode="replicate", # TODO: check if this is necessary + ) + x = x.to(dtype=dtype) + self.prev_features = x[:, :, -self.temporal_padding:] + + b, c, f, h, w = x.size() + outputs = [] + i = 0 + while i + self.temporal_padding + 1 <= f: + out = super().forward(x[:, :, i:i + self.temporal_padding + 1]) + i += self.t_stride + outputs.append(out) + return torch.concat(outputs, 2) + elif self.padding_flag == 4: + if self.t_stride == 2: + x = torch.concat( + [self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim = 2 + ) + else: + x = torch.concat( + [self.prev_features, x], dim = 2 + ) + x = x.to(dtype=dtype) + self.prev_features = x[:, :, -self.temporal_padding:] + + b, c, f, h, w = x.size() + outputs = [] + i = 0 + while i + self.temporal_padding + 1 <= f: + out = super().forward(x[:, :, i:i + self.temporal_padding + 1]) + i += self.t_stride + outputs.append(out) + return torch.concat(outputs, 2) + elif self.padding_flag == 5: + x = F.pad( + x, + pad=(0, 0, 0, 0, self.temporal_padding, 0), + mode="replicate", # TODO: check if this is necessary + ) + x = x.to(dtype=dtype) + self.prev_features = x[:, :, -self.temporal_padding:] + return super().forward(x) + elif self.padding_flag == 6: + if self.t_stride == 2: + x = torch.concat( + [self.prev_features[:, :, -(self.temporal_padding - 1):], x], dim = 2 + ) + else: + x = torch.concat( + [self.prev_features, x], dim = 2 + ) + self.prev_features = x[:, :, -self.temporal_padding:] + x = x.to(dtype=dtype) + return super().forward(x) else: x = F.pad( x, pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin), ) - return super().forward(x) - - def set_padding_one_frame(self): - def _set_padding_one_frame(name, module): - if hasattr(module, 'padding_flag'): - print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) - module.padding_flag = 1 - for sub_name, sub_mod in module.named_children(): - _set_padding_one_frame(sub_name, sub_mod) - for name, module in self.named_children(): - _set_padding_one_frame(name, module) - - def set_padding_more_frame(self): - def _set_padding_more_frame(name, module): - if hasattr(module, 'padding_flag'): - print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) - module.padding_flag = 2 - for sub_name, sub_mod in module.named_children(): - _set_padding_more_frame(sub_name, sub_mod) - for name, module in self.named_children(): - _set_padding_more_frame(name, module) + x = x.to(dtype=dtype) + return super().forward(x) class ResidualBlock2D(nn.Module): def __init__( @@ -142,15 +186,29 @@ class ResidualBlock2D(nn.Module): else: self.shortcut = nn.Identity() + self.set_3dgroupnorm = False + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = self.shortcut(x) - x = self.norm1(x) + if self.set_3dgroupnorm: + batch_size = x.shape[0] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm1(x) + x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + else: + x = self.norm1(x) x = self.nonlinearity(x) x = self.conv1(x) - x = self.norm2(x) + if self.set_3dgroupnorm: + batch_size = x.shape[0] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm2(x) + x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + else: + x = self.norm2(x) x = self.nonlinearity(x) x = self.dropout(x) @@ -201,15 +259,29 @@ class ResidualBlock3D(nn.Module): else: self.shortcut = nn.Identity() + self.set_3dgroupnorm = False + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = self.shortcut(x) - x = self.norm1(x) + if self.set_3dgroupnorm: + batch_size = x.shape[0] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm1(x) + x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + else: + x = self.norm1(x) x = self.nonlinearity(x) x = self.conv1(x) - x = self.norm2(x) + if self.set_3dgroupnorm: + batch_size = x.shape[0] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm2(x) + x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + else: + x = self.norm2(x) x = self.nonlinearity(x) x = self.dropout(x) @@ -238,11 +310,18 @@ class SpatialNorm2D(nn.Module): self.norm = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.set_3dgroupnorm = False def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor: f_size = f.shape[-2:] zq = F.interpolate(zq, size=f_size, mode="nearest") - norm_f = self.norm(f) + if self.set_3dgroupnorm: + batch_size = f.shape[0] + f = rearrange(f, "b c t h w -> (b t) c h w") + norm_f = self.norm(f) + norm_f = rearrange(norm_f, "(b t) c h w -> b c t h w", b=batch_size) + else: + norm_f = self.norm(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f diff --git a/easyanimate/vae/ldm/modules/vaemodules/upsamplers.py b/easyanimate/vae/ldm/modules/vaemodules/upsamplers.py index 16288f13ec1cd70b3ad3e15a27a6fe676884e25e..13694832df4781d5c3dc661d272296469e47f358 100644 --- a/easyanimate/vae/ldm/modules/vaemodules/upsamplers.py +++ b/easyanimate/vae/ldm/modules/vaemodules/upsamplers.py @@ -137,6 +137,7 @@ class SpatialTemporalUpsampler3D(Upsampler): ) self.padding_flag = 0 + self.set_3dgroupnorm = False def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest") @@ -145,32 +146,12 @@ class SpatialTemporalUpsampler3D(Upsampler): if self.padding_flag == 0: if x.shape[2] > 1: first_frame, x = x[:, :, :1], x[:, :, 1:] - x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear") + x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest") x = torch.cat([first_frame, x], dim=2) - elif self.padding_flag == 2: - x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear") + elif self.padding_flag == 2 or self.padding_flag == 4 or self.padding_flag == 5 or self.padding_flag == 6: + x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear" if not self.set_3dgroupnorm else "nearest") return x - def set_padding_one_frame(self): - def _set_padding_one_frame(name, module): - if hasattr(module, 'padding_flag'): - print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) - module.padding_flag = 1 - for sub_name, sub_mod in module.named_children(): - _set_padding_one_frame(sub_name, sub_mod) - for name, module in self.named_children(): - _set_padding_one_frame(name, module) - - def set_padding_more_frame(self): - def _set_padding_more_frame(name, module): - if hasattr(module, 'padding_flag'): - print('Set pad mode for module[%s] type=%s' % (name, str(type(module)))) - module.padding_flag = 2 - for sub_name, sub_mod in module.named_children(): - _set_padding_more_frame(sub_name, sub_mod) - for name, module in self.named_children(): - _set_padding_more_frame(name, module) - class SpatialTemporalUpsamplerD2S3D(Upsampler): def __init__(self, in_channels: int, out_channels: int): super().__init__( diff --git a/easyanimate/video_caption/README.md b/easyanimate/video_caption/README.md deleted file mode 100644 index 92c6f9abbac2e39c12996a849299457a2c1dbfbe..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/README.md +++ /dev/null @@ -1,90 +0,0 @@ -# Video Caption -EasyAnimate uses multi-modal LLMs to generate captions for frames extracted from the video firstly, and then employs LLMs to summarize and refine the generated frame captions into the final video caption. By leveraging [sglang](https://github.com/sgl-project/sglang)/[vLLM](https://github.com/vllm-project/vllm) and [accelerate distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference), the entire processing could be very fast. - -English | [简体中文](./README_zh-CN.md) - -## Quick Start -1. Cloud usage: AliyunDSW/Docker - - Check [README.md](../../README.md#quick-start) for details. - -2. Local usage - - ```shell - # Install EasyAnimate requirements firstly. - cd EasyAnimate && pip install -r requirements.txt - - # Install additional requirements for video caption. - cd easyanimate/video_caption && pip install -r requirements.txt --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ - - # Use DDP instead of DP in EasyOCR detection. - site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])') - cp -v easyocr_detection_patched.py $site_pkg_path/easyocr/detection.py - - # We strongly recommend using Docker unless you can properly handle the dependency between vllm with torch(cuda). - ``` - -## Data preprocessing -Data preprocessing can be divided into three parts: - -- Video cut. -- Video cleaning. -- Video caption. - -The input for data preprocessing can be a video folder or a metadata file (txt/csv/jsonl) containing the video path column. Please check `get_video_path_list` function in [utils/video_utils.py](utils/video_utils.py) for details. - -For easier understanding, we use one data from Panda70m as an example for data preprocessing, [Download here](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/v2/--C66yU3LjM_2.mp4). Please download the video and push it in "datasets/panda_70m/before_vcut/" - -``` -📦 datasets/ -├── 📂 panda_70m/ -│ └── 📂 before_vcut/ -│ └── 📄 --C66yU3LjM_2.mp4 -``` - -1. Video cut - - For long video cut, EasyAnimate utilizes PySceneDetect to identify scene changes within the video and performs scene cutting based on certain threshold values to ensure consistency in the themes of the video segments. After cutting, we only keep segments with lengths ranging from 3 to 10 seconds for model training. - - We have completed the parameters for ```stage_1_video_cut.sh```, so I can run it directly using the command sh ```stage_1_video_cut.sh```. After executing ```stage_1_video_cut.sh```, we obtained short videos in ```easyanimate/video_caption/datasets/panda_70m/train```. - - ```shell - sh stage_1_video_cut.sh - ``` -2. Video cleaning - - Following SVD's data preparation process, EasyAnimate provides a simple yet effective data processing pipeline for high-quality data filtering and labeling. It also supports distributed processing to accelerate the speed of data preprocessing. The overall process is as follows: - - - Duration filtering: Analyze the basic information of the video to filter out low-quality videos that are short in duration or low in resolution. This filtering result is corresponding to the video cut (3s ~ 10s videos). - - Aesthetic filtering: Filter out videos with poor content (blurry, dim, etc.) by calculating the average aesthetic score of uniformly distributed 4 frames. - - Text filtering: Use easyocr to calculate the text proportion of middle frames to filter out videos with a large proportion of text. - - Motion filtering: Calculate interframe optical flow differences to filter out videos that move too slowly or too quickly. - - The process file of **Aesthetic filtering** is ```compute_video_frame_quality.py```. After executing ```compute_video_frame_quality.py```, we obtained the file ```datasets/panda_70m/aesthetic_score.jsonl```, where each line corresponds to the aesthetic score of each video. - - The process file of **Text filtering** is ```compute_text_score.py```. After executing ```compute_text_score.py```, we obtained the file ```datasets/panda_70m/text_score.jsonl```, where each line corresponds to the text score of each video. - - The process file of **Motion filtering** is ```compute_motion_score.py```. Motion filtering is based on Aesthetic filtering and Text filtering; only samples that meet certain aesthetic scores and text scores will undergo calculation for the Motion score. After executing ```compute_motion_score.py```, we obtained the file ```datasets/panda_70m/motion_score.jsonl```, where each line corresponds to the motion score of each video. - - Then we need to filter videos by motion scores. After executing ```filter_videos_by_motion_score.py```, we get the file ```datasets/panda_70m/train.jsonl```, which includes the video we need to caption. - - We have completed the parameters for stage_2_filter_data.sh, so I can run it directly using the command sh stage_2_filter_data.sh. - - ```shell - sh stage_2_filter_data.sh - ``` -3. Video caption - - Video captioning is carried out in two stages. The first stage involves extracting frames from a video and generating descriptions for them. Subsequently, a large language model is used to summarize these descriptions into a caption. - - We have conducted a detailed and manual comparison of open sourced multi-modal LLMs such as [Qwen-VL](https://huggingface.co/Qwen/Qwen-VL), [ShareGPT4V-7B](https://huggingface.co/Lin-Chen/ShareGPT4V-7B), [deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat) and etc. And we found that [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) is capable of generating more detailed captions with fewer hallucinations. Additionally, it is supported by serving engines like [sglang](https://github.com/sgl-project/sglang) and [lmdepoly](https://github.com/InternLM/lmdeploy), enabling faster inference. - - Firstly, we use ```caption_video_frame.py``` to generate frame captions. Then, we use ```caption_summary.py``` to generate summary captions. - - We have completed the parameters for stage_3_video_caption.sh, so I can run it directly using the command sh stage_3_video_caption.sh. After executing ```stage_3_video_cut.sh```, we obtained last json ```train_panda_70m.json``` for easyanimate training. - - ```shell - sh stage_3_video_caption.sh - ``` - - If you cannot access to Huggingface, you can run `export HF_ENDPOINT=https://hf-mirror.com` before the above command to download the summary caption model automatically. \ No newline at end of file diff --git a/easyanimate/video_caption/README_zh-CN.md b/easyanimate/video_caption/README_zh-CN.md deleted file mode 100644 index b1bd34f717dd23b77a4d9eb2890f2349b0b6d958..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/README_zh-CN.md +++ /dev/null @@ -1,90 +0,0 @@ -# 数据预处理 - - -EasyAnimate 对数据进行了场景切分、视频过滤和视频打标来得到高质量的有标注视频训练使用。使用多模态大型语言模型(LLMs)为从视频中提取的帧生成字幕,然后利用LLMs将生成的帧字幕总结并细化为最终的视频字幕。通过利用sglang/vLLM和加速分布式推理,高效完成视频的打标。 - -[English](./README.md) | 简体中文 - -## 快速开始 -1. 云上使用: 阿里云DSW/Docker - 参考 [README.md](../../README_zh-CN.md#quick-start) 查看更多细节。 - -2. 本地安装 - - ```shell - # Install EasyAnimate requirements firstly. - cd EasyAnimate && pip install -r requirements.txt - - # Install additional requirements for video caption. - cd easyanimate/video_caption && pip install -r requirements.txt --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ - - # Use DDP instead of DP in EasyOCR detection. - site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])') - cp -v easyocr_detection_patched.py $site_pkg_path/easyocr/detection.py - - # We strongly recommend using Docker unless you can properly handle the dependency between vllm with torch(cuda). - ``` - -## 数据预处理 -数据预处理可以分为一下三步: - -- 视频切分 -- 视频过滤 -- 视频打标 - -数据预处理的输入可以是视频文件夹或包含视频路径列的元数据文件(txt/csv/jsonl格式)。详情请查看[utils/video_utils.py](utils/video_utils.py) 文件中的 `get_video_path_list` 函数。 - -为了便于理解,我们以Panda70m的一个数据为例进行数据预处理,点击[这里](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/v2/--C66yU3LjM_2.mp4)下载视频。请下载视频并放在下面的路径:"datasets/panda_70m/before_vcut/" - -``` -📦 datasets/ -├── 📂 panda_70m/ -│ └── 📂 before_vcut/ -│ └── 📄 --C66yU3LjM_2.mp4 -``` - -1. 视频切分 - - 对于长视频剪辑,EasyAnimate 利用 PySceneDetect 来识别视频中的场景变化,并根据特定的阈值进行场景切割,以确保视频片段主题的一致性。切割后,我们只保留长度在3到10秒之间的片段,用于模型训练。 - - 我们整理了完整的方案在 ```stage_1_video_cut.sh``` 文件中, 您可以直接运行```stage_1_video_cut.sh```. 执行完成后可以在 ```easyanimate/video_caption/datasets/panda_70m/train``` 文件夹中查看结果。 - - ```shell - sh stage_1_video_cut.sh - ``` -2. 视频过滤 - - 遵循SVD([Stable Video Diffusion](https://github.com/Stability-AI/generative-models))的数据准备流程,EasyAnimate 提供了一个简单而有效的数据处理管道,用于高质量数据的过滤和标记。我们还支持分布式处理来加快数据预处理的速度。整个过程如下:: - - - 时长过滤: 分析视频的基本信息,筛选出时长过短或分辨率过低的低质量视频。我们保留3秒至10秒的视频。 - - 美学过滤: 通过计算均匀分布的4帧的平均审美分数,过滤掉内容质量差的视频(模糊、暗淡等)。 - - 文本过滤: 使用 [easyocr](https://github.com/JaidedAI/EasyOCR) 来计算中间帧的文本比例,以筛选出含有大量文本的视频。 - - 运动过滤: 计算帧间光流差异,以筛选出移动过慢或过快的视频。 - - **美学过滤** 的代码在 ```compute_video_frame_quality.py```. 执行 ```compute_video_frame_quality.py```,我们可以生成 ```datasets/panda_70m/aesthetic_score.jsonl```文件, 计算每条视频的美学得分。 - - **文本过滤** 的代码在 ```compute_text_score.py```. 执行```compute_text_score.py```, 我们可以生成 ```datasets/panda_70m/text_score.jsonl```文件, 计算每个视频的文字占比。 - - **运动过滤** 的代码在 ```compute_motion_score.py```. 运动过滤基于审美过滤和文本过滤;只有达到一定审美分数和文本分数的样本才会进行运动分数的计算。 执行 ```compute_motion_score.py```, 我们可以生成 ```datasets/panda_70m/motion_score.jsonl```, 计算每条视频的运动得分。 - - 接着执行 ```filter_videos_by_motion_score.py```来得过滤视频。我们最终得到筛选后需要打标的 ```datasets/panda_70m/train.jsonl```文件。 - - 我们将视频过滤的流程整理为 ```stage_2_filter_data.sh```,直接执行该脚本来完成视频数据的过滤。 - - ```shell - sh stage_2_filter_data.sh - ``` -3. 视频打标 - - - 视频打标生成分为两个阶段。第一阶段涉及从视频中提取帧并为它们生成描述。随后,使用大型语言模型将这些描述汇总成一条字幕。 - - 我们详细对比了现有的多模态大语言模型(诸如[Qwen-VL](https://huggingface.co/Qwen/Qwen-VL), [ShareGPT4V-7B](https://huggingface.co/Lin-Chen/ShareGPT4V-7B), [deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat))生成文本描述的效果。 最终选择 [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) 来进行视频文本描述的生成,它能生成详细的描述并有更少的幻觉。此外,我们引入 [sglang](https://github.com/sgl-project/sglang),[lmdepoly](https://github.com/InternLM/lmdeploy), 来加速推理的过程。 - - 首先,我们用 ```caption_video_frame.py``` 来生成文本描述,并用 ```caption_summary.py``` 来总结描述信息。我们将上述过程整理在 ```stage_3_video_caption.sh```, 直接运行它来生成视频的文本描述。我们最终得到 ```train_panda_70m.json``` 用于EasyAnmate 的训练。 - - ```shell - sh stage_3_video_caption.sh - ``` - - 请注意,如遇网络问题,您可以设置 `export HF_ENDPOINT=https://hf-mirror.com` 来自动下载视频打标模型。 \ No newline at end of file diff --git a/easyanimate/video_caption/caption_summary.py b/easyanimate/video_caption/caption_summary.py deleted file mode 100644 index d0c99f44cb44cb286cc86f047471fcb51fc1478c..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/caption_summary.py +++ /dev/null @@ -1,134 +0,0 @@ -import argparse -import os -import re -from tqdm import tqdm - -import pandas as pd -from vllm import LLM, SamplingParams - -from utils.logger import logger - - -def parse_args(): - parser = argparse.ArgumentParser(description="Recaption the video frame.") - parser.add_argument( - "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)." - ) - parser.add_argument( - "--video_path_column", - type=str, - default="video_path", - help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", - ) - parser.add_argument( - "--caption_column", - type=str, - default="sampled_frame_caption", - help="The column contains the sampled_frame_caption.", - ) - parser.add_argument( - "--remove_quotes", - action="store_true", - help="Whether to remove quotes from caption.", - ) - parser.add_argument( - "--batch_size", - type=int, - default=10, - required=False, - help="The batch size for the video caption.", - ) - parser.add_argument( - "--summary_model_name", - type=str, - default="mistralai/Mistral-7B-Instruct-v0.2", - ) - parser.add_argument( - "--summary_prompt", - type=str, - default=( - "You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, " - "which you need to summarize it into a description of the video clip." - "Please provide your video description following these requirements: " - "1. Describe the basic and necessary information of the video in the third person, be as concise as possible. " - "2. Output the video description directly. Begin with 'In this video'. " - "3. Limit the video description within 100 words. " - "Here is the mid-frame description: " - ), - ) - parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") - parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency to save the output results.") - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - if args.video_metadata_path.endswith(".csv"): - video_metadata_df = pd.read_csv(args.video_metadata_path) - elif args.video_metadata_path.endswith(".jsonl"): - video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) - else: - raise ValueError("The video_metadata_path must end with .csv or .jsonl.") - video_path_list = video_metadata_df[args.video_path_column].tolist() - sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist() - - if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): - raise ValueError("The saved_path must end with .csv or .jsonl.") - - if os.path.exists(args.saved_path): - if args.saved_path.endswith(".csv"): - saved_metadata_df = pd.read_csv(args.saved_path) - elif args.saved_path.endswith(".jsonl"): - saved_metadata_df = pd.read_json(args.saved_path, lines=True) - saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() - video_path_list = list(set(video_path_list) - set(saved_video_path_list)) - video_metadata_df.set_index(args.video_path_column, inplace=True) - video_metadata_df = video_metadata_df.loc[video_path_list] - sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist() - logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") - - sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256) - summary_model = LLM(model=args.summary_model_name, trust_remote_code=True) - - result_dict = {"video_path": [], "summary_model": [], "summary_caption": []} - - for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)): - batch_video_path = video_path_list[i: i + args.batch_size] - batch_caption = sampled_frame_caption_list[i : i + args.batch_size] - batch_prompt = [] - for caption in batch_caption: - if args.remove_quotes: - caption = re.sub(r'(["\']).*?\1', "", caption) - batch_prompt.append("user:" + args.summary_prompt + str(caption) + "\n assistant:") - batch_output = summary_model.generate(batch_prompt, sampling_params) - - result_dict["video_path"].extend(batch_video_path) - result_dict["summary_model"].extend([args.summary_model_name] * len(batch_caption)) - result_dict["summary_caption"].extend([output.outputs[0].text.rstrip() for output in batch_output]) - - # Save the metadata every args.saved_freq. - if i != 0 and ((i // args.batch_size) % args.saved_freq) == 0: - result_df = pd.DataFrame(result_dict) - if args.saved_path.endswith(".csv"): - header = True if not os.path.exists(args.saved_path) else False - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save result to {args.saved_path}.") - - result_dict = {"video_path": [], "summary_model": [], "summary_caption": []} - - result_df = pd.DataFrame(result_dict) - if args.saved_path.endswith(".csv"): - header = True if not os.path.exists(args.saved_path) else False - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save the final result to {args.saved_path}.") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/easyanimate/video_caption/caption_video_frame.py b/easyanimate/video_caption/caption_video_frame.py deleted file mode 100644 index 09ce26831b3ffddf04e05ef929c5b9fa62dd9f49..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/caption_video_frame.py +++ /dev/null @@ -1,267 +0,0 @@ -import argparse -import copy -import os - -import pandas as pd -from accelerate import PartialState -from accelerate.utils import gather_object -from natsort import natsorted -from tqdm import tqdm -from torch.utils.data import DataLoader - -from utils.logger import logger -from utils.video_dataset import VideoDataset, collate_fn -from utils.video_utils import get_video_path_list, extract_frames - - -ACCELERATE_SUPPORTED_MODELS = ["Qwen-VL-Chat", "internlm-xcomposer2-vl-7b"] -SGLANG_SUPPORTED_MODELS = ["llava-v1.6-vicuna-7b"] - - -def parse_args(): - parser = argparse.ArgumentParser(description="Recaption the video frame.") - parser.add_argument("--video_folder", type=str, default="", help="The video folder.") - parser.add_argument( - "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl/txt)." - ) - parser.add_argument( - "--video_path_column", - type=str, - default="video_path", - help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", - ) - parser.add_argument( - "--batch_size", - type=int, - default=10, - required=False, - help="The batch size for the video dataset.", - ) - parser.add_argument( - "--frame_sample_method", - type=str, - choices=["mid", "uniform"], - default="mid", - ) - parser.add_argument( - "--num_sampled_frames", - type=int, - default=1, - help="num_sampled_frames", - ) - parser.add_argument( - "--image_caption_model_name", - type=str, - choices=ACCELERATE_SUPPORTED_MODELS + SGLANG_SUPPORTED_MODELS, - default="internlm-xcomposer2-vl-7b", - ) - parser.add_argument( - "--image_caption_model_quantized", type=bool, default=True, help="Whether to use the quantized image caption model." - ) - parser.add_argument( - "--image_caption_prompt", - type=str, - default="Describe this image and its style in a very detailed manner.", - ) - parser.add_argument( - "--output_dir", - type=str, - required=True, - help="The directory to create the subfolder (named with the video name) to indicate the video has been processed.", - ) - parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") - parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency to save the output results.") - - args = parser.parse_args() - return args - - -def accelerate_inference(args, video_path_list): - from utils.image_captioner_awq import QwenVLChat, InternLMXComposer2 - - state = PartialState() - device = state.device - if state.num_processes == 1: - device = "cuda:0" - if args.image_caption_model_name == "internlm-xcomposer2-vl-7b": - image_caption_model = InternLMXComposer2(device=device, quantized=args.image_caption_model_quantized) - elif args.image_caption_model_name == "Qwen-VL-Chat": - image_caption_model = QwenVLChat(device=device, quantized=args.image_caption_model_quantized) - - # The workaround can be removed after https://github.com/huggingface/accelerate/pull/2781 is released. - index = len(video_path_list) - len(video_path_list) % state.num_processes - logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to avoid duplicates in state.split_between_processes.") - video_path_list = video_path_list[:index] - - if state.is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - result_list = [] - with state.split_between_processes(video_path_list) as splitted_video_path_list: - for i, video_path in enumerate(tqdm(splitted_video_path_list, desc=f"{state.device}")): - video_id = os.path.splitext(os.path.basename(video_path))[0] - try: - if not os.path.exists(video_path): - print(f"Video {video_id} does not exist. Pass it.") - continue - sampled_frame_list, sampled_frame_idx_list = extract_frames(video_path, num_sample_frames=args.num_sample_frames) - except Exception as e: - print(f"Failed to extract frames from video {video_id}. Error is {e}.") - - video_recaption_output_dir = os.path.join(args.output_dir, video_id) - if os.path.exists(video_recaption_output_dir): - print(f"Video {video_id} has been processed. Pass it.") - continue - else: - os.makedirs(video_recaption_output_dir) - - caption_list = [] - for frame, frame_idx in zip(sampled_frame_list, sampled_frame_idx_list): - frame_path = f"{args.output_dir}/{video_id}_{frame_idx}.png" - frame.save(frame_path) - try: - response, _ = image_caption_model(args.image_caption_prompt, frame_path) - except Exception as e: - print(f"Failed to caption video {video_id}. Error is {e}.") - finally: - os.remove(frame_path) - caption_list.append(response) - - result_meta = {} - if args.video_folder == "": - result_meta[args.video_path_column] = video_path - else: - result_meta[args.video_path_column] = os.path.basename(video_path) - result_meta["image_caption_model"] = args.image_caption_model_name - result_meta["prompt"] = args.image_caption_prompt - result_meta["sampled_frame_idx"] = sampled_frame_idx_list - result_meta["sampled_frame_caption"] = caption_list - result_list.append(copy.deepcopy(result_meta)) - - # Save the metadata in the main process. - if i != 0 and i % args.saved_freq == 0: - state.wait_for_everyone() - gathered_result_list = gather_object(result_list) - if state.is_main_process: - result_df = pd.DataFrame(gathered_result_list) - if args.saved_path.endswith(".csv"): - result_df.to_csv(args.saved_path, index=False) - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True) - print(f"Save result to {args.saved_path}.") - - # Wait for all processes to finish and gather the final result. - state.wait_for_everyone() - gathered_result_list = gather_object(result_list) - # Save the metadata in the main process. - if state.is_main_process: - result_df = pd.DataFrame(gathered_result_list) - if args.saved_path.endswith(".csv"): - result_df.to_csv(args.saved_path, index=False) - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True) - print(f"Save the final result to {args.saved_path}.") - - -def sglang_inference(args, video_path_list): - from utils.image_captioner_sglang import LLaVASRT - - if args.image_caption_model_name == "llava-v1.6-vicuna-7b": - image_caption_model = LLaVASRT() - - result_dict = { - "video_path": [], - "image_caption_model": [], - "prompt": [], - 'sampled_frame_idx': [], - "sampled_frame_caption": [] - } - - video_dataset = VideoDataset( - video_path_list=video_path_list, - sample_method=args.frame_sample_method, - num_sampled_frames=args.num_sampled_frames - ) - video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=16, collate_fn=collate_fn) - for idx, batch in enumerate(tqdm(video_loader)): - if len(batch) == 0: - continue - batch_video_path, batch_frame_idx = batch["video_path"], batch["sampled_frame_idx"] - # [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C]. - batch_frame = [] - for item_sampled_frame in batch["sampled_frame"]: - batch_frame.extend([frame for frame in item_sampled_frame]) - - try: - response_list, _ = image_caption_model([args.image_caption_prompt] * len(batch_frame), batch_frame) - response_list = [response_list[i:i + args.num_sampled_frames] for i in range(0, len(response_list), args.num_sampled_frames)] - except Exception as e: - logger.error(f"Failed to caption video {batch_video_path}. Error is {e}.") - - result_dict["video_path"].extend(batch_video_path) - result_dict["image_caption_model"].extend([args.image_caption_model_name] * len(batch_video_path)) - result_dict["prompt"].extend([args.image_caption_prompt] * len(batch_video_path)) - result_dict["sampled_frame_idx"].extend(batch_frame_idx) - result_dict["sampled_frame_caption"].extend(response_list) - - # Save the metadata in the main process. - if idx != 0 and idx % args.saved_freq == 0: - result_df = pd.DataFrame(result_dict) - if args.saved_path.endswith(".csv"): - header = True if not os.path.exists(args.saved_path) else False - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save result to {args.saved_path}.") - - result_dict = { - "video_path": [], - "image_caption_model": [], - "prompt": [], - 'sampled_frame_idx': [], - "sampled_frame_caption": [] - } - - if len(result_dict["video_path"]) != 0: - result_df = pd.DataFrame(result_dict) - if args.saved_path.endswith(".csv"): - header = True if not os.path.exists(args.saved_path) else False - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save the final result to {args.saved_path}.") - - -def main(): - args = parse_args() - - video_path_list = get_video_path_list( - video_folder=args.video_folder, - video_metadata_path=args.video_metadata_path, - video_path_column=args.video_path_column - ) - - if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): - raise ValueError("The saved_path must end with .csv or .jsonl.") - - if os.path.exists(args.saved_path): - if args.saved_path.endswith(".csv"): - saved_metadata_df = pd.read_csv(args.saved_path) - elif args.saved_path.endswith(".jsonl"): - saved_metadata_df = pd.read_json(args.saved_path, lines=True) - saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() - saved_video_path_list = [os.path.join(args.video_folder, path) for path in saved_video_path_list] - video_path_list = list(set(video_path_list) - set(saved_video_path_list)) - # Sorting to guarantee the same result for each process. - video_path_list = natsorted(video_path_list) - logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") - - if args.image_caption_model_name in SGLANG_SUPPORTED_MODELS: - sglang_inference(args, video_path_list) - elif args.image_caption_model_name in ACCELERATE_SUPPORTED_MODELS: - accelerate_inference(args, video_path_list) - else: - raise ValueError(f"The {args.image_caption_model_name} is not supported.") - - -if __name__ == "__main__": - main() diff --git a/easyanimate/video_caption/compute_motion_score.py b/easyanimate/video_caption/compute_motion_score.py deleted file mode 100644 index 4f8afbac2ef979fcaa15fa094ebcf5109e85d25c..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/compute_motion_score.py +++ /dev/null @@ -1,196 +0,0 @@ -import ast -import argparse -import gc -import os -from contextlib import contextmanager -from pathlib import Path - -import cv2 -import numpy as np -import pandas as pd -from joblib import Parallel, delayed -from natsort import natsorted -from tqdm import tqdm - -from utils.logger import logger -from utils.video_utils import get_video_path_list - - -@contextmanager -def VideoCapture(video_path): - cap = cv2.VideoCapture(video_path) - try: - yield cap - finally: - cap.release() - del cap - gc.collect() - - -def compute_motion_score(video_path): - video_motion_scores = [] - sampling_fps = 2 - - try: - with VideoCapture(video_path) as cap: - fps = cap.get(cv2.CAP_PROP_FPS) - valid_fps = min(max(sampling_fps, 1), fps) - frame_interval = int(fps / valid_fps) - total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - - # if cannot get the second frame, use the last one - frame_interval = min(frame_interval, total_frames - 1) - - prev_frame = None - frame_count = -1 - while cap.isOpened(): - ret, frame = cap.read() - frame_count += 1 - - if not ret: - break - - # skip middle frames - if frame_count % frame_interval != 0: - continue - - gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - if prev_frame is None: - prev_frame = gray_frame - continue - - flow = cv2.calcOpticalFlowFarneback( - prev_frame, - gray_frame, - None, - pyr_scale=0.5, - levels=3, - winsize=15, - iterations=3, - poly_n=5, - poly_sigma=1.2, - flags=0, - ) - mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) - frame_motion_score = np.mean(mag) - video_motion_scores.append(frame_motion_score) - prev_frame = gray_frame - - video_meta_info = { - "video_path": Path(video_path).name, - "motion_score": round(float(np.mean(video_motion_scores)), 5), - } - return video_meta_info - - except Exception as e: - print(f"Compute motion score for video {video_path} with error: {e}.") - - -def parse_args(): - parser = argparse.ArgumentParser(description="Compute the motion score of the videos.") - parser.add_argument("--video_folder", type=str, default="", help="The video folder.") - parser.add_argument( - "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." - ) - parser.add_argument( - "--video_path_column", - type=str, - default="video_path", - help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", - ) - parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") - parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.") - parser.add_argument("--n_jobs", type=int, default=1, help="The number of concurrent processes.") - - parser.add_argument( - "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." - ) - parser.add_argument("--asethetic_score_threshold", type=float, default=4.0, help="The asethetic score threshold.") - parser.add_argument( - "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)." - ) - parser.add_argument("--text_score_threshold", type=float, default=0.02, help="The text threshold.") - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - video_path_list = get_video_path_list( - video_folder=args.video_folder, - video_metadata_path=args.video_metadata_path, - video_path_column=args.video_path_column - ) - - if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): - raise ValueError("The saved_path must end with .csv or .jsonl.") - - if os.path.exists(args.saved_path): - if args.saved_path.endswith(".csv"): - saved_metadata_df = pd.read_csv(args.saved_path) - elif args.saved_path.endswith(".jsonl"): - saved_metadata_df = pd.read_json(args.saved_path, lines=True) - saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() - saved_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in saved_video_path_list] - - video_path_list = list(set(video_path_list).difference(set(saved_video_path_list))) - # Sorting to guarantee the same result for each process. - video_path_list = natsorted(video_path_list) - logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") - - if args.asethetic_score_metadata_path is not None: - if args.asethetic_score_metadata_path.endswith(".csv"): - asethetic_score_df = pd.read_csv(args.asethetic_score_metadata_path) - elif args.asethetic_score_metadata_path.endswith(".jsonl"): - asethetic_score_df = pd.read_json(args.asethetic_score_metadata_path, lines=True) - - # In pandas, csv will save lists as strings, whereas jsonl will not. - asethetic_score_df["aesthetic_score"] = asethetic_score_df["aesthetic_score"].apply( - lambda x: ast.literal_eval(x) if isinstance(x, str) else x - ) - asethetic_score_df["aesthetic_score_mean"] = asethetic_score_df["aesthetic_score"].apply(lambda x: sum(x) / len(x)) - filtered_asethetic_score_df = asethetic_score_df[asethetic_score_df["aesthetic_score_mean"] < args.asethetic_score_threshold] - filtered_video_path_list = filtered_asethetic_score_df[args.video_path_column].tolist() - filtered_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in filtered_video_path_list] - - video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) - # Sorting to guarantee the same result for each process. - video_path_list = natsorted(video_path_list) - logger.info(f"Load {args.asethetic_score_metadata_path} and filter {len(filtered_video_path_list)} videos.") - - if args.text_score_metadata_path is not None: - if args.text_score_metadata_path.endswith(".csv"): - text_score_df = pd.read_csv(args.text_score_metadata_path) - elif args.text_score_metadata_path.endswith(".jsonl"): - text_score_df = pd.read_json(args.text_score_metadata_path, lines=True) - - filtered_text_score_df = text_score_df[text_score_df["text_score"] > args.text_score_threshold] - filtered_video_path_list = filtered_text_score_df[args.video_path_column].tolist() - filtered_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in filtered_video_path_list] - - video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) - # Sorting to guarantee the same result for each process. - video_path_list = natsorted(video_path_list) - logger.info(f"Load {args.text_score_metadata_path} and filter {len(filtered_video_path_list)} videos.") - - for i in tqdm(range(0, len(video_path_list), args.saved_freq)): - result_list = Parallel(n_jobs=args.n_jobs, backend="threading")( - delayed(compute_motion_score)(video_path) for video_path in tqdm(video_path_list[i: i + args.saved_freq]) - ) - result_list = [result for result in result_list if result is not None] - if len(result_list) == 0: - continue - - result_df = pd.DataFrame(result_list) - if args.saved_path.endswith(".csv"): - header = False if os.path.exists(args.saved_path) else True - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save result to {args.saved_path}.") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/easyanimate/video_caption/compute_text_score.py b/easyanimate/video_caption/compute_text_score.py deleted file mode 100644 index f1e8ec5302747289c050a975ff63fb8d5b242f29..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/compute_text_score.py +++ /dev/null @@ -1,198 +0,0 @@ -import ast -import argparse -import os -from pathlib import Path - -import easyocr -import numpy as np -import pandas as pd -from accelerate import PartialState -from accelerate.utils import gather_object -from natsort import natsorted -from tqdm import tqdm -from torchvision.datasets.utils import download_url - -from utils.logger import logger -from utils.video_utils import extract_frames, get_video_path_list - - -def init_ocr_reader(root: str = "~/.cache/easyocr", device: str = "gpu"): - root = os.path.expanduser(root) - if not os.path.exists(root): - os.makedirs(root) - download_url( - "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/easyocr/craft_mlt_25k.pth", - root, - filename="craft_mlt_25k.pth", - md5="2f8227d2def4037cdb3b34389dcf9ec1", - ) - ocr_reader = easyocr.Reader( - lang_list=["en", "ch_sim"], - gpu=device, - recognizer=False, - verbose=False, - model_storage_directory=root, - ) - - return ocr_reader - - -def triangle_area(p1, p2, p3): - """Compute the triangle area according to its coordinates. - """ - x1, y1 = p1 - x2, y2 = p2 - x3, y3 = p3 - tri_area = 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3) - return tri_area - - -def compute_text_score(video_path, ocr_reader): - _, images = extract_frames(video_path, sample_method="mid") - images = [np.array(image) for image in images] - - frame_ocr_area_ratios = [] - for image in images: - # horizontal detected results and free-form detected - horizontal_list, free_list = ocr_reader.detect(np.asarray(image)) - width, height = image.shape[0], image.shape[1] - - total_area = width * height - # rectangles - rect_area = 0 - for xmin, xmax, ymin, ymax in horizontal_list[0]: - if xmax < xmin or ymax < ymin: - continue - rect_area += (xmax - xmin) * (ymax - ymin) - # free-form - quad_area = 0 - try: - for points in free_list[0]: - triangle1 = points[:3] - quad_area += triangle_area(*triangle1) - triangle2 = points[3:] + [points[0]] - quad_area += triangle_area(*triangle2) - except: - quad_area = 0 - text_area = rect_area + quad_area - - frame_ocr_area_ratios.append(text_area / total_area) - - video_meta_info = { - "video_path": Path(video_path).name, - "text_score": round(np.mean(frame_ocr_area_ratios), 5), - } - - return video_meta_info - - -def parse_args(): - parser = argparse.ArgumentParser(description="Compute the text score of the middle frame in the videos.") - parser.add_argument("--video_folder", type=str, default="", help="The video folder.") - parser.add_argument( - "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." - ) - parser.add_argument( - "--video_path_column", - type=str, - default="video_path", - help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).", - ) - parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") - parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.") - parser.add_argument( - "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)." - ) - parser.add_argument("--asethetic_score_threshold", type=float, default=4.0, help="The asethetic score threshold.") - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - video_path_list = get_video_path_list( - video_folder=args.video_folder, - video_metadata_path=args.video_metadata_path, - video_path_column=args.video_path_column - ) - - if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): - raise ValueError("The saved_path must end with .csv or .jsonl.") - - if os.path.exists(args.saved_path): - if args.saved_path.endswith(".csv"): - saved_metadata_df = pd.read_csv(args.saved_path) - elif args.saved_path.endswith(".jsonl"): - saved_metadata_df = pd.read_json(args.saved_path, lines=True) - saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() - saved_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in saved_video_path_list] - - video_path_list = list(set(video_path_list).difference(set(saved_video_path_list))) - # Sorting to guarantee the same result for each process. - video_path_list = natsorted(video_path_list) - logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") - - if args.asethetic_score_metadata_path is not None: - if args.asethetic_score_metadata_path.endswith(".csv"): - asethetic_score_df = pd.read_csv(args.asethetic_score_metadata_path) - elif args.asethetic_score_metadata_path.endswith(".jsonl"): - asethetic_score_df = pd.read_json(args.asethetic_score_metadata_path, lines=True) - - # In pandas, csv will save lists as strings, whereas jsonl will not. - asethetic_score_df["aesthetic_score"] = asethetic_score_df["aesthetic_score"].apply( - lambda x: ast.literal_eval(x) if isinstance(x, str) else x - ) - asethetic_score_df["aesthetic_score_mean"] = asethetic_score_df["aesthetic_score"].apply(lambda x: sum(x) / len(x)) - filtered_asethetic_score_df = asethetic_score_df[asethetic_score_df["aesthetic_score_mean"] < args.asethetic_score_threshold] - filtered_video_path_list = filtered_asethetic_score_df[args.video_path_column].tolist() - filtered_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in filtered_video_path_list] - - video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list))) - # Sorting to guarantee the same result for each process. - video_path_list = natsorted(video_path_list) - logger.info(f"Load {args.asethetic_score_metadata_path} and filter {len(filtered_video_path_list)} videos.") - - state = PartialState() - ocr_reader = init_ocr_reader(device=state.device) - - # The workaround can be removed after https://github.com/huggingface/accelerate/pull/2781 is released. - index = len(video_path_list) - len(video_path_list) % state.num_processes - logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to avoid duplicates in state.split_between_processes.") - video_path_list = video_path_list[:index] - - result_list = [] - with state.split_between_processes(video_path_list) as splitted_video_path_list: - for i, video_path in enumerate(tqdm(splitted_video_path_list)): - video_meta_info = compute_text_score(video_path, ocr_reader) - result_list.append(video_meta_info) - if i != 0 and i % args.saved_freq == 0: - state.wait_for_everyone() - gathered_result_list = gather_object(result_list) - if state.is_main_process: - result_df = pd.DataFrame(gathered_result_list) - if args.saved_path.endswith(".csv"): - header = False if os.path.exists(args.saved_path) else True - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save result to {args.saved_path}.") - result_list = [] - - state.wait_for_everyone() - gathered_result_list = gather_object(result_list) - if state.is_main_process: - logger.info(len(gathered_result_list)) - if len(gathered_result_list) != 0: - result_df = pd.DataFrame(gathered_result_list) - if args.saved_path.endswith(".csv"): - header = False if os.path.exists(args.saved_path) else True - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save the final result to {args.saved_path}.") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/easyanimate/video_caption/compute_video_frame_quality.py b/easyanimate/video_caption/compute_video_frame_quality.py deleted file mode 100644 index 8ebc9997c465c955d05eb3ec5c32af3862c2015b..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/compute_video_frame_quality.py +++ /dev/null @@ -1,176 +0,0 @@ -import argparse -import re -import os - -import pandas as pd -from accelerate import PartialState -from accelerate.utils import gather_object -from natsort import natsorted -from tqdm import tqdm -from torch.utils.data import DataLoader - -import utils.image_evaluator as image_evaluator -from utils.logger import logger -from utils.video_dataset import VideoDataset, collate_fn -from utils.video_utils import get_video_path_list - - -def camel2snake(s: str) -> str: - """Convert camel case to snake case.""" - if not re.match("^[A-Z]+$", s): - pattern = re.compile(r"(? 1 - - video_path_list = get_video_path_list( - video_folder=args.video_folder, - video_metadata_path=args.video_metadata_path, - video_path_column=args.video_path_column - ) - - if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): - raise ValueError("The saved_path must end with .csv or .jsonl.") - - caption_list = None - if args.video_metadata_path is not None and args.caption_column is not None: - if args.video_metadata_path.endswith(".csv"): - video_metadata_df = pd.read_csv(args.video_metadata_path) - elif args.video_metadata_path.endswith(".jsonl"): - video_metadata_df = pd.read_json(args.video_metadata_path, lines=True) - else: - raise ValueError("The video_metadata_path must end with .csv or .jsonl.") - caption_list = video_metadata_df[args.caption_column].tolist() - - if os.path.exists(args.saved_path): - if args.saved_path.endswith(".csv"): - saved_metadata_df = pd.read_csv(args.saved_path) - elif args.saved_path.endswith(".jsonl"): - saved_metadata_df = pd.read_json(args.saved_path, lines=True) - saved_video_path_list = saved_metadata_df[args.video_path_column].tolist() - saved_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in saved_video_path_list] - - video_path_list = list(set(video_path_list).difference(set(saved_video_path_list))) - # Sorting to guarantee the same result for each process. - video_path_list = natsorted(video_path_list) - logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.") - - logger.info("Initializing evaluator metrics...") - state = PartialState() - metric_fns = [getattr(image_evaluator, metric)(device=state.device) for metric in args.metrics] - - # The workaround can be removed after https://github.com/huggingface/accelerate/pull/2781 is released. - index = len(video_path_list) - len(video_path_list) % state.num_processes - logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to avoid duplicates in state.split_between_processes.") - video_path_list = video_path_list[:index] - - result_dict = {args.video_path_column: [], "sample_frame_idx": []} - for metric in args.metrics: - result_dict[camel2snake(metric)] = [] - - with state.split_between_processes(video_path_list) as splitted_video_path_list: - video_dataset = VideoDataset( - video_path_list=splitted_video_path_list, - sample_method="uniform", - num_sampled_frames=args.num_sampled_frames - ) - video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=4, collate_fn=collate_fn) - for idx, batch in enumerate(tqdm(video_loader)): - if len(batch) == 0: - continue - batch_video_path = batch[args.video_path_column] - result_dict["sample_frame_idx"].extend(batch["sampled_frame_idx"]) - # [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C]. - batch_frame = [] - for item_sampled_frame in batch["sampled_frame"]: - batch_frame.extend([frame for frame in item_sampled_frame]) - batch_caption = None - if caption_list is not None: - batch_caption = caption_list[i : i + args.batch_size] - # Compute the frame quality. - for i, metric in enumerate(args.metrics): - # [batch_size * num_sampled_frames] => [batch_size, num_sampled_frames] - quality_scores = metric_fns[i](batch_frame, batch_caption) - quality_scores = [round(score, 5) for score in quality_scores] - quality_scores = [quality_scores[j:j + args.num_sampled_frames] for j in range(0, len(quality_scores), args.num_sampled_frames)] - result_dict[camel2snake(metric)].extend(quality_scores) - - saved_video_path_list = [os.path.basename(video_path) for video_path in batch_video_path] - result_dict[args.video_path_column].extend(saved_video_path_list) - - # Save the metadata in the main process every saved_freq. - if (idx != 0) and (idx % args.saved_freq == 0): - state.wait_for_everyone() - gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()} - if state.is_main_process: - result_df = pd.DataFrame(gathered_result_dict) - if args.saved_path.endswith(".csv"): - header = False if os.path.exists(args.saved_path) else True - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save result to {args.saved_path}.") - for k in result_dict.keys(): - result_dict[k] = [] - - # Wait for all processes to finish and gather the final result. - state.wait_for_everyone() - gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()} - # Save the metadata in the main process. - if state.is_main_process: - result_df = pd.DataFrame(gathered_result_dict) - if len(gathered_result_dict[args.video_path_column]) != 0: - result_df = pd.DataFrame(gathered_result_dict) - if args.saved_path.endswith(".csv"): - header = False if os.path.exists(args.saved_path) else True - result_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - result_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save the final result to {args.saved_path}.") - - -if __name__ == "__main__": - main() diff --git a/easyanimate/video_caption/convert_jsonl_to_json.py b/easyanimate/video_caption/convert_jsonl_to_json.py deleted file mode 100644 index 78b7ad99ce168fb0902237d5b44b0e89da1060c3..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/convert_jsonl_to_json.py +++ /dev/null @@ -1,40 +0,0 @@ -import argparse -import json -import os - -def parse_args(): - parser = argparse.ArgumentParser(description="Convert jsonl to json.") - parser.add_argument("--video_folder", type=str, default="", help="The video folder.") - parser.add_argument( - "--jsonl_load_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." - ) - parser.add_argument("--save_path", type=str, default=None, help="The save path to the output results.") - args = parser.parse_args() - return args - -def main(): - args = parse_args() - - with open(args.jsonl_load_path, "r") as read: - _lines = read.readlines() - - output = [] - for line in _lines: - try: - line = json.loads(line.strip()) - videoid, name = line['video_path'], line['summary_caption'] - output.append( - { - "file_path": os.path.join(args.video_folder, videoid), - "text": name, - "type": "video", - } - ) - except: - pass - - with open(args.save_path, mode="w", encoding="utf-8") as f: - json.dump(output, f, indent=2) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/easyanimate/video_caption/datasets/put preprocess datasets here.txt b/easyanimate/video_caption/datasets/put preprocess datasets here.txt deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/easyanimate/video_caption/easyocr_detection_patched.py b/easyanimate/video_caption/easyocr_detection_patched.py deleted file mode 100644 index e2cffa2b00c7c90aafcde307ce27307ed6e71dbf..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/easyocr_detection_patched.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py. -1. Disable DataParallel. -""" -import torch -import torch.backends.cudnn as cudnn -from torch.autograd import Variable -from PIL import Image -from collections import OrderedDict - -import cv2 -import numpy as np -from .craft_utils import getDetBoxes, adjustResultCoordinates -from .imgproc import resize_aspect_ratio, normalizeMeanVariance -from .craft import CRAFT - -def copyStateDict(state_dict): - if list(state_dict.keys())[0].startswith("module"): - start_idx = 1 - else: - start_idx = 0 - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = ".".join(k.split(".")[start_idx:]) - new_state_dict[name] = v - return new_state_dict - -def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False): - if isinstance(image, np.ndarray) and len(image.shape) == 4: # image is batch of np arrays - image_arrs = image - else: # image is single numpy array - image_arrs = [image] - - img_resized_list = [] - # resize - for img in image_arrs: - img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size, - interpolation=cv2.INTER_LINEAR, - mag_ratio=mag_ratio) - img_resized_list.append(img_resized) - ratio_h = ratio_w = 1 / target_ratio - # preprocessing - x = [np.transpose(normalizeMeanVariance(n_img), (2, 0, 1)) - for n_img in img_resized_list] - x = torch.from_numpy(np.array(x)) - x = x.to(device) - - # forward pass - with torch.no_grad(): - y, feature = net(x) - - boxes_list, polys_list = [], [] - for out in y: - # make score and link map - score_text = out[:, :, 0].cpu().data.numpy() - score_link = out[:, :, 1].cpu().data.numpy() - - # Post-processing - boxes, polys, mapper = getDetBoxes( - score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars) - - # coordinate adjustment - boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h) - polys = adjustResultCoordinates(polys, ratio_w, ratio_h) - if estimate_num_chars: - boxes = list(boxes) - polys = list(polys) - for k in range(len(polys)): - if estimate_num_chars: - boxes[k] = (boxes[k], mapper[k]) - if polys[k] is None: - polys[k] = boxes[k] - boxes_list.append(boxes) - polys_list.append(polys) - - return boxes_list, polys_list - -def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False): - net = CRAFT() - - if device == 'cpu': - net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device))) - if quantize: - try: - torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True) - except: - pass - else: - net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device))) - # net = torch.nn.DataParallel(net).to(device) - net = net.to(device) - cudnn.benchmark = cudnn_benchmark - - net.eval() - return net - -def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs): - result = [] - estimate_num_chars = optimal_num_chars is not None - bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector, - image, text_threshold, - link_threshold, low_text, poly, - device, estimate_num_chars) - if estimate_num_chars: - polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))] - for polys in polys_list] - - for polys in polys_list: - single_img_result = [] - for i, box in enumerate(polys): - poly = np.array(box).astype(np.int32).reshape((-1)) - single_img_result.append(poly) - result.append(single_img_result) - - return result diff --git a/easyanimate/video_caption/filter_videos_by_motion_score.py b/easyanimate/video_caption/filter_videos_by_motion_score.py deleted file mode 100644 index e622aaa405958e10bafbe30bf0f35d6b7b3063a4..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/filter_videos_by_motion_score.py +++ /dev/null @@ -1,55 +0,0 @@ -import ast -import argparse -import gc -import os -from contextlib import contextmanager -from pathlib import Path - -import cv2 -import numpy as np -import pandas as pd -from joblib import Parallel, delayed -from natsort import natsorted -from tqdm import tqdm - -from utils.logger import logger -from utils.video_utils import get_video_path_list - -def parse_args(): - parser = argparse.ArgumentParser(description="Filter the motion score of the videos.") - parser.add_argument( - "--motion_score_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)." - ) - parser.add_argument("--low_motion_score_threshold", type=float, default=3.0, help="The low motion score threshold.") - parser.add_argument("--high_motion_score_threshold", type=float, default=8.0, help="The high motion score threshold.") - parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).") - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")): - raise ValueError("The saved_path must end with .csv or .jsonl.") - - if args.motion_score_metadata_path is not None: - if args.motion_score_metadata_path.endswith(".csv"): - motion_score_df = pd.read_csv(args.motion_score_metadata_path) - elif args.motion_score_metadata_path.endswith(".jsonl"): - motion_score_df = pd.read_json(args.motion_score_metadata_path, lines=True) - - filtered_motion_score_df = motion_score_df[motion_score_df["motion_score"] > args.low_motion_score_threshold] - filtered_motion_score_df = filtered_motion_score_df[motion_score_df["motion_score"] < args.high_motion_score_threshold] - - if args.saved_path.endswith(".csv"): - header = False if os.path.exists(args.saved_path) else True - filtered_motion_score_df.to_csv(args.saved_path, header=header, index=False, mode="a") - elif args.saved_path.endswith(".jsonl"): - filtered_motion_score_df.to_json(args.saved_path, orient="records", lines=True, mode="a") - logger.info(f"Save result to {args.saved_path}.") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/easyanimate/video_caption/requirements.txt b/easyanimate/video_caption/requirements.txt deleted file mode 100644 index b64e18452c0f6de98276e45b47559454de63f100..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -auto_gptq==0.6.0 -pandas>=2.0.0 -vllm==0.3.3 -sglang[srt]==0.1.13 -func_timeout -easyocr==1.7.1 -git+https://github.com/openai/CLIP.git -natsort -joblib -scenedetect -av diff --git a/easyanimate/video_caption/scenedetect_vcut.py b/easyanimate/video_caption/scenedetect_vcut.py deleted file mode 100644 index b49c80b514c09038db3fa2a98c8ebbe4aa6c13a1..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/scenedetect_vcut.py +++ /dev/null @@ -1,235 +0,0 @@ -import argparse -import copy -import json -import os -import shutil -from multiprocessing import Pool - -from scenedetect import SceneManager, open_video -from scenedetect.detectors import ContentDetector -from scenedetect.video_splitter import split_video_ffmpeg -from tqdm import tqdm - -from utils.video_utils import download_video, get_video_path_list - -tmp_file_dir = "./tmp" -DEFAULT_FFMPEG_ARGS = '-c:v libx264 -preset veryfast -crf 22 -c:a aac' - -def parse_args(): - parser = argparse.ArgumentParser( - description = '''Cut video by PySceneDetect''') - parser.add_argument( - 'video', - type = str, - help = '''Input format: - 1. Local video file path. - 2. Video URL. - 3. Local root dir path of videos. - 4. Local txt file of video urls/local file path, line by line. - ''') - parser.add_argument( - '--threshold', - type = float, - nargs='+', - default = [10, 20, 30], - help = 'Threshold list the average change in pixel intensity must exceed to trigger a cut, one-to-one with frame_skip.') - parser.add_argument( - '--frame_skip', - type = int, - nargs='+', - default = [0, 1, 2], - help = 'Number list of frames to skip, coordinate with threshold \ - (i.e. process every 1 in N+1 frames, where N is frame_skip, \ - processing only 1/N+1 percent of the video, \ - speeding up the detection time at the expense of accuracy). One-to-one with threshold.') - parser.add_argument( - '--min_seconds', - type = int, - default = 3, - help = 'Video cut must be longer then min_seconds.') - parser.add_argument( - '--max_seconds', - type = int, - default = 12, - help = 'Video cut must be longer then min_seconds.') - parser.add_argument( - '--save_dir', - type = str, - default = "", - help = 'Video scene cuts save dir, default value means reusing input video dir.') - parser.add_argument( - '--name_template', - type = str, - default = "$VIDEO_NAME-Scene-$SCENE_NUMBER.mp4", - help = 'Video scene cuts save name template.') - parser.add_argument( - '--num_processes', - type = int, - default = os.cpu_count() // 2, - help = 'Number of CPU cores to process the video scene cut.') - parser.add_argument( - "--save_json", action="store_true", help="Whether save json in datasets." - ) - args = parser.parse_args() - return args - - -def split_video_into_scenes( - video_path: str, - threshold: list[float] = [27.0], - frame_skip: list[int] = [0], - min_seconds: int = 3, - max_seconds: int = 8, - save_dir: str = "", - name_template: str = "$VIDEO_NAME-Scene-$SCENE_NUMBER.mp4", - save_json: bool = False ): - # SceneDetect video through casceded (threshold, FPS) - frame_points = [] - frame_timecode = {} - fps = 25.0 - for thre, f_skip in zip(threshold, frame_skip): - # Open our video, create a scene manager, and add a detector. - video = open_video(video_path, backend='pyav') - scene_manager = SceneManager() - scene_manager.add_detector( - # [ContentDetector, ThresholdDetector, AdaptiveDetector] - ContentDetector(threshold=thre, min_scene_len=10) - ) - scene_manager.detect_scenes(video, frame_skip=f_skip, show_progress=False) - scene_list = scene_manager.get_scene_list() - for scene in scene_list: - for frame_time_code in scene: - frame_index = frame_time_code.get_frames() - if frame_index not in frame_points: - frame_points.append(frame_index) - frame_timecode[frame_index] = frame_time_code - fps = frame_time_code.get_framerate() - del video, scene_manager - frame_points = sorted(frame_points) - output_scene_list = [] - - # Detect No Scene Change - if len(frame_points) == 0: - video = open_video(video_path, backend='pyav') - frame_points = [0, video.duration.get_frames() - 1] - frame_timecode = { - frame_points[0]: video.base_timecode, - frame_points[-1]: video.base_timecode + video.base_timecode + video.duration - } - del video - - for idx in range(len(frame_points) - 1): - # Limit save out min seconds - if frame_points[idx+1] - frame_points[idx] < fps * min_seconds: - continue - # Limit save out max seconds - elif frame_points[idx+1] - frame_points[idx] > fps * max_seconds: - tmp_start_timecode = copy.deepcopy(frame_timecode[frame_points[idx]]) - tmp_end_timecode = copy.deepcopy(frame_timecode[frame_points[idx]]) + int(max_seconds * fps) - # Average cut by max seconds - while tmp_end_timecode.get_frames() <= frame_points[idx+1]: - output_scene_list.append(( - copy.deepcopy(tmp_start_timecode), - copy.deepcopy(tmp_end_timecode))) - tmp_start_timecode += int(max_seconds * fps) - tmp_end_timecode += int(max_seconds * fps) - if tmp_end_timecode.get_frames() > frame_points[idx+1] and frame_points[idx+1] - tmp_start_timecode.get_frames() > fps * min_seconds: - output_scene_list.append(( - copy.deepcopy(tmp_start_timecode), - frame_timecode[frame_points[idx+1]])) - del tmp_start_timecode, tmp_end_timecode - continue - output_scene_list.append(( - frame_timecode[frame_points[idx]], - frame_timecode[frame_points[idx+1]])) - - # Reuse video dir - if save_dir == "": - save_dir = os.path.dirname(video_path) - # Ensure save dir exists - elif not os.path.isdir(save_dir): - os.makedirs(save_dir) - - clip_info_path = os.path.join(save_dir, os.path.splitext(os.path.basename(video_path))[0] + '.json') - - output_file_template = os.path.join(save_dir, name_template) - split_video_ffmpeg( - video_path, - output_scene_list, - arg_override=DEFAULT_FFMPEG_ARGS, - output_file_template=output_file_template, - show_progress=False, - show_output=False) # ffmpeg print - - if save_json: - # Save clip info - json.dump( - [(frame_timecode_tuple[0].get_timecode(), frame_timecode_tuple[1].get_timecode()) for frame_timecode_tuple in output_scene_list], - open(clip_info_path, 'w'), - indent=2 - ) - - return clip_info_path - - -def process_single_video(args): - video, threshold, frame_skip, min_seconds, max_seconds, save_dir, name_template, save_json = args - basename = os.path.splitext(os.path.basename(video))[0] - # Video URL - if video.startswith("http"): - save_path = os.path.join(tmp_file_dir, f"{basename}.mp4") - download_success = download_video(video, save_path) - if not download_success: - return - video = save_path - # Local video path - else: - if not os.path.isfile(video): - print(f"Video not exists: {video}") - return - # SceneDetect video cut - try: - split_video_into_scenes( - video_path=video, - threshold=threshold, - frame_skip=frame_skip, - min_seconds=min_seconds, - max_seconds=max_seconds, - save_dir=save_dir, - name_template=name_template, - save_json=save_json - ) - except Exception as e: - print(e, video) - - -def main(): - # Args - args = parse_args() - video_input = args.video - threshold = args.threshold - frame_skip = args.frame_skip - min_seconds = args.min_seconds - max_seconds = args.max_seconds - save_dir = args.save_dir - name_template = args.name_template - num_processes = args.num_processes - save_json = args.save_json - - assert len(threshold) == len(frame_skip), \ - "Threshold must one-to-one match frame_skip." - - video_list = get_video_path_list(video_input) - args_list = [ - (video, threshold, frame_skip, min_seconds, max_seconds, save_dir, name_template, save_json) - for video in video_list - ] - - with Pool(processes=num_processes) as pool: - with tqdm(total=len(video_list)) as progress_bar: - for _ in pool.imap_unordered(process_single_video, args_list): - progress_bar.update(1) - - -if __name__ == "__main__": - main() diff --git a/easyanimate/video_caption/stage_1_video_cut.sh b/easyanimate/video_caption/stage_1_video_cut.sh deleted file mode 100644 index 817f3142ac61df896b04d2d0e04ef57cd020f115..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/stage_1_video_cut.sh +++ /dev/null @@ -1,11 +0,0 @@ -export VIDEO_FOLDER="datasets/panda_70m/before_vcut/" -export OUTPUT_FOLDER="datasets/panda_70m/train/" - -# Cut raw videos -python scenedetect_vcut.py \ - $VIDEO_FOLDER \ - --threshold 10 20 30 \ - --frame_skip 0 1 2 \ - --min_seconds 3 \ - --max_seconds 10 \ - --save_dir $OUTPUT_FOLDER \ No newline at end of file diff --git a/easyanimate/video_caption/stage_2_filter_data.sh b/easyanimate/video_caption/stage_2_filter_data.sh deleted file mode 100644 index b3d9dc41d6b6d3ded0053c45d2cce10bc949ea90..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/stage_2_filter_data.sh +++ /dev/null @@ -1,39 +0,0 @@ -export VIDEO_FOLDER="datasets/panda_70m/train" -export FRAME_QUALITY_SAVE_PATH="datasets/panda_70m/aesthetic_score.jsonl" -export TEXT_SCORE_SAVE_PATH="datasets/panda_70m/text_score.jsonl" -export MOTION_SCORE_SAVE_PATH="datasets/panda_70m/motion_score.jsonl" -export FILTER_BY_MOTION_SCORE_SAVE_PATH="datasets/panda_70m/train.jsonl" - -# Get asethetic score of all videos -CUDA_VISIBLE_DEVICES="0" accelerate launch compute_video_frame_quality.py \ - --video_folder=$VIDEO_FOLDER \ - --video_path_column="video_path" \ - --metrics="AestheticScore" \ - --saved_freq=10 \ - --saved_path=$FRAME_QUALITY_SAVE_PATH \ - --batch_size=8 - -# Get text score of all videos -CUDA_VISIBLE_DEVICES="0" accelerate launch compute_text_score.py \ - --video_folder=$VIDEO_FOLDER \ - --video_path_column="video_path" \ - --saved_freq=10 \ - --saved_path=$TEXT_SCORE_SAVE_PATH \ - --asethetic_score_metadata_path $FRAME_QUALITY_SAVE_PATH - -# Get motion score after filter videos by asethetic score and text score -python compute_motion_score.py \ - --video_folder=$VIDEO_FOLDER \ - --video_path_column="video_path" \ - --saved_freq=10 \ - --saved_path=$MOTION_SCORE_SAVE_PATH \ - --n_jobs=8 \ - --asethetic_score_metadata_path $FRAME_QUALITY_SAVE_PATH \ - --text_score_metadata_path $TEXT_SCORE_SAVE_PATH - -# Filter videos by motion score -python filter_videos_by_motion_score.py \ - --motion_score_metadata_path $MOTION_SCORE_SAVE_PATH \ - --low_motion_score_threshold=3 \ - --high_motion_score_threshold=8 \ - --saved_path $FILTER_BY_MOTION_SCORE_SAVE_PATH diff --git a/easyanimate/video_caption/stage_3_video_caption.sh b/easyanimate/video_caption/stage_3_video_caption.sh deleted file mode 100644 index 68bb0a870cf3934cd544b92eb14ffc2205a44d63..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/stage_3_video_caption.sh +++ /dev/null @@ -1,35 +0,0 @@ -export CUDA_VISIBLE_DEVICES=0 -export VIDEO_FOLDER="datasets/panda_70m/train/" -export MOTION_SCORE_META_PATH="datasets/panda_70m/train.jsonl" -export VIDEO_FRAME_CAPTION_PATH="datasets/panda_70m/frame_caption.jsonl" -export VIDEO_CAPTION_PATH="datasets/panda_70m/summary_caption.jsonl" -export LAST_JSON_PATH="datasets/panda_70m/train_panda_70m.json" - -CUDA_VISIBLE_DEVICES="0" python caption_video_frame.py \ - --video_metadata_path=$MOTION_SCORE_META_PATH \ - --video_folder=$VIDEO_FOLDER \ - --frame_sample_method="mid" \ - --num_sampled_frames=1 \ - --image_caption_model_name="llava-v1.6-vicuna-7b" \ - --image_caption_prompt="Please describe this image in detail." \ - --saved_path=$VIDEO_FRAME_CAPTION_PATH \ - --output_dir="tmp" - -CUDA_VISIBLE_DEVICES="0" python caption_summary.py \ - --video_metadata_path=$VIDEO_FRAME_CAPTION_PATH \ - --video_path_column="video_path" \ - --caption_column="sampled_frame_caption" \ - --summary_model_name="Qwen/Qwen1.5-7B-Chat" \ - --summary_prompt="You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, \ - which you need to summarize it into a description of the video clip. \ - Please provide your video description following these requirements: \ - 1. Describe the basic and necessary information of the video in the third person, be as concise as possible. \ - 2. Output the video description directly. Begin with 'In this video'. \ - 3. Limit the video description within 100 words. \ - Here is the mid-frame description: " \ - --saved_path=$VIDEO_CAPTION_PATH - -python convert_jsonl_to_json.py \ - --video_folder=$VIDEO_FOLDER \ - --jsonl_load_path=$VIDEO_CAPTION_PATH \ - --save_path=$LAST_JSON_PATH \ No newline at end of file diff --git a/easyanimate/video_caption/utils/image_captioner_awq.py b/easyanimate/video_caption/utils/image_captioner_awq.py deleted file mode 100644 index b65b5f45984b099e1f35c34ee5bc18606bb8f4e2..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/utils/image_captioner_awq.py +++ /dev/null @@ -1,90 +0,0 @@ -from pathlib import Path -from typing import Tuple - -import auto_gptq -import torch -from auto_gptq.modeling import BaseGPTQForCausalLM -from transformers import AutoModelForCausalLM, AutoTokenizer - - -class QwenVLChat: - def __init__(self, device: str = "cuda:0", quantized: bool = False) -> None: - if quantized: - self.model = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen-VL-Chat-Int4", device_map=device, trust_remote_code=True - ).eval() - self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True) - else: - self.model = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen-VL-Chat", device_map=device, trust_remote_code=True, fp16=True - ).eval() - self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True) - - def __call__(self, prompt: str, image: str) -> Tuple[str, str]: - query = self.tokenizer.from_list_format([{"image": image}, {"text": prompt}]) - response, history = self.model.chat(self.tokenizer, query=query, history=[]) - - return response, history - - -class InternLMXComposer2QForCausalLM(BaseGPTQForCausalLM): - layers_block_name = "model.layers" - outside_layer_modules = [ - "vit", - "vision_proj", - "model.tok_embeddings", - "model.norm", - "output", - ] - inside_layer_modules = [ - ["attention.wqkv.linear"], - ["attention.wo.linear"], - ["feed_forward.w1.linear", "feed_forward.w3.linear"], - ["feed_forward.w2.linear"], - ] - - -class InternLMXComposer2: - def __init__(self, device: str = "cuda:0", quantized: bool = True): - if quantized: - auto_gptq.modeling._base.SUPPORTED_MODELS = ["internlm"] - self.model = InternLMXComposer2QForCausalLM.from_quantized( - "internlm/internlm-xcomposer2-vl-7b-4bit", trust_remote_code=True, device=device - ).eval() - self.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-xcomposer2-vl-7b-4bit", trust_remote_code=True) - else: - # Setting fp16=True does not work. See https://huggingface.co/internlm/internlm-xcomposer2-vl-7b/discussions/1. - self.model = ( - AutoModelForCausalLM.from_pretrained( - "internlm/internlm-xcomposer2-vl-7b", device_map=device, trust_remote_code=True - ) - .eval() - .to(torch.float16) - ) - self.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-xcomposer2-vl-7b", trust_remote_code=True) - - def __call__(self, prompt: str, image: str): - if not prompt.startswith(""): - prompt = "" + prompt - with torch.cuda.amp.autocast(), torch.no_grad(): - response, history = self.model.chat(self.tokenizer, query=prompt, image=image, history=[], do_sample=False) - return response, history - - -if __name__ == "__main__": - image_folder = "demo/" - wildcard_list = ["*.jpg", "*.png"] - image_list = [] - for wildcard in wildcard_list: - image_list.extend([str(image_path) for image_path in Path(image_folder).glob(wildcard)]) - qwen_vl_chat = QwenVLChat(device="cuda:0", quantized=True) - qwen_vl_prompt = "Please describe this image in detail." - for image in image_list: - response, _ = qwen_vl_chat(qwen_vl_prompt, image) - print(image, response) - - internlm2_vl = InternLMXComposer2(device="cuda:0", quantized=False) - internlm2_vl_prompt = "Please describe this image in detail." - for image in image_list: - response, _ = internlm2_vl(internlm2_vl_prompt, image) - print(image, response) diff --git a/easyanimate/video_caption/utils/image_captioner_sglang.py b/easyanimate/video_caption/utils/image_captioner_sglang.py deleted file mode 100644 index 050d787e2c13ecae246e4614485f9e6cd5c82646..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/utils/image_captioner_sglang.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import time -from datetime import datetime -from typing import List, Union -from pathlib import Path - -import sglang as sgl -from PIL import Image - -from utils.logger import logger - -TMP_DIR = "./tmp" - - -def get_timestamp(): - timestamp_ns = int(time.time_ns()) - milliseconds = timestamp_ns // 1000000 - formatted_time = datetime.fromtimestamp(milliseconds / 1000).strftime("%Y-%m-%d_%H-%M-%S-%f")[:-3] - - return formatted_time - - -class LLaVASRT: - def __init__(self, device: str = "cuda:0", quantized: bool = True): - self.runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.6-vicuna-7b", tokenizer_path="llava-hf/llava-1.5-7b-hf") - sgl.set_default_backend(self.runtime) - logger.info( - f"Start the SGLang runtime for llava-v1.6-vicuna-7b with chat template: {self.runtime.endpoint.chat_template.name}. " - "Input parameter device and quantized do not take effect." - ) - if not os.path.exists(TMP_DIR): - os.makedirs(TMP_DIR, exist_ok=True) - - @sgl.function - def image_qa(s, prompt: str, image: str): - s += sgl.user(sgl.image(image) + prompt) - s += sgl.assistant(sgl.gen("answer")) - - def __call__(self, prompt: Union[str, List[str]], image: Union[str, Image.Image, List[str]]): - pil_input_flag = False - if isinstance(prompt, str) and (isinstance(image, str) or isinstance(image, Image.Image)): - if isinstance(image, Image.Image): - pil_input_flag = True - image_path = os.path.join(TMP_DIR, get_timestamp() + ".jpg") - image.save(image_path) - state = self.image_qa.run(prompt=prompt, image=image, max_new_tokens=256) - # Post-process. - if pil_input_flag: - os.remove(image) - - return state["answer"], state - elif isinstance(prompt, list) and isinstance(image, list): - assert len(prompt) == len(image) - if isinstance(image[0], Image.Image): - pil_input_flag = True - image_path = [os.path.join(TMP_DIR, get_timestamp() + f"-{i}" + ".jpg") for i in range(len(image))] - for i in range(len(image)): - image[i].save(image_path[i]) - image = image_path - batch_query = [{"prompt": p, "image": img} for p, img in zip(prompt, image)] - state = self.image_qa.run_batch(batch_query, max_new_tokens=256) - # Post-process. - if pil_input_flag: - for i in range(len(image)): - os.remove(image[i]) - - return [s["answer"] for s in state], state - else: - raise ValueError("Input prompt and image must be both strings or list of strings with the same length.") - - def __del__(self): - self.runtime.shutdown() - - -if __name__ == "__main__": - image_folder = "demo/" - wildcard_list = ["*.jpg", "*.png"] - image_list = [] - for wildcard in wildcard_list: - image_list.extend([str(image_path) for image_path in Path(image_folder).glob(wildcard)]) - # SGLang need the exclusive GPU and cannot re-initialize CUDA in forked subprocess. - llava_srt = LLaVASRT() - # Batch inference. - llava_srt_prompt = ["Please describe this image in detail."] * len(image_list) - response, _ = llava_srt(llava_srt_prompt, image_list) - print(response) - llava_srt_prompt = "Please describe this image in detail." - for image in image_list: - response, _ = llava_srt(llava_srt_prompt, image) - print(image, response) \ No newline at end of file diff --git a/easyanimate/video_caption/utils/image_evaluator.py b/easyanimate/video_caption/utils/image_evaluator.py deleted file mode 100644 index 3db62f7a12d36b0945bef63ef5b9a09cc1dec8e6..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/utils/image_evaluator.py +++ /dev/null @@ -1,130 +0,0 @@ -import os -from typing import List - -import clip -import torch -import torch.nn as nn -import torch.nn.functional as F -from PIL import Image -from torchvision.datasets.utils import download_url -from transformers import AutoModel, AutoProcessor - -# All metrics. -__all__ = ["AestheticScore", "CLIPScore"] - -_MODELS = { - "CLIP_ViT-L/14": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViT-L-14.pt", - "Aesthetics_V2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/sac%2Blogos%2Bava1-l14-linearMSE.pth", -} -_MD5 = { - "CLIP_ViT-L/14": "096db1af569b284eb76b3881534822d9", - "Aesthetics_V2": "b1047fd767a00134b8fd6529bf19521a", -} - - -# if you changed the MLP architecture during training, change it also here: -class _MLP(nn.Module): - def __init__(self, input_size): - super().__init__() - self.input_size = input_size - self.layers = nn.Sequential( - nn.Linear(self.input_size, 1024), - # nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(1024, 128), - # nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(128, 64), - # nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(64, 16), - # nn.ReLU(), - nn.Linear(16, 1), - ) - - def forward(self, x): - return self.layers(x) - - -class AestheticScore: - """Compute LAION Aesthetics Score V2 based on openai/clip. Note that the default - inference dtype with GPUs is fp16 in openai/clip. - - Ref: - 1. https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py. - 2. https://github.com/openai/CLIP/issues/30. - """ - - def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"): - # The CLIP model is loaded in the evaluation mode. - self.root = os.path.expanduser(root) - if not os.path.exists(self.root): - os.makedirs(self.root) - filename = "ViT-L-14.pt" - download_url(_MODELS["CLIP_ViT-L/14"], self.root, filename=filename, md5=_MD5["CLIP_ViT-L/14"]) - self.clip_model, self.preprocess = clip.load(os.path.join(self.root, filename), device=device) - self.device = device - self._load_mlp() - - def _load_mlp(self): - filename = "sac+logos+ava1-l14-linearMSE.pth" - download_url(_MODELS["Aesthetics_V2"], self.root, filename=filename, md5=_MD5["Aesthetics_V2"]) - state_dict = torch.load(os.path.join(self.root, filename)) - self.mlp = _MLP(768) - self.mlp.load_state_dict(state_dict) - self.mlp.to(self.device) - self.mlp.eval() - - def __call__(self, images: List[Image.Image], texts=None) -> List[float]: - with torch.no_grad(): - images = torch.stack([self.preprocess(image) for image in images]).to(self.device) - image_embs = F.normalize(self.clip_model.encode_image(images)) - scores = self.mlp(image_embs.float()) # torch.float16 -> torch.float32, [N, 1] - return scores.squeeze().tolist() - - def __repr__(self) -> str: - return "aesthetic_score" - - -class CLIPScore: - """Compute CLIP scores for image-text pairs based on huggingface/transformers.""" - - def __init__( - self, - model_name_or_path: str = "openai/clip-vit-large-patch14", - torch_dtype=torch.float16, - device: str = "cpu", - ): - self.model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).eval().to(device) - self.processor = AutoProcessor.from_pretrained(model_name_or_path) - self.torch_dtype = torch_dtype - self.device = device - - def __call__(self, images: List[Image.Image], texts: List[str]) -> List[float]: - assert len(images) == len(texts) - image_inputs = self.processor(images=images, return_tensors="pt") # {"pixel_values": } - if self.torch_dtype == torch.float16: - image_inputs["pixel_values"] = image_inputs["pixel_values"].half() - text_inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True) # {"inputs_id": } - image_inputs, text_inputs = image_inputs.to(self.device), text_inputs.to(self.device) - with torch.no_grad(): - image_embs = F.normalize(self.model.get_image_features(**image_inputs)) - text_embs = F.normalize(self.model.get_text_features(**text_inputs)) - scores = text_embs @ image_embs.T # [N, N] - - return scores.diagonal().tolist() - - def __repr__(self) -> str: - return "clip_score" - - -if __name__ == "__main__": - aesthetic_score = AestheticScore(device="cuda") - clip_score = CLIPScore(device="cuda") - - paths = ["demo/splash_cl2_midframe.jpg"] * 3 - texts = ["a joker", "a woman", "a man"] - images = [Image.open(p).convert("RGB") for p in paths] - - print(aesthetic_score(images)) - print(clip_score(images, texts)) \ No newline at end of file diff --git a/easyanimate/video_caption/utils/logger.py b/easyanimate/video_caption/utils/logger.py deleted file mode 100644 index 754eaf6b379aa39e8b9469c95e17c8ec8128e30d..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/utils/logger.py +++ /dev/null @@ -1,36 +0,0 @@ -# Borrowed from sd-webui-controlnet/scripts/logging.py -import copy -import logging -import sys - - -class ColoredFormatter(logging.Formatter): - COLORS = { - "DEBUG": "\033[0;36m", # CYAN - "INFO": "\033[0;32m", # GREEN - "WARNING": "\033[0;33m", # YELLOW - "ERROR": "\033[0;31m", # RED - "CRITICAL": "\033[0;37;41m", # WHITE ON RED - "RESET": "\033[0m", # RESET COLOR - } - - def format(self, record): - colored_record = copy.copy(record) - levelname = colored_record.levelname - seq = self.COLORS.get(levelname, self.COLORS["RESET"]) - colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" - return super().format(colored_record) - - -# Create a new logger -logger = logging.getLogger("VideoCaption") -logger.propagate = False - -# Add handler if we don't have one. -if not logger.handlers: - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) - logger.addHandler(handler) - -# Configure logger -logger.setLevel("INFO") diff --git a/easyanimate/video_caption/utils/video_dataset.py b/easyanimate/video_caption/utils/video_dataset.py deleted file mode 100644 index 537c4110627d9edb267f96df591c90351b7db0fb..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/utils/video_dataset.py +++ /dev/null @@ -1,83 +0,0 @@ -from pathlib import Path - -import pandas as pd -from func_timeout import FunctionTimedOut, func_timeout -from torch.utils.data import DataLoader, Dataset - -from utils.logger import logger -from utils.video_utils import get_video_path_list, extract_frames - -ALL_VIDEO_EXT = set(["mp4", "webm", "mkv", "avi", "flv", "mov"]) -VIDEO_READER_TIMEOUT = 10 - - -def collate_fn(batch): - batch = list(filter(lambda x: x is not None, batch)) - if len(batch) != 0: - return {k: [item[k] for item in batch] for k in batch[0].keys()} - return {} - - -class VideoDataset(Dataset): - def __init__( - self, - video_path_list=None, - video_folder=None, - video_metadata_path=None, - video_path_column=None, - sample_method="mid", - num_sampled_frames=1, - num_sample_stride=None, - ): - self.video_path_column = video_path_column - self.video_folder = video_folder - self.sample_method = sample_method - self.num_sampled_frames = num_sampled_frames - self.num_sample_stride = num_sample_stride - - if video_path_list is not None: - self.video_path_list = video_path_list - self.metadata_df = pd.DataFrame({video_path_column: self.video_path_list}) - else: - self.video_path_list = get_video_path_list( - video_folder=video_folder, - video_metadata_path=video_metadata_path, - video_path_column=video_path_column - ) - - def __getitem__(self, index): - # video_path = os.path.join(self.video_folder, str(self.video_path_list[index])) - video_path = self.video_path_list[index] - try: - sample_args = (video_path, self.sample_method, self.num_sampled_frames, self.num_sample_stride) - sampled_frame_idx_list, sampled_frame_list = func_timeout( - VIDEO_READER_TIMEOUT, extract_frames, args=sample_args - ) - except FunctionTimedOut: - logger.warning(f"Read {video_path} timeout.") - return None - except Exception as e: - logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.") - return None - item = { - "video_path": Path(video_path).name, - "sampled_frame_idx": sampled_frame_idx_list, - "sampled_frame": sampled_frame_list, - } - - return item - - def __len__(self): - return len(self.video_path_list) - - -if __name__ == "__main__": - video_folder = "your_video_folder" - video_dataset = VideoDataset(video_folder=video_folder) - - video_dataloader = DataLoader( - video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn - ) - for idx, batch in enumerate(video_dataloader): - if len(batch) != 0: - print(batch["video_path"], batch["sampled_frame_idx"], len(batch["video_path"])) \ No newline at end of file diff --git a/easyanimate/video_caption/utils/video_utils.py b/easyanimate/video_caption/utils/video_utils.py deleted file mode 100644 index 4219a208fd2bf9dfedba4daa77d9dc9dae373bdb..0000000000000000000000000000000000000000 --- a/easyanimate/video_caption/utils/video_utils.py +++ /dev/null @@ -1,108 +0,0 @@ -import gc -import os -import random -import urllib.request as request -from contextlib import contextmanager -from pathlib import Path -from typing import List, Tuple, Optional - -import numpy as np -import pandas as pd -from decord import VideoReader -from PIL import Image - -ALL_VIDEO_EXT = set([".mp4", ".webm", ".mkv", ".avi", ".flv", ".mov"]) - - -def get_video_path_list( - video_folder: Optional[str]=None, - video_metadata_path: Optional[str]=None, - video_path_column: Optional[str]=None -) -> List[str]: - """Get all video (absolute) path list from the video folder or the video metadata file. - - Args: - video_folder (str): The absolute path of the folder (including sub-folders) containing all the required video files. - video_metadata_path (str): The absolute path of the video metadata file containing video path list. - video_path_column (str): The column/key for the corresponding video path in the video metadata file (csv/jsonl). - """ - if video_folder is None and video_metadata_path is None: - raise ValueError("Either the video_input or the video_metadata_path should be specified.") - if video_metadata_path is not None: - if video_metadata_path.endswith(".csv"): - if video_path_column is None: - raise ValueError("The video_path_column can not be None if provided a csv file.") - metadata_df = pd.read_csv(video_metadata_path) - video_path_list = metadata_df[video_path_column].tolist() - elif video_metadata_path.endswith(".jsonl"): - if video_path_column is None: - raise ValueError("The video_path_column can not be None if provided a jsonl file.") - metadata_df = pd.read_json(video_metadata_path, lines=True) - video_path_list = metadata_df[video_path_column].tolist() - elif video_metadata_path.endswith(".txt"): - with open(video_metadata_path, "r", encoding="utf-8") as f: - video_path_list = [line.strip() for line in f] - else: - raise ValueError("The video_metadata_path must end with `.csv`, `.jsonl` or `.txt`.") - if video_folder is not None: - video_path_list = [os.path.join(video_folder, video_path) for video_path in video_path_list] - return video_path_list - - if os.path.isfile(video_folder): - video_path_list = [] - if video_folder.endswith("mp4"): - video_path_list.append(video_folder) - elif video_folder.endswith("txt"): - with open(video_folder, "r") as file: - video_path_list += [line.strip() for line in file.readlines()] - return video_path_list - - elif video_folder is not None: - video_path_list = [] - for ext in ALL_VIDEO_EXT: - video_path_list.extend(Path(video_folder).rglob(f"*{ext}")) - video_path_list = [str(video_path) for video_path in video_path_list] - return video_path_list - - -@contextmanager -def video_reader(*args, **kwargs): - """A context manager to solve the memory leak of decord. - """ - vr = VideoReader(*args, **kwargs) - try: - yield vr - finally: - del vr - gc.collect() - - -def extract_frames( - video_path: str, sample_method: str = "mid", num_sampled_frames: int = -1, sample_stride: int = -1 -) -> Optional[Tuple[List[int], List[Image.Image]]]: - with video_reader(video_path, num_threads=2) as vr: - if sample_method == "mid": - sampled_frame_idx_list = [len(vr) // 2] - elif sample_method == "uniform": - sampled_frame_idx_list = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int) - elif sample_method == "random": - clip_length = min(len(vr), (num_sampled_frames - 1) * sample_stride + 1) - start_idx = random.randint(0, len(vr) - clip_length) - sampled_frame_idx_list = np.linspace(start_idx, start_idx + clip_length - 1, num_sampled_frames, dtype=int) - else: - raise ValueError("The sample_method must be mid, uniform or random.") - sampled_frame_list = vr.get_batch(sampled_frame_idx_list).asnumpy() - sampled_frame_list = [Image.fromarray(frame) for frame in sampled_frame_list] - - return list(sampled_frame_idx_list), sampled_frame_list - - -def download_video( - video_url: str, - save_path: str) -> bool: - try: - request.urlretrieve(video_url, save_path) - return os.path.isfile(save_path) - except Exception as e: - print(e, video_url) - return False \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9f19350a36ade007d1d5875823b35e64af91887b..4ee48530662b177c0877c7f281b01ba199f794f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,10 @@ einops safetensors timm tomesd -accelerate -torch>=2.2.0 +torch>=2.1.2 torchdiffeq torchsde +xformers decord datasets numpy @@ -18,6 +18,11 @@ albumentations imageio[ffmpeg] imageio[pyav] tensorboard -gradio==3.41.2 -diffusers==0.27.0 -transformers==4.37.2 +beautifulsoup4 +ftfy +func_timeout +deepspeed +accelerate>=0.25.0 +gradio>=3.41.2 +diffusers>=0.30.1 +transformers>=4.37.2 \ No newline at end of file