Commit c2a6cd2 by bubbliiiing
Parent: b0f1243

Update V5.1
Files changed:

- app.py +9 -18
- config/easyanimate_video_v5.1_magvit_qwen.yaml +21 -0
- easyanimate/api/api.py +1 -1
- easyanimate/api/post_infer.py +2 -2
- easyanimate/data/dataset_image_video.py +220 -32
- easyanimate/models/__init__.py +3 -4
- easyanimate/models/attention.py +60 -31
- easyanimate/models/autoencoder_magvit.py +15 -117
- easyanimate/models/embeddings.py +3 -2
- easyanimate/models/norm.py +16 -0
- easyanimate/models/processor.py +146 -0
- easyanimate/models/transformer3d.py +280 -43
- easyanimate/pipeline/pipeline_easyanimate.py +730 -486
- easyanimate/pipeline/{pipeline_easyanimate_multi_text_encoder_control.py → pipeline_easyanimate_control.py} +448 -229
- easyanimate/pipeline/pipeline_easyanimate_inpaint.py +0 -0
- easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py +0 -925
- easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py +0 -1334
- easyanimate/ui/ui.py +237 -179
- easyanimate/utils/lora_utils.py +42 -30
- easyanimate/utils/utils.py +53 -33
- easyanimate/vae/ldm/models/autoencoder.py +4 -4
- easyanimate/vae/ldm/models/casual3dcnn.py +5 -5
- easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py +5 -5
- easyanimate/vae/ldm/models/omnigen_casual3dcnn.py +13 -9
- easyanimate/vae/ldm/models/omnigen_enc_dec.py +6 -2
- easyanimate/vae/ldm/modules/losses/contperceptual.py +20 -3
- easyanimate/vae/ldm/modules/vaemodules/__init__.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/activations.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/common.py +39 -5
- easyanimate/vae/ldm/modules/vaemodules/down_blocks.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/up_blocks.py +0 -0
- requirements.txt +2 -5
app.py CHANGED

@@ -19,6 +19,9 @@ if __name__ == "__main__":
     #
     # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use,
     # resulting in slower speeds but saving a large amount of GPU memory.
+    #
+    # EasyAnimateV1, V2 and V3 support "model_cpu_offload" "sequential_cpu_offload"
+    # EasyAnimateV4, V5 and V5.1 support "model_cpu_offload" "model_cpu_offload_and_qfloat8" "sequential_cpu_offload"
     GPU_memory_mode = "model_cpu_offload_and_qfloat8"
     # Use torch.float16 if GPU does not support torch.bfloat16
     # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16

@@ -29,11 +32,11 @@ if __name__ == "__main__":
     server_port = 7860

     # Params below is used when ui_mode = "modelscope"
-    edition = "v5"
+    edition = "v5.1"
     # Config
-    config_path = "config/
+    config_path = "config/easyanimate_video_v5.1_magvit_qwen.yaml"
     # Model path of the pretrained model
-    model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
+    model_name = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP"
     # "Inpaint" or "Control"
     model_type = "Inpaint"
     # Save dir

@@ -46,18 +49,6 @@ if __name__ == "__main__":
     else:
         demo, controller = ui(GPU_memory_mode, weight_dtype)

-
-
-
-        server_port=server_port,
-        prevent_thread_lock=True
-    )
-
-    # launch api
-    infer_forward_api(None, app, controller)
-    update_diffusion_transformer_api(None, app, controller)
-    update_edition_api(None, app, controller)
-
-    # not close the python
-    while True:
-        time.sleep(5)
+    demo.launch(
+        server_name=server_name, server_port=server_port
+    )
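Note: the launch change above replaces the old queue()/API-registration/keep-alive flow with a single blocking demo.launch() call. A minimal sketch of the new pattern, using a stand-in ui_stub() for the Space's real ui() (which lives in easyanimate/ui/ui.py and is not shown in this diff):

```python
import gradio as gr

def ui_stub():
    # Stand-in for easyanimate.ui.ui(); the real one returns (demo, controller).
    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        echo = gr.Textbox(label="Echo")
        prompt.submit(lambda p: p, inputs=prompt, outputs=echo)
    return demo, None

if __name__ == "__main__":
    server_name = "0.0.0.0"
    server_port = 7860
    demo, controller = ui_stub()
    # Blocking call: no prevent_thread_lock=True and no `while True: time.sleep(5)`
    # keep-alive loop are needed any more.
    demo.launch(server_name=server_name, server_port=server_port)
```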
config/easyanimate_video_v5.1_magvit_qwen.yaml ADDED

@@ -0,0 +1,21 @@
+transformer_additional_kwargs:
+  transformer_type: "EasyAnimateTransformer3DModel"
+  after_norm: false
+  time_position_encoding_type: "3d_rope"
+  resize_inpaint_mask_directly: true
+  enable_text_attention_mask: true
+  enable_clip_in_inpaint: false
+  add_ref_latent_in_control_model: true
+
+vae_kwargs:
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 4
+  mini_batch_decoder: 1
+  slice_mag_vae: false
+  slice_compression_vae: false
+  cache_compression_vae: false
+  cache_mag_vae: true
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
+  replace_t5_to_llm: true
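A hedged sketch of how a config like this might be consumed. EasyAnimate reads its YAML configs with OmegaConf elsewhere in the repo, but the exact loading code is not part of this commit, so treat the snippet as illustrative only:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("config/easyanimate_video_v5.1_magvit_qwen.yaml")

transformer_kwargs = OmegaConf.to_container(config["transformer_additional_kwargs"])
vae_kwargs = OmegaConf.to_container(config["vae_kwargs"])
text_encoder_kwargs = OmegaConf.to_container(config["text_encoder_kwargs"])

# The v5.1 switches that matter most downstream:
print(transformer_kwargs["time_position_encoding_type"])  # "3d_rope"
print(vae_kwargs["vae_type"])                              # "AutoencoderKLMagvit"
print(text_encoder_kwargs["replace_t5_to_llm"])            # True: swap T5 for an LLM text encoder (the "qwen" in the file name)
```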
easyanimate/api/api.py CHANGED

@@ -93,7 +93,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         lora_model_path = datas.get('lora_model_path', 'none')
         lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
         prompt_textbox = datas.get('prompt_textbox', None)
-        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics.')
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art.')
         sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
         sample_step_slider = datas.get('sample_step_slider', 30)
         resize_method = datas.get('resize_method', "Generate by")
easyanimate/api/post_infer.py CHANGED

@@ -54,14 +54,14 @@ if __name__ == '__main__':
     # -------------------------- #
     #  Step 1: update edition
     # -------------------------- #
-    edition = "v5"
+    edition = "v5.1"
     outputs = post_update_edition(edition)
     print('Output update edition: ', outputs)

     # -------------------------- #
     #  Step 2: update edition
     # -------------------------- #
-    diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
+    diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP"
     outputs = post_diffusion_transformer(diffusion_transformer_path)
     print('Output update edition: ', outputs)

easyanimate/data/dataset_image_video.py CHANGED

@@ -12,9 +12,12 @@ import albumentations
 import cv2
 import numpy as np
 import torch
+import torch.nn.functional as F
 import torchvision.transforms as transforms
 from decord import VideoReader
+from einops import rearrange
 from func_timeout import FunctionTimedOut, func_timeout
+from packaging import version as pver
 from PIL import Image
 from torch.utils.data import BatchSampler, Sampler
 from torch.utils.data.dataset import Dataset

@@ -100,6 +103,152 @@ def get_random_mask(shape):
     else:
         raise ValueError(f"The mask_index {mask_index} is not define")
     return mask
+
+class Camera(object):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    def __init__(self, entry):
+        fx, fy, cx, cy = entry[1:5]
+        self.fx = fx
+        self.fy = fy
+        self.cx = cx
+        self.cy = cy
+        w2c_mat = np.array(entry[7:]).reshape(3, 4)
+        w2c_mat_4x4 = np.eye(4)
+        w2c_mat_4x4[:3, :] = w2c_mat
+        self.w2c_mat = w2c_mat_4x4
+        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
+
+def custom_meshgrid(*args):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
+    if pver.parse(torch.__version__) < pver.parse('1.10'):
+        return torch.meshgrid(*args)
+    else:
+        return torch.meshgrid(*args, indexing='ij')
+
+def get_relative_pose(cam_params):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
+    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
+    cam_to_origin = 0
+    target_cam_c2w = np.array([
+        [1, 0, 0, 0],
+        [0, 1, 0, -cam_to_origin],
+        [0, 0, 1, 0],
+        [0, 0, 0, 1]
+    ])
+    abs2rel = target_cam_c2w @ abs_w2cs[0]
+    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
+    ret_poses = np.array(ret_poses, dtype=np.float32)
+    return ret_poses
+
+def ray_condition(K, c2w, H, W, device):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    # c2w: B, V, 4, 4
+    # K: B, V, 4
+
+    B = K.shape[0]
+
+    j, i = custom_meshgrid(
+        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
+        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
+    )
+    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5          # [B, HxW]
+    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5          # [B, HxW]
+
+    fx, fy, cx, cy = K.chunk(4, dim=-1)     # B,V, 1
+
+    zs = torch.ones_like(i)                 # [B, HxW]
+    xs = (i - cx) / fx * zs
+    ys = (j - cy) / fy * zs
+    zs = zs.expand_as(ys)
+
+    directions = torch.stack((xs, ys, zs), dim=-1)                   # B, V, HW, 3
+    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
+
+    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)         # B, V, 3, HW
+    rays_o = c2w[..., :3, 3]                                         # B, V, 3
+    rays_o = rays_o[:, :, None].expand_as(rays_d)                    # B, V, 3, HW
+    # c2w @ dirctions
+    rays_dxo = torch.cross(rays_o, rays_d)
+    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)              # B, V, H, W, 6
+    # plucker = plucker.permute(0, 1, 4, 2, 3)
+    return plucker
+
+def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
+    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    with open(pose_file_path, 'r') as f:
+        poses = f.readlines()
+
+    poses = [pose.strip().split(' ') for pose in poses[1:]]
+    cam_params = [[float(x) for x in pose] for pose in poses]
+    if return_poses:
+        return cam_params
+    else:
+        cam_params = [Camera(cam_param) for cam_param in cam_params]
+
+        sample_wh_ratio = width / height
+        pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed
+
+        if pose_wh_ratio > sample_wh_ratio:
+            resized_ori_w = height * pose_wh_ratio
+            for cam_param in cam_params:
+                cam_param.fx = resized_ori_w * cam_param.fx / width
+        else:
+            resized_ori_h = width / pose_wh_ratio
+            for cam_param in cam_params:
+                cam_param.fy = resized_ori_h * cam_param.fy / height
+
+        intrinsic = np.asarray([[cam_param.fx * width,
+                                 cam_param.fy * height,
+                                 cam_param.cx * width,
+                                 cam_param.cy * height]
+                                for cam_param in cam_params], dtype=np.float32)
+
+        K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
+        c2ws = get_relative_pose(cam_params)  # Assuming this function is defined elsewhere
+        c2ws = torch.as_tensor(c2ws)[None]    # [1, n_frame, 4, 4]
+        plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
+        plucker_embedding = plucker_embedding[None]
+        plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
+        return plucker_embedding
+
+def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
+    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    cam_params = [Camera(cam_param) for cam_param in cam_params]
+
+    sample_wh_ratio = width / height
+    pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed
+
+    if pose_wh_ratio > sample_wh_ratio:
+        resized_ori_w = height * pose_wh_ratio
+        for cam_param in cam_params:
+            cam_param.fx = resized_ori_w * cam_param.fx / width
+    else:
+        resized_ori_h = width / pose_wh_ratio
+        for cam_param in cam_params:
+            cam_param.fy = resized_ori_h * cam_param.fy / height
+
+    intrinsic = np.asarray([[cam_param.fx * width,
+                             cam_param.fy * height,
+                             cam_param.cx * width,
+                             cam_param.cy * height]
+                            for cam_param in cam_params], dtype=np.float32)
+
+    K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
+    c2ws = get_relative_pose(cam_params)  # Assuming this function is defined elsewhere
+    c2ws = torch.as_tensor(c2ws)[None]    # [1, n_frame, 4, 4]
+    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
+    plucker_embedding = plucker_embedding[None]
+    plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
+    return plucker_embedding
 
 class ImageVideoSampler(BatchSampler):
     """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

@@ -184,7 +333,7 @@ class ImageVideoDataset(Dataset):
         video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
         image_sample_size=512,
         video_repeat=0,
-        text_drop_ratio
+        text_drop_ratio=0.1,
         enable_bucket=False,
         video_length_drop_start=0.1,
         video_length_drop_end=0.9,

@@ -355,7 +504,6 @@
 
         return sample
 
-
 class ImageVideoControlDataset(Dataset):
     def __init__(
         self,

@@ -363,11 +511,12 @@ class ImageVideoControlDataset(Dataset):
         video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
         image_sample_size=512,
         video_repeat=0,
-        text_drop_ratio
+        text_drop_ratio=0.1,
         enable_bucket=False,
         video_length_drop_start=0.1,
         video_length_drop_end=0.9,
         enable_inpaint=False,
+        enable_camera_info=False,
     ):
         # Loading annotations from files
         print(f"loading annotations from {ann_path} ...")

@@ -397,6 +546,7 @@ class ImageVideoControlDataset(Dataset):
         self.enable_bucket = enable_bucket
         self.text_drop_ratio = text_drop_ratio
         self.enable_inpaint = enable_inpaint
+        self.enable_camera_info = enable_camera_info
 
         self.video_length_drop_start = video_length_drop_start
         self.video_length_drop_end = video_length_drop_end

@@ -412,6 +562,13 @@ class ImageVideoControlDataset(Dataset):
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
         )
+        if self.enable_camera_info:
+            self.video_transforms_camera = transforms.Compose(
+                [
+                    transforms.Resize(min(self.video_sample_size)),
+                    transforms.CenterCrop(self.video_sample_size)
+                ]
+            )
 
         # Image params
         self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)

@@ -484,33 +641,59 @@ class ImageVideoControlDataset(Dataset):
             else:
                 control_video_id = os.path.join(self.data_root, control_video_id)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                if not self.enable_bucket:
-                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
-                    control_pixel_values = control_pixel_values / 255.
-                    del control_video_reader
+            if self.enable_camera_info:
+                if control_video_id.lower().endswith('.txt'):
+                    if not self.enable_bucket:
+                        control_pixel_values = torch.zeros_like(pixel_values)
+
+                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
+                        control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
+                        control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
+                        control_camera_values = self.video_transforms_camera(control_camera_values)
+                    else:
+                        control_pixel_values = np.zeros_like(pixel_values)
+
+                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
+                        control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
+                        control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
+                        control_camera_values = np.array([control_camera_values[index] for index in batch_index])
                 else:
-
-
-
-
-
+                    if not self.enable_bucket:
+                        control_pixel_values = torch.zeros_like(pixel_values)
+                        control_camera_values = None
+                    else:
+                        control_pixel_values = np.zeros_like(pixel_values)
+                        control_camera_values = None
+            else:
+                with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
+                    try:
+                        sample_args = (control_video_reader, batch_index)
+                        control_pixel_values = func_timeout(
+                            VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                        )
+                        resized_frames = []
+                        for i in range(len(control_pixel_values)):
+                            frame = control_pixel_values[i]
+                            resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                            resized_frames.append(resized_frame)
+                        control_pixel_values = np.array(resized_frames)
+                    except FunctionTimedOut:
+                        raise ValueError(f"Read {idx} timeout.")
+                    except Exception as e:
+                        raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                    if not self.enable_bucket:
+                        control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
+                        control_pixel_values = control_pixel_values / 255.
+                        del control_video_reader
+                    else:
+                        control_pixel_values = control_pixel_values
+
+                if not self.enable_bucket:
+                    control_pixel_values = self.video_transforms(control_pixel_values)
+                control_camera_values = None
+
+            return pixel_values, control_pixel_values, control_camera_values, text, "video"
         else:
             image_path, text = data_info['file_path'], data_info['text']
             if self.data_root is not None:

@@ -536,7 +719,8 @@ class ImageVideoControlDataset(Dataset):
                 control_image = self.image_transforms(control_image).unsqueeze(0)
             else:
                 control_image = np.expand_dims(np.array(control_image), 0)
-
+
+            return image, control_image, None, text, 'image'
 
     def __len__(self):
         return self.length

@@ -552,13 +736,17 @@ class ImageVideoControlDataset(Dataset):
                 if data_type_local != data_type:
                     raise ValueError("data_type_local != data_type")
 
-                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
+                pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx)
+
                 sample["pixel_values"] = pixel_values
                 sample["control_pixel_values"] = control_pixel_values
                 sample["text"] = name
                 sample["data_type"] = data_type
                 sample["idx"] = idx
-
+
+                if self.enable_camera_info:
+                    sample["control_camera_values"] = control_camera_values
+
                 if len(sample) > 0:
                     break
             except Exception as e:
easyanimate/models/__init__.py CHANGED

@@ -1,8 +1,7 @@
-from .autoencoder_magvit import (
+from .autoencoder_magvit import (AutoencoderKL, AutoencoderKLCogVideoX,
+                                 AutoencoderKLMagvit)
 from .transformer3d import (EasyAnimateTransformer3DModel,
-
-                            Transformer3DModel)
-
+                            HunyuanTransformer3DModel, Transformer3DModel)
 
 name_to_transformer3d = {
     "Transformer3DModel": Transformer3DModel,
easyanimate/models/attention.py CHANGED

@@ -29,7 +29,7 @@ from diffusers.models.embeddings import (SinusoidalPositionalEmbedding,
                                          get_3d_sincos_pos_embed)
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero,
+from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero,
                                             CogVideoXLayerNormZero)
 from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging
 from diffusers.utils.import_utils import is_xformers_available

@@ -38,12 +38,11 @@ from einops import rearrange, repeat
 from torch import nn
 
 from .motion_module import PositionalEncoding, get_motion_module
-from .norm import AdaLayerNormShift,
+from .norm import AdaLayerNormShift, EasyAnimateLayerNormZero, FP32LayerNorm
 from .processor import (EasyAnimateAttnProcessor2_0,
+                        EasyAnimateSWAttnProcessor2_0,
                         LazyKVCompressionProcessor2_0)
 
-
-
 if is_xformers_available():
     import xformers
     import xformers.ops

@@ -1042,7 +1041,9 @@ class EasyAnimateDiTBlock(nn.Module):
         ff_bias: bool = True,
         qk_norm: bool = True,
         after_norm: bool = False,
-        norm_type: str="fp32_layer_norm"
+        norm_type: str="fp32_layer_norm",
+        is_mmdit_block: bool = True,
+        is_swa: bool = False,
     ):
         super().__init__()
 
@@ -1051,6 +1052,7 @@ class EasyAnimateDiTBlock(nn.Module):
             time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
         )
 
+        self.is_swa = is_swa
         self.attn1 = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,

@@ -1058,17 +1060,20 @@ class EasyAnimateDiTBlock(nn.Module):
             qk_norm="layer_norm" if qk_norm else None,
             eps=1e-6,
             bias=True,
-            processor=EasyAnimateAttnProcessor2_0(),
-        )
-        self.attn2 = Attention(
-            query_dim=dim,
-            dim_head=attention_head_dim,
-            heads=num_attention_heads,
-            qk_norm="layer_norm" if qk_norm else None,
-            eps=1e-6,
-            bias=True,
-            processor=EasyAnimateAttnProcessor2_0(),
+            processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(),
         )
+        if is_mmdit_block:
+            self.attn2 = Attention(
+                query_dim=dim,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                qk_norm="layer_norm" if qk_norm else None,
+                eps=1e-6,
+                bias=True,
+                processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(),
+            )
+        else:
+            self.attn2 = None
 
         # FFN Part
         self.norm2 = EasyAnimateLayerNormZero(

@@ -1082,14 +1087,18 @@ class EasyAnimateDiTBlock(nn.Module):
             inner_dim=ff_inner_dim,
             bias=ff_bias,
         )
-
-
-
-
-
-
-
-
+        if is_mmdit_block:
+            self.txt_ff = FeedForward(
+                dim,
+                dropout=dropout,
+                activation_fn=activation_fn,
+                final_dropout=final_dropout,
+                inner_dim=ff_inner_dim,
+                bias=ff_bias,
+            )
+        else:
+            self.txt_ff = None
+
         if after_norm:
             self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         else:

@@ -1101,6 +1110,9 @@ class EasyAnimateDiTBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_frames = None,
+        height = None,
+        width = None
     ) -> torch.Tensor:
         # Norm
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(

@@ -1108,12 +1120,23 @@ class EasyAnimateDiTBlock(nn.Module):
         )
 
         # Attn
-
-
-
-
-
-
+        if self.is_swa:
+            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+                attn2=self.attn2,
+                num_frames=num_frames,
+                height=height,
+                width=width,
+            )
+        else:
+            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+                attn2=self.attn2
+            )
         hidden_states = hidden_states + gate_msa * attn_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
 
@@ -1125,10 +1148,16 @@ class EasyAnimateDiTBlock(nn.Module):
         # FFN
         if self.norm3 is not None:
             norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
-
+            if self.txt_ff is not None:
+                norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
+            else:
+                norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
         else:
             norm_hidden_states = self.ff(norm_hidden_states)
-
+            if self.txt_ff is not None:
+                norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
+            else:
+                norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
         hidden_states = hidden_states + gate_ff * norm_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states
         return hidden_states, encoder_hidden_states
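The two new constructor flags split EasyAnimateDiTBlock into variants: is_mmdit_block decides whether text tokens get their own attention projections (attn2) and FFN (txt_ff), and is_swa swaps the attention processor for the sliding-window one. A toy sketch of what the MMDiT flag changes for the FFN path; this is not the repo's implementation and the shapes are illustrative only:

```python
import torch
import torch.nn as nn

class ToyDualStreamFFN(nn.Module):
    def __init__(self, dim, is_mmdit_block=True):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # Flag on: text tokens get a separate FFN. Flag off: both streams share self.ff,
        # mirroring the `txt_ff = None` branch in the diff above.
        self.txt_ff = (
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            if is_mmdit_block else None
        )

    def forward(self, hidden_states, encoder_hidden_states):
        hidden_states = hidden_states + self.ff(hidden_states)
        txt_ff = self.txt_ff if self.txt_ff is not None else self.ff
        encoder_hidden_states = encoder_hidden_states + txt_ff(encoder_hidden_states)
        return hidden_states, encoder_hidden_states

video_tokens = torch.randn(1, 256, 64)   # (batch, f*h*w, dim)
text_tokens = torch.randn(1, 32, 64)     # (batch, text_len, dim)
out_video, out_text = ToyDualStreamFFN(64, is_mmdit_block=True)(video_tokens, text_tokens)
```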
easyanimate/models/autoencoder_magvit.py CHANGED

@@ -44,6 +44,7 @@ from ..vae.ldm.models.cogvideox_enc_dec import (CogVideoXCausalConv3d,
                                                 CogVideoXDecoder3D,
                                                 CogVideoXEncoder3D,
                                                 CogVideoXSafeConv3d)
+from ..vae.ldm.models.omnigen_enc_dec import CausalConv3d
 from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
 from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
 
@@ -96,6 +97,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         out_channels: int = 3,
         ch = 128,
         ch_mult = [ 1,2,4,4 ],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         down_block_types: tuple = None,
         up_block_types: tuple = None,

@@ -109,6 +111,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         latent_channels: int = 4,
         norm_num_groups: int = 32,
         scaling_factor: float = 0.1825,
+        force_upcast: float = True,
         slice_mag_vae=True,
         slice_compression_vae=False,
         cache_compression_vae=False,

@@ -130,8 +133,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             in_channels=in_channels,
             out_channels=latent_channels,
             down_block_types=down_block_types,
-            ch
-            ch_mult
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,

@@ -154,8 +158,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             in_channels=latent_channels,
             out_channels=out_channels,
             up_block_types=up_block_types,
-            ch
-            ch_mult
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,

@@ -196,81 +201,10 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
             module.gradient_checkpointing = value
 
-
-
-
-
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
-    def set_default_attn_processor(self):
-        """
-        Disables custom attention processors and sets the default attention implementation.
-        """
-        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
-            processor = AttnAddedKVProcessor()
-        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
-            processor = AttnProcessor()
-        else:
-            raise ValueError(
-                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
-            )
-
-        self.set_attn_processor(processor)
+    def _clear_conv_cache(self):
+        for name, module in self.named_modules():
+            if isinstance(module, CausalConv3d):
+                module._clear_conv_cache()
 
     @apply_forward_hook
     def encode(

@@ -308,6 +242,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
 
+        self._clear_conv_cache()
         if not return_dict:
             return (posterior,)
 
@@ -355,6 +290,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         else:
             decoded = self._decode(z).sample
 
+        self._clear_conv_cache()
         if not return_dict:
             return (decoded,)
 
@@ -519,44 +455,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
         return DecoderOutput(sample=dec)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     @classmethod
     def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
         import json
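The behavioural addition here is _clear_conv_cache(), which walks the module tree and resets the temporal cache kept by every CausalConv3d after each encode()/decode() call. A minimal sketch of the same pattern with a stand-in conv class (the real CausalConv3d lives in easyanimate/vae/ldm/models/omnigen_enc_dec.py):

```python
import torch.nn as nn

class ToyCausalConv3d(nn.Conv3d):
    """Stand-in: caches trailing frames so streamed chunks stay causally consistent."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prev_features = None          # temporal cache

    def _clear_conv_cache(self):
        self.prev_features = None

class ToyMagvitVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(ToyCausalConv3d(3, 8, 3, padding=1), nn.SiLU())

    def _clear_conv_cache(self):
        # Same traversal as the new method on AutoencoderKLMagvit above.
        for _, module in self.named_modules():
            if isinstance(module, ToyCausalConv3d):
                module._clear_conv_cache()

vae = ToyMagvitVAE()
vae._clear_conv_cache()   # called after encode()/decode(), as in the diff
```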
easyanimate/models/embeddings.py CHANGED

@@ -4,8 +4,9 @@ from typing import Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
-from diffusers.models.embeddings import (PixArtAlphaTextProjection,
-                                         TimestepEmbedding, Timesteps
+from diffusers.models.embeddings import (PixArtAlphaTextProjection,
+                                         TimestepEmbedding, Timesteps,
+                                         get_timestep_embedding)
 from einops import rearrange
 from torch import nn
 
easyanimate/models/norm.py CHANGED

@@ -25,6 +25,22 @@ class FP32LayerNorm(nn.LayerNorm):
             inputs.float(), self.normalized_shape, None, None, self.eps
         ).to(origin_dtype)
 
+class EasyAnimateRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
 class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     """
     For PixArt-Alpha.
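EasyAnimateRMSNorm above is the usual LLaMA-style RMSNorm: normalise in float32 by the root mean square over the last dimension, rescale with a learned per-channel weight, and cast back toward the input dtype. An equivalent functional sketch for reference:

```python
import torch

def rms_norm(hidden_states, weight, eps=1e-6):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

x = torch.randn(2, 16, 64, dtype=torch.bfloat16)
w = torch.ones(64)                 # matches nn.Parameter(torch.ones(hidden_size))
y = rms_norm(x, w)
print(y.shape, y.dtype)            # torch.Size([2, 16, 64]); float32 here because the fp32 weight promotes the result
```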
easyanimate/models/processor.py CHANGED

@@ -310,3 +310,149 @@ class EasyAnimateAttnProcessor2_0:
             hidden_states = attn.to_out[1](hidden_states)
             encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
         return hidden_states, encoder_hidden_states
+
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import pad_input, unpad_input
+except:
+    print("Flash Attention is not installed. Please install with `pip install flash-attn`, if you want to use SWA.")
+
+class EasyAnimateSWAttnProcessor2_0:
+    def __init__(self, window_size=1024):
+        self.window_size = window_size
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        num_frames: int = None,
+        height: int = None,
+        width: int = None,
+        attn2: Attention = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attn2 is None:
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        if attn2 is not None:
+            query_txt = attn2.to_q(encoder_hidden_states)
+            key_txt = attn2.to_k(encoder_hidden_states)
+            value_txt = attn2.to_v(encoder_hidden_states)
+
+            inner_dim = key_txt.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim)
+
+            if attn2.norm_q is not None:
+                query_txt = attn2.norm_q(query_txt)
+            if attn2.norm_k is not None:
+                key_txt = attn2.norm_k(key_txt)
+
+            query = torch.cat([query_txt, query], dim=2)
+            key = torch.cat([key_txt, key], dim=2)
+            value = torch.cat([value_txt, value], dim=1)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        query = query.transpose(1, 2).to(value)
+        key = key.transpose(1, 2).to(value)
+        interval = max((query.size(1) - text_seq_length) // (self.window_size - text_seq_length), 1)
+
+        cross_key = torch.cat([key[:, :text_seq_length], key[:, text_seq_length::interval]], dim=1)
+        cross_val = torch.cat([value[:, :text_seq_length], value[:, text_seq_length::interval]], dim=1)
+        cross_hidden_states = flash_attn_func(query, cross_key, cross_val, dropout_p=0.0, causal=False)
+
+        # Split and rearrange to six directions
+        querys = torch.tensor_split(query[:, text_seq_length:], 6, 2)
+        keys = torch.tensor_split(key[:, text_seq_length:], 6, 2)
+        values = torch.tensor_split(value[:, text_seq_length:], 6, 2)
+
+        new_querys = [querys[0]]
+        new_keys = [keys[0]]
+        new_values = [values[0]]
+        for index, mode in enumerate(
+            [
+                "bs (f h w) hn hd -> bs (f w h) hn hd",
+                "bs (f h w) hn hd -> bs (h f w) hn hd",
+                "bs (f h w) hn hd -> bs (h w f) hn hd",
+                "bs (f h w) hn hd -> bs (w f h) hn hd",
+                "bs (f h w) hn hd -> bs (w h f) hn hd"
+            ]
+        ):
+            new_querys.append(rearrange(querys[index + 1], mode, f=num_frames, h=height, w=width))
+            new_keys.append(rearrange(keys[index + 1], mode, f=num_frames, h=height, w=width))
+            new_values.append(rearrange(values[index + 1], mode, f=num_frames, h=height, w=width))
+        query = torch.cat(new_querys, dim=2)
+        key = torch.cat(new_keys, dim=2)
+        value = torch.cat(new_values, dim=2)
+
+        # apply attention
+        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False, window_size=(self.window_size, self.window_size))
+
+        hidden_states = torch.tensor_split(hidden_states, 6, 2)
+        new_hidden_states = [hidden_states[0]]
+        for index, mode in enumerate(
+            [
+                "bs (f w h) hn hd -> bs (f h w) hn hd",
+                "bs (h f w) hn hd -> bs (f h w) hn hd",
+                "bs (h w f) hn hd -> bs (f h w) hn hd",
+                "bs (w f h) hn hd -> bs (f h w) hn hd",
+                "bs (w h f) hn hd -> bs (f h w) hn hd"
+            ]
+        ):
+            new_hidden_states.append(rearrange(hidden_states[index + 1], mode, f=num_frames, h=height, w=width))
+        hidden_states = torch.cat([cross_hidden_states[:, :text_seq_length], torch.cat(new_hidden_states, dim=2)], dim=1) + cross_hidden_states
+
+        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+        if attn2 is None:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
+        return hidden_states, encoder_hidden_states
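The new EasyAnimateSWAttnProcessor2_0 trades full attention for two cheaper passes: attention against a strided subsample of keys/values (the cross_key/cross_val path) plus flash-attn sliding-window attention over the video tokens, with the six head groups re-ordered along different (f, h, w) axis orders so each group's window sweeps a different direction. A rough sketch of those two ingredients using plain PyTorch SDPA instead of flash_attn, ignoring text tokens and the per-direction rearrange; an approximation, not the repo's kernel:

```python
import torch
import torch.nn.functional as F

B, H, L, D = 1, 4, 4096, 64
window = 1024
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# (a) "global" context: attend to every interval-th key/value, as cross_key/cross_val do.
interval = max(L // window, 1)
global_out = F.scaled_dot_product_attention(q, k[:, :, ::interval], v[:, :, ::interval])

# (b) local context: attention restricted to non-overlapping windows
# (flash_attn's window_size argument gives the banded/sliding version of this).
qw = q.view(B, H, L // window, window, D)
kw = k.view(B, H, L // window, window, D)
vw = v.view(B, H, L // window, window, D)
local_out = F.scaled_dot_product_attention(qw, kw, vw).view(B, H, L, D)

out = local_out + global_out   # combined, mirroring the final residual add in the diff
print(out.shape)               # torch.Size([1, 4, 4096, 64])
```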
easyanimate/models/transformer3d.py
CHANGED
@@ -39,8 +39,9 @@ from torch import nn
 from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
                         SelfAttentionTemporalTransformerBlock,
                         TemporalTransformerBlock, zero_module)
-from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding,
-
+from .embeddings import (HunyuanCombinedTimestepTextSizeStyleEmbedding,
+                         TimePositionalEncoding)
+from .norm import AdaLayerNormSingle, EasyAnimateRMSNorm
 from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
                     TemporalUpsampler3D, UnPatch1D)
 from .resampler import Resampler

@@ -142,6 +143,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         attention_type: str = "default",
         caption_channels: int = None,
+        n_query=8,
         # block type
         basic_block_type: str = "motionmodule",
         # enable_uvit

@@ -168,6 +170,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         after_norm = False,
         resize_inpaint_mask_directly: bool = False,
         enable_clip_in_inpaint: bool = True,
+        position_of_clip_embedding: str = "head",
+        enable_zero_in_inpaint: bool = False,
         enable_text_attention_mask: bool = True,
         add_noise_in_inpaint_model: bool = False,
     ):

@@ -192,6 +196,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
         interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
         interpolation_scale = max(interpolation_scale, 1)
+        self.n_query = n_query

         if self.casual_3d:
             self.pos_embed = CasualPatchEmbed3D(

@@ -397,16 +402,22 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        timestep: Optional[torch.LongTensor] = None,
+        timestep_cond = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        text_embedding_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states_t5: Optional[torch.Tensor] = None,
+        text_embedding_mask_t5: Optional[torch.Tensor] = None,
+        image_meta_size = None,
+        style = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
         inpaint_latents: torch.Tensor = None,
         control_latents: torch.Tensor = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        clip_encoder_hidden_states: Optional[torch.Tensor] = None,
-        timestep: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Dict[str, torch.Tensor] = None,
         class_labels: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
-
+        clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         clip_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):

@@ -432,7 +443,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                 An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                 is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                 negative values to the attention scores corresponding to "discard" tokens.
-
+            text_embedding_mask ( `torch.Tensor`, *optional*):
                 Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                     * Mask `(batch, sequence_length)` True = keep, False = discard.

@@ -466,11 +477,12 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)

+        text_embedding_mask = text_embedding_mask.squeeze(1)
         if clip_attention_mask is not None:
-
+            text_embedding_mask = torch.cat([text_embedding_mask, clip_attention_mask], dim=1)
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
-        if
-        encoder_attention_mask = (1 -
+        if text_embedding_mask is not None and text_embedding_mask.ndim == 2:
+            encoder_attention_mask = (1 - text_embedding_mask.to(encoder_hidden_states.dtype)) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

         if inpaint_latents is not None:

@@ -637,7 +649,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         return Transformer3DModelOutput(sample=output)

     @classmethod
-    def from_pretrained_2d(
+    def from_pretrained_2d(
+        cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={},
+        low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
+    ):
         if subfolder is not None:
             pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

@@ -649,16 +664,73 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             config = json.load(f)

         from diffusers.utils import WEIGHTS_NAME
-        model = cls.from_config(config, **transformer_additional_kwargs)
         model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         model_file_safetensors = model_file.replace(".bin", ".safetensors")
-
+
+        if low_cpu_mem_usage:
+            try:
+                import re
+
+                from diffusers.models.modeling_utils import \
+                    load_model_dict_into_meta
+                from diffusers.utils import is_accelerate_available
+                if is_accelerate_available():
+                    import accelerate
+
+                # Instantiate model with empty weights
+                with accelerate.init_empty_weights():
+                    model = cls.from_config(config, **transformer_additional_kwargs)
+
+                param_device = "cpu"
+                from safetensors.torch import load_file, safe_open
+                state_dict = load_file(model_file_safetensors)
+                model._convert_deprecated_attention_blocks(state_dict)
+                # move the params from meta device to cpu
+                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if len(missing_keys) > 0:
+                    raise ValueError(
+                        f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
+                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+                        " those weights or else make sure your checkpoint file is correct."
+                    )
+
+                unexpected_keys = load_model_dict_into_meta(
+                    model,
+                    state_dict,
+                    device=param_device,
+                    dtype=torch_dtype,
+                    model_name_or_path=pretrained_model_path,
+                )
+
+                if cls._keys_to_ignore_on_load_unexpected is not None:
+                    for pat in cls._keys_to_ignore_on_load_unexpected:
+                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+                if len(unexpected_keys) > 0:
+                    print(
+                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+                    )
+                return model
+            except Exception as e:
+                print(
+                    f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
+                )
+
+        model = cls.from_config(config, **transformer_additional_kwargs)
+        if os.path.exists(model_file):
+            state_dict = torch.load(model_file, map_location="cpu")
+        elif os.path.exists(model_file_safetensors):
             from safetensors.torch import load_file, safe_open
             state_dict = load_file(model_file_safetensors)
         else:
-
-
-            state_dict =
+            from safetensors.torch import load_file, safe_open
+            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+            state_dict = {}
+            for model_file_safetensors in model_files_safetensors:
+                _state_dict = load_file(model_file_safetensors)
+                for key in _state_dict:
+                    state_dict[key] = _state_dict[key]

         if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
             new_shape = model.state_dict()['pos_embed.proj.weight'].size()

@@ -692,6 +764,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
         print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")

+        model = model.to(torch_dtype)
         return model

 class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):

@@ -769,6 +842,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
         after_norm = False,
         resize_inpaint_mask_directly: bool = False,
         enable_clip_in_inpaint: bool = True,
+        position_of_clip_embedding: str = "full",
         enable_text_attention_mask: bool = True,
         add_noise_in_inpaint_model: bool = False,
     ):

@@ -909,6 +983,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
         control_latents: torch.Tensor = None,
         clip_encoder_hidden_states: Optional[torch.Tensor]=None,
         clip_attention_mask: Optional[torch.Tensor]=None,
+        added_cond_kwargs: Dict[str, torch.Tensor] = None,
         return_dict=True,
     ):
         """

@@ -1085,7 +1160,10 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
         return Transformer2DModelOutput(sample=output)

     @classmethod
-    def from_pretrained_2d(
+    def from_pretrained_2d(
+        cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
+        low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
+    ):
         if subfolder is not None:
             pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

@@ -1097,16 +1175,73 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
             config = json.load(f)

         from diffusers.utils import WEIGHTS_NAME
-        model = cls.from_config(config, **transformer_additional_kwargs)
         model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         model_file_safetensors = model_file.replace(".bin", ".safetensors")
-
+
+        if low_cpu_mem_usage:
+            try:
+                import re
+
+                from diffusers.models.modeling_utils import \
+                    load_model_dict_into_meta
+                from diffusers.utils import is_accelerate_available
+                if is_accelerate_available():
+                    import accelerate
+
+                # Instantiate model with empty weights
+                with accelerate.init_empty_weights():
+                    model = cls.from_config(config, **transformer_additional_kwargs)
+
+                param_device = "cpu"
+                from safetensors.torch import load_file, safe_open
+                state_dict = load_file(model_file_safetensors)
+                model._convert_deprecated_attention_blocks(state_dict)
+                # move the params from meta device to cpu
+                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if len(missing_keys) > 0:
+                    raise ValueError(
+                        f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
+                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+                        " those weights or else make sure your checkpoint file is correct."
+                    )
+
+                unexpected_keys = load_model_dict_into_meta(
+                    model,
+                    state_dict,
+                    device=param_device,
+                    dtype=torch_dtype,
+                    model_name_or_path=pretrained_model_path,
+                )
+
+                if cls._keys_to_ignore_on_load_unexpected is not None:
+                    for pat in cls._keys_to_ignore_on_load_unexpected:
+                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+                if len(unexpected_keys) > 0:
+                    print(
+                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+                    )
+                return model
+            except Exception as e:
+                print(
+                    f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
+                )
+
+        model = cls.from_config(config, **transformer_additional_kwargs)
+        if os.path.exists(model_file):
+            state_dict = torch.load(model_file, map_location="cpu")
+        elif os.path.exists(model_file_safetensors):
             from safetensors.torch import load_file, safe_open
             state_dict = load_file(model_file_safetensors)
         else:
-
-
-            state_dict =
+            from safetensors.torch import load_file, safe_open
+            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+            state_dict = {}
+            for model_file_safetensors in model_files_safetensors:
+                _state_dict = load_file(model_file_safetensors)
+                for key in _state_dict:
+                    state_dict[key] = _state_dict[key]

         if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
             new_shape = model.state_dict()['pos_embed.proj.weight'].size()

@@ -1156,6 +1291,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
         params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
         print(f"### attn1 Parameters: {sum(params) / 1e6} M")

+        model = model.to(torch_dtype)
         return model

 class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):

@@ -1178,8 +1314,11 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         timestep_activation_fn: str = "silu",
         freq_shift: int = 0,
         num_layers: int = 30,
+        mmdit_layers: int = 10000,
+        swa_layers: list = None,
         dropout: float = 0.0,
         time_embed_dim: int = 512,
+        add_norm_text_encoder: bool = False,
         text_embed_dim: int = 4096,
         text_embed_dim_t5: int = 4096,
         norm_eps: float = 1e-5,

@@ -1191,8 +1330,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         after_norm = False,
         resize_inpaint_mask_directly: bool = False,
         enable_clip_in_inpaint: bool = True,
+        position_of_clip_embedding: str = "full",
         enable_text_attention_mask: bool = True,
         add_noise_in_inpaint_model: bool = False,
+        add_ref_latent_in_control_model: bool = False,
     ):
         super().__init__()
         self.num_heads = num_attention_heads

@@ -1211,8 +1352,20 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         self.proj = nn.Conv2d(
             in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
         )
-
-
+        if not add_norm_text_encoder:
+            self.text_proj = nn.Linear(text_embed_dim, self.inner_dim)
+            if text_embed_dim_t5 is not None:
+                self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim)
+        else:
+            self.text_proj = nn.Sequential(
+                EasyAnimateRMSNorm(text_embed_dim),
+                nn.Linear(text_embed_dim, self.inner_dim)
+            )
+            if text_embed_dim_t5 is not None:
+                self.text_proj_t5 = nn.Sequential(
+                    EasyAnimateRMSNorm(text_embed_dim),
+                    nn.Linear(text_embed_dim_t5, self.inner_dim)
+                )

         if ref_channels is not None:
             self.ref_proj = nn.Conv2d(

@@ -1224,23 +1377,45 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):

         if clip_channels is not None:
             self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
-
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        self.swa_layers = swa_layers
+        if swa_layers is not None:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    EasyAnimateDiTBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
+                        activation_fn=activation_fn,
+                        norm_elementwise_affine=norm_elementwise_affine,
+                        norm_eps=norm_eps,
+                        after_norm=after_norm,
+                        is_mmdit_block=True if index < mmdit_layers else False,
+                        is_swa=True if index in swa_layers else False,
+                    )
+                    for index in range(num_layers)
+                ]
+            )
+        else:
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    EasyAnimateDiTBlock(
+                        dim=self.inner_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
+                        activation_fn=activation_fn,
+                        norm_elementwise_affine=norm_elementwise_affine,
+                        norm_eps=norm_eps,
+                        after_norm=after_norm,
+                        is_mmdit_block=True if _ < mmdit_layers else False,
+                    )
+                    for _ in range(num_layers)
+                ]
+            )
         self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)

         # 5. Output blocks

@@ -1275,6 +1450,7 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         ref_latents: Optional[torch.Tensor] = None,
         clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         clip_attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Dict[str, torch.Tensor] = None,
         return_dict=True,
     ):
         batch_size, channels, video_length, height, width = hidden_states.size()

@@ -1343,6 +1519,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    video_length,
+                    height // self.patch_size,
+                    width // self.patch_size,
                     **ckpt_kwargs,
                 )
             else:

@@ -1351,6 +1530,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    num_frames=video_length,
+                    height=height // self.patch_size,
+                    width=width // self.patch_size
                 )

         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

@@ -1371,7 +1553,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         return Transformer2DModelOutput(sample=output)

     @classmethod
-    def from_pretrained_2d(
+    def from_pretrained_2d(
+        cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
+        low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
+    ):
         if subfolder is not None:
             pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

@@ -1383,9 +1568,60 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
             config = json.load(f)

         from diffusers.utils import WEIGHTS_NAME
-        model = cls.from_config(config, **transformer_additional_kwargs)
         model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         model_file_safetensors = model_file.replace(".bin", ".safetensors")
+
+        if low_cpu_mem_usage:
+            try:
+                import re
+
+                from diffusers.models.modeling_utils import \
+                    load_model_dict_into_meta
+                from diffusers.utils import is_accelerate_available
+                if is_accelerate_available():
+                    import accelerate
+
+                # Instantiate model with empty weights
+                with accelerate.init_empty_weights():
+                    model = cls.from_config(config, **transformer_additional_kwargs)
+
+                param_device = "cpu"
+                from safetensors.torch import load_file, safe_open
+                state_dict = load_file(model_file_safetensors)
+                model._convert_deprecated_attention_blocks(state_dict)
+                # move the params from meta device to cpu
+                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if len(missing_keys) > 0:
+                    raise ValueError(
+                        f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
+                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+                        " those weights or else make sure your checkpoint file is correct."
+                    )
+
+                unexpected_keys = load_model_dict_into_meta(
+                    model,
+                    state_dict,
+                    device=param_device,
+                    dtype=torch_dtype,
+                    model_name_or_path=pretrained_model_path,
+                )
+
+                if cls._keys_to_ignore_on_load_unexpected is not None:
+                    for pat in cls._keys_to_ignore_on_load_unexpected:
+                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+                if len(unexpected_keys) > 0:
+                    print(
+                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+                    )
+                return model
+            except Exception as e:
+                print(
+                    f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
+                )
+
+        model = cls.from_config(config, **transformer_additional_kwargs)
         if os.path.exists(model_file):
             state_dict = torch.load(model_file, map_location="cpu")
         elif os.path.exists(model_file_safetensors):

@@ -1433,4 +1669,5 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
         print(f"### attn1 Parameters: {sum(params) / 1e6} M")

+        model = model.to(torch_dtype)
         return model
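The `from_pretrained_2d` changes above add a `low_cpu_mem_usage` path (empty-weight initialization via `accelerate`, loading a single `.safetensors` file onto the meta device, with a fallback to the regular loader) and a final cast of the model to `torch_dtype`. A hypothetical call under those assumptions is sketched below; the checkpoint path and `subfolder` value are placeholders, not values taken from this commit.

# Hypothetical usage of the updated loader; path and subfolder are placeholders.
import torch
from easyanimate.models.transformer3d import EasyAnimateTransformer3DModel

transformer = EasyAnimateTransformer3DModel.from_pretrained_2d(
    "models/Diffusion_Transformer/your-easyanimate-checkpoint",  # placeholder path
    subfolder="transformer",                                     # assumed repo layout
    low_cpu_mem_usage=True,      # falls back to the regular path if it fails
    torch_dtype=torch.bfloat16,  # weights are cast to this dtype before returning
)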
easyanimate/pipeline/pipeline_easyanimate.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,61 +12,113 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import copy
-import html
 import inspect
-import re
-import urllib.parse as ul
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
-
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models import AutoencoderKL
-from diffusers.
 from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
-                             is_bs4_available, is_ftfy_available,
                              replace_example_docstring)
 from diffusers.utils.torch_utils import randn_tensor
 from einops import rearrange
 from tqdm import tqdm
-from transformers import

-from ..models

-

-
-

-if is_ftfy_available():
-    import ftfy


 EXAMPLE_DOC_STRING = """
     Examples:
-        ```
         >>> import torch
         >>> from diffusers import EasyAnimatePipeline
-
-
-        >>>
-        >>>
-        >>>
-
-
-
         ```
     """

|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
65 |
def retrieve_timesteps(
|
66 |
scheduler,
|
67 |
num_inference_steps: Optional[int] = None,
|
68 |
device: Optional[Union[str, torch.device]] = None,
|
69 |
timesteps: Optional[List[int]] = None,
|
|
|
70 |
**kwargs,
|
71 |
):
|
72 |
"""
|
@@ -77,19 +129,23 @@ def retrieve_timesteps(
|
|
77 |
scheduler (`SchedulerMixin`):
|
78 |
The scheduler to get timesteps from.
|
79 |
num_inference_steps (`int`):
|
80 |
-
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
81 |
-
|
82 |
device (`str` or `torch.device`, *optional*):
|
83 |
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
84 |
timesteps (`List[int]`, *optional*):
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
88 |
|
89 |
Returns:
|
90 |
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
91 |
second element is the number of inference steps.
|
92 |
"""
|
|
|
|
|
93 |
if timesteps is not None:
|
94 |
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
95 |
if not accepts_timesteps:
|
@@ -100,86 +156,113 @@ def retrieve_timesteps(
|
|
100 |
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
101 |
timesteps = scheduler.timesteps
|
102 |
num_inference_steps = len(timesteps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
else:
|
104 |
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
105 |
timesteps = scheduler.timesteps
|
106 |
return timesteps, num_inference_steps
|
107 |
|
108 |
-
@dataclass
|
109 |
-
class EasyAnimatePipelineOutput(BaseOutput):
|
110 |
-
videos: Union[torch.Tensor, np.ndarray]
|
111 |
|
112 |
class EasyAnimatePipeline(DiffusionPipeline):
|
113 |
r"""
|
114 |
-
Pipeline for text-to-
|
115 |
|
116 |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
117 |
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
118 |
|
|
|
|
|
|
|
|
|
119 |
Args:
|
120 |
-
vae ([`
|
121 |
-
Variational Auto-Encoder (VAE) Model to encode and decode
|
122 |
-
text_encoder ([`
|
123 |
-
|
124 |
-
[
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
133 |
"""
|
134 |
-
bad_punct_regex = re.compile(
|
135 |
-
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
136 |
-
) # noqa
|
137 |
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
def __init__(
|
142 |
self,
|
143 |
-
|
144 |
-
text_encoder:
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
148 |
):
|
149 |
super().__init__()
|
150 |
|
151 |
self.register_modules(
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
)
|
154 |
|
155 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
156 |
-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
157 |
-
self.enable_autocast_float8_transformer_flag = False
|
158 |
-
|
159 |
-
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
160 |
-
def mask_text_embeddings(self, emb, mask):
|
161 |
-
if emb.shape[0] == 1:
|
162 |
-
keep_index = mask.sum().item()
|
163 |
-
return emb[:, :, :keep_index, :], keep_index
|
164 |
-
else:
|
165 |
-
masked_feature = emb * mask[:, None, :, None]
|
166 |
-
return masked_feature, emb.shape[2]
|
167 |
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
def encode_prompt(
|
170 |
self,
|
171 |
-
prompt:
|
172 |
-
|
173 |
-
|
174 |
num_images_per_prompt: int = 1,
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
max_sequence_length: int =
|
182 |
-
|
|
|
183 |
):
|
184 |
r"""
|
185 |
Encodes the prompt into text encoder hidden states.
|
@@ -187,33 +270,46 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
187 |
Args:
|
188 |
prompt (`str` or `List[str]`, *optional*):
|
189 |
prompt to be encoded
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
whether to use classifier free guidance or not
|
196 |
-
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
197 |
number of images that should be generated per prompt
|
198 |
-
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
|
|
201 |
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
202 |
provided, text embeddings will be generated from `prompt` input argument.
|
203 |
-
negative_prompt_embeds (`torch.
|
204 |
-
Pre-generated negative text embeddings.
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
209 |
"""
|
|
|
|
|
210 |
|
211 |
-
|
212 |
-
|
213 |
-
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
214 |
|
215 |
-
if
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
if prompt is not None and isinstance(prompt, str):
|
219 |
batch_size = 1
|
@@ -222,74 +318,199 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
222 |
else:
|
223 |
batch_size = prompt_embeds.shape[0]
|
224 |
|
225 |
-
# See Section 3.1. of the paper.
|
226 |
-
max_length = max_sequence_length
|
227 |
-
|
228 |
if prompt_embeds is None:
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
)
|
249 |
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
264 |
|
265 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
266 |
-
# duplicate text embeddings
|
267 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
268 |
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
269 |
-
prompt_attention_mask = prompt_attention_mask.
|
270 |
-
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
271 |
|
272 |
# get unconditional embeddings for classifier free guidance
|
273 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
|
294 |
if do_classifier_free_guidance:
|
295 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
@@ -299,14 +520,9 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
299 |
|
300 |
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
301 |
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
else:
|
306 |
-
negative_prompt_embeds = None
|
307 |
-
negative_prompt_attention_mask = None
|
308 |
-
|
309 |
-
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
310 |
|
311 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
312 |
def prepare_extra_step_kwargs(self, generator, eta):
|
@@ -331,20 +547,25 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
331 |
prompt,
|
332 |
height,
|
333 |
width,
|
334 |
-
negative_prompt,
|
335 |
-
callback_steps,
|
336 |
prompt_embeds=None,
|
337 |
negative_prompt_embeds=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
):
|
339 |
-
if height %
|
340 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
341 |
|
342 |
-
if
|
343 |
-
|
344 |
):
|
345 |
raise ValueError(
|
346 |
-
f"`
|
347 |
-
f" {type(callback_steps)}."
|
348 |
)
|
349 |
|
350 |
if prompt is not None and prompt_embeds is not None:
|
@@ -356,14 +577,18 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
356 |
raise ValueError(
|
357 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
358 |
)
|
|
|
|
|
|
|
|
|
359 |
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
360 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
361 |
|
362 |
-
if
|
363 |
-
raise ValueError(
|
364 |
-
|
365 |
-
|
366 |
-
)
|
367 |
|
368 |
if negative_prompt is not None and negative_prompt_embeds is not None:
|
369 |
raise ValueError(
|
@@ -371,6 +596,13 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
371 |
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
372 |
)
|
373 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
375 |
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
376 |
raise ValueError(
|
@@ -378,153 +610,25 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
378 |
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
379 |
f" {negative_prompt_embeds.shape}."
|
380 |
)
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
if clean_caption and not is_ftfy_available():
|
390 |
-
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
391 |
-
logger.warn("Setting `clean_caption` to False...")
|
392 |
-
clean_caption = False
|
393 |
-
|
394 |
-
if not isinstance(text, (tuple, list)):
|
395 |
-
text = [text]
|
396 |
-
|
397 |
-
def process(text: str):
|
398 |
-
if clean_caption:
|
399 |
-
text = self._clean_caption(text)
|
400 |
-
text = self._clean_caption(text)
|
401 |
-
else:
|
402 |
-
text = text.lower().strip()
|
403 |
-
return text
|
404 |
-
|
405 |
-
return [process(t) for t in text]
|
406 |
-
|
407 |
-
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
408 |
-
def _clean_caption(self, caption):
|
409 |
-
caption = str(caption)
|
410 |
-
caption = ul.unquote_plus(caption)
|
411 |
-
caption = caption.strip().lower()
|
412 |
-
caption = re.sub("<person>", "person", caption)
|
413 |
-
# urls:
|
414 |
-
caption = re.sub(
|
415 |
-
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
416 |
-
"",
|
417 |
-
caption,
|
418 |
-
) # regex for urls
|
419 |
-
caption = re.sub(
|
420 |
-
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
421 |
-
"",
|
422 |
-
caption,
|
423 |
-
) # regex for urls
|
424 |
-
# html:
|
425 |
-
caption = BeautifulSoup(caption, features="html.parser").text
|
426 |
-
|
427 |
-
# @<nickname>
|
428 |
-
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
429 |
-
|
430 |
-
# 31C0—31EF CJK Strokes
|
431 |
-
# 31F0—31FF Katakana Phonetic Extensions
|
432 |
-
# 3200—32FF Enclosed CJK Letters and Months
|
433 |
-
# 3300—33FF CJK Compatibility
|
434 |
-
# 3400—4DBF CJK Unified Ideographs Extension A
|
435 |
-
# 4DC0—4DFF Yijing Hexagram Symbols
|
436 |
-
# 4E00—9FFF CJK Unified Ideographs
|
437 |
-
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
438 |
-
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
439 |
-
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
440 |
-
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
441 |
-
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
442 |
-
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
443 |
-
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
444 |
-
#######################################################
|
445 |
-
|
446 |
-
# все виды тире / all types of dash --> "-"
|
447 |
-
caption = re.sub(
|
448 |
-
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
449 |
-
"-",
|
450 |
-
caption,
|
451 |
-
)
|
452 |
-
|
453 |
-
# кавычки к одному стандарту
|
454 |
-
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
455 |
-
caption = re.sub(r"[‘’]", "'", caption)
|
456 |
-
|
457 |
-
# "
|
458 |
-
caption = re.sub(r""?", "", caption)
|
459 |
-
# &
|
460 |
-
caption = re.sub(r"&", "", caption)
|
461 |
-
|
462 |
-
# ip adresses:
|
463 |
-
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
464 |
-
|
465 |
-
# article ids:
|
466 |
-
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
467 |
-
|
468 |
-
# \n
|
469 |
-
caption = re.sub(r"\\n", " ", caption)
|
470 |
-
|
471 |
-
# "#123"
|
472 |
-
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
473 |
-
# "#12345.."
|
474 |
-
caption = re.sub(r"#\d{5,}\b", "", caption)
|
475 |
-
# "123456.."
|
476 |
-
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
477 |
-
# filenames:
|
478 |
-
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
479 |
-
|
480 |
-
#
|
481 |
-
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
482 |
-
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
483 |
-
|
484 |
-
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
485 |
-
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
486 |
-
|
487 |
-
# this-is-my-cute-cat / this_is_my_cute_cat
|
488 |
-
regex2 = re.compile(r"(?:\-|\_)")
|
489 |
-
if len(re.findall(regex2, caption)) > 3:
|
490 |
-
caption = re.sub(regex2, " ", caption)
|
491 |
-
|
492 |
-
caption = ftfy.fix_text(caption)
|
493 |
-
caption = html.unescape(html.unescape(caption))
|
494 |
-
|
495 |
-
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
496 |
-
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
497 |
-
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
498 |
-
|
499 |
-
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
500 |
-
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
501 |
-
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
502 |
-
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
503 |
-
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
504 |
-
|
505 |
-
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
506 |
-
|
507 |
-
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
508 |
-
|
509 |
-
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
510 |
-
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
511 |
-
caption = re.sub(r"\s+", " ", caption)
|
512 |
-
|
513 |
-
caption.strip()
|
514 |
-
|
515 |
-
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
516 |
-
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
517 |
-
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
518 |
-
caption = re.sub(r"^\.\S+$", "", caption)
|
519 |
-
|
520 |
-
return caption.strip()
|
521 |
|
522 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
523 |
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
524 |
-
if self.vae.quant_conv.weight.ndim==5:
|
525 |
-
|
526 |
-
|
527 |
-
|
|
|
|
|
|
|
|
|
|
|
528 |
else:
|
529 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
530 |
|
@@ -538,11 +642,12 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
538 |
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
539 |
else:
|
540 |
latents = latents.to(device)
|
541 |
-
|
542 |
# scale the initial noise by the standard deviation required by the scheduler
|
543 |
-
|
|
|
544 |
return latents
|
545 |
-
|
546 |
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
547 |
if video.size()[2] <= mini_batch_encoder:
|
548 |
return video
|
@@ -558,16 +663,17 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
558 |
|
559 |
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
560 |
return video
|
561 |
-
|
562 |
def decode_latents(self, latents):
|
563 |
video_length = latents.shape[2]
|
564 |
latents = 1 / self.vae.config.scaling_factor * latents
|
565 |
-
if self.vae.quant_conv.weight.ndim==5:
|
566 |
mini_batch_encoder = self.vae.mini_batch_encoder
|
567 |
mini_batch_decoder = self.vae.mini_batch_decoder
|
568 |
video = self.vae.decode(latents)[0]
|
569 |
video = video.clamp(-1, 1)
|
570 |
-
|
|
|
571 |
else:
|
572 |
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
573 |
video = []
|
@@ -580,8 +686,28 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
580 |
video = video.cpu().float().numpy()
|
581 |
return video
|
582 |
|
583 |
-
|
584 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
585 |
|
586 |
@torch.no_grad()
|
587 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
@@ -589,103 +715,131 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
589 |
self,
|
590 |
prompt: Union[str, List[str]] = None,
|
591 |
video_length: Optional[int] = None,
|
592 |
-
negative_prompt: str = "",
|
593 |
-
num_inference_steps: int = 20,
|
594 |
-
timesteps: List[int] = None,
|
595 |
-
guidance_scale: float = 4.5,
|
596 |
-
num_images_per_prompt: Optional[int] = 1,
|
597 |
height: Optional[int] = None,
|
598 |
width: Optional[int] = None,
|
599 |
-
|
|
|
|
|
|
|
|
|
600 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
601 |
-
latents: Optional[torch.
|
602 |
-
prompt_embeds: Optional[torch.
|
603 |
-
|
604 |
-
negative_prompt_embeds: Optional[torch.
|
605 |
-
|
|
|
|
|
|
|
|
|
606 |
output_type: Optional[str] = "latent",
|
607 |
return_dict: bool = True,
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
|
|
|
|
|
|
|
|
612 |
comfyui_progressbar: bool = False,
|
613 |
-
|
614 |
-
)
|
615 |
-
"""
|
616 |
-
|
617 |
-
|
618 |
-
Args:
|
619 |
-
prompt (`str` or `List[str]`, *optional*):
|
620 |
-
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
621 |
-
instead.
|
622 |
-
negative_prompt (`str` or `List[str]`, *optional*):
|
623 |
-
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
624 |
-
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
625 |
-
less than `1`).
|
626 |
-
num_inference_steps (`int`, *optional*, defaults to 100):
|
627 |
-
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
628 |
-
expense of slower inference.
|
629 |
-
timesteps (`List[int]`, *optional*):
|
630 |
-
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
631 |
-
timesteps are used. Must be in descending order.
|
632 |
-
guidance_scale (`float`, *optional*, defaults to 7.0):
|
633 |
-
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
634 |
-
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
635 |
-
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
636 |
-
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
637 |
-
usually at the expense of lower image quality.
|
638 |
-
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
639 |
-
The number of images to generate per prompt.
|
640 |
-
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
641 |
-
The height in pixels of the generated image.
|
642 |
-
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
643 |
-
The width in pixels of the generated image.
|
644 |
-
eta (`float`, *optional*, defaults to 0.0):
|
645 |
-
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
646 |
-
[`schedulers.DDIMScheduler`], will be ignored for others.
|
647 |
-
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
648 |
-
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
649 |
-
to make generation deterministic.
|
650 |
-
latents (`torch.FloatTensor`, *optional*):
|
651 |
-
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
652 |
-
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
653 |
-
tensor will ge generated by sampling using the supplied random `generator`.
|
654 |
-
prompt_embeds (`torch.FloatTensor`, *optional*):
|
655 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
656 |
-
provided, text embeddings will be generated from `prompt` input argument.
|
657 |
-
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
658 |
-
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
659 |
-
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
660 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
661 |
-
The output format of the generate image. Choose between
|
662 |
-
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
663 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
664 |
-
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
665 |
-
callback (`Callable`, *optional*):
|
666 |
-
A function that will be called every `callback_steps` steps during inference. The function will be
|
667 |
-
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
668 |
-
callback_steps (`int`, *optional*, defaults to 1):
|
669 |
-
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
670 |
-
called at every step.
|
671 |
-
clean_caption (`bool`, *optional*, defaults to `True`):
|
672 |
-
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
673 |
-
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
674 |
-
prompt.
|
675 |
-
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
676 |
|
677 |
Examples:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
|
679 |
Returns:
|
680 |
-
[`~pipelines.
|
681 |
-
If `return_dict` is `True`, [`~pipelines.
|
682 |
-
returned where the first element is a list with the generated images
|
|
|
|
|
683 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
684 |
# 1. Check inputs. Raise error if not correct
|
685 |
-
|
686 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
687 |
|
688 |
-
# 2.
|
689 |
if prompt is not None and isinstance(prompt, str):
|
690 |
batch_size = 1
|
691 |
elif prompt is not None and isinstance(prompt, list):
|
@@ -694,136 +848,223 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
694 |
batch_size = prompt_embeds.shape[0]
|
695 |
|
696 |
device = self._execution_device
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
|
|
702 |
|
703 |
# 3. Encode input prompt
|
704 |
(
|
705 |
prompt_embeds,
|
706 |
-
prompt_attention_mask,
|
707 |
negative_prompt_embeds,
|
|
|
708 |
negative_prompt_attention_mask,
|
709 |
) = self.encode_prompt(
|
710 |
-
prompt,
|
711 |
-
do_classifier_free_guidance,
|
712 |
-
negative_prompt=negative_prompt,
|
713 |
-
num_images_per_prompt=num_images_per_prompt,
|
714 |
device=device,
|
|
|
|
|
|
|
|
|
715 |
prompt_embeds=prompt_embeds,
|
716 |
negative_prompt_embeds=negative_prompt_embeds,
|
717 |
prompt_attention_mask=prompt_attention_mask,
|
718 |
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
719 |
-
|
720 |
-
max_sequence_length=max_sequence_length,
|
721 |
)
|
722 |
-
if
|
723 |
-
|
724 |
-
|
725 |
|
726 |
# 4. Prepare timesteps
|
727 |
-
|
728 |
|
729 |
-
# 5. Prepare
|
730 |
-
|
731 |
latents = self.prepare_latents(
|
732 |
batch_size * num_images_per_prompt,
|
733 |
-
|
734 |
video_length,
|
735 |
height,
|
736 |
width,
|
737 |
-
|
738 |
device,
|
739 |
generator,
|
740 |
latents,
|
741 |
)
|
|
|
|
|
742 |
|
743 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
744 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
745 |
|
746 |
-
#
|
747 |
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
748 |
-
if self.transformer.config.sample_size == 128:
|
749 |
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
750 |
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
751 |
-
resolution = resolution.to(dtype=
|
752 |
-
aspect_ratio = aspect_ratio.to(dtype=
|
753 |
|
754 |
-
if do_classifier_free_guidance:
|
755 |
resolution = torch.cat([resolution, resolution], dim=0)
|
756 |
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
757 |
|
758 |
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
759 |
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
|
770 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
771 |
for i, t in enumerate(timesteps):
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
current_timestep = current_timestep[None].to(latent_model_input.device)
|
787 |
-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
788 |
-
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
789 |
-
|
790 |
-
# predict noise model_output
|
791 |
noise_pred = self.transformer(
|
792 |
latent_model_input,
|
|
|
793 |
encoder_hidden_states=prompt_embeds,
|
794 |
-
|
795 |
-
|
|
|
|
|
|
|
|
|
796 |
added_cond_kwargs=added_cond_kwargs,
|
797 |
return_dict=False,
|
798 |
)[0]
|
|
|
|
|
|
|
799 |
|
800 |
# perform guidance
|
801 |
-
if do_classifier_free_guidance:
|
802 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
803 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
804 |
|
805 |
-
|
806 |
-
|
807 |
-
noise_pred = noise_pred
|
808 |
-
else:
|
809 |
-
noise_pred = noise_pred
|
810 |
|
811 |
-
# compute previous
|
812 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
813 |
|
814 |
-
|
815 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
816 |
progress_bar.update()
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
|
821 |
if comfyui_progressbar:
|
822 |
pbar.update(1)
|
823 |
|
824 |
-
if self.enable_autocast_float8_transformer_flag:
|
825 |
-
self.transformer = self.transformer.to("cpu", origin_weight_dtype)
|
826 |
-
|
827 |
# Post-processing
|
828 |
video = self.decode_latents(latents)
|
829 |
|
@@ -831,7 +1072,10 @@ class EasyAnimatePipeline(DiffusionPipeline):
|
|
831 |
if output_type == "latent":
|
832 |
video = torch.from_numpy(video)
|
833 |
|
|
|
|
|
|
|
834 |
if not return_dict:
|
835 |
return video
|
836 |
|
837 |
-
return EasyAnimatePipelineOutput(
|
|
|
1 |
+
# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
|
2 |
#
|
3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
# you may not use this file except in compliance with the License.
|
|
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
|
|
|
|
|
15 |
import inspect
|
|
|
|
|
16 |
from dataclasses import dataclass
|
17 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
18 |
|
19 |
import numpy as np
|
20 |
import torch
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from diffusers import DiffusionPipeline
|
23 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
24 |
from diffusers.image_processor import VaeImageProcessor
|
25 |
+
from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
|
26 |
+
from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
|
27 |
+
get_3d_rotary_pos_embed)
|
28 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
29 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
30 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
31 |
+
StableDiffusionSafetyChecker
|
32 |
+
from diffusers.schedulers import DDIMScheduler, FlowMatchEulerDiscreteScheduler
|
33 |
from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
|
34 |
+
is_bs4_available, is_ftfy_available,
|
35 |
+
is_torch_xla_available, logging,
|
36 |
replace_example_docstring)
|
37 |
from diffusers.utils.torch_utils import randn_tensor
|
38 |
from einops import rearrange
|
39 |
+
from PIL import Image
|
40 |
from tqdm import tqdm
|
41 |
+
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
|
42 |
+
Qwen2Tokenizer, Qwen2VLForConditionalGeneration,
|
43 |
+
T5EncoderModel, T5Tokenizer)
|
44 |
|
45 |
+
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
46 |
+
from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput
|
47 |
|
48 |
+
if is_torch_xla_available():
|
49 |
+
import torch_xla.core.xla_model as xm
|
50 |
|
51 |
+
XLA_AVAILABLE = True
|
52 |
+
else:
|
53 |
+
XLA_AVAILABLE = False
|
54 |
|
|
|
|
|
55 |
|
56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
57 |
|
58 |
EXAMPLE_DOC_STRING = """
|
59 |
Examples:
|
60 |
+
```python
|
61 |
>>> import torch
|
62 |
>>> from diffusers import EasyAnimatePipeline
|
63 |
+
>>> from diffusers.utils import export_to_video
|
64 |
+
|
65 |
+
>>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh" or "alibaba-pai/EasyAnimateV5.1-7b-zh"
|
66 |
+
>>> pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.float16).to("cuda")
|
67 |
+
>>> prompt = (
|
68 |
+
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
69 |
+
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
70 |
+
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
71 |
+
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
72 |
+
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
73 |
+
... "atmosphere of this unique musical performance."
|
74 |
+
... )
|
75 |
+
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).sample[0]
|
76 |
+
>>> export_to_video(video, "output.mp4", fps=8)
|
77 |
```
|
78 |
"""
|
79 |
|
80 |
+
|
81 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
82 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
83 |
+
tw = tgt_width
|
84 |
+
th = tgt_height
|
85 |
+
h, w = src
|
86 |
+
r = h / w
|
87 |
+
if r > (th / tw):
|
88 |
+
resize_height = th
|
89 |
+
resize_width = int(round(th / h * w))
|
90 |
+
else:
|
91 |
+
resize_width = tw
|
92 |
+
resize_height = int(round(tw / w * h))
|
93 |
+
|
94 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
95 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
96 |
+
|
97 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
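A quick sanity check of the helper above, as a sketch: the patch size of 2 and the 480x720 base resolution are assumptions used only for illustration. At the base aspect ratio the crop region covers the whole latent grid; a taller grid is fit to the height and centred horizontally.

```python
# Illustrative only: patch_size=2 and the 480x720 base grid are assumptions.
patch_size = 2
grid = (480 // 8 // patch_size, 720 // 8 // patch_size)          # (30, 45)
base_w, base_h = 720 // 8 // patch_size, 480 // 8 // patch_size  # 45, 30

# At the base aspect ratio the crop spans the full grid.
assert get_resize_crop_region_for_grid(grid, base_w, base_h) == ((0, 0), (30, 45))

# A taller-than-base grid is resized to fit the height and centred horizontally.
assert get_resize_crop_region_for_grid((60, 30), base_w, base_h) == ((0, 15), (30, 30))
```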
|
98 |
+
|
99 |
+
|
100 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
101 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
102 |
+
"""
|
103 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
104 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
105 |
+
"""
|
106 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
107 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
108 |
+
# rescale the results from guidance (fixes overexposure)
|
109 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
110 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
111 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
112 |
+
return noise_cfg
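For context, a minimal sketch of how `rescale_noise_cfg` is combined with classifier-free guidance later in the denoising loop; the tensor shapes and the scale/rescale values are illustrative assumptions, not values mandated by the pipeline.

```python
import torch

# Dummy predictions standing in for the transformer's CFG outputs (assumed shapes).
noise_pred_uncond = torch.randn(1, 16, 13, 64, 64)
noise_pred_text = torch.randn(1, 16, 13, 64, 64)
guidance_scale, guidance_rescale = 6.0, 0.7

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Pull the guided prediction back toward the text-conditioned statistics to counter overexposure.
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
```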
|
113 |
+
|
114 |
+
|
115 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
116 |
def retrieve_timesteps(
|
117 |
scheduler,
|
118 |
num_inference_steps: Optional[int] = None,
|
119 |
device: Optional[Union[str, torch.device]] = None,
|
120 |
timesteps: Optional[List[int]] = None,
|
121 |
+
sigmas: Optional[List[float]] = None,
|
122 |
**kwargs,
|
123 |
):
|
124 |
"""
|
|
|
129 |
scheduler (`SchedulerMixin`):
|
130 |
The scheduler to get timesteps from.
|
131 |
num_inference_steps (`int`):
|
132 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
133 |
+
must be `None`.
|
134 |
device (`str` or `torch.device`, *optional*):
|
135 |
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
136 |
timesteps (`List[int]`, *optional*):
|
137 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
138 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
139 |
+
sigmas (`List[float]`, *optional*):
|
140 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
141 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
142 |
|
143 |
Returns:
|
144 |
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
145 |
second element is the number of inference steps.
|
146 |
"""
|
147 |
+
if timesteps is not None and sigmas is not None:
|
148 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
149 |
if timesteps is not None:
|
150 |
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
151 |
if not accepts_timesteps:
|
|
|
156 |
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
157 |
timesteps = scheduler.timesteps
|
158 |
num_inference_steps = len(timesteps)
|
159 |
+
elif sigmas is not None:
|
160 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
161 |
+
if not accept_sigmas:
|
162 |
+
raise ValueError(
|
163 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
164 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
165 |
+
)
|
166 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
167 |
+
timesteps = scheduler.timesteps
|
168 |
+
num_inference_steps = len(timesteps)
|
169 |
else:
|
170 |
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
171 |
timesteps = scheduler.timesteps
|
172 |
return timesteps, num_inference_steps
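A small usage sketch of `retrieve_timesteps` (the scheduler configuration values are assumptions): the helper either defers to the scheduler's own spacing or forwards custom `timesteps`/`sigmas`, and passing both at once raises a `ValueError`.

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

# Default spacing from the scheduler.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
assert num_steps == 30 and len(timesteps) == 30

# Passing both custom timesteps and sigmas is rejected before the scheduler is touched.
try:
    retrieve_timesteps(scheduler, timesteps=[999, 500, 1], sigmas=[1.0, 0.5, 0.0])
except ValueError as err:
    print(err)
```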
|
173 |
|
|
|
|
|
|
|
174 |
|
175 |
class EasyAnimatePipeline(DiffusionPipeline):
|
176 |
r"""
|
177 |
+
Pipeline for text-to-video generation using EasyAnimate.
|
178 |
|
179 |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
180 |
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
181 |
|
182 |
+
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
183 |
+
EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
|
184 |
+
HunyuanDiT team) in V5.
|
185 |
+
|
186 |
Args:
|
187 |
+
vae ([`AutoencoderKLMagvit`]):
|
188 |
+
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
189 |
+
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
|
190 |
+
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
191 |
+
EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5.
|
192 |
+
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
|
193 |
+
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
|
194 |
+
transformer ([`EasyAnimateTransformer3DModel`]):
|
195 |
+
The EasyAnimate transformer model designed by the EasyAnimate team.
|
196 |
+
text_encoder_2 (`T5EncoderModel`):
|
197 |
+
EasyAnimate does not use text_encoder_2 in V5.1.
|
198 |
+
EasyAnimate uses [mT5](https://huggingface.co/google/mt5-base) embedder in V5.
|
199 |
+
tokenizer_2 (`T5Tokenizer`):
|
200 |
+
The tokenizer for the mT5 embedder.
|
201 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
202 |
+
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
203 |
"""
|
|
|
|
|
|
|
204 |
|
205 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
206 |
+
_optional_components = [
|
207 |
+
"text_encoder_2",
|
208 |
+
"tokenizer_2",
|
209 |
+
"text_encoder",
|
210 |
+
"tokenizer",
|
211 |
+
]
|
212 |
+
_callback_tensor_inputs = [
|
213 |
+
"latents",
|
214 |
+
"prompt_embeds",
|
215 |
+
"negative_prompt_embeds",
|
216 |
+
"prompt_embeds_2",
|
217 |
+
"negative_prompt_embeds_2",
|
218 |
+
]
|
219 |
|
220 |
def __init__(
|
221 |
self,
|
222 |
+
vae: AutoencoderKLMagvit,
|
223 |
+
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
|
224 |
+
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
|
225 |
+
text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]],
|
226 |
+
tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]],
|
227 |
+
transformer: EasyAnimateTransformer3DModel,
|
228 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
229 |
):
|
230 |
super().__init__()
|
231 |
|
232 |
self.register_modules(
|
233 |
+
vae=vae,
|
234 |
+
text_encoder=text_encoder,
|
235 |
+
text_encoder_2=text_encoder_2,
|
236 |
+
tokenizer=tokenizer,
|
237 |
+
tokenizer_2=tokenizer_2,
|
238 |
+
transformer=transformer,
|
239 |
+
scheduler=scheduler,
|
240 |
)
|
241 |
|
242 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
|
243 |
|
244 |
+
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
245 |
+
super().enable_sequential_cpu_offload(*args, **kwargs)
|
246 |
+
if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
|
247 |
+
import accelerate
|
248 |
+
accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
|
249 |
+
self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
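A hedged usage sketch of the override above: sequential offload is enabled as usual, and when the transformer carries a `clip_projection` module it is pinned back onto the GPU. The checkpoint name below is only an example.

```python
import torch

# Example checkpoint name; adjust to the model actually in use.
pipe = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.bfloat16
)
# Each submodule is moved to the GPU only while it runs, trading speed for memory.
pipe.enable_sequential_cpu_offload()
```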
|
250 |
+
|
251 |
def encode_prompt(
|
252 |
self,
|
253 |
+
prompt: str,
|
254 |
+
device: torch.device,
|
255 |
+
dtype: torch.dtype,
|
256 |
num_images_per_prompt: int = 1,
|
257 |
+
do_classifier_free_guidance: bool = True,
|
258 |
+
negative_prompt: Optional[str] = None,
|
259 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
260 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
261 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
262 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
263 |
+
max_sequence_length: Optional[int] = None,
|
264 |
+
text_encoder_index: int = 0,
|
265 |
+
actual_max_sequence_length: int = 256
|
266 |
):
|
267 |
r"""
|
268 |
Encodes the prompt into text encoder hidden states.
|
|
|
270 |
Args:
|
271 |
prompt (`str` or `List[str]`, *optional*):
|
272 |
prompt to be encoded
|
273 |
+
device (`torch.device`):
|
274 |
+
torch device
|
275 |
+
dtype (`torch.dtype`):
|
276 |
+
torch dtype
|
277 |
+
num_images_per_prompt (`int`):
|
|
|
|
|
278 |
number of images that should be generated per prompt
|
279 |
+
do_classifier_free_guidance (`bool`):
|
280 |
+
whether to use classifier free guidance or not
|
281 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
282 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
283 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
284 |
+
less than `1`).
|
285 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
286 |
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
287 |
provided, text embeddings will be generated from `prompt` input argument.
|
288 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
289 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
290 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
291 |
+
argument.
|
292 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
293 |
+
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
294 |
+
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
295 |
+
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
296 |
+
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
297 |
+
text_encoder_index (`int`, *optional*):
|
298 |
+
Index of the text encoder to use. `0` for clip and `1` for T5.
|
299 |
"""
|
300 |
+
tokenizers = [self.tokenizer, self.tokenizer_2]
|
301 |
+
text_encoders = [self.text_encoder, self.text_encoder_2]
|
302 |
|
303 |
+
tokenizer = tokenizers[text_encoder_index]
|
304 |
+
text_encoder = text_encoders[text_encoder_index]
|
|
|
305 |
|
306 |
+
if max_sequence_length is None:
|
307 |
+
if text_encoder_index == 0:
|
308 |
+
max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
|
309 |
+
if text_encoder_index == 1:
|
310 |
+
max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
|
311 |
+
else:
|
312 |
+
max_length = max_sequence_length
|
313 |
|
314 |
if prompt is not None and isinstance(prompt, str):
|
315 |
batch_size = 1
|
|
|
318 |
else:
|
319 |
batch_size = prompt_embeds.shape[0]
|
320 |
|
|
|
|
|
|
|
321 |
if prompt_embeds is None:
|
322 |
+
if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
|
323 |
+
text_inputs = tokenizer(
|
324 |
+
prompt,
|
325 |
+
padding="max_length",
|
326 |
+
max_length=max_length,
|
327 |
+
truncation=True,
|
328 |
+
return_attention_mask=True,
|
329 |
+
return_tensors="pt",
|
330 |
+
)
|
331 |
+
text_input_ids = text_inputs.input_ids
|
332 |
+
if text_input_ids.shape[-1] > actual_max_sequence_length:
|
333 |
+
reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
334 |
+
text_inputs = tokenizer(
|
335 |
+
reprompt,
|
336 |
+
padding="max_length",
|
337 |
+
max_length=max_length,
|
338 |
+
truncation=True,
|
339 |
+
return_attention_mask=True,
|
340 |
+
return_tensors="pt",
|
341 |
+
)
|
342 |
+
text_input_ids = text_inputs.input_ids
|
343 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
344 |
+
|
345 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
346 |
+
text_input_ids, untruncated_ids
|
347 |
+
):
|
348 |
+
_actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
|
349 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
|
350 |
+
logger.warning(
|
351 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
352 |
+
f" {_actual_max_sequence_length} tokens: {removed_text}"
|
353 |
+
)
|
354 |
+
|
355 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
356 |
+
|
357 |
+
if self.transformer.config.enable_text_attention_mask:
|
358 |
+
prompt_embeds = text_encoder(
|
359 |
+
text_input_ids.to(device),
|
360 |
+
attention_mask=prompt_attention_mask,
|
361 |
+
)
|
362 |
+
else:
|
363 |
+
prompt_embeds = text_encoder(
|
364 |
+
text_input_ids.to(device)
|
365 |
+
)
|
366 |
+
prompt_embeds = prompt_embeds[0]
|
367 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
368 |
+
else:
|
369 |
+
if prompt is not None and isinstance(prompt, str):
|
370 |
+
messages = [
|
371 |
+
{
|
372 |
+
"role": "user",
|
373 |
+
"content": [{"type": "text", "text": prompt}],
|
374 |
+
}
|
375 |
+
]
|
376 |
+
else:
|
377 |
+
messages = [
|
378 |
+
{
|
379 |
+
"role": "user",
|
380 |
+
"content": [{"type": "text", "text": _prompt}],
|
381 |
+
} for _prompt in prompt
|
382 |
+
]
|
383 |
+
text = tokenizer.apply_chat_template(
|
384 |
+
messages, tokenize=False, add_generation_prompt=True
|
385 |
)
|
386 |
|
387 |
+
text_inputs = tokenizer(
|
388 |
+
text=[text],
|
389 |
+
padding="max_length",
|
390 |
+
max_length=max_length,
|
391 |
+
truncation=True,
|
392 |
+
return_attention_mask=True,
|
393 |
+
padding_side="right",
|
394 |
+
return_tensors="pt",
|
395 |
+
)
|
396 |
+
text_inputs = text_inputs.to(text_encoder.device)
|
397 |
+
|
398 |
+
text_input_ids = text_inputs.input_ids
|
399 |
+
prompt_attention_mask = text_inputs.attention_mask
|
400 |
+
if self.transformer.config.enable_text_attention_mask:
|
401 |
+
# Inference: Generation of the output
|
402 |
+
prompt_embeds = text_encoder(
|
403 |
+
input_ids=text_input_ids,
|
404 |
+
attention_mask=prompt_attention_mask,
|
405 |
+
output_hidden_states=True).hidden_states[-2]
|
406 |
+
else:
|
407 |
+
raise ValueError("LLM needs attention_mask")
|
408 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
409 |
+
|
410 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
411 |
|
412 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
413 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
414 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
415 |
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
416 |
+
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
|
|
417 |
|
418 |
# get unconditional embeddings for classifier free guidance
|
419 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
420 |
+
if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
|
421 |
+
uncond_tokens: List[str]
|
422 |
+
if negative_prompt is None:
|
423 |
+
uncond_tokens = [""] * batch_size
|
424 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
425 |
+
raise TypeError(
|
426 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
427 |
+
f" {type(prompt)}."
|
428 |
+
)
|
429 |
+
elif isinstance(negative_prompt, str):
|
430 |
+
uncond_tokens = [negative_prompt]
|
431 |
+
elif batch_size != len(negative_prompt):
|
432 |
+
raise ValueError(
|
433 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
434 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
435 |
+
" the batch size of `prompt`."
|
436 |
+
)
|
437 |
+
else:
|
438 |
+
uncond_tokens = negative_prompt
|
439 |
+
|
440 |
+
max_length = prompt_embeds.shape[1]
|
441 |
+
uncond_input = tokenizer(
|
442 |
+
uncond_tokens,
|
443 |
+
padding="max_length",
|
444 |
+
max_length=max_length,
|
445 |
+
truncation=True,
|
446 |
+
return_tensors="pt",
|
447 |
+
)
|
448 |
+
uncond_input_ids = uncond_input.input_ids
|
449 |
+
if uncond_input_ids.shape[-1] > actual_max_sequence_length:
|
450 |
+
reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
451 |
+
uncond_input = tokenizer(
|
452 |
+
reuncond_tokens,
|
453 |
+
padding="max_length",
|
454 |
+
max_length=max_length,
|
455 |
+
truncation=True,
|
456 |
+
return_attention_mask=True,
|
457 |
+
return_tensors="pt",
|
458 |
+
)
|
459 |
+
uncond_input_ids = uncond_input.input_ids
|
460 |
+
|
461 |
+
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
|
462 |
+
if self.transformer.config.enable_text_attention_mask:
|
463 |
+
negative_prompt_embeds = text_encoder(
|
464 |
+
uncond_input.input_ids.to(device),
|
465 |
+
attention_mask=negative_prompt_attention_mask,
|
466 |
+
)
|
467 |
+
else:
|
468 |
+
negative_prompt_embeds = text_encoder(
|
469 |
+
uncond_input.input_ids.to(device)
|
470 |
+
)
|
471 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
472 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
473 |
+
else:
|
474 |
+
if negative_prompt is not None and isinstance(negative_prompt, str):
|
475 |
+
messages = [
|
476 |
+
{
|
477 |
+
"role": "user",
|
478 |
+
"content": [{"type": "text", "text": negative_prompt}],
|
479 |
+
}
|
480 |
+
]
|
481 |
+
else:
|
482 |
+
messages = [
|
483 |
+
{
|
484 |
+
"role": "user",
|
485 |
+
"content": [{"type": "text", "text": _negative_prompt}],
|
486 |
+
} for _negative_prompt in negative_prompt
|
487 |
+
]
|
488 |
+
text = tokenizer.apply_chat_template(
|
489 |
+
messages, tokenize=False, add_generation_prompt=True
|
490 |
+
)
|
491 |
|
492 |
+
text_inputs = tokenizer(
|
493 |
+
text=[text],
|
494 |
+
padding="max_length",
|
495 |
+
max_length=max_length,
|
496 |
+
truncation=True,
|
497 |
+
return_attention_mask=True,
|
498 |
+
padding_side="right",
|
499 |
+
return_tensors="pt",
|
500 |
+
)
|
501 |
+
text_inputs = text_inputs.to(text_encoder.device)
|
502 |
+
|
503 |
+
text_input_ids = text_inputs.input_ids
|
504 |
+
negative_prompt_attention_mask = text_inputs.attention_mask
|
505 |
+
if self.transformer.config.enable_text_attention_mask:
|
506 |
+
# Inference: Generation of the output
|
507 |
+
negative_prompt_embeds = text_encoder(
|
508 |
+
input_ids=text_input_ids,
|
509 |
+
attention_mask=negative_prompt_attention_mask,
|
510 |
+
output_hidden_states=True).hidden_states[-2]
|
511 |
+
else:
|
512 |
+
raise ValueError("LLM needs attention_mask")
|
513 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
514 |
|
515 |
if do_classifier_free_guidance:
|
516 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
|
|
520 |
|
521 |
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
522 |
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
523 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
|
524 |
+
|
525 |
+
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
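For the Qwen2-VL branch above, the prompt is first wrapped in a chat message and rendered through the tokenizer's chat template before padding, mirroring the calls in `encode_prompt`. A rough standalone sketch follows; the checkpoint id and `max_length` are assumptions.

```python
from transformers import Qwen2Tokenizer

tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
messages = [{"role": "user", "content": [{"type": "text", "text": "A cat playing the piano"}]}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(
    text=[text],
    padding="max_length",
    max_length=256,
    truncation=True,
    return_attention_mask=True,
    padding_side="right",
    return_tensors="pt",
)
# `inputs.input_ids` and `inputs.attention_mask` are what feed the text encoder.
```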
|
526 |
|
527 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
528 |
def prepare_extra_step_kwargs(self, generator, eta):
|
|
|
547 |
prompt,
|
548 |
height,
|
549 |
width,
|
550 |
+
negative_prompt=None,
|
|
|
551 |
prompt_embeds=None,
|
552 |
negative_prompt_embeds=None,
|
553 |
+
prompt_attention_mask=None,
|
554 |
+
negative_prompt_attention_mask=None,
|
555 |
+
prompt_embeds_2=None,
|
556 |
+
negative_prompt_embeds_2=None,
|
557 |
+
prompt_attention_mask_2=None,
|
558 |
+
negative_prompt_attention_mask_2=None,
|
559 |
+
callback_on_step_end_tensor_inputs=None,
|
560 |
):
|
561 |
+
if height % 16 != 0 or width % 16 != 0:
|
562 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
563 |
|
564 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
565 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
566 |
):
|
567 |
raise ValueError(
|
568 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
|
|
569 |
)
|
570 |
|
571 |
if prompt is not None and prompt_embeds is not None:
|
|
|
577 |
raise ValueError(
|
578 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
579 |
)
|
580 |
+
elif prompt is None and prompt_embeds_2 is None:
|
581 |
+
raise ValueError(
|
582 |
+
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
|
583 |
+
)
|
584 |
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
585 |
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
586 |
|
587 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
588 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
589 |
+
|
590 |
+
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
|
591 |
+
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
|
592 |
|
593 |
if negative_prompt is not None and negative_prompt_embeds is not None:
|
594 |
raise ValueError(
|
|
|
596 |
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
597 |
)
|
598 |
|
599 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
600 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
601 |
+
|
602 |
+
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
|
603 |
+
raise ValueError(
|
604 |
+
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
|
605 |
+
)
|
606 |
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
607 |
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
608 |
raise ValueError(
|
|
|
610 |
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
611 |
f" {negative_prompt_embeds.shape}."
|
612 |
)
|
613 |
+
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
|
614 |
+
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
|
615 |
+
raise ValueError(
|
616 |
+
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
|
617 |
+
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
|
618 |
+
f" {negative_prompt_embeds_2.shape}."
|
619 |
+
)
|
|
620 |
|
621 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
622 |
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
623 |
+
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
624 |
+
if self.vae.cache_mag_vae:
|
625 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
626 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
627 |
+
shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
628 |
+
else:
|
629 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
630 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
631 |
+
shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
632 |
else:
|
633 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
634 |
|
|
|
642 |
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
643 |
else:
|
644 |
latents = latents.to(device)
|
645 |
+
|
646 |
# scale the initial noise by the standard deviation required by the scheduler
|
647 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
648 |
+
latents = latents * self.scheduler.init_noise_sigma
|
649 |
return latents
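To make the MagViT branch of `prepare_latents` concrete, here is a small arithmetic sketch with purely hypothetical values (the channel count, frame count and mini-batch sizes are assumptions, not the model's actual configuration).

```python
# Hypothetical values, for illustrating the cache_mag_vae shape rule only.
batch_size, num_channels_latents = 1, 16
video_length, height, width = 49, 512, 512
vae_scale_factor = 8
mini_batch_encoder, mini_batch_decoder = 4, 1

temporal = (video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1 if video_length != 1 else 1
shape = (batch_size, num_channels_latents, temporal,
         height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (1, 16, 13, 64, 64)
```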
|
650 |
+
|
651 |
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
652 |
if video.size()[2] <= mini_batch_encoder:
|
653 |
return video
|
|
|
663 |
|
664 |
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
665 |
return video
|
666 |
+
|
667 |
def decode_latents(self, latents):
|
668 |
video_length = latents.shape[2]
|
669 |
latents = 1 / self.vae.config.scaling_factor * latents
|
670 |
+
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
671 |
mini_batch_encoder = self.vae.mini_batch_encoder
|
672 |
mini_batch_decoder = self.vae.mini_batch_decoder
|
673 |
video = self.vae.decode(latents)[0]
|
674 |
video = video.clamp(-1, 1)
|
675 |
+
if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
|
676 |
+
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
677 |
else:
|
678 |
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
679 |
video = []
|
|
|
686 |
video = video.cpu().float().numpy()
|
687 |
return video
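Downstream, the array returned by `decode_latents` can be converted into PIL frames. This is a sketch under the assumption, suggested by the clamps above, that the array has shape `(batch, channels, frames, height, width)` with values in `[-1, 1]`.

```python
import numpy as np
from PIL import Image

def video_array_to_frames(video: np.ndarray) -> list:
    # Assumed layout: (batch, channels, frames, height, width), values in [-1, 1].
    video = (video / 2 + 0.5).clip(0, 1)
    frames = []
    for i in range(video.shape[2]):
        frame = (video[0, :, i].transpose(1, 2, 0) * 255).astype(np.uint8)
        frames.append(Image.fromarray(frame))
    return frames
```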
|
688 |
|
689 |
+
@property
|
690 |
+
def guidance_scale(self):
|
691 |
+
return self._guidance_scale
|
692 |
+
|
693 |
+
@property
|
694 |
+
def guidance_rescale(self):
|
695 |
+
return self._guidance_rescale
|
696 |
+
|
697 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
698 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
699 |
+
# corresponds to doing no classifier free guidance.
|
700 |
+
@property
|
701 |
+
def do_classifier_free_guidance(self):
|
702 |
+
return self._guidance_scale > 1
|
703 |
+
|
704 |
+
@property
|
705 |
+
def num_timesteps(self):
|
706 |
+
return self._num_timesteps
|
707 |
+
|
708 |
+
@property
|
709 |
+
def interrupt(self):
|
710 |
+
return self._interrupt
|
711 |
|
712 |
@torch.no_grad()
|
713 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
|
|
715 |
self,
|
716 |
prompt: Union[str, List[str]] = None,
|
717 |
video_length: Optional[int] = None,
|
|
|
|
|
|
|
|
|
|
|
718 |
height: Optional[int] = None,
|
719 |
width: Optional[int] = None,
|
720 |
+
num_inference_steps: Optional[int] = 50,
|
721 |
+
guidance_scale: Optional[float] = 5.0,
|
722 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
723 |
+
num_images_per_prompt: Optional[int] = 1,
|
724 |
+
eta: Optional[float] = 0.0,
|
725 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
726 |
+
latents: Optional[torch.Tensor] = None,
|
727 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
728 |
+
prompt_embeds_2: Optional[torch.Tensor] = None,
|
729 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
730 |
+
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
|
731 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
732 |
+
prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
733 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
734 |
+
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
735 |
output_type: Optional[str] = "latent",
|
736 |
return_dict: bool = True,
|
737 |
+
callback_on_step_end: Optional[
|
738 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
739 |
+
] = None,
|
740 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
741 |
+
guidance_rescale: float = 0.0,
|
742 |
+
original_size: Optional[Tuple[int, int]] = (1024, 1024),
|
743 |
+
target_size: Optional[Tuple[int, int]] = None,
|
744 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
745 |
comfyui_progressbar: bool = False,
|
746 |
+
timesteps: Optional[List[int]] = None,
|
747 |
+
):
|
748 |
+
r"""
|
749 |
+
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
750 |
|
751 |
Examples:
|
752 |
+
prompt (`str` or `List[str]`, *optional*):
|
753 |
+
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
754 |
+
video_length (`int`, *optional*):
|
755 |
+
Length of the generated video (in frames).
|
756 |
+
height (`int`, *optional*):
|
757 |
+
Height of the generated image in pixels.
|
758 |
+
width (`int`, *optional*):
|
759 |
+
Width of the generated image in pixels.
|
760 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
761 |
+
Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
|
762 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
763 |
+
Classifier-free guidance scale. Higher values push the output to follow the prompt more closely, sometimes at the cost of image quality.
|
764 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
765 |
+
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
766 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
767 |
+
Number of images to generate for each prompt.
|
768 |
+
eta (`float`, *optional*, defaults to 0.0):
|
769 |
+
Corresponds to parameter eta (η) from the DDIM paper; only applies to DDIM-style schedulers and is ignored by others.
|
770 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
771 |
+
A generator to ensure reproducibility in image generation.
|
772 |
+
latents (`torch.Tensor`, *optional*):
|
773 |
+
Predefined latent tensors to condition generation.
|
774 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
775 |
+
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
776 |
+
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
777 |
+
Secondary text embeddings to supplement or replace the initial prompt embeddings.
|
778 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
779 |
+
Embeddings for negative prompts. Overrides string inputs if defined.
|
780 |
+
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
|
781 |
+
Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
|
782 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
783 |
+
Attention mask for the primary prompt embeddings.
|
784 |
+
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
785 |
+
Attention mask for the secondary prompt embeddings.
|
786 |
+
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
787 |
+
Attention mask for negative prompt embeddings.
|
788 |
+
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
789 |
+
Attention mask for secondary negative prompt embeddings.
|
790 |
+
output_type (`str`, *optional*, defaults to "latent"):
|
791 |
+
Format of the generated output, either as a PIL image or as a NumPy array.
|
792 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
793 |
+
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
794 |
+
callback_on_step_end (`Callable`, *optional*):
|
795 |
+
Functions called at the end of each denoising step.
|
796 |
+
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
797 |
+
Tensor names to be included in callback function calls.
|
798 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
799 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf); `0.0` disables rescaling.
|
800 |
+
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
|
801 |
+
Original dimensions of the output.
|
802 |
+
target_size (`Tuple[int, int]`, *optional*):
|
803 |
+
Desired output dimensions for calculations.
|
804 |
+
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
|
805 |
+
Coordinates for cropping.
|
806 |
|
807 |
Returns:
|
808 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
809 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
810 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
811 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
812 |
+
"not-safe-for-work" (nsfw) content.
|
813 |
"""
|
814 |
+
|
815 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
816 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
817 |
+
|
818 |
+
# 0. default height and width
|
819 |
+
height = int((height // 16) * 16)
|
820 |
+
width = int((width // 16) * 16)
|
821 |
+
|
822 |
# 1. Check inputs. Raise error if not correct
|
823 |
+
self.check_inputs(
|
824 |
+
prompt,
|
825 |
+
height,
|
826 |
+
width,
|
827 |
+
negative_prompt,
|
828 |
+
prompt_embeds,
|
829 |
+
negative_prompt_embeds,
|
830 |
+
prompt_attention_mask,
|
831 |
+
negative_prompt_attention_mask,
|
832 |
+
prompt_embeds_2,
|
833 |
+
negative_prompt_embeds_2,
|
834 |
+
prompt_attention_mask_2,
|
835 |
+
negative_prompt_attention_mask_2,
|
836 |
+
callback_on_step_end_tensor_inputs,
|
837 |
+
)
|
838 |
+
self._guidance_scale = guidance_scale
|
839 |
+
self._guidance_rescale = guidance_rescale
|
840 |
+
self._interrupt = False
|
841 |
|
842 |
+
# 2. Define call parameters
|
843 |
if prompt is not None and isinstance(prompt, str):
|
844 |
batch_size = 1
|
845 |
elif prompt is not None and isinstance(prompt, list):
|
|
|
848 |
batch_size = prompt_embeds.shape[0]
|
849 |
|
850 |
device = self._execution_device
|
851 |
+
if self.text_encoder is not None:
|
852 |
+
dtype = self.text_encoder.dtype
|
853 |
+
elif self.text_encoder_2 is not None:
|
854 |
+
dtype = self.text_encoder_2.dtype
|
855 |
+
else:
|
856 |
+
dtype = self.transformer.dtype
|
857 |
|
858 |
# 3. Encode input prompt
|
859 |
(
|
860 |
prompt_embeds,
|
|
|
861 |
negative_prompt_embeds,
|
862 |
+
prompt_attention_mask,
|
863 |
negative_prompt_attention_mask,
|
864 |
) = self.encode_prompt(
|
865 |
+
prompt=prompt,
|
|
|
|
|
|
|
866 |
device=device,
|
867 |
+
dtype=dtype,
|
868 |
+
num_images_per_prompt=num_images_per_prompt,
|
869 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
870 |
+
negative_prompt=negative_prompt,
|
871 |
prompt_embeds=prompt_embeds,
|
872 |
negative_prompt_embeds=negative_prompt_embeds,
|
873 |
prompt_attention_mask=prompt_attention_mask,
|
874 |
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
875 |
+
text_encoder_index=0,
|
|
|
876 |
)
|
877 |
+
if self.tokenizer_2 is not None:
|
878 |
+
(
|
879 |
+
prompt_embeds_2,
|
880 |
+
negative_prompt_embeds_2,
|
881 |
+
prompt_attention_mask_2,
|
882 |
+
negative_prompt_attention_mask_2,
|
883 |
+
) = self.encode_prompt(
|
884 |
+
prompt=prompt,
|
885 |
+
device=device,
|
886 |
+
dtype=dtype,
|
887 |
+
num_images_per_prompt=num_images_per_prompt,
|
888 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
889 |
+
negative_prompt=negative_prompt,
|
890 |
+
prompt_embeds=prompt_embeds_2,
|
891 |
+
negative_prompt_embeds=negative_prompt_embeds_2,
|
892 |
+
prompt_attention_mask=prompt_attention_mask_2,
|
893 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
|
894 |
+
text_encoder_index=1,
|
895 |
+
)
|
896 |
+
else:
|
897 |
+
prompt_embeds_2 = None
|
898 |
+
negative_prompt_embeds_2 = None
|
899 |
+
prompt_attention_mask_2 = None
|
900 |
+
negative_prompt_attention_mask_2 = None
|
901 |
|
902 |
# 4. Prepare timesteps
|
903 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
904 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
905 |
+
else:
|
906 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
907 |
+
if comfyui_progressbar:
|
908 |
+
from comfy.utils import ProgressBar
|
909 |
+
pbar = ProgressBar(num_inference_steps + 1)
|
910 |
|
911 |
+
# 5. Prepare latent variables
|
912 |
+
num_channels_latents = self.transformer.config.in_channels
|
913 |
latents = self.prepare_latents(
|
914 |
batch_size * num_images_per_prompt,
|
915 |
+
num_channels_latents,
|
916 |
video_length,
|
917 |
height,
|
918 |
width,
|
919 |
+
dtype,
|
920 |
device,
|
921 |
generator,
|
922 |
latents,
|
923 |
)
|
924 |
+
if comfyui_progressbar:
|
925 |
+
pbar.update(1)
|
926 |
|
927 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
928 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
929 |
|
930 |
+
# 7 create image_rotary_emb, style embedding & time ids
|
931 |
+
grid_height = height // 8 // self.transformer.config.patch_size
|
932 |
+
grid_width = width // 8 // self.transformer.config.patch_size
|
933 |
+
if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
|
934 |
+
base_size_width = 720 // 8 // self.transformer.config.patch_size
|
935 |
+
base_size_height = 480 // 8 // self.transformer.config.patch_size
|
936 |
+
|
937 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
938 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
939 |
+
)
|
940 |
+
image_rotary_emb = get_3d_rotary_pos_embed(
|
941 |
+
self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
|
942 |
+
temporal_size=latents.size(2), use_real=True,
|
943 |
+
)
|
944 |
+
else:
|
945 |
+
base_size = 512 // 8 // self.transformer.config.patch_size
|
946 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
947 |
+
(grid_height, grid_width), base_size, base_size
|
948 |
+
)
|
949 |
+
image_rotary_emb = get_2d_rotary_pos_embed(
|
950 |
+
self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
|
951 |
+
)
|
952 |
+
|
953 |
+
# Get other hunyuan params
|
954 |
+
target_size = target_size or (height, width)
|
955 |
+
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
956 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
957 |
+
style = torch.tensor([0], device=device)
|
958 |
+
|
959 |
+
if self.do_classifier_free_guidance:
|
960 |
+
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
|
961 |
+
style = torch.cat([style] * 2, dim=0)
|
962 |
+
|
963 |
+
# To latents.device
|
964 |
+
add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat(
|
965 |
+
batch_size * num_images_per_prompt, 1
|
966 |
+
)
|
967 |
+
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
|
968 |
+
|
969 |
+
# Get other pixart params
|
970 |
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
971 |
+
if self.transformer.config.get("sample_size", 64) == 128:
|
972 |
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
973 |
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
974 |
+
resolution = resolution.to(dtype=dtype, device=device)
|
975 |
+
aspect_ratio = aspect_ratio.to(dtype=dtype, device=device)
|
976 |
|
977 |
+
if self.do_classifier_free_guidance:
|
978 |
resolution = torch.cat([resolution, resolution], dim=0)
|
979 |
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
980 |
|
981 |
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
982 |
|
983 |
+
if self.do_classifier_free_guidance:
|
984 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
985 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
986 |
+
if prompt_embeds_2 is not None:
|
987 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
988 |
+
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
|
989 |
+
|
990 |
+
# To latents.device
|
991 |
+
prompt_embeds = prompt_embeds.to(device=device)
|
992 |
+
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
993 |
+
if prompt_embeds_2 is not None:
|
994 |
+
prompt_embeds_2 = prompt_embeds_2.to(device=device)
|
995 |
+
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
|
996 |
+
|
997 |
+
# 8. Denoising loop
|
998 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
999 |
+
self._num_timesteps = len(timesteps)
|
1000 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1001 |
for i, t in enumerate(timesteps):
|
1002 |
+
if self.interrupt:
|
1003 |
+
continue
|
1004 |
+
|
1005 |
+
# expand the latents if we are doing classifier free guidance
|
1006 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1007 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
1008 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1009 |
+
|
1010 |
+
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
1011 |
+
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
1012 |
+
dtype=latent_model_input.dtype
|
1013 |
+
)
|
1014 |
+
|
1015 |
+
# predict the noise residual
|
|
1016 |
noise_pred = self.transformer(
|
1017 |
latent_model_input,
|
1018 |
+
t_expand,
|
1019 |
encoder_hidden_states=prompt_embeds,
|
1020 |
+
text_embedding_mask=prompt_attention_mask,
|
1021 |
+
encoder_hidden_states_t5=prompt_embeds_2,
|
1022 |
+
text_embedding_mask_t5=prompt_attention_mask_2,
|
1023 |
+
image_meta_size=add_time_ids,
|
1024 |
+
style=style,
|
1025 |
+
image_rotary_emb=image_rotary_emb,
|
1026 |
added_cond_kwargs=added_cond_kwargs,
|
1027 |
return_dict=False,
|
1028 |
)[0]
|
1029 |
+
|
1030 |
+
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
1031 |
+
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
1032 |
|
1033 |
# perform guidance
|
1034 |
+
if self.do_classifier_free_guidance:
|
1035 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1036 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1037 |
|
1038 |
+
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
1039 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1040 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
|
|
|
|
1041 |
|
1042 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1043 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1044 |
|
1045 |
+
if callback_on_step_end is not None:
|
1046 |
+
callback_kwargs = {}
|
1047 |
+
for k in callback_on_step_end_tensor_inputs:
|
1048 |
+
callback_kwargs[k] = locals()[k]
|
1049 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1050 |
+
|
1051 |
+
latents = callback_outputs.pop("latents", latents)
|
1052 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1053 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1054 |
+
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
|
1055 |
+
negative_prompt_embeds_2 = callback_outputs.pop(
|
1056 |
+
"negative_prompt_embeds_2", negative_prompt_embeds_2
|
1057 |
+
)
|
1058 |
+
|
1059 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1060 |
progress_bar.update()
|
1061 |
+
|
1062 |
+
if XLA_AVAILABLE:
|
1063 |
+
xm.mark_step()
|
1064 |
|
1065 |
if comfyui_progressbar:
|
1066 |
pbar.update(1)
|
1067 |
|
|
|
|
|
|
|
1068 |
# Post-processing
|
1069 |
video = self.decode_latents(latents)
|
1070 |
|
|
|
1072 |
if output_type == "latent":
|
1073 |
video = torch.from_numpy(video)
|
1074 |
|
1075 |
+
# Offload all models
|
1076 |
+
self.maybe_free_model_hooks()
|
1077 |
+
|
1078 |
if not return_dict:
|
1079 |
return video
|
1080 |
|
1081 |
+
return EasyAnimatePipelineOutput(frames=video)
|
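Putting the pieces of `__call__` together, a hedged end-to-end sketch: the checkpoint name, resolution and frame count are assumptions, and the output field name follows the `frames=` return statement above.

```python
import torch

pipe = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.bfloat16
).to("cuda")

def keep_latents(pipeline, step, timestep, callback_kwargs):
    # Tensors named in `callback_on_step_end_tensor_inputs` arrive here and may be modified.
    return {"latents": callback_kwargs["latents"]}

out = pipe(
    prompt="A corgi surfing a small wave at sunset",
    video_length=49,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=6.0,
    guidance_rescale=0.7,
    callback_on_step_end=keep_latents,
)
video = out.frames[0]
```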
easyanimate/pipeline/{pipeline_easyanimate_multi_text_encoder_control.py → pipeline_easyanimate_control.py}
RENAMED
@@ -31,7 +31,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|
31 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
32 |
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
33 |
StableDiffusionSafetyChecker
|
34 |
-
from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler
|
|
|
35 |
from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
|
36 |
is_bs4_available, is_ftfy_available,
|
37 |
is_torch_xla_available, logging,
|
@@ -41,11 +42,12 @@ from einops import rearrange
|
|
41 |
from PIL import Image
|
42 |
from tqdm import tqdm
|
43 |
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
|
44 |
-
CLIPVisionModelWithProjection,
|
45 |
-
T5EncoderModel,
|
|
|
46 |
|
47 |
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
48 |
-
from .
|
49 |
|
50 |
if is_torch_xla_available():
|
51 |
import torch_xla.core.xla_model as xm
|
@@ -64,6 +66,7 @@ EXAMPLE_DOC_STRING = """
|
|
64 |
```
|
65 |
"""
|
66 |
|
|
|
67 |
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
68 |
tw = tgt_width
|
69 |
th = tgt_height
|
@@ -97,44 +100,140 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
|
97 |
return noise_cfg
|
98 |
|
99 |
|
100 |
-
|
101 |
r"""
|
102 |
Pipeline for text-to-video generation using EasyAnimate.
|
103 |
|
104 |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
105 |
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
106 |
|
|
|
107 |
EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
|
108 |
-
HunyuanDiT team)
|
109 |
|
110 |
Args:
|
111 |
vae ([`AutoencoderKLMagvit`]):
|
112 |
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
113 |
-
text_encoder (Optional[`~transformers.
|
114 |
-
|
115 |
-
EasyAnimate uses
|
116 |
-
tokenizer (Optional[`~transformers.
|
117 |
-
A `
|
118 |
transformer ([`EasyAnimateTransformer3DModel`]):
|
119 |
-
The EasyAnimate model designed by
|
120 |
text_encoder_2 (`T5EncoderModel`):
|
121 |
-
|
|
|
122 |
tokenizer_2 (`T5Tokenizer`):
|
123 |
The tokenizer for the mT5 embedder.
|
124 |
-
scheduler ([`
|
125 |
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
126 |
"""
|
127 |
|
128 |
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
129 |
_optional_components = [
|
130 |
-
"safety_checker",
|
131 |
-
"feature_extractor",
|
132 |
"text_encoder_2",
|
133 |
"tokenizer_2",
|
134 |
"text_encoder",
|
135 |
"tokenizer",
|
136 |
]
|
137 |
-
_exclude_from_cpu_offload = ["safety_checker"]
|
138 |
_callback_tensor_inputs = [
|
139 |
"latents",
|
140 |
"prompt_embeds",
|
@@ -146,53 +245,30 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
146 |
def __init__(
|
147 |
self,
|
148 |
vae: AutoencoderKLMagvit,
|
149 |
-
text_encoder: BertModel,
|
150 |
-
tokenizer: BertTokenizer,
|
151 |
-
text_encoder_2: T5EncoderModel,
|
152 |
-
tokenizer_2: T5Tokenizer,
|
153 |
transformer: EasyAnimateTransformer3DModel,
|
154 |
-
scheduler:
|
155 |
-
safety_checker: StableDiffusionSafetyChecker,
|
156 |
-
feature_extractor: CLIPImageProcessor,
|
157 |
-
requires_safety_checker: bool = True
|
158 |
):
|
159 |
super().__init__()
|
160 |
|
161 |
self.register_modules(
|
162 |
vae=vae,
|
163 |
text_encoder=text_encoder,
|
|
|
164 |
tokenizer=tokenizer,
|
165 |
tokenizer_2=tokenizer_2,
|
166 |
transformer=transformer,
|
167 |
scheduler=scheduler,
|
168 |
-
safety_checker=safety_checker,
|
169 |
-
feature_extractor=feature_extractor,
|
170 |
-
text_encoder_2=text_encoder_2
|
171 |
)
|
172 |
|
173 |
-
if safety_checker is None and requires_safety_checker:
|
174 |
-
logger.warning(
|
175 |
-
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
176 |
-
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
177 |
-
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
178 |
-
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
179 |
-
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
180 |
-
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
181 |
-
)
|
182 |
-
|
183 |
-
if safety_checker is not None and feature_extractor is None:
|
184 |
-
raise ValueError(
|
185 |
-
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
186 |
-
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
187 |
-
)
|
188 |
-
|
189 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
190 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
191 |
self.mask_processor = VaeImageProcessor(
|
192 |
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
193 |
)
|
194 |
-
self.enable_autocast_float8_transformer_flag = False
|
195 |
-
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
196 |
|
197 |
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
198 |
super().enable_sequential_cpu_offload(*args, **kwargs)
|
@@ -272,19 +348,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
272 |
batch_size = prompt_embeds.shape[0]
|
273 |
|
274 |
if prompt_embeds is None:
|
275 |
-
|
276 |
-
prompt,
|
277 |
-
padding="max_length",
|
278 |
-
max_length=max_length,
|
279 |
-
truncation=True,
|
280 |
-
return_attention_mask=True,
|
281 |
-
return_tensors="pt",
|
282 |
-
)
|
283 |
-
text_input_ids = text_inputs.input_ids
|
284 |
-
if text_input_ids.shape[-1] > actual_max_sequence_length:
|
285 |
-
reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
286 |
text_inputs = tokenizer(
|
287 |
-
|
288 |
padding="max_length",
|
289 |
max_length=max_length,
|
290 |
truncation=True,
|
@@ -292,91 +358,188 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
292 |
return_tensors="pt",
|
293 |
)
|
294 |
text_input_ids = text_inputs.input_ids
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
text_input_ids
|
310 |
-
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
else:
|
313 |
-
|
314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
)
|
316 |
-
prompt_embeds = prompt_embeds[0]
|
317 |
-
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
320 |
|
321 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
322 |
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
323 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
324 |
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
|
|
325 |
|
326 |
# get unconditional embeddings for classifier free guidance
|
327 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
uncond_tokens,
|
350 |
-
padding="max_length",
|
351 |
-
max_length=max_length,
|
352 |
-
truncation=True,
|
353 |
-
return_tensors="pt",
|
354 |
-
)
|
355 |
-
uncond_input_ids = uncond_input.input_ids
|
356 |
-
if uncond_input_ids.shape[-1] > actual_max_sequence_length:
|
357 |
-
reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
358 |
uncond_input = tokenizer(
|
359 |
-
|
360 |
padding="max_length",
|
361 |
max_length=max_length,
|
362 |
truncation=True,
|
363 |
-
return_attention_mask=True,
|
364 |
return_tensors="pt",
|
365 |
)
|
366 |
uncond_input_ids = uncond_input.input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
else:
|
375 |
-
|
376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
)
|
378 |
-
|
379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
|
381 |
if do_classifier_free_guidance:
|
382 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
@@ -386,24 +549,10 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
386 |
|
387 |
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
388 |
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
389 |
-
|
|
|
390 |
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
391 |
|
392 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
393 |
-
def run_safety_checker(self, image, device, dtype):
|
394 |
-
if self.safety_checker is None:
|
395 |
-
has_nsfw_concept = None
|
396 |
-
else:
|
397 |
-
if torch.is_tensor(image):
|
398 |
-
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
399 |
-
else:
|
400 |
-
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
401 |
-
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
402 |
-
image, has_nsfw_concept = self.safety_checker(
|
403 |
-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
404 |
-
)
|
405 |
-
return image, has_nsfw_concept
|
406 |
-
|
407 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
408 |
def prepare_extra_step_kwargs(self, generator, eta):
|
409 |
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
@@ -438,8 +587,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
438 |
negative_prompt_attention_mask_2=None,
|
439 |
callback_on_step_end_tensor_inputs=None,
|
440 |
):
|
441 |
-
if height %
|
442 |
-
raise ValueError(f"`height` and `width` have to be divisible by
|
443 |
|
444 |
if callback_on_step_end_tensor_inputs is not None and not all(
|
445 |
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
@@ -524,43 +673,44 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
524 |
latents = latents.to(device)
|
525 |
|
526 |
# scale the initial noise by the standard deviation required by the scheduler
|
527 |
-
|
|
|
528 |
return latents
|
529 |
|
530 |
def prepare_control_latents(
|
531 |
-
self,
|
532 |
):
|
533 |
-
# resize the
|
534 |
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
535 |
# and half precision
|
536 |
|
537 |
-
if
|
538 |
-
|
539 |
bs = 1
|
540 |
-
|
541 |
-
for i in range(0,
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
if
|
550 |
-
|
551 |
bs = 1
|
552 |
-
|
553 |
-
for i in range(0,
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
else:
|
561 |
-
|
562 |
|
563 |
-
return
|
564 |
|
565 |
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
566 |
if video.size()[2] <= mini_batch_encoder:
|
@@ -623,9 +773,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
623 |
def interrupt(self):
|
624 |
return self._interrupt
|
625 |
|
626 |
-
def enable_autocast_float8_transformer(self):
|
627 |
-
self.enable_autocast_float8_transformer_flag = True
|
628 |
-
|
629 |
@torch.no_grad()
|
630 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
631 |
def __call__(
|
@@ -635,6 +782,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
635 |
height: Optional[int] = None,
|
636 |
width: Optional[int] = None,
|
637 |
control_video: Union[torch.FloatTensor] = None,
|
|
|
|
|
638 |
num_inference_steps: Optional[int] = 50,
|
639 |
guidance_scale: Optional[float] = 5.0,
|
640 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
@@ -661,6 +810,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
661 |
target_size: Optional[Tuple[int, int]] = None,
|
662 |
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
663 |
comfyui_progressbar: bool = False,
|
|
|
664 |
):
|
665 |
r"""
|
666 |
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
@@ -765,6 +915,12 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
765 |
batch_size = prompt_embeds.shape[0]
|
766 |
|
767 |
device = self._execution_device
|
|
|
|
|
|
|
|
|
|
|
|
|
768 |
|
769 |
# 3. Encode input prompt
|
770 |
(
|
@@ -775,7 +931,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
775 |
) = self.encode_prompt(
|
776 |
prompt=prompt,
|
777 |
device=device,
|
778 |
-
dtype=
|
779 |
num_images_per_prompt=num_images_per_prompt,
|
780 |
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
781 |
negative_prompt=negative_prompt,
|
@@ -785,28 +941,36 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
785 |
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
786 |
text_encoder_index=0,
|
787 |
)
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
|
|
|
|
|
|
|
|
|
|
807 |
|
808 |
# 4. Prepare timesteps
|
809 |
-
self.scheduler
|
|
|
|
|
|
|
810 |
timesteps = self.scheduler.timesteps
|
811 |
if comfyui_progressbar:
|
812 |
from comfy.utils import ProgressBar
|
@@ -820,7 +984,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
820 |
video_length,
|
821 |
height,
|
822 |
width,
|
823 |
-
|
824 |
device,
|
825 |
generator,
|
826 |
latents,
|
@@ -828,27 +992,69 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
828 |
if comfyui_progressbar:
|
829 |
pbar.update(1)
|
830 |
|
831 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
832 |
video_length = control_video.shape[2]
|
833 |
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
834 |
control_video = control_video.to(dtype=torch.float32)
|
835 |
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
836 |
else:
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
852 |
|
853 |
if comfyui_progressbar:
|
854 |
pbar.update(1)
|
@@ -880,34 +1086,49 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
880 |
)
|
881 |
|
882 |
# Get other hunyuan params
|
883 |
-
style = torch.tensor([0], device=device)
|
884 |
-
|
885 |
target_size = target_size or (height, width)
|
886 |
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
887 |
-
add_time_ids = torch.tensor([add_time_ids], dtype=
|
|
|
888 |
|
889 |
if self.do_classifier_free_guidance:
|
890 |
-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
891 |
-
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
892 |
-
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
893 |
-
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
|
894 |
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
|
895 |
style = torch.cat([style] * 2, dim=0)
|
896 |
|
897 |
# To latents.device
|
898 |
-
|
899 |
-
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
900 |
-
prompt_embeds_2 = prompt_embeds_2.to(device=device)
|
901 |
-
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
|
902 |
-
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
|
903 |
batch_size * num_images_per_prompt, 1
|
904 |
)
|
905 |
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
|
906 |
|
907 |
-
|
908 |
-
|
909 |
-
|
910 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
911 |
# 8. Denoising loop
|
912 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
913 |
self._num_timesteps = len(timesteps)
|
@@ -918,7 +1139,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
918 |
|
919 |
# expand the latents if we are doing classifier free guidance
|
920 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
921 |
-
|
|
|
922 |
|
923 |
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
924 |
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
@@ -935,8 +1157,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
935 |
image_meta_size=add_time_ids,
|
936 |
style=style,
|
937 |
image_rotary_emb=image_rotary_emb,
|
938 |
-
|
939 |
control_latents=control_latents,
|
|
|
940 |
)[0]
|
941 |
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
942 |
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
@@ -976,10 +1199,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
976 |
if comfyui_progressbar:
|
977 |
pbar.update(1)
|
978 |
|
979 |
-
if self.enable_autocast_float8_transformer_flag:
|
980 |
-
self.transformer = self.transformer.to("cpu", origin_weight_dtype)
|
981 |
-
|
982 |
-
torch.cuda.empty_cache()
|
983 |
# Post-processing
|
984 |
video = self.decode_latents(latents)
|
985 |
|
@@ -993,4 +1212,4 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
|
|
993 |
if not return_dict:
|
994 |
return video
|
995 |
|
996 |
-
return EasyAnimatePipelineOutput(
|
|
|
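Alongside the removals above, the scheduler handling is the most consequential change shown in the new version below: the hard dependency on `DDIMScheduler`/`DPMSolverMultistepScheduler` is dropped, `FlowMatchEulerDiscreteScheduler` becomes the expected default, and `hasattr` guards are added around `init_noise_sigma` and `scale_model_input`. A minimal sketch of why those guards are needed, using the diffusers scheduler classes directly (the step count and `mu` value here are illustrative, mirroring the `mu=1` used later in `__call__`; they are not taken from this commit):

```python
import torch
from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler

latents = torch.randn(1, 16, 9, 32, 32)

for scheduler in (FlowMatchEulerDiscreteScheduler(), DDIMScheduler()):
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        scheduler.set_timesteps(30, mu=1)  # mu only matters when dynamic shifting is enabled
    else:
        scheduler.set_timesteps(30)
    # Flow-matching schedulers may not expose init_noise_sigma / scale_model_input,
    # hence the hasattr checks in the new pipeline code.
    if hasattr(scheduler, "init_noise_sigma"):
        latents = latents * scheduler.init_noise_sigma
    print(type(scheduler).__name__, len(scheduler.timesteps))
```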
31 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
32 |
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
33 |
StableDiffusionSafetyChecker
|
34 |
+
from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
|
35 |
+
FlowMatchEulerDiscreteScheduler)
|
36 |
from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
|
37 |
is_bs4_available, is_ftfy_available,
|
38 |
is_torch_xla_available, logging,
|
|
|
42 |
from PIL import Image
|
43 |
from tqdm import tqdm
|
44 |
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
|
45 |
+
CLIPVisionModelWithProjection, Qwen2Tokenizer,
|
46 |
+
Qwen2VLForConditionalGeneration, T5EncoderModel,
|
47 |
+
T5Tokenizer)
|
48 |
|
49 |
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
50 |
+
from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput
|
51 |
|
52 |
if is_torch_xla_available():
|
53 |
import torch_xla.core.xla_model as xm
|
|
|
66 |
```
|
67 |
"""
|
68 |
|
69 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
70 |
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
71 |
tw = tgt_width
|
72 |
th = tgt_height
|
|
|
100 |
return noise_cfg
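The `get_resize_crop_region_for_grid` helper kept in the context above is what maps the source resolution onto the rotary-embedding grid before the crop coordinates are computed. A quick self-contained check of its behaviour, restated here in condensed form (the 480x640 source and 512x512 target are arbitrary example values, not taken from this commit):

```python
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw, th = tgt_width, tgt_height
    h, w = src
    if h / w > th / tw:
        # relatively taller than the target grid: fit the height
        resize_height, resize_width = th, int(round(th / h * w))
    else:
        # relatively wider than the target grid: fit the width
        resize_width, resize_height = tw, int(round(tw / w * h))
    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))
    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

print(get_resize_crop_region_for_grid((480, 640), 512, 512))
# ((64, 0), (448, 512)): a 480x640 clip is resized to 384x512 and vertically centred
```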
|
101 |
|
102 |
|
103 |
+
# Resize mask information in magvit
|
104 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
105 |
+
latent_size = latent.size()
|
106 |
+
|
107 |
+
if process_first_frame_only:
|
108 |
+
target_size = list(latent_size[2:])
|
109 |
+
target_size[0] = 1
|
110 |
+
first_frame_resized = F.interpolate(
|
111 |
+
mask[:, :, 0:1, :, :],
|
112 |
+
size=target_size,
|
113 |
+
mode='trilinear',
|
114 |
+
align_corners=False
|
115 |
+
)
|
116 |
+
|
117 |
+
target_size = list(latent_size[2:])
|
118 |
+
target_size[0] = target_size[0] - 1
|
119 |
+
if target_size[0] != 0:
|
120 |
+
remaining_frames_resized = F.interpolate(
|
121 |
+
mask[:, :, 1:, :, :],
|
122 |
+
size=target_size,
|
123 |
+
mode='trilinear',
|
124 |
+
align_corners=False
|
125 |
+
)
|
126 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
127 |
+
else:
|
128 |
+
resized_mask = first_frame_resized
|
129 |
+
else:
|
130 |
+
target_size = list(latent_size[2:])
|
131 |
+
resized_mask = F.interpolate(
|
132 |
+
mask,
|
133 |
+
size=target_size,
|
134 |
+
mode='trilinear',
|
135 |
+
align_corners=False
|
136 |
+
)
|
137 |
+
return resized_mask
|
138 |
+
|
139 |
+
|
140 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
141 |
+
def retrieve_timesteps(
|
142 |
+
scheduler,
|
143 |
+
num_inference_steps: Optional[int] = None,
|
144 |
+
device: Optional[Union[str, torch.device]] = None,
|
145 |
+
timesteps: Optional[List[int]] = None,
|
146 |
+
sigmas: Optional[List[float]] = None,
|
147 |
+
**kwargs,
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
151 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
scheduler (`SchedulerMixin`):
|
155 |
+
The scheduler to get timesteps from.
|
156 |
+
num_inference_steps (`int`):
|
157 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
158 |
+
must be `None`.
|
159 |
+
device (`str` or `torch.device`, *optional*):
|
160 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
161 |
+
timesteps (`List[int]`, *optional*):
|
162 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
163 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
164 |
+
sigmas (`List[float]`, *optional*):
|
165 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
166 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
170 |
+
second element is the number of inference steps.
|
171 |
+
"""
|
172 |
+
if timesteps is not None and sigmas is not None:
|
173 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
174 |
+
if timesteps is not None:
|
175 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
176 |
+
if not accepts_timesteps:
|
177 |
+
raise ValueError(
|
178 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
179 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
180 |
+
)
|
181 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
182 |
+
timesteps = scheduler.timesteps
|
183 |
+
num_inference_steps = len(timesteps)
|
184 |
+
elif sigmas is not None:
|
185 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
186 |
+
if not accept_sigmas:
|
187 |
+
raise ValueError(
|
188 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
189 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
190 |
+
)
|
191 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
192 |
+
timesteps = scheduler.timesteps
|
193 |
+
num_inference_steps = len(timesteps)
|
194 |
+
else:
|
195 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
196 |
+
timesteps = scheduler.timesteps
|
197 |
+
return timesteps, num_inference_steps
|
198 |
+
|
199 |
+
|
200 |
+
class EasyAnimateControlPipeline(DiffusionPipeline):
|
201 |
r"""
|
202 |
Pipeline for text-to-video generation using EasyAnimate.
|
203 |
|
204 |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
205 |
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
206 |
|
207 |
+
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
208 |
EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
|
209 |
+
HunyuanDiT team) in V5.
|
210 |
|
211 |
Args:
|
212 |
vae ([`AutoencoderKLMagvit`]):
|
213 |
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
214 |
+
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
|
215 |
+
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
216 |
+
EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5.
|
217 |
+
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
|
218 |
+
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
|
219 |
transformer ([`EasyAnimateTransformer3DModel`]):
|
220 |
+
The EasyAnimate model designed by EasyAnimate Team.
|
221 |
text_encoder_2 (`T5EncoderModel`):
|
222 |
+
EasyAnimate does not use text_encoder_2 in V5.1.
|
223 |
+
EasyAnimate uses [mT5](https://huggingface.co/google/mt5-base) embedder in V5.
|
224 |
tokenizer_2 (`T5Tokenizer`):
|
225 |
The tokenizer for the mT5 embedder.
|
226 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
227 |
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
228 |
"""
|
229 |
|
230 |
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
231 |
_optional_components = [
|
|
|
|
|
232 |
"text_encoder_2",
|
233 |
"tokenizer_2",
|
234 |
"text_encoder",
|
235 |
"tokenizer",
|
236 |
]
|
|
|
237 |
_callback_tensor_inputs = [
|
238 |
"latents",
|
239 |
"prompt_embeds",
|
|
|
245 |
def __init__(
|
246 |
self,
|
247 |
vae: AutoencoderKLMagvit,
|
248 |
+
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
|
249 |
+
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
|
250 |
+
text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]],
|
251 |
+
tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]],
|
252 |
transformer: EasyAnimateTransformer3DModel,
|
253 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
|
|
|
|
|
|
254 |
):
|
255 |
super().__init__()
|
256 |
|
257 |
self.register_modules(
|
258 |
vae=vae,
|
259 |
text_encoder=text_encoder,
|
260 |
+
text_encoder_2=text_encoder_2,
|
261 |
tokenizer=tokenizer,
|
262 |
tokenizer_2=tokenizer_2,
|
263 |
transformer=transformer,
|
264 |
scheduler=scheduler,
|
|
|
|
|
|
|
265 |
)
|
266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
268 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
269 |
self.mask_processor = VaeImageProcessor(
|
270 |
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
271 |
)
|
|
|
|
|
272 |
|
273 |
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
274 |
super().enable_sequential_cpu_offload(*args, **kwargs)
|
|
|
348 |
batch_size = prompt_embeds.shape[0]
|
349 |
|
350 |
if prompt_embeds is None:
|
351 |
+
if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
text_inputs = tokenizer(
|
353 |
+
prompt,
|
354 |
padding="max_length",
|
355 |
max_length=max_length,
|
356 |
truncation=True,
|
|
|
358 |
return_tensors="pt",
|
359 |
)
|
360 |
text_input_ids = text_inputs.input_ids
|
361 |
+
if text_input_ids.shape[-1] > actual_max_sequence_length:
|
362 |
+
reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
363 |
+
text_inputs = tokenizer(
|
364 |
+
reprompt,
|
365 |
+
padding="max_length",
|
366 |
+
max_length=max_length,
|
367 |
+
truncation=True,
|
368 |
+
return_attention_mask=True,
|
369 |
+
return_tensors="pt",
|
370 |
+
)
|
371 |
+
text_input_ids = text_inputs.input_ids
|
372 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
373 |
+
|
374 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
375 |
+
text_input_ids, untruncated_ids
|
376 |
+
):
|
377 |
+
_actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
|
378 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
|
379 |
+
logger.warning(
|
380 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
381 |
+
f" {_actual_max_sequence_length} tokens: {removed_text}"
|
382 |
+
)
|
383 |
+
|
384 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
385 |
+
|
386 |
+
if self.transformer.config.enable_text_attention_mask:
|
387 |
+
prompt_embeds = text_encoder(
|
388 |
+
text_input_ids.to(device),
|
389 |
+
attention_mask=prompt_attention_mask,
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
prompt_embeds = text_encoder(
|
393 |
+
text_input_ids.to(device)
|
394 |
+
)
|
395 |
+
prompt_embeds = prompt_embeds[0]
|
396 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
397 |
else:
|
398 |
+
if prompt is not None and isinstance(prompt, str):
|
399 |
+
messages = [
|
400 |
+
{
|
401 |
+
"role": "user",
|
402 |
+
"content": [{"type": "text", "text": prompt}],
|
403 |
+
}
|
404 |
+
]
|
405 |
+
else:
|
406 |
+
messages = [
|
407 |
+
{
|
408 |
+
"role": "user",
|
409 |
+
"content": [{"type": "text", "text": _prompt}],
|
410 |
+
} for _prompt in prompt
|
411 |
+
]
|
412 |
+
text = tokenizer.apply_chat_template(
|
413 |
+
messages, tokenize=False, add_generation_prompt=True
|
414 |
)
|
|
|
|
|
415 |
|
416 |
+
text_inputs = tokenizer(
|
417 |
+
text=[text],
|
418 |
+
padding="max_length",
|
419 |
+
max_length=max_length,
|
420 |
+
truncation=True,
|
421 |
+
return_attention_mask=True,
|
422 |
+
padding_side="right",
|
423 |
+
return_tensors="pt",
|
424 |
+
)
|
425 |
+
text_inputs = text_inputs.to(text_encoder.device)
|
426 |
+
|
427 |
+
text_input_ids = text_inputs.input_ids
|
428 |
+
prompt_attention_mask = text_inputs.attention_mask
|
429 |
+
if self.transformer.config.enable_text_attention_mask:
|
430 |
+
# Inference: Generation of the output
|
431 |
+
prompt_embeds = text_encoder(
|
432 |
+
input_ids=text_input_ids,
|
433 |
+
attention_mask=prompt_attention_mask,
|
434 |
+
output_hidden_states=True).hidden_states[-2]
|
435 |
+
else:
|
436 |
+
raise ValueError("LLM needs attention_mask")
|
437 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
438 |
+
|
439 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
440 |
|
441 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
442 |
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
443 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
444 |
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
445 |
+
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
446 |
|
447 |
# get unconditional embeddings for classifier free guidance
|
448 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
449 |
+
if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
|
450 |
+
uncond_tokens: List[str]
|
451 |
+
if negative_prompt is None:
|
452 |
+
uncond_tokens = [""] * batch_size
|
453 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
454 |
+
raise TypeError(
|
455 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
456 |
+
f" {type(prompt)}."
|
457 |
+
)
|
458 |
+
elif isinstance(negative_prompt, str):
|
459 |
+
uncond_tokens = [negative_prompt]
|
460 |
+
elif batch_size != len(negative_prompt):
|
461 |
+
raise ValueError(
|
462 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
463 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
464 |
+
" the batch size of `prompt`."
|
465 |
+
)
|
466 |
+
else:
|
467 |
+
uncond_tokens = negative_prompt
|
468 |
+
|
469 |
+
max_length = prompt_embeds.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
uncond_input = tokenizer(
|
471 |
+
uncond_tokens,
|
472 |
padding="max_length",
|
473 |
max_length=max_length,
|
474 |
truncation=True,
|
|
|
475 |
return_tensors="pt",
|
476 |
)
|
477 |
uncond_input_ids = uncond_input.input_ids
|
478 |
+
if uncond_input_ids.shape[-1] > actual_max_sequence_length:
|
479 |
+
reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
480 |
+
uncond_input = tokenizer(
|
481 |
+
reuncond_tokens,
|
482 |
+
padding="max_length",
|
483 |
+
max_length=max_length,
|
484 |
+
truncation=True,
|
485 |
+
return_attention_mask=True,
|
486 |
+
return_tensors="pt",
|
487 |
+
)
|
488 |
+
uncond_input_ids = uncond_input.input_ids
|
489 |
|
490 |
+
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
|
491 |
+
if self.transformer.config.enable_text_attention_mask:
|
492 |
+
negative_prompt_embeds = text_encoder(
|
493 |
+
uncond_input.input_ids.to(device),
|
494 |
+
attention_mask=negative_prompt_attention_mask,
|
495 |
+
)
|
496 |
+
else:
|
497 |
+
negative_prompt_embeds = text_encoder(
|
498 |
+
uncond_input.input_ids.to(device)
|
499 |
+
)
|
500 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
501 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
502 |
else:
|
503 |
+
if negative_prompt is not None and isinstance(negative_prompt, str):
|
504 |
+
messages = [
|
505 |
+
{
|
506 |
+
"role": "user",
|
507 |
+
"content": [{"type": "text", "text": negative_prompt}],
|
508 |
+
}
|
509 |
+
]
|
510 |
+
else:
|
511 |
+
messages = [
|
512 |
+
{
|
513 |
+
"role": "user",
|
514 |
+
"content": [{"type": "text", "text": _negative_prompt}],
|
515 |
+
} for _negative_prompt in negative_prompt
|
516 |
+
]
|
517 |
+
text = tokenizer.apply_chat_template(
|
518 |
+
messages, tokenize=False, add_generation_prompt=True
|
519 |
)
|
520 |
+
|
521 |
+
text_inputs = tokenizer(
|
522 |
+
text=[text],
|
523 |
+
padding="max_length",
|
524 |
+
max_length=max_length,
|
525 |
+
truncation=True,
|
526 |
+
return_attention_mask=True,
|
527 |
+
padding_side="right",
|
528 |
+
return_tensors="pt",
|
529 |
+
)
|
530 |
+
text_inputs = text_inputs.to(text_encoder.device)
|
531 |
+
|
532 |
+
text_input_ids = text_inputs.input_ids
|
533 |
+
negative_prompt_attention_mask = text_inputs.attention_mask
|
534 |
+
if self.transformer.config.enable_text_attention_mask:
|
535 |
+
# Inference: Generation of the output
|
536 |
+
negative_prompt_embeds = text_encoder(
|
537 |
+
input_ids=text_input_ids,
|
538 |
+
attention_mask=negative_prompt_attention_mask,
|
539 |
+
output_hidden_states=True).hidden_states[-2]
|
540 |
+
else:
|
541 |
+
raise ValueError("LLM needs attention_mask")
|
542 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
543 |
|
544 |
if do_classifier_free_guidance:
|
545 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
|
|
549 |
|
550 |
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
551 |
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
552 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
|
553 |
+
|
554 |
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
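In the V5.1 branch of `encode_prompt` above, the raw prompt is wrapped in a single-turn chat message, run through the tokenizer's chat template, and the second-to-last hidden state of the Qwen2-VL encoder is used as the text embedding. A standalone sketch of that path (the checkpoint name and `max_length=256` are assumptions for illustration; the pipeline simply uses whatever text encoder and sequence limits the loaded EasyAnimate weights define):

```python
import torch
from transformers import Qwen2Tokenizer, Qwen2VLForConditionalGeneration

# Assumed checkpoint for illustration only.
name = "Qwen/Qwen2-VL-7B-Instruct"
tokenizer = Qwen2Tokenizer.from_pretrained(name)
text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(name, torch_dtype=torch.bfloat16)

prompt = "A panda playing guitar on a mountain top"
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = tokenizer(text=[text], padding="max_length", max_length=256, truncation=True,
                   return_attention_mask=True, padding_side="right", return_tensors="pt")
with torch.no_grad():
    out = text_encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask,
                       output_hidden_states=True)

prompt_embeds = out.hidden_states[-2]        # second-to-last layer, as in encode_prompt above
prompt_attention_mask = inputs.attention_mask
```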
|
555 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
557 |
def prepare_extra_step_kwargs(self, generator, eta):
|
558 |
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
|
|
587 |
negative_prompt_attention_mask_2=None,
|
588 |
callback_on_step_end_tensor_inputs=None,
|
589 |
):
|
590 |
+
if height % 16 != 0 or width % 16 != 0:
|
591 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
592 |
|
593 |
if callback_on_step_end_tensor_inputs is not None and not all(
|
594 |
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
|
|
673 |
latents = latents.to(device)
|
674 |
|
675 |
# scale the initial noise by the standard deviation required by the scheduler
|
676 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
677 |
+
latents = latents * self.scheduler.init_noise_sigma
|
678 |
return latents
|
679 |
|
680 |
def prepare_control_latents(
|
681 |
+
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
682 |
):
|
683 |
+
# resize the control to latents shape as we concatenate the control to the latents
|
684 |
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
685 |
# and half precision
|
686 |
|
687 |
+
if control is not None:
|
688 |
+
control = control.to(device=device, dtype=dtype)
|
689 |
bs = 1
|
690 |
+
new_control = []
|
691 |
+
for i in range(0, control.shape[0], bs):
|
692 |
+
control_bs = control[i : i + bs]
|
693 |
+
control_bs = self.vae.encode(control_bs)[0]
|
694 |
+
control_bs = control_bs.mode()
|
695 |
+
new_control.append(control_bs)
|
696 |
+
control = torch.cat(new_control, dim = 0)
|
697 |
+
control = control * self.vae.config.scaling_factor
|
698 |
+
|
699 |
+
if control_image is not None:
|
700 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
701 |
bs = 1
|
702 |
+
new_control_pixel_values = []
|
703 |
+
for i in range(0, control_image.shape[0], bs):
|
704 |
+
control_pixel_values_bs = control_image[i : i + bs]
|
705 |
+
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
706 |
+
control_pixel_values_bs = control_pixel_values_bs.mode()
|
707 |
+
new_control_pixel_values.append(control_pixel_values_bs)
|
708 |
+
control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
|
709 |
+
control_image_latents = control_image_latents * self.vae.config.scaling_factor
|
710 |
else:
|
711 |
+
control_image_latents = None
|
712 |
|
713 |
+
return control, control_image_latents
|
714 |
|
715 |
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
716 |
if video.size()[2] <= mini_batch_encoder:
|
|
|
773 |
def interrupt(self):
|
774 |
return self._interrupt
|
775 |
|
|
|
|
|
|
|
776 |
@torch.no_grad()
|
777 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
778 |
def __call__(
|
|
|
782 |
height: Optional[int] = None,
|
783 |
width: Optional[int] = None,
|
784 |
control_video: Union[torch.FloatTensor] = None,
|
785 |
+
control_camera_video: Union[torch.FloatTensor] = None,
|
786 |
+
ref_image: Union[torch.FloatTensor] = None,
|
787 |
num_inference_steps: Optional[int] = 50,
|
788 |
guidance_scale: Optional[float] = 5.0,
|
789 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
|
|
810 |
target_size: Optional[Tuple[int, int]] = None,
|
811 |
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
812 |
comfyui_progressbar: bool = False,
|
813 |
+
timesteps: Optional[List[int]] = None,
|
814 |
):
|
815 |
r"""
|
816 |
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
|
|
915 |
batch_size = prompt_embeds.shape[0]
|
916 |
|
917 |
device = self._execution_device
|
918 |
+
if self.text_encoder is not None:
|
919 |
+
dtype = self.text_encoder.dtype
|
920 |
+
elif self.text_encoder_2 is not None:
|
921 |
+
dtype = self.text_encoder_2.dtype
|
922 |
+
else:
|
923 |
+
dtype = self.transformer.dtype
|
924 |
|
925 |
# 3. Encode input prompt
|
926 |
(
|
|
|
931 |
) = self.encode_prompt(
|
932 |
prompt=prompt,
|
933 |
device=device,
|
934 |
+
dtype=dtype,
|
935 |
num_images_per_prompt=num_images_per_prompt,
|
936 |
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
937 |
negative_prompt=negative_prompt,
|
|
|
941 |
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
942 |
text_encoder_index=0,
|
943 |
)
|
944 |
+
if self.tokenizer_2 is not None:
|
945 |
+
(
|
946 |
+
prompt_embeds_2,
|
947 |
+
negative_prompt_embeds_2,
|
948 |
+
prompt_attention_mask_2,
|
949 |
+
negative_prompt_attention_mask_2,
|
950 |
+
) = self.encode_prompt(
|
951 |
+
prompt=prompt,
|
952 |
+
device=device,
|
953 |
+
dtype=dtype,
|
954 |
+
num_images_per_prompt=num_images_per_prompt,
|
955 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
956 |
+
negative_prompt=negative_prompt,
|
957 |
+
prompt_embeds=prompt_embeds_2,
|
958 |
+
negative_prompt_embeds=negative_prompt_embeds_2,
|
959 |
+
prompt_attention_mask=prompt_attention_mask_2,
|
960 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
|
961 |
+
text_encoder_index=1,
|
962 |
+
)
|
963 |
+
else:
|
964 |
+
prompt_embeds_2 = None
|
965 |
+
negative_prompt_embeds_2 = None
|
966 |
+
prompt_attention_mask_2 = None
|
967 |
+
negative_prompt_attention_mask_2 = None
|
968 |
|
969 |
# 4. Prepare timesteps
|
970 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
971 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
972 |
+
else:
|
973 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
974 |
timesteps = self.scheduler.timesteps
|
975 |
if comfyui_progressbar:
|
976 |
from comfy.utils import ProgressBar
|
|
|
984 |
video_length,
|
985 |
height,
|
986 |
width,
|
987 |
+
dtype,
|
988 |
device,
|
989 |
generator,
|
990 |
latents,
|
|
|
992 |
if comfyui_progressbar:
|
993 |
pbar.update(1)
|
994 |
|
995 |
+
if control_camera_video is not None:
|
996 |
+
control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
|
997 |
+
control_video_latents = control_video_latents * 6
|
998 |
+
control_latents = (
|
999 |
+
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
1000 |
+
).to(device, dtype)
|
1001 |
+
elif control_video is not None:
|
1002 |
video_length = control_video.shape[2]
|
1003 |
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
1004 |
control_video = control_video.to(dtype=torch.float32)
|
1005 |
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
1006 |
+
control_video_latents = self.prepare_control_latents(
|
1007 |
+
None,
|
1008 |
+
control_video,
|
1009 |
+
batch_size,
|
1010 |
+
height,
|
1011 |
+
width,
|
1012 |
+
dtype,
|
1013 |
+
device,
|
1014 |
+
generator,
|
1015 |
+
self.do_classifier_free_guidance
|
1016 |
+
)[1]
|
1017 |
+
control_latents = (
|
1018 |
+
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
1019 |
+
).to(device, dtype)
|
1020 |
else:
|
1021 |
+
control_video_latents = torch.zeros_like(latents).to(device, dtype)
|
1022 |
+
control_latents = (
|
1023 |
+
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
1024 |
+
).to(device, dtype)
|
1025 |
+
|
1026 |
+
if ref_image is not None:
|
1027 |
+
video_length = ref_image.shape[2]
|
1028 |
+
ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
|
1029 |
+
ref_image = ref_image.to(dtype=torch.float32)
|
1030 |
+
ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
|
1031 |
+
|
1032 |
+
ref_image_latentes = self.prepare_control_latents(
|
1033 |
+
None,
|
1034 |
+
ref_image,
|
1035 |
+
batch_size,
|
1036 |
+
height,
|
1037 |
+
width,
|
1038 |
+
prompt_embeds.dtype,
|
1039 |
+
device,
|
1040 |
+
generator,
|
1041 |
+
self.do_classifier_free_guidance
|
1042 |
+
)[1]
|
1043 |
+
|
1044 |
+
ref_image_latentes_conv_in = torch.zeros_like(latents)
|
1045 |
+
if latents.size()[2] != 1:
|
1046 |
+
ref_image_latentes_conv_in[:, :, :1] = ref_image_latentes
|
1047 |
+
ref_image_latentes_conv_in = (
|
1048 |
+
torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in
|
1049 |
+
).to(device, dtype)
|
1050 |
+
control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
|
1051 |
+
else:
|
1052 |
+
if self.transformer.config.get("add_ref_latent_in_control_model", False):
|
1053 |
+
ref_image_latentes_conv_in = torch.zeros_like(latents)
|
1054 |
+
ref_image_latentes_conv_in = (
|
1055 |
+
torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in
|
1056 |
+
).to(device, dtype)
|
1057 |
+
control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
|
1058 |
|
1059 |
if comfyui_progressbar:
|
1060 |
pbar.update(1)
|
|
|
1086 |
)
|
1087 |
|
1088 |
# Get other hunyuan params
|
|
|
|
|
1089 |
target_size = target_size or (height, width)
|
1090 |
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
1091 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
1092 |
+
style = torch.tensor([0], device=device)
|
1093 |
|
1094 |
if self.do_classifier_free_guidance:
|
|
|
|
|
|
|
|
|
1095 |
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
|
1096 |
style = torch.cat([style] * 2, dim=0)
|
1097 |
|
1098 |
# To latents.device
|
1099 |
+
add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat(
|
|
|
|
|
|
|
|
|
1100 |
batch_size * num_images_per_prompt, 1
|
1101 |
)
|
1102 |
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
|
1103 |
|
1104 |
+
# Get other pixart params
|
1105 |
+
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
1106 |
+
if self.transformer.config.get("sample_size", 64) == 128:
|
1107 |
+
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
1108 |
+
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
1109 |
+
resolution = resolution.to(dtype=dtype, device=device)
|
1110 |
+
aspect_ratio = aspect_ratio.to(dtype=dtype, device=device)
|
1111 |
+
|
1112 |
+
if self.do_classifier_free_guidance:
|
1113 |
+
resolution = torch.cat([resolution, resolution], dim=0)
|
1114 |
+
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
1115 |
+
|
1116 |
+
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
1117 |
+
|
1118 |
+
if self.do_classifier_free_guidance:
|
1119 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
1120 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
1121 |
+
if prompt_embeds_2 is not None:
|
1122 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
1123 |
+
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
|
1124 |
+
|
1125 |
+
# To latents.device
|
1126 |
+
prompt_embeds = prompt_embeds.to(device=device)
|
1127 |
+
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
1128 |
+
if prompt_embeds_2 is not None:
|
1129 |
+
prompt_embeds_2 = prompt_embeds_2.to(device=device)
|
1130 |
+
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
|
1131 |
+
|
1132 |
# 8. Denoising loop
|
1133 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1134 |
self._num_timesteps = len(timesteps)
|
|
|
1139 |
|
1140 |
# expand the latents if we are doing classifier free guidance
|
1141 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1142 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
1143 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1144 |
|
1145 |
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
1146 |
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
|
|
1157 |
image_meta_size=add_time_ids,
|
1158 |
style=style,
|
1159 |
image_rotary_emb=image_rotary_emb,
|
1160 |
+
added_cond_kwargs=added_cond_kwargs,
|
1161 |
control_latents=control_latents,
|
1162 |
+
return_dict=False,
|
1163 |
)[0]
|
1164 |
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
1165 |
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
|
|
1199 |
if comfyui_progressbar:
|
1200 |
pbar.update(1)
|
1201 |
|
|
|
|
|
|
|
|
|
1202 |
# Post-processing
|
1203 |
video = self.decode_latents(latents)
|
1204 |
|
|
|
1212 |
if not return_dict:
|
1213 |
return video
|
1214 |
|
1215 |
+
return EasyAnimatePipelineOutput(frames=video)
|
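Taken together, the renamed `EasyAnimateControlPipeline` is loaded and called like any other diffusers pipeline, with the control signal passed as a `(batch, channels, frames, height, width)` video tensor. A hedged end-to-end sketch (the checkpoint path, frame count and the `video_length` keyword are assumptions based on the signature excerpts above, not something this diff pins down):

```python
import torch
from easyanimate.pipeline.pipeline_easyanimate_control import EasyAnimateControlPipeline

# Hypothetical local checkpoint directory for the V5.1 control weights.
pipe = EasyAnimateControlPipeline.from_pretrained(
    "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-Control",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

# A random tensor stands in for a real pose/depth/canny control video in [0, 1].
control_video = torch.rand(1, 3, 49, 512, 512)

video = pipe(
    prompt="A person dancing in the rain",
    video_length=49,          # assumed frame-count argument
    height=512,
    width=512,
    control_video=control_video,
    num_inference_steps=50,
    guidance_scale=6.0,
).frames
```

Because `retrieve_timesteps` is scheduler-aware, the same call works whether the checkpoint ships a `FlowMatchEulerDiscreteScheduler` (V5.1) or a DDIM-style scheduler from earlier versions.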
easyanimate/pipeline/pipeline_easyanimate_inpaint.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py
DELETED
@@ -1,925 +0,0 @@
|
|
1 |
-
# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import inspect
|
16 |
-
from typing import Callable, Dict, List, Optional, Tuple, Union
|
17 |
-
|
18 |
-
import numpy as np
|
19 |
-
import torch
|
20 |
-
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
21 |
-
from diffusers.image_processor import VaeImageProcessor
|
22 |
-
from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
|
23 |
-
get_3d_rotary_pos_embed)
|
24 |
-
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
25 |
-
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
26 |
-
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
27 |
-
StableDiffusionSafetyChecker
|
28 |
-
from diffusers.schedulers import DDIMScheduler
|
29 |
-
from diffusers.utils import (is_torch_xla_available, logging,
|
30 |
-
replace_example_docstring)
|
31 |
-
from diffusers.utils.torch_utils import randn_tensor
|
32 |
-
from einops import rearrange
|
33 |
-
from tqdm import tqdm
|
34 |
-
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
|
35 |
-
T5Tokenizer, T5EncoderModel)
|
36 |
-
|
37 |
-
from .pipeline_easyanimate import EasyAnimatePipelineOutput
|
38 |
-
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
39 |
-
|
40 |
-
if is_torch_xla_available():
|
41 |
-
import torch_xla.core.xla_model as xm
|
42 |
-
|
43 |
-
XLA_AVAILABLE = True
|
44 |
-
else:
|
45 |
-
XLA_AVAILABLE = False
|
46 |
-
|
47 |
-
|
48 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
49 |
-
|
50 |
-
EXAMPLE_DOC_STRING = """
|
51 |
-
Examples:
|
52 |
-
```py
|
53 |
-
>>> pass
|
54 |
-
```
|
55 |
-
"""
|
56 |
-
|
57 |
-
|
58 |
-
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
59 |
-
tw = tgt_width
|
60 |
-
th = tgt_height
|
61 |
-
h, w = src
|
62 |
-
r = h / w
|
63 |
-
if r > (th / tw):
|
64 |
-
resize_height = th
|
65 |
-
resize_width = int(round(th / h * w))
|
66 |
-
else:
|
67 |
-
resize_width = tw
|
68 |
-
resize_height = int(round(tw / w * h))
|
69 |
-
|
70 |
-
crop_top = int(round((th - resize_height) / 2.0))
|
71 |
-
crop_left = int(round((tw - resize_width) / 2.0))
|
72 |
-
|
73 |
-
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
74 |
-
|
75 |
-
|
76 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
77 |
-
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
78 |
-
"""
|
79 |
-
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
80 |
-
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
81 |
-
"""
|
82 |
-
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
83 |
-
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
84 |
-
# rescale the results from guidance (fixes overexposure)
|
85 |
-
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
86 |
-
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
87 |
-
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
88 |
-
return noise_cfg
|
89 |
-
|
90 |
-
|
91 |
-
class EasyAnimatePipeline_Multi_Text_Encoder(DiffusionPipeline):
|
92 |
-
r"""
|
93 |
-
Pipeline for text-to-video generation using EasyAnimate.
|
94 |
-
|
95 |
-
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
96 |
-
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
97 |
-
|
98 |
-
EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
|
99 |
-
HunyuanDiT team)
|
100 |
-
|
101 |
-
Args:
|
102 |
-
vae ([`AutoencoderKLMagvit`]):
|
103 |
-
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
104 |
-
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
|
105 |
-
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
106 |
-
EasyAnimate uses a fine-tuned [bilingual CLIP].
|
107 |
-
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
|
108 |
-
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
|
109 |
-
transformer ([`EasyAnimateTransformer3DModel`]):
|
110 |
-
The EasyAnimate model designed by Tencent Hunyuan.
|
111 |
-
text_encoder_2 (`T5EncoderModel`):
|
112 |
-
The mT5 embedder.
|
113 |
-
tokenizer_2 (`T5Tokenizer`):
|
114 |
-
The tokenizer for the mT5 embedder.
|
115 |
-
scheduler ([`DDIMScheduler`]):
|
116 |
-
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
117 |
-
"""
|
118 |
-
|
119 |
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
120 |
-
_optional_components = [
|
121 |
-
"safety_checker",
|
122 |
-
"feature_extractor",
|
123 |
-
"text_encoder_2",
|
124 |
-
"tokenizer_2",
|
125 |
-
"text_encoder",
|
126 |
-
"tokenizer",
|
127 |
-
]
|
128 |
-
_exclude_from_cpu_offload = ["safety_checker"]
|
129 |
-
_callback_tensor_inputs = [
|
130 |
-
"latents",
|
131 |
-
"prompt_embeds",
|
132 |
-
"negative_prompt_embeds",
|
133 |
-
"prompt_embeds_2",
|
134 |
-
"negative_prompt_embeds_2",
|
135 |
-
]
|
136 |
-
|
137 |
-
def __init__(
|
138 |
-
self,
|
139 |
-
vae: AutoencoderKLMagvit,
|
140 |
-
text_encoder: BertModel,
|
141 |
-
tokenizer: BertTokenizer,
|
142 |
-
text_encoder_2: T5EncoderModel,
|
143 |
-
tokenizer_2: T5Tokenizer,
|
144 |
-
transformer: EasyAnimateTransformer3DModel,
|
145 |
-
scheduler: DDIMScheduler,
|
146 |
-
safety_checker: StableDiffusionSafetyChecker,
|
147 |
-
feature_extractor: CLIPImageProcessor,
|
148 |
-
requires_safety_checker: bool = True,
|
149 |
-
):
|
150 |
-
super().__init__()
|
151 |
-
|
152 |
-
self.register_modules(
|
153 |
-
vae=vae,
|
154 |
-
text_encoder=text_encoder,
|
155 |
-
tokenizer=tokenizer,
|
156 |
-
tokenizer_2=tokenizer_2,
|
157 |
-
transformer=transformer,
|
158 |
-
scheduler=scheduler,
|
159 |
-
safety_checker=safety_checker,
|
160 |
-
feature_extractor=feature_extractor,
|
161 |
-
text_encoder_2=text_encoder_2,
|
162 |
-
)
|
163 |
-
|
164 |
-
if safety_checker is None and requires_safety_checker:
|
165 |
-
logger.warning(
|
166 |
-
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
167 |
-
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
168 |
-
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
169 |
-
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
170 |
-
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
171 |
-
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
172 |
-
)
|
173 |
-
|
174 |
-
if safety_checker is not None and feature_extractor is None:
|
175 |
-
raise ValueError(
|
176 |
-
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
177 |
-
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
178 |
-
)
|
179 |
-
|
180 |
-
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
181 |
-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
182 |
-
self.enable_autocast_float8_transformer_flag = False
|
183 |
-
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
184 |
-
|
185 |
-
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
186 |
-
super().enable_sequential_cpu_offload(*args, **kwargs)
|
187 |
-
if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
|
188 |
-
import accelerate
|
189 |
-
accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
|
190 |
-
self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
|
191 |
-
|
192 |
-
def encode_prompt(
|
193 |
-
self,
|
194 |
-
prompt: str,
|
195 |
-
device: torch.device,
|
196 |
-
dtype: torch.dtype,
|
197 |
-
num_images_per_prompt: int = 1,
|
198 |
-
do_classifier_free_guidance: bool = True,
|
199 |
-
negative_prompt: Optional[str] = None,
|
200 |
-
prompt_embeds: Optional[torch.Tensor] = None,
|
201 |
-
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
202 |
-
prompt_attention_mask: Optional[torch.Tensor] = None,
|
203 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
204 |
-
max_sequence_length: Optional[int] = None,
|
205 |
-
text_encoder_index: int = 0,
|
206 |
-
actual_max_sequence_length: int = 256
|
207 |
-
):
|
208 |
-
r"""
|
209 |
-
Encodes the prompt into text encoder hidden states.
|
210 |
-
|
211 |
-
Args:
|
212 |
-
prompt (`str` or `List[str]`, *optional*):
|
213 |
-
prompt to be encoded
|
214 |
-
device: (`torch.device`):
|
215 |
-
torch device
|
216 |
-
dtype (`torch.dtype`):
|
217 |
-
torch dtype
|
218 |
-
num_images_per_prompt (`int`):
|
219 |
-
number of images that should be generated per prompt
|
220 |
-
do_classifier_free_guidance (`bool`):
|
221 |
-
whether to use classifier free guidance or not
|
222 |
-
negative_prompt (`str` or `List[str]`, *optional*):
|
223 |
-
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
224 |
-
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
225 |
-
less than `1`).
|
226 |
-
prompt_embeds (`torch.Tensor`, *optional*):
|
227 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
228 |
-
provided, text embeddings will be generated from `prompt` input argument.
|
229 |
-
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
230 |
-
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
231 |
-
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
232 |
-
argument.
|
233 |
-
prompt_attention_mask (`torch.Tensor`, *optional*):
|
234 |
-
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
235 |
-
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
236 |
-
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
237 |
-
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
238 |
-
text_encoder_index (`int`, *optional*):
|
239 |
-
Index of the text encoder to use. `0` for clip and `1` for T5.
|
240 |
-
"""
|
241 |
-
tokenizers = [self.tokenizer, self.tokenizer_2]
|
242 |
-
text_encoders = [self.text_encoder, self.text_encoder_2]
|
243 |
-
|
244 |
-
tokenizer = tokenizers[text_encoder_index]
|
245 |
-
text_encoder = text_encoders[text_encoder_index]
|
246 |
-
|
247 |
-
if max_sequence_length is None:
|
248 |
-
if text_encoder_index == 0:
|
249 |
-
max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
|
250 |
-
if text_encoder_index == 1:
|
251 |
-
max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
|
252 |
-
else:
|
253 |
-
max_length = max_sequence_length
|
254 |
-
|
255 |
-
if prompt is not None and isinstance(prompt, str):
|
256 |
-
batch_size = 1
|
257 |
-
elif prompt is not None and isinstance(prompt, list):
|
258 |
-
batch_size = len(prompt)
|
259 |
-
else:
|
260 |
-
batch_size = prompt_embeds.shape[0]
|
261 |
-
|
262 |
-
if prompt_embeds is None:
|
263 |
-
text_inputs = tokenizer(
|
264 |
-
prompt,
|
265 |
-
padding="max_length",
|
266 |
-
max_length=max_length,
|
267 |
-
truncation=True,
|
268 |
-
return_attention_mask=True,
|
269 |
-
return_tensors="pt",
|
270 |
-
)
|
271 |
-
text_input_ids = text_inputs.input_ids
|
272 |
-
if text_input_ids.shape[-1] > actual_max_sequence_length:
|
273 |
-
reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
274 |
-
text_inputs = tokenizer(
|
275 |
-
reprompt,
|
276 |
-
padding="max_length",
|
277 |
-
max_length=max_length,
|
278 |
-
truncation=True,
|
279 |
-
return_attention_mask=True,
|
280 |
-
return_tensors="pt",
|
281 |
-
)
|
282 |
-
text_input_ids = text_inputs.input_ids
|
283 |
-
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
284 |
-
|
285 |
-
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
286 |
-
text_input_ids, untruncated_ids
|
287 |
-
):
|
288 |
-
_actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
|
289 |
-
removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
|
290 |
-
logger.warning(
|
291 |
-
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
292 |
-
f" {_actual_max_sequence_length} tokens: {removed_text}"
|
293 |
-
)
|
294 |
-
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
295 |
-
|
296 |
-
if self.transformer.config.enable_text_attention_mask:
|
297 |
-
prompt_embeds = text_encoder(
|
298 |
-
text_input_ids.to(device),
|
299 |
-
attention_mask=prompt_attention_mask,
|
300 |
-
)
|
301 |
-
else:
|
302 |
-
prompt_embeds = text_encoder(
|
303 |
-
text_input_ids.to(device)
|
304 |
-
)
|
305 |
-
prompt_embeds = prompt_embeds[0]
|
306 |
-
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
307 |
-
|
308 |
-
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
309 |
-
|
310 |
-
bs_embed, seq_len, _ = prompt_embeds.shape
|
311 |
-
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
312 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
313 |
-
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
314 |
-
|
315 |
-
# get unconditional embeddings for classifier free guidance
|
316 |
-
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
317 |
-
uncond_tokens: List[str]
|
318 |
-
if negative_prompt is None:
|
319 |
-
uncond_tokens = [""] * batch_size
|
320 |
-
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
321 |
-
raise TypeError(
|
322 |
-
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
323 |
-
f" {type(prompt)}."
|
324 |
-
)
|
325 |
-
elif isinstance(negative_prompt, str):
|
326 |
-
uncond_tokens = [negative_prompt]
|
327 |
-
elif batch_size != len(negative_prompt):
|
328 |
-
raise ValueError(
|
329 |
-
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
330 |
-
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
331 |
-
" the batch size of `prompt`."
|
332 |
-
)
|
333 |
-
else:
|
334 |
-
uncond_tokens = negative_prompt
|
335 |
-
|
336 |
-
max_length = prompt_embeds.shape[1]
|
337 |
-
uncond_input = tokenizer(
|
338 |
-
uncond_tokens,
|
339 |
-
padding="max_length",
|
340 |
-
max_length=max_length,
|
341 |
-
truncation=True,
|
342 |
-
return_tensors="pt",
|
343 |
-
)
|
344 |
-
uncond_input_ids = uncond_input.input_ids
|
345 |
-
if uncond_input_ids.shape[-1] > actual_max_sequence_length:
|
346 |
-
reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
347 |
-
uncond_input = tokenizer(
|
348 |
-
reuncond_tokens,
|
349 |
-
padding="max_length",
|
350 |
-
max_length=max_length,
|
351 |
-
truncation=True,
|
352 |
-
return_attention_mask=True,
|
353 |
-
return_tensors="pt",
|
354 |
-
)
|
355 |
-
uncond_input_ids = uncond_input.input_ids
|
356 |
-
|
357 |
-
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
|
358 |
-
if self.transformer.config.enable_text_attention_mask:
|
359 |
-
negative_prompt_embeds = text_encoder(
|
360 |
-
uncond_input.input_ids.to(device),
|
361 |
-
attention_mask=negative_prompt_attention_mask,
|
362 |
-
)
|
363 |
-
else:
|
364 |
-
negative_prompt_embeds = text_encoder(
|
365 |
-
uncond_input.input_ids.to(device)
|
366 |
-
)
|
367 |
-
negative_prompt_embeds = negative_prompt_embeds[0]
|
368 |
-
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
369 |
-
|
370 |
-
if do_classifier_free_guidance:
|
371 |
-
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
372 |
-
seq_len = negative_prompt_embeds.shape[1]
|
373 |
-
|
374 |
-
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
375 |
-
|
376 |
-
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
377 |
-
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
378 |
-
|
379 |
-
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
380 |
-
|
381 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
382 |
-
def run_safety_checker(self, image, device, dtype):
|
383 |
-
if self.safety_checker is None:
|
384 |
-
has_nsfw_concept = None
|
385 |
-
else:
|
386 |
-
if torch.is_tensor(image):
|
387 |
-
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
388 |
-
else:
|
389 |
-
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
390 |
-
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
391 |
-
image, has_nsfw_concept = self.safety_checker(
|
392 |
-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
393 |
-
)
|
394 |
-
return image, has_nsfw_concept
|
395 |
-
|
396 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
397 |
-
def prepare_extra_step_kwargs(self, generator, eta):
|
398 |
-
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
399 |
-
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
400 |
-
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
401 |
-
# and should be between [0, 1]
|
402 |
-
|
403 |
-
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
404 |
-
extra_step_kwargs = {}
|
405 |
-
if accepts_eta:
|
406 |
-
extra_step_kwargs["eta"] = eta
|
407 |
-
|
408 |
-
# check if the scheduler accepts generator
|
409 |
-
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
410 |
-
if accepts_generator:
|
411 |
-
extra_step_kwargs["generator"] = generator
|
412 |
-
return extra_step_kwargs
|
413 |
-
|
414 |
-
def check_inputs(
|
415 |
-
self,
|
416 |
-
prompt,
|
417 |
-
height,
|
418 |
-
width,
|
419 |
-
negative_prompt=None,
|
420 |
-
prompt_embeds=None,
|
421 |
-
negative_prompt_embeds=None,
|
422 |
-
prompt_attention_mask=None,
|
423 |
-
negative_prompt_attention_mask=None,
|
424 |
-
prompt_embeds_2=None,
|
425 |
-
negative_prompt_embeds_2=None,
|
426 |
-
prompt_attention_mask_2=None,
|
427 |
-
negative_prompt_attention_mask_2=None,
|
428 |
-
callback_on_step_end_tensor_inputs=None,
|
429 |
-
):
|
430 |
-
if height % 8 != 0 or width % 8 != 0:
|
431 |
-
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
432 |
-
|
433 |
-
if callback_on_step_end_tensor_inputs is not None and not all(
|
434 |
-
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
435 |
-
):
|
436 |
-
raise ValueError(
|
437 |
-
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
438 |
-
)
|
439 |
-
|
440 |
-
if prompt is not None and prompt_embeds is not None:
|
441 |
-
raise ValueError(
|
442 |
-
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
443 |
-
" only forward one of the two."
|
444 |
-
)
|
445 |
-
elif prompt is None and prompt_embeds is None:
|
446 |
-
raise ValueError(
|
447 |
-
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
448 |
-
)
|
449 |
-
elif prompt is None and prompt_embeds_2 is None:
|
450 |
-
raise ValueError(
|
451 |
-
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
|
452 |
-
)
|
453 |
-
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
454 |
-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
455 |
-
|
456 |
-
if prompt_embeds is not None and prompt_attention_mask is None:
|
457 |
-
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
458 |
-
|
459 |
-
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
|
460 |
-
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
|
461 |
-
|
462 |
-
if negative_prompt is not None and negative_prompt_embeds is not None:
|
463 |
-
raise ValueError(
|
464 |
-
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
465 |
-
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
466 |
-
)
|
467 |
-
|
468 |
-
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
469 |
-
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
470 |
-
|
471 |
-
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
|
472 |
-
raise ValueError(
|
473 |
-
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
|
474 |
-
)
|
475 |
-
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
476 |
-
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
477 |
-
raise ValueError(
|
478 |
-
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
479 |
-
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
480 |
-
f" {negative_prompt_embeds.shape}."
|
481 |
-
)
|
482 |
-
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
|
483 |
-
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
|
484 |
-
raise ValueError(
|
485 |
-
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
|
486 |
-
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
|
487 |
-
f" {negative_prompt_embeds_2.shape}."
|
488 |
-
)
|
489 |
-
|
490 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
491 |
-
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
492 |
-
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
493 |
-
if self.vae.cache_mag_vae:
|
494 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
495 |
-
mini_batch_decoder = self.vae.mini_batch_decoder
|
496 |
-
shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
497 |
-
else:
|
498 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
499 |
-
mini_batch_decoder = self.vae.mini_batch_decoder
|
500 |
-
shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
501 |
-
else:
|
502 |
-
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
503 |
-
|
504 |
-
if isinstance(generator, list) and len(generator) != batch_size:
|
505 |
-
raise ValueError(
|
506 |
-
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
507 |
-
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
508 |
-
)
|
509 |
-
|
510 |
-
if latents is None:
|
511 |
-
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
512 |
-
else:
|
513 |
-
latents = latents.to(device)
|
514 |
-
|
515 |
-
# scale the initial noise by the standard deviation required by the scheduler
|
516 |
-
latents = latents * self.scheduler.init_noise_sigma
|
517 |
-
return latents
|
518 |
-
|
519 |
-
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
520 |
-
if video.size()[2] <= mini_batch_encoder:
|
521 |
-
return video
|
522 |
-
prefix_index_before = mini_batch_encoder // 2
|
523 |
-
prefix_index_after = mini_batch_encoder - prefix_index_before
|
524 |
-
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
525 |
-
|
526 |
-
# Encode middle videos
|
527 |
-
latents = self.vae.encode(pixel_values)[0]
|
528 |
-
latents = latents.mode()
|
529 |
-
# Decode middle videos
|
530 |
-
middle_video = self.vae.decode(latents)[0]
|
531 |
-
|
532 |
-
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
533 |
-
return video
|
534 |
-
|
535 |
-
def decode_latents(self, latents):
|
536 |
-
video_length = latents.shape[2]
|
537 |
-
latents = 1 / self.vae.config.scaling_factor * latents
|
538 |
-
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
539 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
540 |
-
mini_batch_decoder = self.vae.mini_batch_decoder
|
541 |
-
video = self.vae.decode(latents)[0]
|
542 |
-
video = video.clamp(-1, 1)
|
543 |
-
if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
|
544 |
-
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
545 |
-
else:
|
546 |
-
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
547 |
-
video = []
|
548 |
-
for frame_idx in tqdm(range(latents.shape[0])):
|
549 |
-
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
550 |
-
video = torch.cat(video)
|
551 |
-
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
552 |
-
video = (video / 2 + 0.5).clamp(0, 1)
|
553 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
554 |
-
video = video.cpu().float().numpy()
|
555 |
-
return video
|
556 |
-
|
557 |
-
@property
|
558 |
-
def guidance_scale(self):
|
559 |
-
return self._guidance_scale
|
560 |
-
|
561 |
-
@property
|
562 |
-
def guidance_rescale(self):
|
563 |
-
return self._guidance_rescale
|
564 |
-
|
565 |
-
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
566 |
-
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
567 |
-
# corresponds to doing no classifier free guidance.
|
568 |
-
@property
|
569 |
-
def do_classifier_free_guidance(self):
|
570 |
-
return self._guidance_scale > 1
|
571 |
-
|
572 |
-
@property
|
573 |
-
def num_timesteps(self):
|
574 |
-
return self._num_timesteps
|
575 |
-
|
576 |
-
@property
|
577 |
-
def interrupt(self):
|
578 |
-
return self._interrupt
|
579 |
-
|
580 |
-
def enable_autocast_float8_transformer(self):
|
581 |
-
self.enable_autocast_float8_transformer_flag = True
|
582 |
-
|
583 |
-
@torch.no_grad()
|
584 |
-
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
585 |
-
def __call__(
|
586 |
-
self,
|
587 |
-
prompt: Union[str, List[str]] = None,
|
588 |
-
video_length: Optional[int] = None,
|
589 |
-
height: Optional[int] = None,
|
590 |
-
width: Optional[int] = None,
|
591 |
-
num_inference_steps: Optional[int] = 50,
|
592 |
-
guidance_scale: Optional[float] = 5.0,
|
593 |
-
negative_prompt: Optional[Union[str, List[str]]] = None,
|
594 |
-
num_images_per_prompt: Optional[int] = 1,
|
595 |
-
eta: Optional[float] = 0.0,
|
596 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
597 |
-
latents: Optional[torch.Tensor] = None,
|
598 |
-
prompt_embeds: Optional[torch.Tensor] = None,
|
599 |
-
prompt_embeds_2: Optional[torch.Tensor] = None,
|
600 |
-
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
601 |
-
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
|
602 |
-
prompt_attention_mask: Optional[torch.Tensor] = None,
|
603 |
-
prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
604 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
605 |
-
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
606 |
-
output_type: Optional[str] = "latent",
|
607 |
-
return_dict: bool = True,
|
608 |
-
callback_on_step_end: Optional[
|
609 |
-
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
610 |
-
] = None,
|
611 |
-
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
612 |
-
guidance_rescale: float = 0.0,
|
613 |
-
original_size: Optional[Tuple[int, int]] = (1024, 1024),
|
614 |
-
target_size: Optional[Tuple[int, int]] = None,
|
615 |
-
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
616 |
-
comfyui_progressbar: bool = False,
|
617 |
-
):
|
618 |
-
r"""
|
619 |
-
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
620 |
-
|
621 |
-
Examples:
|
622 |
-
prompt (`str` or `List[str]`, *optional*):
|
623 |
-
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
624 |
-
video_length (`int`, *optional*):
|
625 |
-
Length of the generated video (in frames).
|
626 |
-
height (`int`, *optional*):
|
627 |
-
Height of the generated image in pixels.
|
628 |
-
width (`int`, *optional*):
|
629 |
-
Width of the generated image in pixels.
|
630 |
-
num_inference_steps (`int`, *optional*, defaults to 50):
|
631 |
-
Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
|
632 |
-
guidance_scale (`float`, *optional*, defaults to 5.0):
|
633 |
-
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
|
634 |
-
negative_prompt (`str` or `List[str]`, *optional*):
|
635 |
-
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
636 |
-
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
637 |
-
Number of images to generate for each prompt.
|
638 |
-
eta (`float`, *optional*, defaults to 0.0):
|
639 |
-
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
|
640 |
-
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
641 |
-
A generator to ensure reproducibility in image generation.
|
642 |
-
latents (`torch.Tensor`, *optional*):
|
643 |
-
Predefined latent tensors to condition generation.
|
644 |
-
prompt_embeds (`torch.Tensor`, *optional*):
|
645 |
-
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
646 |
-
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
647 |
-
Secondary text embeddings to supplement or replace the initial prompt embeddings.
|
648 |
-
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
649 |
-
Embeddings for negative prompts. Overrides string inputs if defined.
|
650 |
-
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
|
651 |
-
Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
|
652 |
-
prompt_attention_mask (`torch.Tensor`, *optional*):
|
653 |
-
Attention mask for the primary prompt embeddings.
|
654 |
-
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
655 |
-
Attention mask for the secondary prompt embeddings.
|
656 |
-
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
657 |
-
Attention mask for negative prompt embeddings.
|
658 |
-
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
659 |
-
Attention mask for secondary negative prompt embeddings.
|
660 |
-
output_type (`str`, *optional*, defaults to "latent"):
|
661 |
-
Format of the generated output, either as a PIL image or as a NumPy array.
|
662 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
663 |
-
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
664 |
-
callback_on_step_end (`Callable`, *optional*):
|
665 |
-
Functions called at the end of each denoising step.
|
666 |
-
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
667 |
-
Tensor names to be included in callback function calls.
|
668 |
-
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
669 |
-
Adjusts noise levels based on guidance scale.
|
670 |
-
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
|
671 |
-
Original dimensions of the output.
|
672 |
-
target_size (`Tuple[int, int]`, *optional*):
|
673 |
-
Desired output dimensions for calculations.
|
674 |
-
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
|
675 |
-
Coordinates for cropping.
|
676 |
-
|
677 |
-
Returns:
|
678 |
-
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
679 |
-
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
680 |
-
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
681 |
-
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
682 |
-
"not-safe-for-work" (nsfw) content.
|
683 |
-
"""
|
684 |
-
|
685 |
-
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
686 |
-
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
687 |
-
|
688 |
-
# 0. default height and width
|
689 |
-
height = int((height // 16) * 16)
|
690 |
-
width = int((width // 16) * 16)
|
691 |
-
|
692 |
-
# 1. Check inputs. Raise error if not correct
|
693 |
-
self.check_inputs(
|
694 |
-
prompt,
|
695 |
-
height,
|
696 |
-
width,
|
697 |
-
negative_prompt,
|
698 |
-
prompt_embeds,
|
699 |
-
negative_prompt_embeds,
|
700 |
-
prompt_attention_mask,
|
701 |
-
negative_prompt_attention_mask,
|
702 |
-
prompt_embeds_2,
|
703 |
-
negative_prompt_embeds_2,
|
704 |
-
prompt_attention_mask_2,
|
705 |
-
negative_prompt_attention_mask_2,
|
706 |
-
callback_on_step_end_tensor_inputs,
|
707 |
-
)
|
708 |
-
self._guidance_scale = guidance_scale
|
709 |
-
self._guidance_rescale = guidance_rescale
|
710 |
-
self._interrupt = False
|
711 |
-
|
712 |
-
# 2. Define call parameters
|
713 |
-
if prompt is not None and isinstance(prompt, str):
|
714 |
-
batch_size = 1
|
715 |
-
elif prompt is not None and isinstance(prompt, list):
|
716 |
-
batch_size = len(prompt)
|
717 |
-
else:
|
718 |
-
batch_size = prompt_embeds.shape[0]
|
719 |
-
|
720 |
-
device = self._execution_device
|
721 |
-
|
722 |
-
# 3. Encode input prompt
|
723 |
-
(
|
724 |
-
prompt_embeds,
|
725 |
-
negative_prompt_embeds,
|
726 |
-
prompt_attention_mask,
|
727 |
-
negative_prompt_attention_mask,
|
728 |
-
) = self.encode_prompt(
|
729 |
-
prompt=prompt,
|
730 |
-
device=device,
|
731 |
-
dtype=self.transformer.dtype,
|
732 |
-
num_images_per_prompt=num_images_per_prompt,
|
733 |
-
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
734 |
-
negative_prompt=negative_prompt,
|
735 |
-
prompt_embeds=prompt_embeds,
|
736 |
-
negative_prompt_embeds=negative_prompt_embeds,
|
737 |
-
prompt_attention_mask=prompt_attention_mask,
|
738 |
-
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
739 |
-
text_encoder_index=0,
|
740 |
-
)
|
741 |
-
(
|
742 |
-
prompt_embeds_2,
|
743 |
-
negative_prompt_embeds_2,
|
744 |
-
prompt_attention_mask_2,
|
745 |
-
negative_prompt_attention_mask_2,
|
746 |
-
) = self.encode_prompt(
|
747 |
-
prompt=prompt,
|
748 |
-
device=device,
|
749 |
-
dtype=self.transformer.dtype,
|
750 |
-
num_images_per_prompt=num_images_per_prompt,
|
751 |
-
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
752 |
-
negative_prompt=negative_prompt,
|
753 |
-
prompt_embeds=prompt_embeds_2,
|
754 |
-
negative_prompt_embeds=negative_prompt_embeds_2,
|
755 |
-
prompt_attention_mask=prompt_attention_mask_2,
|
756 |
-
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
|
757 |
-
text_encoder_index=1,
|
758 |
-
)
|
759 |
-
torch.cuda.empty_cache()
|
760 |
-
|
761 |
-
# 4. Prepare timesteps
|
762 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
763 |
-
timesteps = self.scheduler.timesteps
|
764 |
-
if comfyui_progressbar:
|
765 |
-
from comfy.utils import ProgressBar
|
766 |
-
pbar = ProgressBar(num_inference_steps + 1)
|
767 |
-
|
768 |
-
# 5. Prepare latent variables
|
769 |
-
num_channels_latents = self.transformer.config.in_channels
|
770 |
-
latents = self.prepare_latents(
|
771 |
-
batch_size * num_images_per_prompt,
|
772 |
-
num_channels_latents,
|
773 |
-
video_length,
|
774 |
-
height,
|
775 |
-
width,
|
776 |
-
prompt_embeds.dtype,
|
777 |
-
device,
|
778 |
-
generator,
|
779 |
-
latents,
|
780 |
-
)
|
781 |
-
if comfyui_progressbar:
|
782 |
-
pbar.update(1)
|
783 |
-
|
784 |
-
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
785 |
-
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
786 |
-
|
787 |
-
# 7 create image_rotary_emb, style embedding & time ids
|
788 |
-
grid_height = height // 8 // self.transformer.config.patch_size
|
789 |
-
grid_width = width // 8 // self.transformer.config.patch_size
|
790 |
-
if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
|
791 |
-
base_size_width = 720 // 8 // self.transformer.config.patch_size
|
792 |
-
base_size_height = 480 // 8 // self.transformer.config.patch_size
|
793 |
-
|
794 |
-
grid_crops_coords = get_resize_crop_region_for_grid(
|
795 |
-
(grid_height, grid_width), base_size_width, base_size_height
|
796 |
-
)
|
797 |
-
image_rotary_emb = get_3d_rotary_pos_embed(
|
798 |
-
self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
|
799 |
-
temporal_size=latents.size(2), use_real=True,
|
800 |
-
)
|
801 |
-
else:
|
802 |
-
base_size = 512 // 8 // self.transformer.config.patch_size
|
803 |
-
grid_crops_coords = get_resize_crop_region_for_grid(
|
804 |
-
(grid_height, grid_width), base_size, base_size
|
805 |
-
)
|
806 |
-
image_rotary_emb = get_2d_rotary_pos_embed(
|
807 |
-
self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
|
808 |
-
)
|
809 |
-
|
810 |
-
# Get other hunyuan params
|
811 |
-
style = torch.tensor([0], device=device)
|
812 |
-
|
813 |
-
target_size = target_size or (height, width)
|
814 |
-
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
815 |
-
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
|
816 |
-
|
817 |
-
if self.do_classifier_free_guidance:
|
818 |
-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
819 |
-
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
820 |
-
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
821 |
-
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
|
822 |
-
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
|
823 |
-
style = torch.cat([style] * 2, dim=0)
|
824 |
-
|
825 |
-
# To latents.device
|
826 |
-
prompt_embeds = prompt_embeds.to(device=device)
|
827 |
-
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
828 |
-
prompt_embeds_2 = prompt_embeds_2.to(device=device)
|
829 |
-
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
|
830 |
-
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
|
831 |
-
batch_size * num_images_per_prompt, 1
|
832 |
-
)
|
833 |
-
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
|
834 |
-
|
835 |
-
torch.cuda.empty_cache()
|
836 |
-
if self.enable_autocast_float8_transformer_flag:
|
837 |
-
origin_weight_dtype = self.transformer.dtype
|
838 |
-
self.transformer = self.transformer.to(torch.float8_e4m3fn)
|
839 |
-
# 8. Denoising loop
|
840 |
-
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
841 |
-
self._num_timesteps = len(timesteps)
|
842 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
843 |
-
for i, t in enumerate(timesteps):
|
844 |
-
if self.interrupt:
|
845 |
-
continue
|
846 |
-
|
847 |
-
# expand the latents if we are doing classifier free guidance
|
848 |
-
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
849 |
-
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
850 |
-
|
851 |
-
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
852 |
-
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
853 |
-
dtype=latent_model_input.dtype
|
854 |
-
)
|
855 |
-
|
856 |
-
# predict the noise residual
|
857 |
-
noise_pred = self.transformer(
|
858 |
-
latent_model_input,
|
859 |
-
t_expand,
|
860 |
-
encoder_hidden_states=prompt_embeds,
|
861 |
-
text_embedding_mask=prompt_attention_mask,
|
862 |
-
encoder_hidden_states_t5=prompt_embeds_2,
|
863 |
-
text_embedding_mask_t5=prompt_attention_mask_2,
|
864 |
-
image_meta_size=add_time_ids,
|
865 |
-
style=style,
|
866 |
-
image_rotary_emb=image_rotary_emb,
|
867 |
-
return_dict=False,
|
868 |
-
)[0]
|
869 |
-
|
870 |
-
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
871 |
-
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
872 |
-
|
873 |
-
# perform guidance
|
874 |
-
if self.do_classifier_free_guidance:
|
875 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
876 |
-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
877 |
-
|
878 |
-
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
879 |
-
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
880 |
-
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
881 |
-
|
882 |
-
# compute the previous noisy sample x_t -> x_t-1
|
883 |
-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
884 |
-
|
885 |
-
if callback_on_step_end is not None:
|
886 |
-
callback_kwargs = {}
|
887 |
-
for k in callback_on_step_end_tensor_inputs:
|
888 |
-
callback_kwargs[k] = locals()[k]
|
889 |
-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
890 |
-
|
891 |
-
latents = callback_outputs.pop("latents", latents)
|
892 |
-
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
893 |
-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
894 |
-
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
|
895 |
-
negative_prompt_embeds_2 = callback_outputs.pop(
|
896 |
-
"negative_prompt_embeds_2", negative_prompt_embeds_2
|
897 |
-
)
|
898 |
-
|
899 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
900 |
-
progress_bar.update()
|
901 |
-
|
902 |
-
if XLA_AVAILABLE:
|
903 |
-
xm.mark_step()
|
904 |
-
|
905 |
-
if comfyui_progressbar:
|
906 |
-
pbar.update(1)
|
907 |
-
|
908 |
-
if self.enable_autocast_float8_transformer_flag:
|
909 |
-
self.transformer = self.transformer.to("cpu", origin_weight_dtype)
|
910 |
-
|
911 |
-
torch.cuda.empty_cache()
|
912 |
-
# Post-processing
|
913 |
-
video = self.decode_latents(latents)
|
914 |
-
|
915 |
-
# Convert to tensor
|
916 |
-
if output_type == "latent":
|
917 |
-
video = torch.from_numpy(video)
|
918 |
-
|
919 |
-
# Offload all models
|
920 |
-
self.maybe_free_model_hooks()
|
921 |
-
|
922 |
-
if not return_dict:
|
923 |
-
return video
|
924 |
-
|
925 |
-
return EasyAnimatePipelineOutput(videos=video)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py
DELETED
@@ -1,1334 +0,0 @@
|
|
1 |
-
# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import inspect
|
16 |
-
from typing import Callable, Dict, List, Optional, Tuple, Union
|
17 |
-
|
18 |
-
import torch
|
19 |
-
import torch.nn.functional as F
|
20 |
-
from diffusers import DiffusionPipeline
|
21 |
-
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
22 |
-
from diffusers.image_processor import VaeImageProcessor
|
23 |
-
from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
|
24 |
-
from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
|
25 |
-
get_3d_rotary_pos_embed)
|
26 |
-
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
27 |
-
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
28 |
-
StableDiffusionSafetyChecker
|
29 |
-
from diffusers.schedulers import DDIMScheduler
|
30 |
-
from diffusers.utils import (is_torch_xla_available, logging,
|
31 |
-
replace_example_docstring)
|
32 |
-
from diffusers.utils.torch_utils import randn_tensor
|
33 |
-
from einops import rearrange
|
34 |
-
from PIL import Image
|
35 |
-
from tqdm import tqdm
|
36 |
-
from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
|
37 |
-
CLIPVisionModelWithProjection, T5Tokenizer,
|
38 |
-
T5EncoderModel)
|
39 |
-
|
40 |
-
from .pipeline_easyanimate import EasyAnimatePipelineOutput
|
41 |
-
from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
42 |
-
|
43 |
-
if is_torch_xla_available():
|
44 |
-
import torch_xla.core.xla_model as xm
|
45 |
-
|
46 |
-
XLA_AVAILABLE = True
|
47 |
-
else:
|
48 |
-
XLA_AVAILABLE = False
|
49 |
-
|
50 |
-
|
51 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
52 |
-
|
53 |
-
EXAMPLE_DOC_STRING = """
|
54 |
-
Examples:
|
55 |
-
```py
|
56 |
-
>>> pass
|
57 |
-
```
|
58 |
-
"""
|
59 |
-
|
60 |
-
|
61 |
-
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
62 |
-
tw = tgt_width
|
63 |
-
th = tgt_height
|
64 |
-
h, w = src
|
65 |
-
r = h / w
|
66 |
-
if r > (th / tw):
|
67 |
-
resize_height = th
|
68 |
-
resize_width = int(round(th / h * w))
|
69 |
-
else:
|
70 |
-
resize_width = tw
|
71 |
-
resize_height = int(round(tw / w * h))
|
72 |
-
|
73 |
-
crop_top = int(round((th - resize_height) / 2.0))
|
74 |
-
crop_left = int(round((tw - resize_width) / 2.0))
|
75 |
-
|
76 |
-
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
77 |
-
|
78 |
-
|
79 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
80 |
-
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
81 |
-
"""
|
82 |
-
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
83 |
-
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
84 |
-
"""
|
85 |
-
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
86 |
-
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
87 |
-
# rescale the results from guidance (fixes overexposure)
|
88 |
-
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
89 |
-
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
90 |
-
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
91 |
-
return noise_cfg
|
92 |
-
|
93 |
-
|
94 |
-
def resize_mask(mask, latent, process_first_frame_only=True):
|
95 |
-
latent_size = latent.size()
|
96 |
-
|
97 |
-
if process_first_frame_only:
|
98 |
-
target_size = list(latent_size[2:])
|
99 |
-
target_size[0] = 1
|
100 |
-
first_frame_resized = F.interpolate(
|
101 |
-
mask[:, :, 0:1, :, :],
|
102 |
-
size=target_size,
|
103 |
-
mode='trilinear',
|
104 |
-
align_corners=False
|
105 |
-
)
|
106 |
-
|
107 |
-
target_size = list(latent_size[2:])
|
108 |
-
target_size[0] = target_size[0] - 1
|
109 |
-
if target_size[0] != 0:
|
110 |
-
remaining_frames_resized = F.interpolate(
|
111 |
-
mask[:, :, 1:, :, :],
|
112 |
-
size=target_size,
|
113 |
-
mode='trilinear',
|
114 |
-
align_corners=False
|
115 |
-
)
|
116 |
-
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
117 |
-
else:
|
118 |
-
resized_mask = first_frame_resized
|
119 |
-
else:
|
120 |
-
target_size = list(latent_size[2:])
|
121 |
-
resized_mask = F.interpolate(
|
122 |
-
mask,
|
123 |
-
size=target_size,
|
124 |
-
mode='trilinear',
|
125 |
-
align_corners=False
|
126 |
-
)
|
127 |
-
return resized_mask
|
128 |
-
|
129 |
-
|
130 |
-
def add_noise_to_reference_video(image, ratio=None):
|
131 |
-
if ratio is None:
|
132 |
-
sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
|
133 |
-
sigma = torch.exp(sigma).to(image.dtype)
|
134 |
-
else:
|
135 |
-
sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
|
136 |
-
|
137 |
-
image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
|
138 |
-
image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
|
139 |
-
image = image + image_noise
|
140 |
-
return image
|
141 |
-
|
142 |
-
|
143 |
-
class EasyAnimatePipeline_Multi_Text_Encoder_Inpaint(DiffusionPipeline):
|
144 |
-
r"""
|
145 |
-
Pipeline for text-to-video generation using EasyAnimate.
|
146 |
-
|
147 |
-
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
148 |
-
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
149 |
-
|
150 |
-
EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
|
151 |
-
HunyuanDiT team)
|
152 |
-
|
153 |
-
Args:
|
154 |
-
vae ([`AutoencoderKLMagvit`]):
|
155 |
-
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
156 |
-
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
|
157 |
-
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
158 |
-
EasyAnimate uses a fine-tuned [bilingual CLIP].
|
159 |
-
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
|
160 |
-
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
|
161 |
-
transformer ([`EasyAnimateTransformer3DModel`]):
|
162 |
-
The EasyAnimate model designed by Tencent Hunyuan.
|
163 |
-
text_encoder_2 (`T5EncoderModel`):
|
164 |
-
The mT5 embedder.
|
165 |
-
tokenizer_2 (`T5Tokenizer`):
|
166 |
-
The tokenizer for the mT5 embedder.
|
167 |
-
scheduler ([`DDIMScheduler`]):
|
168 |
-
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
169 |
-
clip_image_processor (`CLIPImageProcessor`):
|
170 |
-
The CLIP image embedder.
|
171 |
-
clip_image_encoder (`CLIPVisionModelWithProjection`):
|
172 |
-
The image processor for the CLIP image embedder.
|
173 |
-
"""
|
174 |
-
|
175 |
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae"
|
176 |
-
_optional_components = [
|
177 |
-
"safety_checker",
|
178 |
-
"feature_extractor",
|
179 |
-
"text_encoder_2",
|
180 |
-
"tokenizer_2",
|
181 |
-
"text_encoder",
|
182 |
-
"tokenizer",
|
183 |
-
"clip_image_encoder",
|
184 |
-
]
|
185 |
-
_exclude_from_cpu_offload = ["safety_checker"]
|
186 |
-
_callback_tensor_inputs = [
|
187 |
-
"latents",
|
188 |
-
"prompt_embeds",
|
189 |
-
"negative_prompt_embeds",
|
190 |
-
"prompt_embeds_2",
|
191 |
-
"negative_prompt_embeds_2",
|
192 |
-
]
|
193 |
-
|
194 |
-
def __init__(
|
195 |
-
self,
|
196 |
-
vae: AutoencoderKLMagvit,
|
197 |
-
text_encoder: BertModel,
|
198 |
-
tokenizer: BertTokenizer,
|
199 |
-
text_encoder_2: T5EncoderModel,
|
200 |
-
tokenizer_2: T5Tokenizer,
|
201 |
-
transformer: EasyAnimateTransformer3DModel,
|
202 |
-
scheduler: DDIMScheduler,
|
203 |
-
safety_checker: StableDiffusionSafetyChecker,
|
204 |
-
feature_extractor: CLIPImageProcessor,
|
205 |
-
requires_safety_checker: bool = True,
|
206 |
-
clip_image_processor: CLIPImageProcessor = None,
|
207 |
-
clip_image_encoder: CLIPVisionModelWithProjection = None,
|
208 |
-
):
|
209 |
-
super().__init__()
|
210 |
-
|
211 |
-
self.register_modules(
|
212 |
-
vae=vae,
|
213 |
-
text_encoder=text_encoder,
|
214 |
-
tokenizer=tokenizer,
|
215 |
-
tokenizer_2=tokenizer_2,
|
216 |
-
transformer=transformer,
|
217 |
-
scheduler=scheduler,
|
218 |
-
safety_checker=safety_checker,
|
219 |
-
feature_extractor=feature_extractor,
|
220 |
-
text_encoder_2=text_encoder_2,
|
221 |
-
clip_image_processor=clip_image_processor,
|
222 |
-
clip_image_encoder=clip_image_encoder,
|
223 |
-
)
|
224 |
-
|
225 |
-
if safety_checker is None and requires_safety_checker:
|
226 |
-
logger.warning(
|
227 |
-
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
228 |
-
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
229 |
-
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
230 |
-
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
231 |
-
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
232 |
-
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
233 |
-
)
|
234 |
-
|
235 |
-
if safety_checker is not None and feature_extractor is None:
|
236 |
-
raise ValueError(
|
237 |
-
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
238 |
-
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
239 |
-
)
|
240 |
-
|
241 |
-
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
242 |
-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
243 |
-
self.mask_processor = VaeImageProcessor(
|
244 |
-
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
245 |
-
)
|
246 |
-
self.enable_autocast_float8_transformer_flag = False
|
247 |
-
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
248 |
-
|
249 |
-
def enable_sequential_cpu_offload(self, *args, **kwargs):
|
250 |
-
super().enable_sequential_cpu_offload(*args, **kwargs)
|
251 |
-
if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
|
252 |
-
import accelerate
|
253 |
-
accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
|
254 |
-
self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
|
255 |
-
|
256 |
-
def encode_prompt(
|
257 |
-
self,
|
258 |
-
prompt: str,
|
259 |
-
device: torch.device,
|
260 |
-
dtype: torch.dtype,
|
261 |
-
num_images_per_prompt: int = 1,
|
262 |
-
do_classifier_free_guidance: bool = True,
|
263 |
-
negative_prompt: Optional[str] = None,
|
264 |
-
prompt_embeds: Optional[torch.Tensor] = None,
|
265 |
-
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
266 |
-
prompt_attention_mask: Optional[torch.Tensor] = None,
|
267 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
268 |
-
max_sequence_length: Optional[int] = None,
|
269 |
-
text_encoder_index: int = 0,
|
270 |
-
actual_max_sequence_length: int = 256
|
271 |
-
):
|
272 |
-
r"""
|
273 |
-
Encodes the prompt into text encoder hidden states.
|
274 |
-
|
275 |
-
Args:
|
276 |
-
prompt (`str` or `List[str]`, *optional*):
|
277 |
-
prompt to be encoded
|
278 |
-
device: (`torch.device`):
|
279 |
-
torch device
|
280 |
-
dtype (`torch.dtype`):
|
281 |
-
torch dtype
|
282 |
-
num_images_per_prompt (`int`):
|
283 |
-
number of images that should be generated per prompt
|
284 |
-
do_classifier_free_guidance (`bool`):
|
285 |
-
whether to use classifier free guidance or not
|
286 |
-
negative_prompt (`str` or `List[str]`, *optional*):
|
287 |
-
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
288 |
-
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
289 |
-
less than `1`).
|
290 |
-
prompt_embeds (`torch.Tensor`, *optional*):
|
291 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
292 |
-
provided, text embeddings will be generated from `prompt` input argument.
|
293 |
-
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
294 |
-
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
295 |
-
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
296 |
-
argument.
|
297 |
-
prompt_attention_mask (`torch.Tensor`, *optional*):
|
298 |
-
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
299 |
-
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
300 |
-
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
301 |
-
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
302 |
-
text_encoder_index (`int`, *optional*):
|
303 |
-
Index of the text encoder to use. `0` for clip and `1` for T5.
|
304 |
-
"""
|
305 |
-
tokenizers = [self.tokenizer, self.tokenizer_2]
|
306 |
-
text_encoders = [self.text_encoder, self.text_encoder_2]
|
307 |
-
|
308 |
-
tokenizer = tokenizers[text_encoder_index]
|
309 |
-
text_encoder = text_encoders[text_encoder_index]
|
310 |
-
|
311 |
-
if max_sequence_length is None:
|
312 |
-
if text_encoder_index == 0:
|
313 |
-
max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
|
314 |
-
if text_encoder_index == 1:
|
315 |
-
max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
|
316 |
-
else:
|
317 |
-
max_length = max_sequence_length
|
318 |
-
|
319 |
-
if prompt is not None and isinstance(prompt, str):
|
320 |
-
batch_size = 1
|
321 |
-
elif prompt is not None and isinstance(prompt, list):
|
322 |
-
batch_size = len(prompt)
|
323 |
-
else:
|
324 |
-
batch_size = prompt_embeds.shape[0]
|
325 |
-
|
326 |
-
if prompt_embeds is None:
|
327 |
-
text_inputs = tokenizer(
|
328 |
-
prompt,
|
329 |
-
padding="max_length",
|
330 |
-
max_length=max_length,
|
331 |
-
truncation=True,
|
332 |
-
return_attention_mask=True,
|
333 |
-
return_tensors="pt",
|
334 |
-
)
|
335 |
-
text_input_ids = text_inputs.input_ids
|
336 |
-
if text_input_ids.shape[-1] > actual_max_sequence_length:
|
337 |
-
reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
338 |
-
text_inputs = tokenizer(
|
339 |
-
reprompt,
|
340 |
-
padding="max_length",
|
341 |
-
max_length=max_length,
|
342 |
-
truncation=True,
|
343 |
-
return_attention_mask=True,
|
344 |
-
return_tensors="pt",
|
345 |
-
)
|
346 |
-
text_input_ids = text_inputs.input_ids
|
347 |
-
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
348 |
-
|
349 |
-
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
350 |
-
text_input_ids, untruncated_ids
|
351 |
-
):
|
352 |
-
_actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
|
353 |
-
removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
|
354 |
-
logger.warning(
|
355 |
-
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
356 |
-
f" {_actual_max_sequence_length} tokens: {removed_text}"
|
357 |
-
)
|
358 |
-
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
359 |
-
if self.transformer.config.enable_text_attention_mask:
|
360 |
-
prompt_embeds = text_encoder(
|
361 |
-
text_input_ids.to(device),
|
362 |
-
attention_mask=prompt_attention_mask,
|
363 |
-
)
|
364 |
-
else:
|
365 |
-
prompt_embeds = text_encoder(
|
366 |
-
text_input_ids.to(device)
|
367 |
-
)
|
368 |
-
prompt_embeds = prompt_embeds[0]
|
369 |
-
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
370 |
-
|
371 |
-
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
372 |
-
|
373 |
-
bs_embed, seq_len, _ = prompt_embeds.shape
|
374 |
-
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
375 |
-
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
376 |
-
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
377 |
-
|
378 |
-
# get unconditional embeddings for classifier free guidance
|
379 |
-
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
380 |
-
uncond_tokens: List[str]
|
381 |
-
if negative_prompt is None:
|
382 |
-
uncond_tokens = [""] * batch_size
|
383 |
-
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
384 |
-
raise TypeError(
|
385 |
-
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
386 |
-
f" {type(prompt)}."
|
387 |
-
)
|
388 |
-
elif isinstance(negative_prompt, str):
|
389 |
-
uncond_tokens = [negative_prompt]
|
390 |
-
elif batch_size != len(negative_prompt):
|
391 |
-
raise ValueError(
|
392 |
-
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
393 |
-
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
394 |
-
" the batch size of `prompt`."
|
395 |
-
)
|
396 |
-
else:
|
397 |
-
uncond_tokens = negative_prompt
|
398 |
-
|
399 |
-
max_length = prompt_embeds.shape[1]
|
400 |
-
uncond_input = tokenizer(
|
401 |
-
uncond_tokens,
|
402 |
-
padding="max_length",
|
403 |
-
max_length=max_length,
|
404 |
-
truncation=True,
|
405 |
-
return_tensors="pt",
|
406 |
-
)
|
407 |
-
uncond_input_ids = uncond_input.input_ids
|
408 |
-
if uncond_input_ids.shape[-1] > actual_max_sequence_length:
|
409 |
-
reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
|
410 |
-
uncond_input = tokenizer(
|
411 |
-
reuncond_tokens,
|
412 |
-
padding="max_length",
|
413 |
-
max_length=max_length,
|
414 |
-
truncation=True,
|
415 |
-
return_attention_mask=True,
|
416 |
-
return_tensors="pt",
|
417 |
-
)
|
418 |
-
uncond_input_ids = uncond_input.input_ids
|
419 |
-
|
420 |
-
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
|
421 |
-
if self.transformer.config.enable_text_attention_mask:
|
422 |
-
negative_prompt_embeds = text_encoder(
|
423 |
-
uncond_input.input_ids.to(device),
|
424 |
-
attention_mask=negative_prompt_attention_mask,
|
425 |
-
)
|
426 |
-
else:
|
427 |
-
negative_prompt_embeds = text_encoder(
|
428 |
-
uncond_input.input_ids.to(device)
|
429 |
-
)
|
430 |
-
negative_prompt_embeds = negative_prompt_embeds[0]
|
431 |
-
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
432 |
-
|
433 |
-
if do_classifier_free_guidance:
|
434 |
-
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
435 |
-
seq_len = negative_prompt_embeds.shape[1]
|
436 |
-
|
437 |
-
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
438 |
-
|
439 |
-
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
440 |
-
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
441 |
-
|
442 |
-
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
443 |
-
|
444 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
445 |
-
def run_safety_checker(self, image, device, dtype):
|
446 |
-
if self.safety_checker is None:
|
447 |
-
has_nsfw_concept = None
|
448 |
-
else:
|
449 |
-
if torch.is_tensor(image):
|
450 |
-
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
451 |
-
else:
|
452 |
-
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
453 |
-
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
454 |
-
image, has_nsfw_concept = self.safety_checker(
|
455 |
-
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
456 |
-
)
|
457 |
-
return image, has_nsfw_concept
|
458 |
-
|
459 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
460 |
-
def prepare_extra_step_kwargs(self, generator, eta):
|
461 |
-
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
462 |
-
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
463 |
-
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
464 |
-
# and should be between [0, 1]
|
465 |
-
|
466 |
-
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
467 |
-
extra_step_kwargs = {}
|
468 |
-
if accepts_eta:
|
469 |
-
extra_step_kwargs["eta"] = eta
|
470 |
-
|
471 |
-
# check if the scheduler accepts generator
|
472 |
-
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
473 |
-
if accepts_generator:
|
474 |
-
extra_step_kwargs["generator"] = generator
|
475 |
-
return extra_step_kwargs
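For illustration only, a self-contained sketch of the same signature-inspection pattern used above; the step function here is a stand-in, not part of the pipeline:

import inspect

def build_step_kwargs(step_fn, eta=0.0, generator=None):
    # Keep only the keyword arguments that the given scheduler step function accepts.
    accepted = set(inspect.signature(step_fn).parameters.keys())
    extra = {}
    if "eta" in accepted:
        extra["eta"] = eta
    if "generator" in accepted:
        extra["generator"] = generator
    return extra

def fake_step(model_output, timestep, sample, generator=None):  # stand-in scheduler step
    return sample

print(build_step_kwargs(fake_step, eta=0.5))  # {'generator': None} -- eta is dropped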
|
476 |
-
|
477 |
-
def check_inputs(
|
478 |
-
self,
|
479 |
-
prompt,
|
480 |
-
height,
|
481 |
-
width,
|
482 |
-
negative_prompt=None,
|
483 |
-
prompt_embeds=None,
|
484 |
-
negative_prompt_embeds=None,
|
485 |
-
prompt_attention_mask=None,
|
486 |
-
negative_prompt_attention_mask=None,
|
487 |
-
prompt_embeds_2=None,
|
488 |
-
negative_prompt_embeds_2=None,
|
489 |
-
prompt_attention_mask_2=None,
|
490 |
-
negative_prompt_attention_mask_2=None,
|
491 |
-
callback_on_step_end_tensor_inputs=None,
|
492 |
-
):
|
493 |
-
if height % 8 != 0 or width % 8 != 0:
|
494 |
-
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
495 |
-
|
496 |
-
if callback_on_step_end_tensor_inputs is not None and not all(
|
497 |
-
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
498 |
-
):
|
499 |
-
raise ValueError(
|
500 |
-
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
501 |
-
)
|
502 |
-
|
503 |
-
if prompt is not None and prompt_embeds is not None:
|
504 |
-
raise ValueError(
|
505 |
-
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
506 |
-
" only forward one of the two."
|
507 |
-
)
|
508 |
-
elif prompt is None and prompt_embeds is None:
|
509 |
-
raise ValueError(
|
510 |
-
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
511 |
-
)
|
512 |
-
elif prompt is None and prompt_embeds_2 is None:
|
513 |
-
raise ValueError(
|
514 |
-
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
|
515 |
-
)
|
516 |
-
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
517 |
-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
518 |
-
|
519 |
-
if prompt_embeds is not None and prompt_attention_mask is None:
|
520 |
-
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
521 |
-
|
522 |
-
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
|
523 |
-
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
|
524 |
-
|
525 |
-
if negative_prompt is not None and negative_prompt_embeds is not None:
|
526 |
-
raise ValueError(
|
527 |
-
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
528 |
-
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
529 |
-
)
|
530 |
-
|
531 |
-
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
532 |
-
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
533 |
-
|
534 |
-
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
|
535 |
-
raise ValueError(
|
536 |
-
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
|
537 |
-
)
|
538 |
-
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
539 |
-
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
540 |
-
raise ValueError(
|
541 |
-
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
542 |
-
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
543 |
-
f" {negative_prompt_embeds.shape}."
|
544 |
-
)
|
545 |
-
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
|
546 |
-
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
|
547 |
-
raise ValueError(
|
548 |
-
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
|
549 |
-
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
|
550 |
-
f" {negative_prompt_embeds_2.shape}."
|
551 |
-
)
|
552 |
-
|
553 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
554 |
-
def get_timesteps(self, num_inference_steps, strength, device):
|
555 |
-
# get the original timestep using init_timestep
|
556 |
-
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
557 |
-
|
558 |
-
t_start = max(num_inference_steps - init_timestep, 0)
|
559 |
-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
560 |
-
|
561 |
-
return timesteps, num_inference_steps - t_start
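A small worked example of the strength-to-timestep mapping implemented above, as a standalone sketch (the scheduler order is assumed to be 1):

def steps_for_strength(num_inference_steps: int, strength: float, order: int = 1):
    # Mirrors get_timesteps: how many denoising steps actually run for a given strength.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return t_start * order, num_inference_steps - t_start

print(steps_for_strength(50, 1.0))  # (0, 50): full denoising from pure noise
print(steps_for_strength(50, 0.4))  # (30, 20): only the last 20 of 50 steps run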
|
562 |
-
|
563 |
-
def prepare_mask_latents(
|
564 |
-
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
|
565 |
-
):
|
566 |
-
# resize the mask to latents shape as we concatenate the mask to the latents
|
567 |
-
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
568 |
-
# and half precision
|
569 |
-
if mask is not None:
|
570 |
-
mask = mask.to(device=device, dtype=self.vae.dtype)
|
571 |
-
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
572 |
-
bs = 1
|
573 |
-
new_mask = []
|
574 |
-
for i in range(0, mask.shape[0], bs):
|
575 |
-
mask_bs = mask[i : i + bs]
|
576 |
-
mask_bs = self.vae.encode(mask_bs)[0]
|
577 |
-
mask_bs = mask_bs.mode()
|
578 |
-
new_mask.append(mask_bs)
|
579 |
-
mask = torch.cat(new_mask, dim = 0)
|
580 |
-
mask = mask * self.vae.config.scaling_factor
|
581 |
-
|
582 |
-
else:
|
583 |
-
if mask.shape[1] == 4:
|
584 |
-
mask = mask
|
585 |
-
else:
|
586 |
-
video_length = mask.shape[2]
|
587 |
-
mask = rearrange(mask, "b c f h w -> (b f) c h w")
|
588 |
-
mask = self._encode_vae_image(mask, generator=generator)
|
589 |
-
mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
|
590 |
-
|
591 |
-
if masked_image is not None:
|
592 |
-
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
593 |
-
if self.transformer.config.add_noise_in_inpaint_model:
|
594 |
-
masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
|
595 |
-
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
596 |
-
bs = 1
|
597 |
-
new_mask_pixel_values = []
|
598 |
-
for i in range(0, masked_image.shape[0], bs):
|
599 |
-
mask_pixel_values_bs = masked_image[i : i + bs]
|
600 |
-
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
601 |
-
mask_pixel_values_bs = mask_pixel_values_bs.mode()
|
602 |
-
new_mask_pixel_values.append(mask_pixel_values_bs)
|
603 |
-
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
604 |
-
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
605 |
-
|
606 |
-
else:
|
607 |
-
if masked_image.shape[1] == 4:
|
608 |
-
masked_image_latents = masked_image
|
609 |
-
else:
|
610 |
-
video_length = masked_image.shape[2]
|
611 |
-
masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
|
612 |
-
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
|
613 |
-
masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
|
614 |
-
|
615 |
-
# align the device to prevent device errors when concatenating it with the latent model input
|
616 |
-
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
617 |
-
else:
|
618 |
-
masked_image_latents = None
|
619 |
-
|
620 |
-
return mask, masked_image_latents
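For illustration, the frame-folding trick used above (so a 2D image VAE can encode a video frame by frame) in isolation; the shapes are made up:

import torch
from einops import rearrange

video = torch.randn(2, 3, 8, 64, 64)                      # (batch, channels, frames, height, width)
frames = rearrange(video, "b c f h w -> (b f) c h w")      # (16, 3, 64, 64): frames become batch items
# ... the image VAE would encode `frames` here ...
restored = rearrange(frames, "(b f) c h w -> b c f h w", f=8)
assert restored.shape == video.shape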
|
621 |
-
|
622 |
-
def prepare_latents(
|
623 |
-
self,
|
624 |
-
batch_size,
|
625 |
-
num_channels_latents,
|
626 |
-
height,
|
627 |
-
width,
|
628 |
-
video_length,
|
629 |
-
dtype,
|
630 |
-
device,
|
631 |
-
generator,
|
632 |
-
latents=None,
|
633 |
-
video=None,
|
634 |
-
timestep=None,
|
635 |
-
is_strength_max=True,
|
636 |
-
return_noise=False,
|
637 |
-
return_video_latents=False,
|
638 |
-
):
|
639 |
-
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
640 |
-
if self.vae.cache_mag_vae:
|
641 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
642 |
-
mini_batch_decoder = self.vae.mini_batch_decoder
|
643 |
-
shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
644 |
-
else:
|
645 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
646 |
-
mini_batch_decoder = self.vae.mini_batch_decoder
|
647 |
-
shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
648 |
-
else:
|
649 |
-
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
650 |
-
|
651 |
-
if isinstance(generator, list) and len(generator) != batch_size:
|
652 |
-
raise ValueError(
|
653 |
-
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
654 |
-
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
655 |
-
)
|
656 |
-
|
657 |
-
if return_video_latents or (latents is None and not is_strength_max):
|
658 |
-
video = video.to(device=device, dtype=self.vae.dtype)
|
659 |
-
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
660 |
-
bs = 1
|
661 |
-
new_video = []
|
662 |
-
for i in range(0, video.shape[0], bs):
|
663 |
-
video_bs = video[i : i + bs]
|
664 |
-
video_bs = self.vae.encode(video_bs)[0]
|
665 |
-
video_bs = video_bs.sample()
|
666 |
-
new_video.append(video_bs)
|
667 |
-
video = torch.cat(new_video, dim = 0)
|
668 |
-
video = video * self.vae.config.scaling_factor
|
669 |
-
|
670 |
-
else:
|
671 |
-
if video.shape[1] == 4:
|
672 |
-
video = video
|
673 |
-
else:
|
674 |
-
video_length = video.shape[2]
|
675 |
-
video = rearrange(video, "b c f h w -> (b f) c h w")
|
676 |
-
video = self._encode_vae_image(video, generator=generator)
|
677 |
-
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
678 |
-
video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
|
679 |
-
video_latents = video_latents.to(device=device, dtype=dtype)
|
680 |
-
|
681 |
-
if latents is None:
|
682 |
-
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
683 |
-
# if strength is 1. then initialise the latents to noise, else initialise to image + noise
|
684 |
-
latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
|
685 |
-
# if pure noise then scale the initial latents by the Scheduler's init sigma
|
686 |
-
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
|
687 |
-
else:
|
688 |
-
noise = latents.to(device)
|
689 |
-
latents = noise * self.scheduler.init_noise_sigma
|
690 |
-
|
691 |
-
# scale the initial noise by the standard deviation required by the scheduler
|
692 |
-
outputs = (latents,)
|
693 |
-
|
694 |
-
if return_noise:
|
695 |
-
outputs += (noise,)
|
696 |
-
|
697 |
-
if return_video_latents:
|
698 |
-
outputs += (video_latents,)
|
699 |
-
|
700 |
-
return outputs
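A worked example of the temporal latent size computed above; the mini-batch values are illustrative and in practice come from the VAE configuration:

def temporal_latent_size(video_length, mini_batch_encoder=4, mini_batch_decoder=1, cache_mag_vae=True):
    # Mirrors the shape logic in prepare_latents for the 3D (video) VAE branch.
    if video_length == 1:
        return 1
    if cache_mag_vae:
        return (video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1
    return video_length // mini_batch_encoder * mini_batch_decoder

print(temporal_latent_size(49))  # (49 - 1) // 4 * 1 + 1 = 13 latent frames
print(temporal_latent_size(1))   # single image: 1 latent frame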
|
701 |
-
|
702 |
-
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
703 |
-
if video.size()[2] <= mini_batch_encoder:
|
704 |
-
return video
|
705 |
-
prefix_index_before = mini_batch_encoder // 2
|
706 |
-
prefix_index_after = mini_batch_encoder - prefix_index_before
|
707 |
-
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
708 |
-
|
709 |
-
# Encode middle videos
|
710 |
-
latents = self.vae.encode(pixel_values)[0]
|
711 |
-
latents = latents.mode()
|
712 |
-
# Decode middle videos
|
713 |
-
middle_video = self.vae.decode(latents)[0]
|
714 |
-
|
715 |
-
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
716 |
-
return video
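For orientation, the boundary window used above with an illustrative mini-batch size:

mini_batch_encoder = 8                                          # illustrative value
prefix_index_before = mini_batch_encoder // 2                   # 4
prefix_index_after = mini_batch_encoder - prefix_index_before   # 4
# Frames video[:, :, 4:-4] are re-encoded/decoded and averaged with the original decode,
# softening seams between independently decoded mini-batches.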
|
717 |
-
|
718 |
-
def decode_latents(self, latents):
|
719 |
-
video_length = latents.shape[2]
|
720 |
-
latents = 1 / self.vae.config.scaling_factor * latents
|
721 |
-
if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
|
722 |
-
mini_batch_encoder = self.vae.mini_batch_encoder
|
723 |
-
mini_batch_decoder = self.vae.mini_batch_decoder
|
724 |
-
video = self.vae.decode(latents)[0]
|
725 |
-
video = video.clamp(-1, 1)
|
726 |
-
if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
|
727 |
-
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
728 |
-
else:
|
729 |
-
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
730 |
-
video = []
|
731 |
-
for frame_idx in tqdm(range(latents.shape[0])):
|
732 |
-
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
733 |
-
video = torch.cat(video)
|
734 |
-
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
735 |
-
video = (video / 2 + 0.5).clamp(0, 1)
|
736 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
737 |
-
video = video.cpu().float().numpy()
|
738 |
-
return video
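A minimal sketch of the two rescalings performed in decode_latents; the scaling factor and the tanh stand-in for the VAE decoder are assumptions for illustration only:

import torch

scaling_factor = 0.18215                     # example value; in practice vae.config.scaling_factor
latents = torch.randn(1, 4, 16, 16)

latents = latents / scaling_factor           # undo the scaling applied when the video was encoded
decoded = torch.tanh(latents)                # stand-in for vae.decode(...), which outputs roughly [-1, 1]
video = (decoded / 2 + 0.5).clamp(0, 1)      # map [-1, 1] -> [0, 1] before converting to frames
print(float(video.min()) >= 0.0, float(video.max()) <= 1.0)  # True True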
|
739 |
-
|
740 |
-
@property
|
741 |
-
def guidance_scale(self):
|
742 |
-
return self._guidance_scale
|
743 |
-
|
744 |
-
@property
|
745 |
-
def guidance_rescale(self):
|
746 |
-
return self._guidance_rescale
|
747 |
-
|
748 |
-
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
749 |
-
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
750 |
-
# corresponds to doing no classifier free guidance.
|
751 |
-
@property
|
752 |
-
def do_classifier_free_guidance(self):
|
753 |
-
return self._guidance_scale > 1
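For reference, a tiny numeric sketch of the classifier-free guidance combination that this flag gates in the denoising loop below:

import torch

guidance_scale = 5.0
noise_pred_uncond = torch.zeros(2, 4)
noise_pred_text = torch.ones(2, 4)

# guidance_scale = 1 reduces to the text-conditioned prediction alone; larger values push the
# result further away from the unconditional prediction, towards the text-conditioned one.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred[0, 0].item())  # 5.0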
|
754 |
-
|
755 |
-
@property
|
756 |
-
def num_timesteps(self):
|
757 |
-
return self._num_timesteps
|
758 |
-
|
759 |
-
@property
|
760 |
-
def interrupt(self):
|
761 |
-
return self._interrupt
|
762 |
-
|
763 |
-
def enable_autocast_float8_transformer(self):
|
764 |
-
self.enable_autocast_float8_transformer_flag = True
|
765 |
-
|
766 |
-
@torch.no_grad()
|
767 |
-
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
768 |
-
def __call__(
|
769 |
-
self,
|
770 |
-
prompt: Union[str, List[str]] = None,
|
771 |
-
video_length: Optional[int] = None,
|
772 |
-
video: Union[torch.FloatTensor] = None,
|
773 |
-
mask_video: Union[torch.FloatTensor] = None,
|
774 |
-
masked_video_latents: Union[torch.FloatTensor] = None,
|
775 |
-
height: Optional[int] = None,
|
776 |
-
width: Optional[int] = None,
|
777 |
-
num_inference_steps: Optional[int] = 50,
|
778 |
-
guidance_scale: Optional[float] = 5.0,
|
779 |
-
negative_prompt: Optional[Union[str, List[str]]] = None,
|
780 |
-
num_images_per_prompt: Optional[int] = 1,
|
781 |
-
eta: Optional[float] = 0.0,
|
782 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
783 |
-
latents: Optional[torch.Tensor] = None,
|
784 |
-
prompt_embeds: Optional[torch.Tensor] = None,
|
785 |
-
prompt_embeds_2: Optional[torch.Tensor] = None,
|
786 |
-
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
787 |
-
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
|
788 |
-
prompt_attention_mask: Optional[torch.Tensor] = None,
|
789 |
-
prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
790 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
791 |
-
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
|
792 |
-
output_type: Optional[str] = "latent",
|
793 |
-
return_dict: bool = True,
|
794 |
-
callback_on_step_end: Optional[
|
795 |
-
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
796 |
-
] = None,
|
797 |
-
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
798 |
-
guidance_rescale: float = 0.0,
|
799 |
-
original_size: Optional[Tuple[int, int]] = (1024, 1024),
|
800 |
-
target_size: Optional[Tuple[int, int]] = None,
|
801 |
-
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
802 |
-
clip_image: Image = None,
|
803 |
-
clip_apply_ratio: float = 0.40,
|
804 |
-
strength: float = 1.0,
|
805 |
-
noise_aug_strength: float = 0.0563,
|
806 |
-
comfyui_progressbar: bool = False,
|
807 |
-
):
|
808 |
-
r"""
|
809 |
-
The call function to the pipeline for video generation and inpainting with EasyAnimate.
|
810 |
-
|
811 |
-
Examples:
|
812 |
-
prompt (`str` or `List[str]`, *optional*):
|
813 |
-
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
814 |
-
video_length (`int`, *optional*):
|
815 |
-
Length of the video to be generated, in frames. This parameter determines the duration and
|
816 |
-
continuity of generated content.
|
817 |
-
video (`torch.FloatTensor`, *optional*):
|
818 |
-
A tensor representing an input video, which can be modified depending on the prompts provided.
|
819 |
-
mask_video (`torch.FloatTensor`, *optional*):
|
820 |
-
A tensor to specify areas of the video to be masked (omitted from generation).
|
821 |
-
masked_video_latents (`torch.FloatTensor`, *optional*):
|
822 |
-
Latents from masked portions of the video, utilized during image generation.
|
823 |
-
height (`int`, *optional*):
|
824 |
-
The height in pixels of the generated image or video frames.
|
825 |
-
width (`int`, *optional*):
|
826 |
-
The width in pixels of the generated image or video frames.
|
827 |
-
num_inference_steps (`int`, *optional*, defaults to 50):
|
828 |
-
The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
|
829 |
-
inference time. This parameter is modulated by `strength`.
|
830 |
-
guidance_scale (`float`, *optional*, defaults to 5.0):
|
831 |
-
A higher guidance scale value encourages the model to generate images closely linked to the text
|
832 |
-
`prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
|
833 |
-
negative_prompt (`str` or `List[str]`, *optional*):
|
834 |
-
The prompt or prompts to guide what to exclude in image generation. If not defined, you need to
|
835 |
-
provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale <= 1`).
|
836 |
-
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
837 |
-
The number of images to generate per prompt.
|
838 |
-
eta (`float`, *optional*, defaults to 0.0):
|
839 |
-
A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
|
840 |
-
[`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the
|
841 |
-
inference process.
|
842 |
-
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
843 |
-
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
|
844 |
-
random seeds which helps in making generation deterministic.
|
845 |
-
latents (`torch.Tensor`, *optional*):
|
846 |
-
A pre-computed latent representation which can be used to guide the generation process.
|
847 |
-
prompt_embeds (`torch.Tensor`, *optional*):
|
848 |
-
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
849 |
-
provided, embeddings are generated from the `prompt` input argument.
|
850 |
-
prompt_embeds_2 (`torch.Tensor`, *optional*):
|
851 |
-
Secondary set of pre-generated text embeddings, useful for advanced prompt weighting.
|
852 |
-
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
853 |
-
Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs.
|
854 |
-
If not provided, embeddings are generated from the `negative_prompt` argument.
|
855 |
-
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
|
856 |
-
Secondary set of pre-generated negative text embeddings for further control.
|
857 |
-
prompt_attention_mask (`torch.Tensor`, *optional*):
|
858 |
-
Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
|
859 |
-
`prompt_embeds`.
|
860 |
-
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
861 |
-
Attention mask for the secondary prompt embedding.
|
862 |
-
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
863 |
-
Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
|
864 |
-
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
|
865 |
-
Attention mask for the secondary negative prompt embedding.
|
866 |
-
output_type (`str`, *optional*, defaults to `"latent"`):
|
867 |
-
The output format of the generated video. With the default `"latent"`, the decoded video is returned as
|
868 |
-
a `torch.Tensor`; otherwise a NumPy array is returned.
|
869 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
870 |
-
If set to `True`, an [`EasyAnimatePipelineOutput`] will be returned;
|
871 |
-
otherwise, the generated video is returned directly.
|
872 |
-
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
873 |
-
A callback function (or a list of them) that will be executed at the end of each denoising step,
|
874 |
-
allowing for custom processing during generation.
|
875 |
-
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
876 |
-
Specifies which tensor inputs should be included in the callback function. If not defined, all tensor
|
877 |
-
inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
|
878 |
-
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
879 |
-
Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
|
880 |
-
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
881 |
-
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
|
882 |
-
The original dimensions of the image. Used to compute time ids during the generation process.
|
883 |
-
target_size (`Tuple[int, int]`, *optional*):
|
884 |
-
The targeted dimensions of the generated image, also utilized in the time id calculations.
|
885 |
-
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
|
886 |
-
Coordinates defining the top left corner of any cropping, utilized while calculating the time ids.
|
887 |
-
clip_image (`Image`, *optional*):
|
888 |
-
An optional reference image encoded by the CLIP vision encoder and injected as extra conditioning when the transformer enables CLIP conditioning for inpainting.
|
889 |
-
clip_apply_ratio (`float`, *optional*, defaults to 0.40):
|
890 |
-
Fraction of the denoising schedule, counted from the end, during which the CLIP image conditioning is applied; earlier steps receive zeroed CLIP features.
|
891 |
-
strength (`float`, *optional*, defaults to 1.0):
|
892 |
-
Indicates how much to transform the reference `video`. A value of 1.0 starts denoising from pure noise and
|
893 |
-
effectively ignores the input video, while lower values preserve more of it.
|
894 |
-
comfyui_progressbar (`bool`, *optional*, defaults to `False`):
|
895 |
-
Enables a progress bar in ComfyUI, providing visual feedback during the generation process.
|
896 |
-
|
897 |
-
Examples:
|
898 |
-
# Example usage of the function for generating images based on prompts.
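A minimal sketch of a call (the pipeline object and the prompt are placeholders, not taken from this file):

    import torch

    # `pipeline` is assumed to be an already constructed EasyAnimate inpaint pipeline.
    output = pipeline(
        prompt="a corgi running on the beach",
        video_length=49,
        height=384,
        width=672,
        num_inference_steps=50,
        guidance_scale=7.0,
        generator=torch.Generator("cuda").manual_seed(42),
    )
    video = output.videos  # decoded frames as a tensor (output_type defaults to "latent")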
|
899 |
-
|
900 |
-
Returns:
|
901 |
-
[`EasyAnimatePipelineOutput`] or the raw generated video:
|
902 |
-
Returns a structured `EasyAnimatePipelineOutput` containing the generated video when `return_dict` is
|
903 |
-
`True`; otherwise (when `return_dict` is `False`) the generated video is returned directly, without any
|
904 |
-
additional wrapper or safety flags.
|
905 |
-
"""
|
906 |
-
|
907 |
-
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
908 |
-
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
909 |
-
|
910 |
-
# 0. default height and width
|
911 |
-
height = int(height // 16 * 16)
|
912 |
-
width = int(width // 16 * 16)
|
913 |
-
|
914 |
-
# 1. Check inputs. Raise error if not correct
|
915 |
-
self.check_inputs(
|
916 |
-
prompt,
|
917 |
-
height,
|
918 |
-
width,
|
919 |
-
negative_prompt,
|
920 |
-
prompt_embeds,
|
921 |
-
negative_prompt_embeds,
|
922 |
-
prompt_attention_mask,
|
923 |
-
negative_prompt_attention_mask,
|
924 |
-
prompt_embeds_2,
|
925 |
-
negative_prompt_embeds_2,
|
926 |
-
prompt_attention_mask_2,
|
927 |
-
negative_prompt_attention_mask_2,
|
928 |
-
callback_on_step_end_tensor_inputs,
|
929 |
-
)
|
930 |
-
self._guidance_scale = guidance_scale
|
931 |
-
self._guidance_rescale = guidance_rescale
|
932 |
-
self._interrupt = False
|
933 |
-
|
934 |
-
# 2. Define call parameters
|
935 |
-
if prompt is not None and isinstance(prompt, str):
|
936 |
-
batch_size = 1
|
937 |
-
elif prompt is not None and isinstance(prompt, list):
|
938 |
-
batch_size = len(prompt)
|
939 |
-
else:
|
940 |
-
batch_size = prompt_embeds.shape[0]
|
941 |
-
|
942 |
-
device = self._execution_device
|
943 |
-
|
944 |
-
# 3. Encode input prompt
|
945 |
-
(
|
946 |
-
prompt_embeds,
|
947 |
-
negative_prompt_embeds,
|
948 |
-
prompt_attention_mask,
|
949 |
-
negative_prompt_attention_mask,
|
950 |
-
) = self.encode_prompt(
|
951 |
-
prompt=prompt,
|
952 |
-
device=device,
|
953 |
-
dtype=self.transformer.dtype,
|
954 |
-
num_images_per_prompt=num_images_per_prompt,
|
955 |
-
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
956 |
-
negative_prompt=negative_prompt,
|
957 |
-
prompt_embeds=prompt_embeds,
|
958 |
-
negative_prompt_embeds=negative_prompt_embeds,
|
959 |
-
prompt_attention_mask=prompt_attention_mask,
|
960 |
-
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
961 |
-
text_encoder_index=0,
|
962 |
-
)
|
963 |
-
(
|
964 |
-
prompt_embeds_2,
|
965 |
-
negative_prompt_embeds_2,
|
966 |
-
prompt_attention_mask_2,
|
967 |
-
negative_prompt_attention_mask_2,
|
968 |
-
) = self.encode_prompt(
|
969 |
-
prompt=prompt,
|
970 |
-
device=device,
|
971 |
-
dtype=self.transformer.dtype,
|
972 |
-
num_images_per_prompt=num_images_per_prompt,
|
973 |
-
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
974 |
-
negative_prompt=negative_prompt,
|
975 |
-
prompt_embeds=prompt_embeds_2,
|
976 |
-
negative_prompt_embeds=negative_prompt_embeds_2,
|
977 |
-
prompt_attention_mask=prompt_attention_mask_2,
|
978 |
-
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
|
979 |
-
text_encoder_index=1,
|
980 |
-
)
|
981 |
-
torch.cuda.empty_cache()
|
982 |
-
|
983 |
-
# 4. set timesteps
|
984 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
985 |
-
timesteps, num_inference_steps = self.get_timesteps(
|
986 |
-
num_inference_steps=num_inference_steps, strength=strength, device=device
|
987 |
-
)
|
988 |
-
if comfyui_progressbar:
|
989 |
-
from comfy.utils import ProgressBar
|
990 |
-
pbar = ProgressBar(num_inference_steps + 3)
|
991 |
-
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
992 |
-
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
993 |
-
# create a boolean to check whether the strength is set to 1; if so, initialise the latents with pure noise
|
994 |
-
is_strength_max = strength == 1.0
|
995 |
-
|
996 |
-
if video is not None:
|
997 |
-
video_length = video.shape[2]
|
998 |
-
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
999 |
-
init_video = init_video.to(dtype=torch.float32)
|
1000 |
-
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
1001 |
-
else:
|
1002 |
-
init_video = None
|
1003 |
-
|
1004 |
-
# Prepare latent variables
|
1005 |
-
num_channels_latents = self.vae.config.latent_channels
|
1006 |
-
num_channels_transformer = self.transformer.config.in_channels
|
1007 |
-
return_image_latents = num_channels_transformer == num_channels_latents
|
1008 |
-
|
1009 |
-
# 5. Prepare latents.
|
1010 |
-
latents_outputs = self.prepare_latents(
|
1011 |
-
batch_size * num_images_per_prompt,
|
1012 |
-
num_channels_latents,
|
1013 |
-
height,
|
1014 |
-
width,
|
1015 |
-
video_length,
|
1016 |
-
prompt_embeds.dtype,
|
1017 |
-
device,
|
1018 |
-
generator,
|
1019 |
-
latents,
|
1020 |
-
video=init_video,
|
1021 |
-
timestep=latent_timestep,
|
1022 |
-
is_strength_max=is_strength_max,
|
1023 |
-
return_noise=True,
|
1024 |
-
return_video_latents=return_image_latents,
|
1025 |
-
)
|
1026 |
-
if return_image_latents:
|
1027 |
-
latents, noise, image_latents = latents_outputs
|
1028 |
-
else:
|
1029 |
-
latents, noise = latents_outputs
|
1030 |
-
|
1031 |
-
if comfyui_progressbar:
|
1032 |
-
pbar.update(1)
|
1033 |
-
|
1034 |
-
# 6. Prepare CLIP image latents if needed.
|
1035 |
-
if clip_image is not None and self.transformer.enable_clip_in_inpaint:
|
1036 |
-
inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
|
1037 |
-
inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
|
1038 |
-
clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:]
|
1039 |
-
clip_encoder_hidden_states_neg = torch.zeros(
|
1040 |
-
[
|
1041 |
-
batch_size,
|
1042 |
-
int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
|
1043 |
-
int(self.clip_image_encoder.config.hidden_size)
|
1044 |
-
]
|
1045 |
-
).to(latents.device, dtype=latents.dtype)
|
1046 |
-
|
1047 |
-
clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
|
1048 |
-
clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
|
1049 |
-
|
1050 |
-
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states
|
1051 |
-
clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask
|
1052 |
-
|
1053 |
-
elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint:
|
1054 |
-
clip_encoder_hidden_states = torch.zeros(
|
1055 |
-
[
|
1056 |
-
batch_size,
|
1057 |
-
int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
|
1058 |
-
int(self.clip_image_encoder.config.hidden_size)
|
1059 |
-
]
|
1060 |
-
).to(latents.device, dtype=latents.dtype)
|
1061 |
-
|
1062 |
-
clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query])
|
1063 |
-
clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
|
1064 |
-
|
1065 |
-
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states
|
1066 |
-
clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask
|
1067 |
-
|
1068 |
-
else:
|
1069 |
-
clip_encoder_hidden_states_input = None
|
1070 |
-
clip_attention_mask_input = None
|
1071 |
-
if comfyui_progressbar:
|
1072 |
-
pbar.update(1)
|
1073 |
-
|
1074 |
-
# 7. Prepare inpaint latents if needed.
|
1075 |
-
if mask_video is not None:
|
1076 |
-
if (mask_video == 255).all():
|
1077 |
-
# Use zero latents if we want pure text-to-video (the mask marks the whole video for regeneration).
|
1078 |
-
if self.transformer.resize_inpaint_mask_directly:
|
1079 |
-
mask_latents = torch.zeros_like(latents)[:, :1].to(latents.device, latents.dtype)
|
1080 |
-
else:
|
1081 |
-
mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1082 |
-
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1083 |
-
|
1084 |
-
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
|
1085 |
-
masked_video_latents_input = (
|
1086 |
-
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
|
1087 |
-
)
|
1088 |
-
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
1089 |
-
else:
|
1090 |
-
# Prepare mask latent variables
|
1091 |
-
video_length = video.shape[2]
|
1092 |
-
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
1093 |
-
mask_condition = mask_condition.to(dtype=torch.float32)
|
1094 |
-
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
1095 |
-
|
1096 |
-
if num_channels_transformer != num_channels_latents:
|
1097 |
-
mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
|
1098 |
-
if masked_video_latents is None:
|
1099 |
-
masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
|
1100 |
-
else:
|
1101 |
-
masked_video = masked_video_latents
|
1102 |
-
|
1103 |
-
if self.transformer.resize_inpaint_mask_directly:
|
1104 |
-
_, masked_video_latents = self.prepare_mask_latents(
|
1105 |
-
None,
|
1106 |
-
masked_video,
|
1107 |
-
batch_size,
|
1108 |
-
height,
|
1109 |
-
width,
|
1110 |
-
prompt_embeds.dtype,
|
1111 |
-
device,
|
1112 |
-
generator,
|
1113 |
-
self.do_classifier_free_guidance,
|
1114 |
-
noise_aug_strength=noise_aug_strength,
|
1115 |
-
)
|
1116 |
-
mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae)
|
1117 |
-
mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
|
1118 |
-
else:
|
1119 |
-
mask_latents, masked_video_latents = self.prepare_mask_latents(
|
1120 |
-
mask_condition_tile,
|
1121 |
-
masked_video,
|
1122 |
-
batch_size,
|
1123 |
-
height,
|
1124 |
-
width,
|
1125 |
-
prompt_embeds.dtype,
|
1126 |
-
device,
|
1127 |
-
generator,
|
1128 |
-
self.do_classifier_free_guidance,
|
1129 |
-
noise_aug_strength=noise_aug_strength,
|
1130 |
-
)
|
1131 |
-
|
1132 |
-
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
|
1133 |
-
masked_video_latents_input = (
|
1134 |
-
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
|
1135 |
-
)
|
1136 |
-
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
1137 |
-
else:
|
1138 |
-
inpaint_latents = None
|
1139 |
-
|
1140 |
-
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
|
1141 |
-
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
1142 |
-
else:
|
1143 |
-
if num_channels_transformer != num_channels_latents:
|
1144 |
-
mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1145 |
-
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
1146 |
-
|
1147 |
-
mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
|
1148 |
-
masked_video_latents_input = (
|
1149 |
-
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
|
1150 |
-
)
|
1151 |
-
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
1152 |
-
else:
|
1153 |
-
mask = torch.zeros_like(init_video[:, :1])
|
1154 |
-
mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
|
1155 |
-
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
1156 |
-
|
1157 |
-
inpaint_latents = None
|
1158 |
-
if comfyui_progressbar:
|
1159 |
-
pbar.update(1)
|
1160 |
-
|
1161 |
-
# Check that sizes of mask, masked image and latents match
|
1162 |
-
if num_channels_transformer != num_channels_latents:
|
1163 |
-
num_channels_mask = mask_latents.shape[1]
|
1164 |
-
num_channels_masked_image = masked_video_latents.shape[1]
|
1165 |
-
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
|
1166 |
-
raise ValueError(
|
1167 |
-
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
|
1168 |
-
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
1169 |
-
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
1170 |
-
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
1171 |
-
" `pipeline.transformer` or your `mask_image` or `image` input."
|
1172 |
-
)
|
1173 |
-
|
1174 |
-
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1175 |
-
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1176 |
-
|
1177 |
-
# 9. Create image_rotary_emb, style embedding & time ids
|
1178 |
-
grid_height = height // 8 // self.transformer.config.patch_size
|
1179 |
-
grid_width = width // 8 // self.transformer.config.patch_size
|
1180 |
-
if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
|
1181 |
-
base_size_width = 720 // 8 // self.transformer.config.patch_size
|
1182 |
-
base_size_height = 480 // 8 // self.transformer.config.patch_size
|
1183 |
-
|
1184 |
-
grid_crops_coords = get_resize_crop_region_for_grid(
|
1185 |
-
(grid_height, grid_width), base_size_width, base_size_height
|
1186 |
-
)
|
1187 |
-
image_rotary_emb = get_3d_rotary_pos_embed(
|
1188 |
-
self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
|
1189 |
-
temporal_size=latents.size(2), use_real=True,
|
1190 |
-
)
|
1191 |
-
else:
|
1192 |
-
base_size = 512 // 8 // self.transformer.config.patch_size
|
1193 |
-
grid_crops_coords = get_resize_crop_region_for_grid(
|
1194 |
-
(grid_height, grid_width), base_size, base_size
|
1195 |
-
)
|
1196 |
-
image_rotary_emb = get_2d_rotary_pos_embed(
|
1197 |
-
self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
|
1198 |
-
)
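For orientation, a worked example of the grid-size arithmetic above; the height, width and patch size are illustrative values:

height, width, patch_size = 480, 720, 2        # illustrative values
grid_height = height // 8 // patch_size         # 30
grid_width = width // 8 // patch_size           # 45
base_size_height = 480 // 8 // patch_size       # 30 (3D RoPE branch)
base_size_width = 720 // 8 // patch_size        # 45
print(grid_height, grid_width)                  # 30 45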
|
1199 |
-
|
1200 |
-
# Get other hunyuan params
|
1201 |
-
style = torch.tensor([0], device=device)
|
1202 |
-
|
1203 |
-
target_size = target_size or (height, width)
|
1204 |
-
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
1205 |
-
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
|
1206 |
-
|
1207 |
-
if self.do_classifier_free_guidance:
|
1208 |
-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
1209 |
-
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
1210 |
-
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
1211 |
-
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
|
1212 |
-
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
|
1213 |
-
style = torch.cat([style] * 2, dim=0)
|
1214 |
-
|
1215 |
-
prompt_embeds = prompt_embeds.to(device=device)
|
1216 |
-
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
1217 |
-
prompt_embeds_2 = prompt_embeds_2.to(device=device)
|
1218 |
-
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
|
1219 |
-
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
|
1220 |
-
batch_size * num_images_per_prompt, 1
|
1221 |
-
)
|
1222 |
-
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
|
1223 |
-
|
1224 |
-
torch.cuda.empty_cache()
|
1225 |
-
if self.enable_autocast_float8_transformer_flag:
|
1226 |
-
origin_weight_dtype = self.transformer.dtype
|
1227 |
-
self.transformer = self.transformer.to(torch.float8_e4m3fn)
|
1228 |
-
# 10. Denoising loop
|
1229 |
-
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1230 |
-
self._num_timesteps = len(timesteps)
|
1231 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1232 |
-
for i, t in enumerate(timesteps):
|
1233 |
-
if self.interrupt:
|
1234 |
-
continue
|
1235 |
-
|
1236 |
-
# expand the latents if we are doing classifier free guidance
|
1237 |
-
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1238 |
-
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1239 |
-
|
1240 |
-
if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
|
1241 |
-
clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
|
1242 |
-
clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
|
1243 |
-
else:
|
1244 |
-
clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
|
1245 |
-
clip_attention_mask_actual_input = clip_attention_mask_input
|
1246 |
-
|
1247 |
-
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
1248 |
-
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
1249 |
-
dtype=latent_model_input.dtype
|
1250 |
-
)
|
1251 |
-
|
1252 |
-
# predict the noise residual
|
1253 |
-
noise_pred = self.transformer(
|
1254 |
-
latent_model_input,
|
1255 |
-
t_expand,
|
1256 |
-
encoder_hidden_states=prompt_embeds,
|
1257 |
-
text_embedding_mask=prompt_attention_mask,
|
1258 |
-
encoder_hidden_states_t5=prompt_embeds_2,
|
1259 |
-
text_embedding_mask_t5=prompt_attention_mask_2,
|
1260 |
-
image_meta_size=add_time_ids,
|
1261 |
-
style=style,
|
1262 |
-
image_rotary_emb=image_rotary_emb,
|
1263 |
-
inpaint_latents=inpaint_latents,
|
1264 |
-
clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
|
1265 |
-
clip_attention_mask=clip_attention_mask_actual_input,
|
1266 |
-
return_dict=False,
|
1267 |
-
)[0]
|
1268 |
-
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
1269 |
-
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
1270 |
-
|
1271 |
-
# perform guidance
|
1272 |
-
if self.do_classifier_free_guidance:
|
1273 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1274 |
-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1275 |
-
|
1276 |
-
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
1277 |
-
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1278 |
-
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
1279 |
-
|
1280 |
-
# compute the previous noisy sample x_t -> x_t-1
|
1281 |
-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1282 |
-
|
1283 |
-
if num_channels_transformer == 4:
|
1284 |
-
init_latents_proper = image_latents
|
1285 |
-
init_mask = mask
|
1286 |
-
if i < len(timesteps) - 1:
|
1287 |
-
noise_timestep = timesteps[i + 1]
|
1288 |
-
init_latents_proper = self.scheduler.add_noise(
|
1289 |
-
init_latents_proper, noise, torch.tensor([noise_timestep])
|
1290 |
-
)
|
1291 |
-
|
1292 |
-
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
1293 |
-
|
1294 |
-
if callback_on_step_end is not None:
|
1295 |
-
callback_kwargs = {}
|
1296 |
-
for k in callback_on_step_end_tensor_inputs:
|
1297 |
-
callback_kwargs[k] = locals()[k]
|
1298 |
-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1299 |
-
|
1300 |
-
latents = callback_outputs.pop("latents", latents)
|
1301 |
-
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1302 |
-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1303 |
-
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
|
1304 |
-
negative_prompt_embeds_2 = callback_outputs.pop(
|
1305 |
-
"negative_prompt_embeds_2", negative_prompt_embeds_2
|
1306 |
-
)
|
1307 |
-
|
1308 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1309 |
-
progress_bar.update()
|
1310 |
-
|
1311 |
-
if XLA_AVAILABLE:
|
1312 |
-
xm.mark_step()
|
1313 |
-
|
1314 |
-
if comfyui_progressbar:
|
1315 |
-
pbar.update(1)
|
1316 |
-
|
1317 |
-
if self.enable_autocast_float8_transformer_flag:
|
1318 |
-
self.transformer = self.transformer.to("cpu", origin_weight_dtype)
|
1319 |
-
|
1320 |
-
torch.cuda.empty_cache()
|
1321 |
-
# Post-processing
|
1322 |
-
video = self.decode_latents(latents)
|
1323 |
-
|
1324 |
-
# Convert to tensor
|
1325 |
-
if output_type == "latent":
|
1326 |
-
video = torch.from_numpy(video)
|
1327 |
-
|
1328 |
-
# Offload all models
|
1329 |
-
self.maybe_free_model_hooks()
|
1330 |
-
|
1331 |
-
if not return_dict:
|
1332 |
-
return video
|
1333 |
-
|
1334 |
-
return EasyAnimatePipelineOutput(videos=video)
|
easyanimate/ui/ui.py
CHANGED
@@ -17,41 +17,42 @@ import torch
 from diffusers import (AutoencoderKL, DDIMScheduler,
                        DPMSolverMultistepScheduler,
                        EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
-                       PNDMScheduler)
+                       FlowMatchEulerDiscreteScheduler, PNDMScheduler)
 from diffusers.utils.import_utils import is_xformers_available
 from omegaconf import OmegaConf
 from PIL import Image
 from safetensors import safe_open
 from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
-                          CLIPVisionModelWithProjection,
-                          T5EncoderModel,
+                          CLIPVisionModelWithProjection, Qwen2Tokenizer,
+                          Qwen2VLForConditionalGeneration, T5EncoderModel,
+                          T5Tokenizer)
 
-from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
-from easyanimate.utils.utils import (
-from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper
-    EasyAnimatePipeline_Multi_Text_Encoder_Inpaint
+from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
+from ..models import (name_to_autoencoder_magvit,
                       name_to_transformer3d)
+from ..pipeline.pipeline_easyanimate import \
+    EasyAnimatePipeline
+from ..pipeline.pipeline_easyanimate_control import \
+    EasyAnimateControlPipeline
+from ..pipeline.pipeline_easyanimate_inpaint import \
     EasyAnimateInpaintPipeline
+from ..utils.fp8_optimization import convert_weight_dtype_wrapper
+from ..utils.lora_utils import merge_lora, unmerge_lora
+from ..utils.utils import (
     get_image_to_video_latent, get_video_to_video_latent,
     get_width_and_height_from_image_and_base_resolution, save_videos_grid)
 
+ddpm_scheduler_dict = {
     "Euler": EulerDiscreteScheduler,
     "Euler A": EulerAncestralDiscreteScheduler,
     "DPM++": DPMSolverMultistepScheduler,
     "PNDM": PNDMScheduler,
     "DDIM": DDIMScheduler,
 }
+flow_scheduler_dict = {
+    "Flow": FlowMatchEulerDiscreteScheduler,
+}
+all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
 
 gradio_version = pkg_resources.get_distribution("gradio").version
 gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
@@ -98,8 +99,8 @@ class EasyAnimateController:
         self.GPU_memory_mode = GPU_memory_mode
 
         self.weight_dtype = weight_dtype
-        self.edition = "v5"
+        self.edition = "v5.1"
+        self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml"))
 
     def refresh_diffusion_transformer(self):
         self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
@@ -121,26 +122,37 @@ class EasyAnimateController:
         if edition == "v1":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v1_motion_module.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
+                gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
                 gr.update(value=512, minimum=384, maximum=704, step=32), \
                 gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
         elif edition == "v2":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v2_magvit_motion_module.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+                gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
         elif edition == "v3":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v3_slicevae_motion_module.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+                gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
         elif edition == "v4":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v4_slicevae_multi_text_encoder.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+                gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
         elif edition == "v5":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+                gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)
+        elif edition == "v5.1":
+            self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml"))
+            return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+                gr.update(choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)
 
@@ -170,33 +182,55 @@ class EasyAnimateController:
         self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
             diffusion_transformer_dropdown,
             subfolder="transformer",
-            transformer_additional_kwargs=transformer_additional_kwargs
+            transformer_additional_kwargs=transformer_additional_kwargs,
+            torch_dtype=torch.float8_e4m3fn if self.GPU_memory_mode == "model_cpu_offload_and_qfloat8" else self.weight_dtype,
+            low_cpu_mem_usage=True,
+        )
 
         if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
             tokenizer = BertTokenizer.from_pretrained(
                 diffusion_transformer_dropdown, subfolder="tokenizer"
             )
+            if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+                tokenizer_2 = Qwen2Tokenizer.from_pretrained(
+                    os.path.join(diffusion_transformer_dropdown, "tokenizer_2")
+                )
+            else:
+                tokenizer_2 = T5Tokenizer.from_pretrained(
+                    diffusion_transformer_dropdown, subfolder="tokenizer_2"
+                )
         else:
+            if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+                tokenizer = Qwen2Tokenizer.from_pretrained(
+                    os.path.join(diffusion_transformer_dropdown, "tokenizer")
+                )
+            else:
+                tokenizer = T5Tokenizer.from_pretrained(
+                    diffusion_transformer_dropdown, subfolder="tokenizer"
+                )
             tokenizer_2 = None
 
         if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
            text_encoder = BertModel.from_pretrained(
                diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
            )
+           if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+               text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained(
+                   os.path.join(diffusion_transformer_dropdown, "text_encoder_2")
+               )
+           else:
+               text_encoder_2 = T5EncoderModel.from_pretrained(
+                   diffusion_transformer_dropdown, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
+               )
+        else:
+           if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+               text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
+                   os.path.join(diffusion_transformer_dropdown, "text_encoder")
+               )
+           else:
+               text_encoder = T5EncoderModel.from_pretrained(
+                   diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
+               )
            text_encoder_2 = None
 
         # Get pipeline
@@ -212,23 +246,18 @@ class EasyAnimateController:
             clip_image_processor = None
 
         # Get Scheduler
+        if self.edition in ["v5.1"]:
+            Choosen_Scheduler = all_cheduler_dict["Flow"]
+        else:
+            Choosen_Scheduler = all_cheduler_dict["Euler"]
         scheduler = Choosen_Scheduler.from_pretrained(
             diffusion_transformer_dropdown,
             subfolder="scheduler"
         )
 
+        if self.model_type == "Inpaint":
             if self.transformer.config.in_channels != self.vae.config.latent_channels:
+                self.pipeline = EasyAnimateInpaintPipeline(
-                    diffusion_transformer_dropdown,
                     text_encoder=text_encoder,
                     text_encoder_2=text_encoder_2,
                     tokenizer=tokenizer,
@@ -236,13 +265,11 @@ class EasyAnimateController:
                     tokenizer_2=tokenizer_2,
                     vae=self.vae,
                     transformer=self.transformer,
                     scheduler=scheduler,
-                    torch_dtype=self.weight_dtype,
                     clip_image_encoder=clip_image_encoder,
                     clip_image_processor=clip_image_processor,
+                ).to(self.weight_dtype)
             else:
+                self.pipeline = EasyAnimatePipeline(
-                    diffusion_transformer_dropdown,
                     text_encoder=text_encoder,
                     text_encoder_2=text_encoder_2,
                     tokenizer=tokenizer,
@@ -250,40 +277,25 @@ class EasyAnimateController:
                     tokenizer_2=tokenizer_2,
                     vae=self.vae,
                     transformer=self.transformer,
                     scheduler=scheduler,
+                ).to(self.weight_dtype)
         else:
+            self.pipeline = EasyAnimateControlPipeline(
+                text_encoder=text_encoder,
+                text_encoder_2=text_encoder_2,
+                tokenizer=tokenizer,
+                tokenizer_2=tokenizer_2,
+                vae=self.vae,
+                transformer=self.transformer,
+                scheduler=scheduler,
+            ).to(self.weight_dtype)
 
         if self.GPU_memory_mode == "sequential_cpu_offload":
             self.pipeline.enable_sequential_cpu_offload()
         elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
             self.pipeline.enable_model_cpu_offload()
-            self.pipeline.enable_autocast_float8_transformer()
             convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
         else:
+            self.pipeline.enable_model_cpu_offload()
         print("Update diffusion transformer done")
         return gr.update()
 
@@ -374,8 +386,10 @@ class EasyAnimateController:
         if self.base_model_path != base_model_dropdown:
             self.update_base_model(base_model_dropdown)
 
+        if self.motion_module_path != motion_module_dropdown:
+            self.update_motion_module(motion_module_dropdown)
+
         if self.lora_model_path != lora_model_dropdown:
-            print("Update lora model")
             self.update_lora_model(lora_model_dropdown)
 
         if control_video is not None and self.model_type == "Inpaint":
@@ -426,19 +440,21 @@ class EasyAnimateController:
         else:
             raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
 
-        fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition]
+        fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition]
         is_image = True if generation_method == "Image Generation" else False
 
+        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+        else: seed_textbox = np.random.randint(0, 1e10)
+        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
 
+        if is_xformers_available() \
+            and self.inference_config['transformer_additional_kwargs'].get('transformer_type', 'Transformer3DModel') == 'Transformer3DModel':
+            self.transformer.enable_xformers_memory_efficient_attention()
+
+        self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
         if self.lora_model_path != "none":
             # lora part
             self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
-
-        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
-        else: seed_textbox = np.random.randint(0, 1e10)
-        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
 
         try:
             if self.model_type == "Inpaint":
@@ -480,7 +496,7 @@ class EasyAnimateController:
                     video = input_video,
                     mask_video = input_video_mask,
                     strength = 1,
-                ).
+                ).frames
 
                 if init_frames != 0:
                     mix_ratio = torch.from_numpy(
@@ -531,7 +547,7 @@ class EasyAnimateController:
                     video = input_video,
                     mask_video = input_video_mask,
                     strength = strength,
-                ).
+                ).frames
             else:
                 if self.vae.cache_mag_vae:
                     length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
@@ -547,7 +563,7 @@ class EasyAnimateController:
                    height = height_slider,
                    video_length = length_slider if not is_image else 1,
                    generator = generator
-                ).
+                ).frames
            else:
                if self.vae.cache_mag_vae:
                    length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
@@ -566,7 +582,7 @@ class EasyAnimateController:
                    generator = generator,
 
                    control_video = input_video,
-                ).
+                ).frames
        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
@@ -676,8 +692,8 @@ def ui(GPU_memory_mode, weight_dtype):
        with gr.Row():
            easyanimate_edition_dropdown = gr.Dropdown(
                label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
-                choices=["v1", "v2", "v3", "v4", "v5"],
-                value="v5",
+                choices=["v1", "v2", "v3", "v4", "v5", "v5.1"],
+                value="v5.1",
                interactive=True,
            )
        gr.Markdown(
@@ -751,13 +767,22 @@ def ui(GPU_memory_mode, weight_dtype):
            """
        )
 
+        prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
+        gr.Markdown(
+            """
+            Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, solid" to the neg prompt can increase dynamism.
+            使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
+            """
+        )
+        negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
 
        with gr.Row():
            with gr.Column():
                with gr.Row():
+                    sampler_dropdown   = gr.Dropdown(
+                        label="Sampling method (采样器种类)",
+                        choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
+                    )
                    sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
 
                resize_method = gr.Radio(
@@ -794,11 +819,11 @@ def ui(GPU_memory_mode, weight_dtype):
                template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
                def select_template(evt: gr.SelectData):
                    text = {
+                        "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
+                        "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
+                        "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
+                        "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
+                        "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
                    }[template_gallery_path[evt.index]]
                    return template_gallery_path[evt.index], text
 
@@ -838,6 +863,7 @@ def ui(GPU_memory_mode, weight_dtype):
                    gr.Markdown(
                        """
                        Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
+                        Only normal controls are supported in app.py; trajectory control and camera control need ComfyUI, as shown in https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui.
                        """
                    )
                    control_video = gr.Video(
@@ -927,6 +953,7 @@ def ui(GPU_memory_mode, weight_dtype):
                diffusion_transformer_dropdown,
                motion_module_dropdown,
                motion_module_refresh_button,
+                sampler_dropdown,
                width_slider,
                height_slider,
                length_slider,
@@ -1003,33 +1030,55 @@ class EasyAnimateController_Modelscope:
        self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
            model_name,
            subfolder="transformer",
-            transformer_additional_kwargs=transformer_additional_kwargs
+            transformer_additional_kwargs=transformer_additional_kwargs,
+            torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
+            low_cpu_mem_usage=True,
+        )
 
        if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
            tokenizer = BertTokenizer.from_pretrained(
                model_name, subfolder="tokenizer"
            )
+            if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+                tokenizer_2 = Qwen2Tokenizer.from_pretrained(
+                    os.path.join(model_name, "tokenizer_2")
+                )
+            else:
+                tokenizer_2 = T5Tokenizer.from_pretrained(
+                    model_name, subfolder="tokenizer_2"
+                )
        else:
+            if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+                tokenizer = Qwen2Tokenizer.from_pretrained(
+                    os.path.join(model_name, "tokenizer")
+                )
+            else:
+                tokenizer = T5Tokenizer.from_pretrained(
+                    model_name, subfolder="tokenizer"
+                )
            tokenizer_2 = None
 
        if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
            text_encoder = BertModel.from_pretrained(
                model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
            )
+            if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+                text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained(
+                    os.path.join(model_name, "text_encoder_2"), torch_dtype=self.weight_dtype
+                )
+            else:
+                text_encoder_2 = T5EncoderModel.from_pretrained(
+                    model_name, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
+                )
+        else:
+            if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
+                text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
+                    os.path.join(model_name, "text_encoder"), torch_dtype=self.weight_dtype
+                )
+            else:
+                text_encoder = T5EncoderModel.from_pretrained(
+                    model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
+                )
            text_encoder_2 = None
 
        # Get pipeline
@@ -1045,23 +1094,18 @@ class EasyAnimateController_Modelscope:
            clip_image_processor = None
 
        # Get Scheduler
+        if self.edition in ["v5.1"]:
+            Choosen_Scheduler = all_cheduler_dict["Flow"]
+        else:
+            Choosen_Scheduler = all_cheduler_dict["Euler"]
        scheduler = Choosen_Scheduler.from_pretrained(
            model_name,
            subfolder="scheduler"
        )
 
+        if model_type == "Inpaint":
            if self.transformer.config.in_channels != self.vae.config.latent_channels:
+                self.pipeline = EasyAnimateInpaintPipeline(
-                    model_name,
                    text_encoder=text_encoder,
                    text_encoder_2=text_encoder_2,
                    tokenizer=tokenizer,
@@ -1069,51 +1113,34 @@ class EasyAnimateController_Modelscope:
                    tokenizer_2=tokenizer_2,
                    vae=self.vae,
                    transformer=self.transformer,
                    scheduler=scheduler,
-                    torch_dtype=self.weight_dtype,
                    clip_image_encoder=clip_image_encoder,
                    clip_image_processor=clip_image_processor,
+                ).to(weight_dtype)
            else:
+                self.pipeline = EasyAnimatePipeline(
-                    model_name,
                    text_encoder=text_encoder,
                    text_encoder_2=text_encoder_2,
                    tokenizer=tokenizer,
                    tokenizer_2=tokenizer_2,
                    vae=self.vae,
                    transformer=self.transformer,
                    scheduler=scheduler
+                ).to(weight_dtype)
        else:
+            self.pipeline = EasyAnimateControlPipeline(
+                text_encoder=text_encoder,
+                text_encoder_2=text_encoder_2,
+                tokenizer=tokenizer,
+                tokenizer_2=tokenizer_2,
+                vae=self.vae,
+                transformer=self.transformer,
+                scheduler=scheduler,
+            ).to(weight_dtype)
 
        if GPU_memory_mode == "sequential_cpu_offload":
            self.pipeline.enable_sequential_cpu_offload()
        elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            self.pipeline.enable_model_cpu_offload()
-            self.pipeline.enable_autocast_float8_transformer()
            convert_weight_dtype_wrapper(self.pipeline.transformer, weight_dtype)
        else:
            GPU_memory_mode.enable_model_cpu_offload()
@@ -1214,17 +1241,17 @@ class EasyAnimateController_Modelscope:
        else:
            raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
 
-        fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition]
+        fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition]
        is_image = True if generation_method == "Image Generation" else False
 
-        self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
-        if self.lora_model_path != "none":
-            # lora part
-            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
-
        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
        else: seed_textbox = np.random.randint(0, 1e10)
        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
+
+        self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
+        if self.lora_model_path != "none":
+            # lora part
+            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
 
        try:
            if self.model_type == "Inpaint":
@@ -1254,7 +1281,7 @@ class EasyAnimateController_Modelscope:
                    video = input_video,
                    mask_video = input_video_mask,
                    strength = strength,
-                ).
+                ).frames
            else:
                sample = self.pipeline(
                    prompt_textbox,
@@ -1265,7 +1292,7 @@ class EasyAnimateController_Modelscope:
                    height = height_slider,
                    video_length = length_slider if not is_image else 1,
                    generator = generator
-                ).
+                ).frames
            else:
                if self.vae.cache_mag_vae:
                    length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
@@ -1285,7 +1312,7 @@ class EasyAnimateController_Modelscope:
                    generator = generator,
 
                    control_video = input_video,
-                ).
+                ).frames
        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
@@ -1406,13 +1433,28 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample,
            """
        )
 
+        prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
+        gr.Markdown(
+            """
+            Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, solid" to the neg prompt can increase dynamism.
+            使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
+            """
+        )
+        negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
 
        with gr.Row():
            with gr.Column():
                with gr.Row():
+                    if edition in ["v5.1"]:
+                        sampler_dropdown   = gr.Dropdown(
+                            label="Sampling method (采样器种类)",
+                            choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
+                        )
+                    else:
+                        sampler_dropdown   = gr.Dropdown(
+                            label="Sampling method (采样器种类)",
+                            choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]
+                        )
                    sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
 
                if edition == "v1":
@@ -1466,11 +1508,11 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample,
                template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
                def select_template(evt: gr.SelectData):
                    text = {
+                        "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
+                        "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
+                        "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
+                        "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
+                        "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
                    }[template_gallery_path[evt.index]]
                    return template_gallery_path[evt.index], text
 
@@ -1510,6 +1552,7 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample,
                    gr.Markdown(
                        """
                        Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
+                        Only normal controls are supported in app.py; trajectory control and camera control need ComfyUI, as shown in https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui.
                        """
                    )
                    control_video = gr.Video(
@@ -1820,13 +1863,28 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
            """
        )
 
+        prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
+        gr.Markdown(
+            """
+            Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, solid" to the neg prompt can increase dynamism.
+            使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
+            """
+        )
+        negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
 
        with gr.Row():
            with gr.Column():
                with gr.Row():
+                    if edition in ["v5.1"]:
+                        sampler_dropdown   = gr.Dropdown(
+                            label="Sampling method (采样器种类)",
+                            choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
+                        )
+                    else:
+                        sampler_dropdown   = gr.Dropdown(
+                            label="Sampling method (采样器种类)",
+                            choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]
+                        )
                    sample_step_slider = gr.Slider(label="Sampling steps", value=40, minimum=10, maximum=40, step=1, interactive=False)
 
                if edition == "v1":
@@ -1875,11 +1933,11 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
                template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
                def select_template(evt: gr.SelectData):
                    text = {
+                        "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
+                        "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
+                        "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
+                        "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
+                        "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
                    }[template_gallery_path[evt.index]]
                    return template_gallery_path[evt.index], text
easyanimate/utils/lora_utils.py
CHANGED
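The hunks below change merge_lora/unmerge_lora so that each targeted layer is cast to a working device and dtype, the low-rank update is applied, and the layer is then restored to its original placement, which is what lets merging work on offloaded fp8/bf16 pipelines. A sketch of the per-layer update they implement, with an invented linear layer for illustration (tensor names follow the diff; this is not the repository's code path, just the same arithmetic):

# Sketch only: merge one LoRA pair into a single linear layer.
import torch

def merge_one_layer(layer, lora_up, lora_down, alpha, multiplier=1.0,
                    device="cpu", dtype=torch.float32):
    # Remember where the layer lives so it can be put back afterwards.
    origin_dtype = layer.weight.data.dtype
    origin_device = layer.weight.data.device

    layer = layer.to(device, dtype)
    lora_up = lora_up.to(device, dtype)      # shape (out_features, rank)
    lora_down = lora_down.to(device, dtype)  # shape (rank, in_features)

    # Matches alpha = elems['alpha'].item() / weight_up.shape[1] in the diff.
    scale = alpha / lora_up.shape[1]
    layer.weight.data += multiplier * scale * torch.mm(lora_up, lora_down)

    # Restore the original dtype/device before handing the layer back.
    return layer.to(origin_device, origin_dtype)

# Example with random weights (rank 8):
# layer = torch.nn.Linear(64, 64)
# merged = merge_one_layer(layer, torch.randn(64, 8), torch.randn(8, 64), alpha=8.0)

unmerge_lora applies the same update with a minus sign, so merging followed by unmerging leaves the weights unchanged up to floating-point error.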
@@ -369,7 +369,6 @@ def create_network(
 def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
     LORA_PREFIX_TRANSFORMER = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
-    SPECIAL_LAYER_NAME = ["text_proj_t5"]
     if state_dict is None:
         state_dict = load_file(lora_path, device=device)
     else:
@@ -410,20 +409,25 @@ def merge_lora
             else:
                 temp_name = layer_infos.pop(0)
 
+        origin_dtype = curr_layer.weight.data.dtype
+        origin_device = curr_layer.weight.data.device
+
+        curr_layer = curr_layer.to(device, dtype)
+        weight_up = elems['lora_up.weight'].to(device, dtype)
+        weight_down = elems['lora_down.weight'].to(device, dtype)
+
         if 'alpha' in elems.keys():
             alpha = elems['alpha'].item() / weight_up.shape[1]
         else:
             alpha = 1.0
 
-        curr_layer.weight.data = curr_layer.weight.data.to(device)
         if len(weight_up.shape) == 4:
-            curr_layer.weight.data += multiplier * alpha * torch.mm(
+            curr_layer.weight.data += multiplier * alpha * torch.mm(
+                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
+            ).unsqueeze(2).unsqueeze(3)
         else:
             curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+        curr_layer = curr_layer.to(origin_device, origin_dtype)
 
     return pipeline
 
@@ -448,35 +452,43 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32)
         layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
         curr_layer = pipeline.transformer
 
+        try:
+            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
+        except Exception:
+            temp_name = layer_infos.pop(0)
+            while len(layer_infos) > -1:
+                try:
+                    curr_layer = curr_layer.__getattr__(temp_name)
+                    if len(layer_infos) > 0:
+                        temp_name = layer_infos.pop(0)
+                    elif len(layer_infos) == 0:
+                        break
+                except Exception:
+                    if len(layer_infos) == 0:
+                        print('Error loading layer')
+                    if len(temp_name) > 0:
+                        temp_name += "_" + layer_infos.pop(0)
+                    else:
+                        temp_name = layer_infos.pop(0)
+
+        origin_dtype = curr_layer.weight.data.dtype
+        origin_device = curr_layer.weight.data.device
+
+        curr_layer = curr_layer.to(device, dtype)
+        weight_up = elems['lora_up.weight'].to(device, dtype)
+        weight_down = elems['lora_down.weight'].to(device, dtype)
+
        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0
 
-        curr_layer.weight.data = curr_layer.weight.data.to(device)
        if len(weight_up.shape) == 4:
-            curr_layer.weight.data -= multiplier * alpha * torch.mm(
+            curr_layer.weight.data -= multiplier * alpha * torch.mm(
+                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
+            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
+        curr_layer = curr_layer.to(origin_device, origin_dtype)
 
+    return pipeline
easyanimate/utils/utils.py
CHANGED
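The rewrite below lets get_video_to_video_latent accept either a video path or an already decoded frame sequence, subsamples frames to the requested fps, resizes them to sample_size, and returns batched tensors in [0, 1]; a new get_image_latent helper does the same for a single reference image. A usage sketch with illustrative arguments (the file names are placeholders):

# Sketch only: typical calls into the helpers defined in the diff below.
from easyanimate.utils.utils import get_image_latent, get_video_to_video_latent

sample_size = (384, 672)   # (height, width), matching the UI sliders
video_length = 49          # frame count used by the v5.x editions

# input_video is a (1, C, T, H, W) float tensor scaled to [0, 1]; the mask has
# the same temporal extent, and ref_image is returned only if one was passed.
input_video, input_video_mask, ref_image = get_video_to_video_latent(
    "pose.mp4",                     # placeholder path; a list of frames also works
    video_length=video_length,
    sample_size=sample_size,
    fps=8,                          # resample the source video down to 8 fps
    validation_video_mask=None,
    ref_image=None,
)

# Single reference image, resized and batched the same way.
ref = get_image_latent("asset/1.png", sample_size=sample_size)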
@@ -169,47 +169,67 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
     return input_video, input_video_mask, clip_image

 def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
+    if input_video_path is not None:
+        if isinstance(input_video_path, str):
+            cap = cv2.VideoCapture(input_video_path)
+            input_video = []
+
+            original_fps = cap.get(cv2.CAP_PROP_FPS)
+            frame_skip = 1 if fps is None else int(original_fps // fps)
+
+            frame_count = 0
+
+            while True:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                if frame_count % frame_skip == 0:
+                    frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
+                    input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+                frame_count += 1
+
+            cap.release()
+        else:
+            input_video = input_video_path
+
+        input_video = torch.from_numpy(np.array(input_video))[:video_length]
+        input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
+
+        if validation_video_mask is not None:
+            validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
+            input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
+
+            input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
+            input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
+            input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
+        else:
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, :] = 255
+    else:
+        input_video, input_video_mask = None, None

     if ref_image is not None:
+        if isinstance(ref_image, str):
+            ref_image = Image.open(ref_image).convert("RGB")
+            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
+            ref_image = torch.from_numpy(np.array(ref_image))
+            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+        else:
+            ref_image = torch.from_numpy(np.array(ref_image))
+            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+    return input_video, input_video_mask, ref_image

+def get_image_latent(ref_image=None, sample_size=None):
+    if ref_image is not None:
+        if isinstance(ref_image, str):
+            ref_image = Image.open(ref_image).convert("RGB")
+            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
+            ref_image = torch.from_numpy(np.array(ref_image))
+            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+        else:
+            ref_image = torch.from_numpy(np.array(ref_image))
+            ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
+
+    return ref_image
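For context, a hedged usage sketch of the rewritten helper; the video path below is a placeholder, and the returned shapes follow what the new code constructs (5D tensors of shape (batch, channel, frame, height, width)):

from easyanimate.utils.utils import get_video_to_video_latent  # module path as in this diff

sample_size = (512, 512)  # (height, width)
input_video, input_video_mask, ref_image = get_video_to_video_latent(
    "input.mp4",           # placeholder path to a local video
    video_length=49,
    sample_size=sample_size,
    fps=8,
)
# input_video: (1, 3, T, H, W) float tensor scaled to [0, 1]
# input_video_mask: (1, 1, T, H, W); filled with 255 when no validation_video_mask is given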
easyanimate/vae/ldm/models/autoencoder.py
CHANGED
@@ -126,13 +126,13 @@ class AutoencoderKLMagvit(pl.LightningModule):
     def configure_optimizers(self):
         lr = self.learning_rate
+        opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
                                    list(self.decoder.parameters())+
                                    list(self.quant_conv.parameters())+
                                    list(self.post_quant_conv.parameters()),
+                                   lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
+        opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
+                                     lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         return [opt_ae, opt_disc], []

     def get_last_layer(self):
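The change above moves both optimizers to AdamW with betas=(0.9, 0.999) and weight_decay=5e-2. A minimal sketch of the same two-optimizer setup on stand-in modules (the Linear layers are placeholders for the real encoder, decoder, and discriminator):

import torch
import torch.nn as nn

encoder, decoder = nn.Linear(8, 4), nn.Linear(4, 8)   # stand-ins for the VAE halves
discriminator = nn.Linear(8, 1)                       # stand-in for the GAN critic
lr = 1e-4

opt_ae = torch.optim.AdamW(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=lr, betas=(0.9, 0.999), weight_decay=5e-2,
)
opt_disc = torch.optim.AdamW(discriminator.parameters(),
                             lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
# Lightning's configure_optimizers then returns [opt_ae, opt_disc], []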
easyanimate/vae/ldm/models/casual3dcnn.py
CHANGED
@@ -279,13 +279,13 @@ class AutoencoderKL(pl.LightningModule):
     def configure_optimizers(self):
         lr = self.learning_rate
+        opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
                                    list(self.decoder.parameters())+
                                    list(self.quant_conv.parameters())+
+                                   list(self.post_quant_conv.parameters()), \
+                                   lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
+        opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
+                                     lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         return [opt_ae, opt_disc], []

     def get_last_layer(self):
easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py
CHANGED
@@ -277,23 +277,23 @@ class AutoencoderKLMagvit_CogVideoX(pl.LightningModule):
                 training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
             else:
                 training_list = list(self.decoder.parameters())
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         elif self.train_encoder_only:
             if self.quant_conv is not None:
                 training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
             else:
                 training_list = list(self.encoder.parameters())
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         else:
             training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
             if self.quant_conv is not None:
                 training_list = training_list + list(self.quant_conv.parameters())
             if self.post_quant_conv is not None:
                 training_list = training_list + list(self.post_quant_conv.parameters())
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
+        opt_disc = torch.optim.AdamW(
             list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
+            lr=lr, betas=(0.9, 0.999), weight_decay=5e-2
         )
         return [opt_ae, opt_disc], []
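As in the previous files, the optimizers become AdamW; additionally, the trainable parameter list depends on the train_decoder_only / train_encoder_only flags. A compact sketch of that selection logic, assuming a model object exposing the same attributes as in the hunk:

def select_trainable_params(model):
    # model is assumed to expose encoder/decoder/quant_conv/post_quant_conv and the two flags
    if model.train_decoder_only:
        params = list(model.decoder.parameters())
        if model.post_quant_conv is not None:
            params += list(model.post_quant_conv.parameters())
    elif model.train_encoder_only:
        params = list(model.encoder.parameters())
        if model.quant_conv is not None:
            params += list(model.quant_conv.parameters())
    else:
        params = list(model.encoder.parameters()) + list(model.decoder.parameters())
        if model.quant_conv is not None:
            params += list(model.quant_conv.parameters())
        if model.post_quant_conv is not None:
            params += list(model.post_quant_conv.parameters())
    return params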
easyanimate/vae/ldm/models/omnigen_casual3dcnn.py
CHANGED
@@ -95,6 +95,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
         out_channels: int = 3,
         ch = 128,
         ch_mult = [ 1,2,4,4 ],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         down_block_types: tuple = None,
         up_block_types: tuple = None,
@@ -129,8 +130,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
             in_channels=in_channels,
             out_channels=latent_channels,
             down_block_types=down_block_types,
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,
@@ -144,6 +146,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
             slice_mag_vae=slice_mag_vae,
             slice_compression_vae=slice_compression_vae,
             cache_compression_vae=cache_compression_vae,
+            cache_mag_vae=cache_mag_vae,
             spatial_group_norm=spatial_group_norm,
             mini_batch_encoder=mini_batch_encoder,
         )
@@ -152,8 +155,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
             in_channels=latent_channels,
             out_channels=out_channels,
             up_block_types=up_block_types,
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,
@@ -292,23 +296,23 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
                 training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
             else:
                 training_list = list(self.decoder.parameters())
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         elif self.train_encoder_only:
             if self.quant_conv is not None:
                 training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
             else:
                 training_list = list(self.encoder.parameters())
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         else:
             training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
             if self.quant_conv is not None:
                 training_list = training_list + list(self.quant_conv.parameters())
             if self.post_quant_conv is not None:
                 training_list = training_list + list(self.post_quant_conv.parameters())
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
+        opt_disc = torch.optim.AdamW(
             list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
+            lr=lr, betas=(0.9, 0.999), weight_decay=5e-2
         )
         return [opt_ae, opt_disc], []
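The new block_out_channels argument makes the per-stage channel widths explicit instead of deriving them from ch and ch_mult; the default [128, 256, 512, 512] is exactly ch * ch_mult, e.g.:

ch = 128
ch_mult = [1, 2, 4, 4]
block_out_channels = [ch * m for m in ch_mult]
print(block_out_channels)  # [128, 256, 512, 512] -- matches the default added in this diff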
easyanimate/vae/ldm/models/omnigen_enc_dec.py
CHANGED
@@ -58,6 +58,7 @@ class Encoder(nn.Module):
         down_block_types = ("SpatialDownBlock3D",),
         ch = 128,
         ch_mult = [1,2,4,4,],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         mid_block_type: str = "MidBlock3D",
         mid_block_use_attention: bool = True,
@@ -77,7 +78,8 @@ class Encoder(nn.Module):
         verbose = False,
     ):
         super().__init__()
+        if block_out_channels is None:
+            block_out_channels = [ch * i for i in ch_mult]
         assert len(down_block_types) == len(block_out_channels), (
             "Number of down block types must match number of block output channels."
         )
@@ -364,6 +366,7 @@ class Decoder(nn.Module):
         up_block_types = ("SpatialUpBlock3D",),
         ch = 128,
         ch_mult = [1,2,4,4,],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         mid_block_type: str = "MidBlock3D",
         mid_block_use_attention: bool = True,
@@ -382,7 +385,8 @@ class Decoder(nn.Module):
         verbose = False,
     ):
         super().__init__()
+        if block_out_channels is None:
+            block_out_channels = [ch * i for i in ch_mult]
         assert len(up_block_types) == len(block_out_channels), (
             "Number of up block types must match number of block output channels."
         )
easyanimate/vae/ldm/modules/losses/contperceptual.py
CHANGED
@@ -9,7 +9,8 @@ from ..vaemodules.discriminator import Discriminator3D
 class LPIPSWithDiscriminator(nn.Module):
     def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                  perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                 outlier_penalty_loss_r=3.0, outlier_penalty_loss_weight=1e5,
                  disc_loss="hinge", l2_loss_weight=0.0, l1_loss_weight=1.0):

         super().__init__()
@@ -34,6 +35,8 @@ class LPIPSWithDiscriminator(nn.Module):
         self.disc_factor = disc_factor
         self.discriminator_weight = disc_weight
         self.disc_conditional = disc_conditional
+        self.outlier_penalty_loss_r = outlier_penalty_loss_r
+        self.outlier_penalty_loss_weight = outlier_penalty_loss_weight
         self.l1_loss_weight = l1_loss_weight
         self.l2_loss_weight = l2_loss_weight

@@ -50,6 +53,18 @@ class LPIPSWithDiscriminator(nn.Module):
         d_weight = d_weight * self.discriminator_weight
         return d_weight

+    def outlier_penalty_loss(self, posteriors, r):
+        batch_size, channels, frames, height, width = posteriors.shape
+        mean_X = posteriors.mean(dim=(3, 4), keepdim=True)
+        std_X = posteriors.std(dim=(3, 4), keepdim=True)
+
+        diff = torch.abs(posteriors - mean_X)
+        penalty = torch.maximum(diff - r * std_X, torch.zeros_like(diff))
+
+        opl = penalty.sum(dim=(3, 4)) / (height * width)
+        opl_final = opl.mean(dim=(0, 1, 2))
+        return opl_final
+
     def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                 global_step, last_layer=None, cond=None, split="train",
                 weights=None):
@@ -86,6 +101,8 @@ class LPIPSWithDiscriminator(nn.Module):
         kl_loss = posteriors.kl()
         kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

+        outlier_penalty_loss = self.outlier_penalty_loss(posteriors.mode(), self.outlier_penalty_loss_r) * self.outlier_penalty_loss_weight
+
         # now the GAN part
         if optimizer_idx == 0:
             # generator update
@@ -102,13 +119,13 @@ class LPIPSWithDiscriminator(nn.Module):
             try:
                 d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
             except RuntimeError:
-                assert not self.training
+                # assert not self.training
                 d_weight = torch.tensor(0.0)
         else:
             d_weight = torch.tensor(0.0)

         disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
-        loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+        loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + outlier_penalty_loss

         log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
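The added outlier_penalty_loss penalizes latent values lying more than r standard deviations from their per-channel, per-frame spatial mean, and it is folded into the generator loss with weight outlier_penalty_loss_weight (1e5 by default). A standalone sketch mirroring the computation on a random latent:

import torch

def outlier_penalty_loss(latents, r=3.0):
    # latents: (B, C, T, H, W), e.g. posteriors.mode() from the VAE's diagonal Gaussian
    _, _, _, height, width = latents.shape
    mean = latents.mean(dim=(3, 4), keepdim=True)   # per-frame, per-channel spatial mean
    std = latents.std(dim=(3, 4), keepdim=True)
    excess = (latents - mean).abs() - r * std       # how far past the r-sigma band each value lies
    penalty = torch.clamp(excess, min=0.0)
    return penalty.sum(dim=(3, 4)).div(height * width).mean()

latents = torch.randn(1, 16, 4, 32, 32)
print(outlier_penalty_loss(latents, r=3.0))  # close to 0 for well-behaved Gaussian latents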
easyanimate/vae/ldm/modules/vaemodules/__init__.py
CHANGED
File without changes

easyanimate/vae/ldm/modules/vaemodules/activations.py
CHANGED
File without changes
easyanimate/vae/ldm/modules/vaemodules/common.py
CHANGED
@@ -8,6 +8,17 @@ from einops import rearrange, repeat
 from .activations import get_activation


+try:
+    current_version = torch.__version__
+    version_numbers = [int(x) for x in current_version.split('.')[:2]]
+    if version_numbers[0] < 2 or (version_numbers[0] == 2 and version_numbers[1] < 2):
+        need_to_float = True
+    else:
+        need_to_float = False
+except Exception as e:
+    print("Encountered an error with Torch version. Set the data type to float in the VAE. ")
+    need_to_float = False
+
 def cast_tuple(t, length = 1):
     return t if isinstance(t, tuple) else ((t,) * length)

@@ -66,10 +77,15 @@ class CausalConv3d(nn.Conv3d):
             **kwargs,
         )

+    def _clear_conv_cache(self):
+        del self.prev_features
+        self.prev_features = None
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: (B, C, T, H, W)
         dtype = x.dtype
+        if need_to_float:
+            x = x.float()
         if self.padding_flag == 0:
             x = F.pad(
                 x,
@@ -85,7 +101,11 @@ class CausalConv3d(nn.Conv3d):
                 mode="replicate", # TODO: check if this is necessary
             )
             x = x.to(dtype=dtype)
+
+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()

             b, c, f, h, w = x.size()
             outputs = []
@@ -105,7 +125,11 @@ class CausalConv3d(nn.Conv3d):
                 [self.prev_features, x], dim = 2
             )
             x = x.to(dtype=dtype)
+
+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()

             b, c, f, h, w = x.size()
             outputs = []
@@ -122,7 +146,12 @@ class CausalConv3d(nn.Conv3d):
                 mode="replicate", # TODO: check if this is necessary
             )
             x = x.to(dtype=dtype)
+
+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()
+
             return super().forward(x)
         elif self.padding_flag == 6:
             if self.t_stride == 2:
@@ -133,7 +162,12 @@ class CausalConv3d(nn.Conv3d):
             x = torch.concat(
                 [self.prev_features, x], dim = 2
             )
+
+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()
+
             x = x.to(dtype=dtype)
             return super().forward(x)
         else:
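The module-level probe added at the top of common.py sets need_to_float on PyTorch builds older than 2.2, and forward() then casts its input to float32 before the padded causal convolution. An isolated sketch of the same version check (simplified; the original wraps it in try/except to survive unusual version strings):

import torch

def torch_older_than(major, minor):
    # torch.__version__ may look like "2.1.2+cu121"; only the first two numeric fields matter here
    ver_major, ver_minor = (int(x) for x in torch.__version__.split(".")[:2])
    return (ver_major, ver_minor) < (major, minor)

need_to_float = torch_older_than(2, 2)  # mirrors the flag used by CausalConv3d.forward above
print(need_to_float)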
easyanimate/vae/ldm/modules/vaemodules/down_blocks.py
CHANGED
File without changes

easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py
CHANGED
File without changes

easyanimate/vae/ldm/modules/vaemodules/up_blocks.py
CHANGED
File without changes
requirements.txt
CHANGED
@@ -6,7 +6,6 @@ tomesd
 torch>=2.1.2
 torchdiffeq
 torchsde
-xformers
 decord
 datasets
 numpy
@@ -21,8 +20,6 @@ tensorboard
 beautifulsoup4
 ftfy
 func_timeout
-deepspeed
 accelerate>=0.25.0
-transformers>=4.37.2
+diffusers==0.30.1
+transformers==4.46.2
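With xformers and deepspeed dropped and diffusers / transformers now pinned, a quick sanity check after pip install -r requirements.txt could look like this (the assertion is only an illustration of the pins above):

import diffusers
import transformers

# versions pinned by this commit's requirements.txt
assert diffusers.__version__ == "0.30.1"
assert transformers.__version__ == "4.46.2"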