bubbliiiing committed c2a6cd2 (1 parent: b0f1243)

Update V5.1

Files changed (33):
  1. app.py +9 -18
  2. config/easyanimate_video_v5.1_magvit_qwen.yaml +21 -0
  3. easyanimate/api/api.py +1 -1
  4. easyanimate/api/post_infer.py +2 -2
  5. easyanimate/data/dataset_image_video.py +220 -32
  6. easyanimate/models/__init__.py +3 -4
  7. easyanimate/models/attention.py +60 -31
  8. easyanimate/models/autoencoder_magvit.py +15 -117
  9. easyanimate/models/embeddings.py +3 -2
  10. easyanimate/models/norm.py +16 -0
  11. easyanimate/models/processor.py +146 -0
  12. easyanimate/models/transformer3d.py +280 -43
  13. easyanimate/pipeline/pipeline_easyanimate.py +730 -486
  14. easyanimate/pipeline/{pipeline_easyanimate_multi_text_encoder_control.py → pipeline_easyanimate_control.py} +448 -229
  15. easyanimate/pipeline/pipeline_easyanimate_inpaint.py +0 -0
  16. easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py +0 -925
  17. easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py +0 -1334
  18. easyanimate/ui/ui.py +237 -179
  19. easyanimate/utils/lora_utils.py +42 -30
  20. easyanimate/utils/utils.py +53 -33
  21. easyanimate/vae/ldm/models/autoencoder.py +4 -4
  22. easyanimate/vae/ldm/models/casual3dcnn.py +5 -5
  23. easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py +5 -5
  24. easyanimate/vae/ldm/models/omnigen_casual3dcnn.py +13 -9
  25. easyanimate/vae/ldm/models/omnigen_enc_dec.py +6 -2
  26. easyanimate/vae/ldm/modules/losses/contperceptual.py +20 -3
  27. easyanimate/vae/ldm/modules/vaemodules/__init__.py +0 -0
  28. easyanimate/vae/ldm/modules/vaemodules/activations.py +0 -0
  29. easyanimate/vae/ldm/modules/vaemodules/common.py +39 -5
  30. easyanimate/vae/ldm/modules/vaemodules/down_blocks.py +0 -0
  31. easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py +0 -0
  32. easyanimate/vae/ldm/modules/vaemodules/up_blocks.py +0 -0
  33. requirements.txt +2 -5
app.py CHANGED
@@ -19,6 +19,9 @@ if __name__ == "__main__":
19
  #
20
  # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use,
21
  # resulting in slower speeds but saving a large amount of GPU memory.
 
 
 
22
  GPU_memory_mode = "model_cpu_offload_and_qfloat8"
23
  # Use torch.float16 if GPU does not support torch.bfloat16
24
  # Some graphics cards, such as V100 and 2080Ti, do not support torch.bfloat16
@@ -29,11 +32,11 @@ if __name__ == "__main__":
29
  server_port = 7860
30
 
31
  # Params below are used when ui_mode = "modelscope"
32
- edition = "v5"
33
  # Config
34
- config_path = "config/easyanimate_video_v5_magvit_multi_text_encoder.yaml"
35
  # Model path of the pretrained model
36
- model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
37
  # "Inpaint" or "Control"
38
  model_type = "Inpaint"
39
  # Save dir
@@ -46,18 +49,6 @@ if __name__ == "__main__":
46
  else:
47
  demo, controller = ui(GPU_memory_mode, weight_dtype)
48
 
49
- # launch gradio
50
- app, _, _ = demo.queue(status_update_rate=1).launch(
51
- server_name=server_name,
52
- server_port=server_port,
53
- prevent_thread_lock=True
54
- )
55
-
56
- # launch api
57
- infer_forward_api(None, app, controller)
58
- update_diffusion_transformer_api(None, app, controller)
59
- update_edition_api(None, app, controller)
60
-
61
- # not close the python
62
- while True:
63
- time.sleep(5)
 
19
  #
20
  # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use,
21
  # resulting in slower speeds but saving a large amount of GPU memory.
22
+ #
23
+ # EasyAnimateV1, V2 and V3 support "model_cpu_offload" "sequential_cpu_offload"
24
+ # EasyAnimateV4, V5 and V5.1 support "model_cpu_offload" "model_cpu_offload_and_qfloat8" "sequential_cpu_offload"
25
  GPU_memory_mode = "model_cpu_offload_and_qfloat8"
26
  # Use torch.float16 if GPU does not support torch.bfloat16
27
  # Some graphics cards, such as V100 and 2080Ti, do not support torch.bfloat16
 
32
  server_port = 7860
33
 
34
  # Params below are used when ui_mode = "modelscope"
35
+ edition = "v5.1"
36
  # Config
37
+ config_path = "config/easyanimate_video_v5.1_magvit_qwen.yaml"
38
  # Model path of the pretrained model
39
+ model_name = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP"
40
  # "Inpaint" or "Control"
41
  model_type = "Inpaint"
42
  # Save dir
 
49
  else:
50
  demo, controller = ui(GPU_memory_mode, weight_dtype)
51
 
52
+ demo.launch(
53
+ server_name=server_name, server_port=server_port
54
+ )
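For context on the offload modes described in the comments above, here is a minimal sketch of how such modes are typically applied to a diffusers pipeline; the `pipeline` object and the float8 step are assumptions for illustration, not the code shipped in this repository.

```python
# Minimal sketch, assuming `pipeline` is an already-built diffusers DiffusionPipeline.
# The qfloat8 branch is only indicated with a comment because the exact quantization
# helper EasyAnimate uses is not part of this diff.
def apply_gpu_memory_mode(pipeline, GPU_memory_mode):
    if GPU_memory_mode == "sequential_cpu_offload":
        # Each sub-module is moved to the GPU only while it runs: slowest, lowest VRAM.
        pipeline.enable_sequential_cpu_offload()
    elif GPU_memory_mode in ("model_cpu_offload", "model_cpu_offload_and_qfloat8"):
        if GPU_memory_mode == "model_cpu_offload_and_qfloat8":
            # EasyAnimateV4/V5/V5.1 additionally quantize transformer weights to float8
            # here (helper not shown in this diff).
            pass
        # Whole sub-models stay on the CPU and are moved to the GPU on demand.
        pipeline.enable_model_cpu_offload()
    else:
        pipeline.to("cuda")
```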
config/easyanimate_video_v5.1_magvit_qwen.yaml ADDED
@@ -0,0 +1,21 @@
1
+ transformer_additional_kwargs:
2
+ transformer_type: "EasyAnimateTransformer3DModel"
3
+ after_norm: false
4
+ time_position_encoding_type: "3d_rope"
5
+ resize_inpaint_mask_directly: true
6
+ enable_text_attention_mask: true
7
+ enable_clip_in_inpaint: false
8
+ add_ref_latent_in_control_model: true
9
+
10
+ vae_kwargs:
11
+ vae_type: "AutoencoderKLMagvit"
12
+ mini_batch_encoder: 4
13
+ mini_batch_decoder: 1
14
+ slice_mag_vae: false
15
+ slice_compression_vae: false
16
+ cache_compression_vae: false
17
+ cache_mag_vae: true
18
+
19
+ text_encoder_kwargs:
20
+ enable_multi_text_encoder: false
21
+ replace_t5_to_llm: true
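A quick sanity check of the new config is to load it and inspect the three kwargs groups; a minimal sketch using PyYAML (the repository itself may load it through OmegaConf):

```python
# Minimal sketch, assuming PyYAML is installed; keys follow the YAML added above.
import yaml

with open("config/easyanimate_video_v5.1_magvit_qwen.yaml") as f:
    config = yaml.safe_load(f)

transformer_kwargs = config["transformer_additional_kwargs"]
vae_kwargs = config["vae_kwargs"]
text_encoder_kwargs = config["text_encoder_kwargs"]

assert vae_kwargs["vae_type"] == "AutoencoderKLMagvit"
# V5.1 drops the multi-text-encoder setup and replaces T5 with an LLM (Qwen) text encoder.
assert text_encoder_kwargs["enable_multi_text_encoder"] is False
assert text_encoder_kwargs["replace_t5_to_llm"] is True
```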
easyanimate/api/api.py CHANGED
@@ -93,7 +93,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
93
  lora_model_path = datas.get('lora_model_path', 'none')
94
  lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
95
  prompt_textbox = datas.get('prompt_textbox', None)
96
- negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics.')
97
  sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
98
  sample_step_slider = datas.get('sample_step_slider', 30)
99
  resize_method = datas.get('resize_method', "Generate by")
 
93
  lora_model_path = datas.get('lora_model_path', 'none')
94
  lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
95
  prompt_textbox = datas.get('prompt_textbox', None)
96
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art.')
97
  sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
98
  sample_step_slider = datas.get('sample_step_slider', 30)
99
  resize_method = datas.get('resize_method', "Generate by")
easyanimate/api/post_infer.py CHANGED
@@ -54,14 +54,14 @@ if __name__ == '__main__':
54
  # -------------------------- #
55
  # Step 1: update edition
56
  # -------------------------- #
57
- edition = "v5"
58
  outputs = post_update_edition(edition)
59
  print('Output update edition: ', outputs)
60
 
61
  # -------------------------- #
62
  # Step 2: update edition
63
  # -------------------------- #
64
- diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
65
  outputs = post_diffusion_transformer(diffusion_transformer_path)
66
  print('Output update edition: ', outputs)
67
 
 
54
  # -------------------------- #
55
  # Step 1: update edition
56
  # -------------------------- #
57
+ edition = "v5.1"
58
  outputs = post_update_edition(edition)
59
  print('Output update edition: ', outputs)
60
 
61
  # -------------------------- #
62
  # Step 2: update edition
63
  # -------------------------- #
64
+ diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP"
65
  outputs = post_diffusion_transformer(diffusion_transformer_path)
66
  print('Output update edition: ', outputs)
67
 
easyanimate/data/dataset_image_video.py CHANGED
@@ -12,9 +12,12 @@ import albumentations
12
  import cv2
13
  import numpy as np
14
  import torch
 
15
  import torchvision.transforms as transforms
16
  from decord import VideoReader
 
17
  from func_timeout import FunctionTimedOut, func_timeout
 
18
  from PIL import Image
19
  from torch.utils.data import BatchSampler, Sampler
20
  from torch.utils.data.dataset import Dataset
@@ -100,6 +103,152 @@ def get_random_mask(shape):
100
  else:
101
  raise ValueError(f"The mask_index {mask_index} is not define")
102
  return mask
 
 
103
 
104
  class ImageVideoSampler(BatchSampler):
105
  """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
@@ -184,7 +333,7 @@ class ImageVideoDataset(Dataset):
184
  video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
185
  image_sample_size=512,
186
  video_repeat=0,
187
- text_drop_ratio=-1,
188
  enable_bucket=False,
189
  video_length_drop_start=0.1,
190
  video_length_drop_end=0.9,
@@ -355,7 +504,6 @@ class ImageVideoDataset(Dataset):
355
 
356
  return sample
357
 
358
-
359
  class ImageVideoControlDataset(Dataset):
360
  def __init__(
361
  self,
@@ -363,11 +511,12 @@ class ImageVideoControlDataset(Dataset):
363
  video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
364
  image_sample_size=512,
365
  video_repeat=0,
366
- text_drop_ratio=-1,
367
  enable_bucket=False,
368
  video_length_drop_start=0.1,
369
  video_length_drop_end=0.9,
370
  enable_inpaint=False,
 
371
  ):
372
  # Loading annotations from files
373
  print(f"loading annotations from {ann_path} ...")
@@ -397,6 +546,7 @@ class ImageVideoControlDataset(Dataset):
397
  self.enable_bucket = enable_bucket
398
  self.text_drop_ratio = text_drop_ratio
399
  self.enable_inpaint = enable_inpaint
 
400
 
401
  self.video_length_drop_start = video_length_drop_start
402
  self.video_length_drop_end = video_length_drop_end
@@ -412,6 +562,13 @@ class ImageVideoControlDataset(Dataset):
412
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
413
  ]
414
  )
 
 
415
 
416
  # Image params
417
  self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
@@ -484,33 +641,59 @@ class ImageVideoControlDataset(Dataset):
484
  else:
485
  control_video_id = os.path.join(self.data_root, control_video_id)
486
 
487
- with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
488
- try:
489
- sample_args = (control_video_reader, batch_index)
490
- control_pixel_values = func_timeout(
491
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
492
- )
493
- resized_frames = []
494
- for i in range(len(control_pixel_values)):
495
- frame = control_pixel_values[i]
496
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
497
- resized_frames.append(resized_frame)
498
- control_pixel_values = np.array(resized_frames)
499
- except FunctionTimedOut:
500
- raise ValueError(f"Read {idx} timeout.")
501
- except Exception as e:
502
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
503
-
504
- if not self.enable_bucket:
505
- control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
506
- control_pixel_values = control_pixel_values / 255.
507
- del control_video_reader
508
  else:
509
- control_pixel_values = control_pixel_values
510
-
511
- if not self.enable_bucket:
512
- control_pixel_values = self.video_transforms(control_pixel_values)
513
- return pixel_values, control_pixel_values, text, "video"
 
 
514
  else:
515
  image_path, text = data_info['file_path'], data_info['text']
516
  if self.data_root is not None:
@@ -536,7 +719,8 @@ class ImageVideoControlDataset(Dataset):
536
  control_image = self.image_transforms(control_image).unsqueeze(0)
537
  else:
538
  control_image = np.expand_dims(np.array(control_image), 0)
539
- return image, control_image, text, 'image'
 
540
 
541
  def __len__(self):
542
  return self.length
@@ -552,13 +736,17 @@ class ImageVideoControlDataset(Dataset):
552
  if data_type_local != data_type:
553
  raise ValueError("data_type_local != data_type")
554
 
555
- pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
 
556
  sample["pixel_values"] = pixel_values
557
  sample["control_pixel_values"] = control_pixel_values
558
  sample["text"] = name
559
  sample["data_type"] = data_type
560
  sample["idx"] = idx
561
-
 
 
 
562
  if len(sample) > 0:
563
  break
564
  except Exception as e:
 
12
  import cv2
13
  import numpy as np
14
  import torch
15
+ import torch.nn.functional as F
16
  import torchvision.transforms as transforms
17
  from decord import VideoReader
18
+ from einops import rearrange
19
  from func_timeout import FunctionTimedOut, func_timeout
20
+ from packaging import version as pver
21
  from PIL import Image
22
  from torch.utils.data import BatchSampler, Sampler
23
  from torch.utils.data.dataset import Dataset
 
103
  else:
104
  raise ValueError(f"The mask_index {mask_index} is not define")
105
  return mask
106
+
107
+ class Camera(object):
108
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
109
+ """
110
+ def __init__(self, entry):
111
+ fx, fy, cx, cy = entry[1:5]
112
+ self.fx = fx
113
+ self.fy = fy
114
+ self.cx = cx
115
+ self.cy = cy
116
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
117
+ w2c_mat_4x4 = np.eye(4)
118
+ w2c_mat_4x4[:3, :] = w2c_mat
119
+ self.w2c_mat = w2c_mat_4x4
120
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
121
+
122
+ def custom_meshgrid(*args):
123
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
124
+ """
125
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
126
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
127
+ return torch.meshgrid(*args)
128
+ else:
129
+ return torch.meshgrid(*args, indexing='ij')
130
+
131
+ def get_relative_pose(cam_params):
132
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
133
+ """
134
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
135
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
136
+ cam_to_origin = 0
137
+ target_cam_c2w = np.array([
138
+ [1, 0, 0, 0],
139
+ [0, 1, 0, -cam_to_origin],
140
+ [0, 0, 1, 0],
141
+ [0, 0, 0, 1]
142
+ ])
143
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
144
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
145
+ ret_poses = np.array(ret_poses, dtype=np.float32)
146
+ return ret_poses
147
+
148
+ def ray_condition(K, c2w, H, W, device):
149
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
150
+ """
151
+ # c2w: B, V, 4, 4
152
+ # K: B, V, 4
153
+
154
+ B = K.shape[0]
155
+
156
+ j, i = custom_meshgrid(
157
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
158
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
159
+ )
160
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
161
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
162
+
163
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
164
+
165
+ zs = torch.ones_like(i) # [B, HxW]
166
+ xs = (i - cx) / fx * zs
167
+ ys = (j - cy) / fy * zs
168
+ zs = zs.expand_as(ys)
169
+
170
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
171
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
172
+
173
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
174
+ rays_o = c2w[..., :3, 3] # B, V, 3
175
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
176
+ # c2w @ directions
177
+ rays_dxo = torch.cross(rays_o, rays_d)
178
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
179
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
180
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
181
+ return plucker
182
+
183
+ def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
184
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
185
+ """
186
+ with open(pose_file_path, 'r') as f:
187
+ poses = f.readlines()
188
+
189
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
190
+ cam_params = [[float(x) for x in pose] for pose in poses]
191
+ if return_poses:
192
+ return cam_params
193
+ else:
194
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
195
+
196
+ sample_wh_ratio = width / height
197
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
198
+
199
+ if pose_wh_ratio > sample_wh_ratio:
200
+ resized_ori_w = height * pose_wh_ratio
201
+ for cam_param in cam_params:
202
+ cam_param.fx = resized_ori_w * cam_param.fx / width
203
+ else:
204
+ resized_ori_h = width / pose_wh_ratio
205
+ for cam_param in cam_params:
206
+ cam_param.fy = resized_ori_h * cam_param.fy / height
207
+
208
+ intrinsic = np.asarray([[cam_param.fx * width,
209
+ cam_param.fy * height,
210
+ cam_param.cx * width,
211
+ cam_param.cy * height]
212
+ for cam_param in cam_params], dtype=np.float32)
213
+
214
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
215
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
216
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
217
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
218
+ plucker_embedding = plucker_embedding[None]
219
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
220
+ return plucker_embedding
221
+
222
+ def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
223
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
224
+ """
225
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
226
+
227
+ sample_wh_ratio = width / height
228
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
229
+
230
+ if pose_wh_ratio > sample_wh_ratio:
231
+ resized_ori_w = height * pose_wh_ratio
232
+ for cam_param in cam_params:
233
+ cam_param.fx = resized_ori_w * cam_param.fx / width
234
+ else:
235
+ resized_ori_h = width / pose_wh_ratio
236
+ for cam_param in cam_params:
237
+ cam_param.fy = resized_ori_h * cam_param.fy / height
238
+
239
+ intrinsic = np.asarray([[cam_param.fx * width,
240
+ cam_param.fy * height,
241
+ cam_param.cx * width,
242
+ cam_param.cy * height]
243
+ for cam_param in cam_params], dtype=np.float32)
244
+
245
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
246
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
247
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
248
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
249
+ plucker_embedding = plucker_embedding[None]
250
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
251
+ return plucker_embedding
252
 
253
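A hedged usage sketch of the camera helpers added above: `process_pose_file` reads a CameraCtrl/RealEstate10K-style pose text file (one header line, then per-frame rows whose slice [1:5] holds fx, fy, cx, cy and slice [7:] a flattened 3x4 world-to-camera matrix) and returns per-frame Plücker embeddings. The file path below is hypothetical.

```python
from easyanimate.data.dataset_image_video import process_pose_file

# Hypothetical pose file in the format parsed by Camera(entry) above.
pose_file = "datasets/camera_poses/example_pose.txt"

# Per-frame Plücker embeddings; after the final rearrange the shape is (n_frames, H, W, 6).
plucker = process_pose_file(pose_file, width=672, height=384, device="cpu")
print(plucker.shape)

# return_poses=True skips the embedding and returns the raw per-frame parameter rows,
# which ImageVideoControlDataset interpolates to the sampled frame count when bucketing.
raw_params = process_pose_file(pose_file, return_poses=True)
```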
  class ImageVideoSampler(BatchSampler):
254
  """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
 
333
  video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
334
  image_sample_size=512,
335
  video_repeat=0,
336
+ text_drop_ratio=0.1,
337
  enable_bucket=False,
338
  video_length_drop_start=0.1,
339
  video_length_drop_end=0.9,
 
504
 
505
  return sample
506
 
 
507
  class ImageVideoControlDataset(Dataset):
508
  def __init__(
509
  self,
 
511
  video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
512
  image_sample_size=512,
513
  video_repeat=0,
514
+ text_drop_ratio=0.1,
515
  enable_bucket=False,
516
  video_length_drop_start=0.1,
517
  video_length_drop_end=0.9,
518
  enable_inpaint=False,
519
+ enable_camera_info=False,
520
  ):
521
  # Loading annotations from files
522
  print(f"loading annotations from {ann_path} ...")
 
546
  self.enable_bucket = enable_bucket
547
  self.text_drop_ratio = text_drop_ratio
548
  self.enable_inpaint = enable_inpaint
549
+ self.enable_camera_info = enable_camera_info
550
 
551
  self.video_length_drop_start = video_length_drop_start
552
  self.video_length_drop_end = video_length_drop_end
 
562
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
563
  ]
564
  )
565
+ if self.enable_camera_info:
566
+ self.video_transforms_camera = transforms.Compose(
567
+ [
568
+ transforms.Resize(min(self.video_sample_size)),
569
+ transforms.CenterCrop(self.video_sample_size)
570
+ ]
571
+ )
572
 
573
  # Image params
574
  self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
 
641
  else:
642
  control_video_id = os.path.join(self.data_root, control_video_id)
643
 
644
+ if self.enable_camera_info:
645
+ if control_video_id.lower().endswith('.txt'):
646
+ if not self.enable_bucket:
647
+ control_pixel_values = torch.zeros_like(pixel_values)
648
+
649
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
650
+ control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
651
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
652
+ control_camera_values = self.video_transforms_camera(control_camera_values)
653
+ else:
654
+ control_pixel_values = np.zeros_like(pixel_values)
655
+
656
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
657
+ control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
658
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
659
+ control_camera_values = np.array([control_camera_values[index] for index in batch_index])
 
 
 
 
 
660
  else:
661
+ if not self.enable_bucket:
662
+ control_pixel_values = torch.zeros_like(pixel_values)
663
+ control_camera_values = None
664
+ else:
665
+ control_pixel_values = np.zeros_like(pixel_values)
666
+ control_camera_values = None
667
+ else:
668
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
669
+ try:
670
+ sample_args = (control_video_reader, batch_index)
671
+ control_pixel_values = func_timeout(
672
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
673
+ )
674
+ resized_frames = []
675
+ for i in range(len(control_pixel_values)):
676
+ frame = control_pixel_values[i]
677
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
678
+ resized_frames.append(resized_frame)
679
+ control_pixel_values = np.array(resized_frames)
680
+ except FunctionTimedOut:
681
+ raise ValueError(f"Read {idx} timeout.")
682
+ except Exception as e:
683
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
684
+
685
+ if not self.enable_bucket:
686
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
687
+ control_pixel_values = control_pixel_values / 255.
688
+ del control_video_reader
689
+ else:
690
+ control_pixel_values = control_pixel_values
691
+
692
+ if not self.enable_bucket:
693
+ control_pixel_values = self.video_transforms(control_pixel_values)
694
+ control_camera_values = None
695
+
696
+ return pixel_values, control_pixel_values, control_camera_values, text, "video"
697
  else:
698
  image_path, text = data_info['file_path'], data_info['text']
699
  if self.data_root is not None:
 
719
  control_image = self.image_transforms(control_image).unsqueeze(0)
720
  else:
721
  control_image = np.expand_dims(np.array(control_image), 0)
722
+
723
+ return image, control_image, None, text, 'image'
724
 
725
  def __len__(self):
726
  return self.length
 
736
  if data_type_local != data_type:
737
  raise ValueError("data_type_local != data_type")
738
 
739
+ pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx)
740
+
741
  sample["pixel_values"] = pixel_values
742
  sample["control_pixel_values"] = control_pixel_values
743
  sample["text"] = name
744
  sample["data_type"] = data_type
745
  sample["idx"] = idx
746
+
747
+ if self.enable_camera_info:
748
+ sample["control_camera_values"] = control_camera_values
749
+
750
  if len(sample) > 0:
751
  break
752
  except Exception as e:
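Putting the new dataset options together, a hedged instantiation sketch; the annotation path and data root are placeholders, and the remaining arguments follow the signature shown above.

```python
from easyanimate.data.dataset_image_video import ImageVideoControlDataset

# Sketch only: paths are hypothetical, and camera values are produced only when the
# control entry of a sample points at a pose .txt file.
dataset = ImageVideoControlDataset(
    ann_path="datasets/internal_datasets/metadata_control.json",  # hypothetical
    data_root="datasets/internal_datasets",                       # hypothetical
    video_sample_size=512,
    video_sample_n_frames=49,
    text_drop_ratio=0.1,      # new default: captions dropped ~10% of the time,
                              # presumably for classifier-free guidance training
    enable_bucket=False,
    enable_camera_info=True,  # new in this commit
)

sample = dataset[0]
print(sample["pixel_values"].shape)
# Present when enable_camera_info=True; may be None for image samples or
# control videos without camera poses.
print(sample.get("control_camera_values"))
```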
easyanimate/models/__init__.py CHANGED
@@ -1,8 +1,7 @@
1
- from .autoencoder_magvit import (AutoencoderKLCogVideoX, AutoencoderKLMagvit, AutoencoderKL)
 
2
  from .transformer3d import (EasyAnimateTransformer3DModel,
3
- HunyuanTransformer3DModel,
4
- Transformer3DModel)
5
-
6
 
7
  name_to_transformer3d = {
8
  "Transformer3DModel": Transformer3DModel,
 
1
+ from .autoencoder_magvit import (AutoencoderKL, AutoencoderKLCogVideoX,
2
+ AutoencoderKLMagvit)
3
  from .transformer3d import (EasyAnimateTransformer3DModel,
4
+ HunyuanTransformer3DModel, Transformer3DModel)
 
 
5
 
6
  name_to_transformer3d = {
7
  "Transformer3DModel": Transformer3DModel,
easyanimate/models/attention.py CHANGED
@@ -29,7 +29,7 @@ from diffusers.models.embeddings import (SinusoidalPositionalEmbedding,
29
  get_3d_sincos_pos_embed)
30
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
31
  from diffusers.models.modeling_utils import ModelMixin
32
- from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero,
33
  CogVideoXLayerNormZero)
34
  from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging
35
  from diffusers.utils.import_utils import is_xformers_available
@@ -38,12 +38,11 @@ from einops import rearrange, repeat
38
  from torch import nn
39
 
40
  from .motion_module import PositionalEncoding, get_motion_module
41
- from .norm import AdaLayerNormShift, FP32LayerNorm, EasyAnimateLayerNormZero
42
  from .processor import (EasyAnimateAttnProcessor2_0,
 
43
  LazyKVCompressionProcessor2_0)
44
 
45
-
46
-
47
  if is_xformers_available():
48
  import xformers
49
  import xformers.ops
@@ -1042,7 +1041,9 @@ class EasyAnimateDiTBlock(nn.Module):
1042
  ff_bias: bool = True,
1043
  qk_norm: bool = True,
1044
  after_norm: bool = False,
1045
- norm_type: str="fp32_layer_norm"
 
 
1046
  ):
1047
  super().__init__()
1048
 
@@ -1051,6 +1052,7 @@ class EasyAnimateDiTBlock(nn.Module):
1051
  time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
1052
  )
1053
 
 
1054
  self.attn1 = Attention(
1055
  query_dim=dim,
1056
  dim_head=attention_head_dim,
@@ -1058,17 +1060,20 @@ class EasyAnimateDiTBlock(nn.Module):
1058
  qk_norm="layer_norm" if qk_norm else None,
1059
  eps=1e-6,
1060
  bias=True,
1061
- processor=EasyAnimateAttnProcessor2_0(),
1062
- )
1063
- self.attn2 = Attention(
1064
- query_dim=dim,
1065
- dim_head=attention_head_dim,
1066
- heads=num_attention_heads,
1067
- qk_norm="layer_norm" if qk_norm else None,
1068
- eps=1e-6,
1069
- bias=True,
1070
- processor=EasyAnimateAttnProcessor2_0(),
1071
  )
 
 
 
1072
 
1073
  # FFN Part
1074
  self.norm2 = EasyAnimateLayerNormZero(
@@ -1082,14 +1087,18 @@ class EasyAnimateDiTBlock(nn.Module):
1082
  inner_dim=ff_inner_dim,
1083
  bias=ff_bias,
1084
  )
1085
- self.txt_ff = FeedForward(
1086
- dim,
1087
- dropout=dropout,
1088
- activation_fn=activation_fn,
1089
- final_dropout=final_dropout,
1090
- inner_dim=ff_inner_dim,
1091
- bias=ff_bias,
1092
- )
 
 
 
 
1093
  if after_norm:
1094
  self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1095
  else:
@@ -1101,6 +1110,9 @@ class EasyAnimateDiTBlock(nn.Module):
1101
  encoder_hidden_states: torch.Tensor,
1102
  temb: torch.Tensor,
1103
  image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 
 
 
1104
  ) -> torch.Tensor:
1105
  # Norm
1106
  norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -1108,12 +1120,23 @@ class EasyAnimateDiTBlock(nn.Module):
1108
  )
1109
 
1110
  # Attn
1111
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
1112
- hidden_states=norm_hidden_states,
1113
- encoder_hidden_states=norm_encoder_hidden_states,
1114
- image_rotary_emb=image_rotary_emb,
1115
- attn2=self.attn2,
1116
- )
 
 
1117
  hidden_states = hidden_states + gate_msa * attn_hidden_states
1118
  encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
1119
 
@@ -1125,10 +1148,16 @@ class EasyAnimateDiTBlock(nn.Module):
1125
  # FFN
1126
  if self.norm3 is not None:
1127
  norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
1128
- norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
 
 
 
1129
  else:
1130
  norm_hidden_states = self.ff(norm_hidden_states)
1131
- norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
 
 
 
1132
  hidden_states = hidden_states + gate_ff * norm_hidden_states
1133
  encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states
1134
  return hidden_states, encoder_hidden_states
 
29
  get_3d_sincos_pos_embed)
30
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
31
  from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero,
33
  CogVideoXLayerNormZero)
34
  from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging
35
  from diffusers.utils.import_utils import is_xformers_available
 
38
  from torch import nn
39
 
40
  from .motion_module import PositionalEncoding, get_motion_module
41
+ from .norm import AdaLayerNormShift, EasyAnimateLayerNormZero, FP32LayerNorm
42
  from .processor import (EasyAnimateAttnProcessor2_0,
43
+ EasyAnimateSWAttnProcessor2_0,
44
  LazyKVCompressionProcessor2_0)
45
 
 
 
46
  if is_xformers_available():
47
  import xformers
48
  import xformers.ops
 
1041
  ff_bias: bool = True,
1042
  qk_norm: bool = True,
1043
  after_norm: bool = False,
1044
+ norm_type: str="fp32_layer_norm",
1045
+ is_mmdit_block: bool = True,
1046
+ is_swa: bool = False,
1047
  ):
1048
  super().__init__()
1049
 
 
1052
  time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
1053
  )
1054
 
1055
+ self.is_swa = is_swa
1056
  self.attn1 = Attention(
1057
  query_dim=dim,
1058
  dim_head=attention_head_dim,
 
1060
  qk_norm="layer_norm" if qk_norm else None,
1061
  eps=1e-6,
1062
  bias=True,
1063
+ processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(),
 
 
1064
  )
1065
+ if is_mmdit_block:
1066
+ self.attn2 = Attention(
1067
+ query_dim=dim,
1068
+ dim_head=attention_head_dim,
1069
+ heads=num_attention_heads,
1070
+ qk_norm="layer_norm" if qk_norm else None,
1071
+ eps=1e-6,
1072
+ bias=True,
1073
+ processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(),
1074
+ )
1075
+ else:
1076
+ self.attn2 = None
1077
 
1078
  # FFN Part
1079
  self.norm2 = EasyAnimateLayerNormZero(
 
1087
  inner_dim=ff_inner_dim,
1088
  bias=ff_bias,
1089
  )
1090
+ if is_mmdit_block:
1091
+ self.txt_ff = FeedForward(
1092
+ dim,
1093
+ dropout=dropout,
1094
+ activation_fn=activation_fn,
1095
+ final_dropout=final_dropout,
1096
+ inner_dim=ff_inner_dim,
1097
+ bias=ff_bias,
1098
+ )
1099
+ else:
1100
+ self.txt_ff = None
1101
+
1102
  if after_norm:
1103
  self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1104
  else:
 
1110
  encoder_hidden_states: torch.Tensor,
1111
  temb: torch.Tensor,
1112
  image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1113
+ num_frames = None,
1114
+ height = None,
1115
+ width = None
1116
  ) -> torch.Tensor:
1117
  # Norm
1118
  norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
 
1120
  )
1121
 
1122
  # Attn
1123
+ if self.is_swa:
1124
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
1125
+ hidden_states=norm_hidden_states,
1126
+ encoder_hidden_states=norm_encoder_hidden_states,
1127
+ image_rotary_emb=image_rotary_emb,
1128
+ attn2=self.attn2,
1129
+ num_frames=num_frames,
1130
+ height=height,
1131
+ width=width,
1132
+ )
1133
+ else:
1134
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
1135
+ hidden_states=norm_hidden_states,
1136
+ encoder_hidden_states=norm_encoder_hidden_states,
1137
+ image_rotary_emb=image_rotary_emb,
1138
+ attn2=self.attn2
1139
+ )
1140
  hidden_states = hidden_states + gate_msa * attn_hidden_states
1141
  encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
1142
 
 
1148
  # FFN
1149
  if self.norm3 is not None:
1150
  norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
1151
+ if self.txt_ff is not None:
1152
+ norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
1153
+ else:
1154
+ norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
1155
  else:
1156
  norm_hidden_states = self.ff(norm_hidden_states)
1157
+ if self.txt_ff is not None:
1158
+ norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
1159
+ else:
1160
+ norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
1161
  hidden_states = hidden_states + gate_ff * norm_hidden_states
1162
  encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states
1163
  return hidden_states, encoder_hidden_states
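To make the effect of the new `is_mmdit_block` flag concrete, a small standalone toy (not the EasyAnimate API) that mirrors the feed-forward routing in the forward pass above: when `txt_ff` is not built, the text tokens fall back to the shared `ff`.

```python
# Toy illustration only; Linear layers stand in for the block's FeedForward modules.
import torch
import torch.nn as nn

dim = 8
ff = nn.Linear(dim, dim)      # stands in for self.ff
txt_ff = nn.Linear(dim, dim)  # stands in for self.txt_ff (built only when is_mmdit_block=True)

hidden_states = torch.randn(1, 16, dim)         # video tokens
encoder_hidden_states = torch.randn(1, 4, dim)  # text tokens

for is_mmdit_block in (True, False):
    text_ff = txt_ff if is_mmdit_block else None
    out_vid = ff(hidden_states)
    # Mirrors the forward() above: use txt_ff when present, otherwise share ff.
    out_txt = text_ff(encoder_hidden_states) if text_ff is not None else ff(encoder_hidden_states)
    print(is_mmdit_block, out_vid.shape, out_txt.shape)
```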
easyanimate/models/autoencoder_magvit.py CHANGED
@@ -44,6 +44,7 @@ from ..vae.ldm.models.cogvideox_enc_dec import (CogVideoXCausalConv3d,
44
  CogVideoXDecoder3D,
45
  CogVideoXEncoder3D,
46
  CogVideoXSafeConv3d)
 
47
  from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
48
  from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
49
 
@@ -96,6 +97,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
96
  out_channels: int = 3,
97
  ch = 128,
98
  ch_mult = [ 1,2,4,4 ],
 
99
  use_gc_blocks = None,
100
  down_block_types: tuple = None,
101
  up_block_types: tuple = None,
@@ -109,6 +111,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
109
  latent_channels: int = 4,
110
  norm_num_groups: int = 32,
111
  scaling_factor: float = 0.1825,
 
112
  slice_mag_vae=True,
113
  slice_compression_vae=False,
114
  cache_compression_vae=False,
@@ -130,8 +133,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
130
  in_channels=in_channels,
131
  out_channels=latent_channels,
132
  down_block_types=down_block_types,
133
- ch = ch,
134
- ch_mult = ch_mult,
 
135
  use_gc_blocks=use_gc_blocks,
136
  mid_block_type=mid_block_type,
137
  mid_block_use_attention=mid_block_use_attention,
@@ -154,8 +158,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
154
  in_channels=latent_channels,
155
  out_channels=out_channels,
156
  up_block_types=up_block_types,
157
- ch = ch,
158
- ch_mult = ch_mult,
 
159
  use_gc_blocks=use_gc_blocks,
160
  mid_block_type=mid_block_type,
161
  mid_block_use_attention=mid_block_use_attention,
@@ -196,81 +201,10 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
196
  if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
197
  module.gradient_checkpointing = value
198
 
199
- @property
200
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
201
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
202
- r"""
203
- Returns:
204
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
205
- indexed by its weight name.
206
- """
207
- # set recursively
208
- processors = {}
209
-
210
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
211
- if hasattr(module, "get_processor"):
212
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
213
-
214
- for sub_name, child in module.named_children():
215
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
216
-
217
- return processors
218
-
219
- for name, module in self.named_children():
220
- fn_recursive_add_processors(name, module, processors)
221
-
222
- return processors
223
-
224
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
225
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
226
- r"""
227
- Sets the attention processor to use to compute attention.
228
-
229
- Parameters:
230
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
231
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
232
- for **all** `Attention` layers.
233
-
234
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
235
- processor. This is strongly recommended when setting trainable attention processors.
236
-
237
- """
238
- count = len(self.attn_processors.keys())
239
-
240
- if isinstance(processor, dict) and len(processor) != count:
241
- raise ValueError(
242
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
243
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
244
- )
245
-
246
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
247
- if hasattr(module, "set_processor"):
248
- if not isinstance(processor, dict):
249
- module.set_processor(processor)
250
- else:
251
- module.set_processor(processor.pop(f"{name}.processor"))
252
-
253
- for sub_name, child in module.named_children():
254
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
255
-
256
- for name, module in self.named_children():
257
- fn_recursive_attn_processor(name, module, processor)
258
-
259
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
260
- def set_default_attn_processor(self):
261
- """
262
- Disables custom attention processors and sets the default attention implementation.
263
- """
264
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
265
- processor = AttnAddedKVProcessor()
266
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
267
- processor = AttnProcessor()
268
- else:
269
- raise ValueError(
270
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
271
- )
272
-
273
- self.set_attn_processor(processor)
274
 
275
  @apply_forward_hook
276
  def encode(
@@ -308,6 +242,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
308
  moments = self.quant_conv(h)
309
  posterior = DiagonalGaussianDistribution(moments)
310
 
 
311
  if not return_dict:
312
  return (posterior,)
313
 
@@ -355,6 +290,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
355
  else:
356
  decoded = self._decode(z).sample
357
 
 
358
  if not return_dict:
359
  return (decoded,)
360
 
@@ -519,44 +455,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
519
 
520
  return DecoderOutput(sample=dec)
521
 
522
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
523
- def fuse_qkv_projections(self):
524
- """
525
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
526
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
527
-
528
- <Tip warning={true}>
529
-
530
- This API is 🧪 experimental.
531
-
532
- </Tip>
533
- """
534
- self.original_attn_processors = None
535
-
536
- for _, attn_processor in self.attn_processors.items():
537
- if "Added" in str(attn_processor.__class__.__name__):
538
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
539
-
540
- self.original_attn_processors = self.attn_processors
541
-
542
- for module in self.modules():
543
- if isinstance(module, Attention):
544
- module.fuse_projections(fuse=True)
545
-
546
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
547
- def unfuse_qkv_projections(self):
548
- """Disables the fused QKV projection if enabled.
549
-
550
- <Tip warning={true}>
551
-
552
- This API is 🧪 experimental.
553
-
554
- </Tip>
555
-
556
- """
557
- if self.original_attn_processors is not None:
558
- self.set_attn_processor(self.original_attn_processors)
559
-
560
  @classmethod
561
  def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
562
  import json
 
44
  CogVideoXDecoder3D,
45
  CogVideoXEncoder3D,
46
  CogVideoXSafeConv3d)
47
+ from ..vae.ldm.models.omnigen_enc_dec import CausalConv3d
48
  from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
49
  from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
50
 
 
97
  out_channels: int = 3,
98
  ch = 128,
99
  ch_mult = [ 1,2,4,4 ],
100
+ block_out_channels = [128, 256, 512, 512],
101
  use_gc_blocks = None,
102
  down_block_types: tuple = None,
103
  up_block_types: tuple = None,
 
111
  latent_channels: int = 4,
112
  norm_num_groups: int = 32,
113
  scaling_factor: float = 0.1825,
114
+ force_upcast: float = True,
115
  slice_mag_vae=True,
116
  slice_compression_vae=False,
117
  cache_compression_vae=False,
 
133
  in_channels=in_channels,
134
  out_channels=latent_channels,
135
  down_block_types=down_block_types,
136
+ ch=ch,
137
+ ch_mult=ch_mult,
138
+ block_out_channels=block_out_channels,
139
  use_gc_blocks=use_gc_blocks,
140
  mid_block_type=mid_block_type,
141
  mid_block_use_attention=mid_block_use_attention,
 
158
  in_channels=latent_channels,
159
  out_channels=out_channels,
160
  up_block_types=up_block_types,
161
+ ch=ch,
162
+ ch_mult=ch_mult,
163
+ block_out_channels=block_out_channels,
164
  use_gc_blocks=use_gc_blocks,
165
  mid_block_type=mid_block_type,
166
  mid_block_use_attention=mid_block_use_attention,
 
201
  if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
202
  module.gradient_checkpointing = value
203
 
204
+ def _clear_conv_cache(self):
205
+ for name, module in self.named_modules():
206
+ if isinstance(module, CausalConv3d):
207
+ module._clear_conv_cache()
 
 
 
 
208
 
209
  @apply_forward_hook
210
  def encode(
 
242
  moments = self.quant_conv(h)
243
  posterior = DiagonalGaussianDistribution(moments)
244
 
245
+ self._clear_conv_cache()
246
  if not return_dict:
247
  return (posterior,)
248
 
 
290
  else:
291
  decoded = self._decode(z).sample
292
 
293
+ self._clear_conv_cache()
294
  if not return_dict:
295
  return (decoded,)
296
 
 
455
 
456
  return DecoderOutput(sample=dec)
457
 
 
 
 
458
  @classmethod
459
  def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
460
  import json
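A hedged sketch of building the MagViT VAE with the `vae_kwargs` from the new config; the `subfolder="vae"` layout of the pretrained directory is an assumption.

```python
from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit

# Sketch, assuming the checkpoint stores its VAE weights under a "vae" subfolder.
vae = AutoencoderKLMagvit.from_pretrained(
    "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP",
    subfolder="vae",
    mini_batch_encoder=4,
    mini_batch_decoder=1,
    cache_mag_vae=True,  # with the causal cache enabled, the new _clear_conv_cache()
                         # call frees CausalConv3d caches after each encode()/decode()
)
```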
easyanimate/models/embeddings.py CHANGED
@@ -4,8 +4,9 @@ from typing import Optional
4
  import numpy as np
5
  import torch
6
  import torch.nn.functional as F
7
- from diffusers.models.embeddings import (PixArtAlphaTextProjection, get_timestep_embedding,
8
- TimestepEmbedding, Timesteps)
 
9
  from einops import rearrange
10
  from torch import nn
11
 
 
4
  import numpy as np
5
  import torch
6
  import torch.nn.functional as F
7
+ from diffusers.models.embeddings import (PixArtAlphaTextProjection,
8
+ TimestepEmbedding, Timesteps,
9
+ get_timestep_embedding)
10
  from einops import rearrange
11
  from torch import nn
12
 
easyanimate/models/norm.py CHANGED
@@ -25,6 +25,22 @@ class FP32LayerNorm(nn.LayerNorm):
25
  inputs.float(), self.normalized_shape, None, None, self.eps
26
  ).to(origin_dtype)
27
 
 
 
 
28
  class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
29
  """
30
  For PixArt-Alpha.
 
25
  inputs.float(), self.normalized_shape, None, None, self.eps
26
  ).to(origin_dtype)
27
 
28
+ class EasyAnimateRMSNorm(nn.Module):
29
+ def __init__(self, hidden_size, eps=1e-6):
30
+ super().__init__()
31
+ self.weight = nn.Parameter(torch.ones(hidden_size))
32
+ self.variance_epsilon = eps
33
+
34
+ def forward(self, hidden_states):
35
+ input_dtype = hidden_states.dtype
36
+ hidden_states = hidden_states.to(torch.float32)
37
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
38
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
39
+ return self.weight * hidden_states.to(input_dtype)
40
+
41
+ def extra_repr(self):
42
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
43
+
44
  class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
45
  """
46
  For PixArt-Alpha.
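The new `EasyAnimateRMSNorm` is a standard LLaMA-style RMSNorm computed in float32; a quick numerical check against a hand-rolled reference:

```python
import torch
from easyanimate.models.norm import EasyAnimateRMSNorm

norm = EasyAnimateRMSNorm(hidden_size=64, eps=1e-6)
x = torch.randn(2, 10, 64, dtype=torch.bfloat16)

y = norm(x)
# Reference: x / sqrt(mean(x^2) + eps) in float32, cast back, then scaled by the weight.
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).to(x.dtype)
assert torch.allclose(y, norm.weight * ref)
print(norm)  # extra_repr shows "(64,), eps=1e-06"
```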
easyanimate/models/processor.py CHANGED
@@ -310,3 +310,149 @@ class EasyAnimateAttnProcessor2_0:
310
  hidden_states = attn.to_out[1](hidden_states)
311
  encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
312
  return hidden_states, encoder_hidden_states
 
 
 
 
310
  hidden_states = attn.to_out[1](hidden_states)
311
  encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
312
  return hidden_states, encoder_hidden_states
313
+
314
+ try:
315
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
316
+ from flash_attn.bert_padding import pad_input, unpad_input
317
+ except:
318
+ print("Flash Attention is not installed. Please install with `pip install flash-attn`, if you want to use SWA.")
319
+
320
+ class EasyAnimateSWAttnProcessor2_0:
321
+ def __init__(self, window_size=1024):
322
+ self.window_size = window_size
323
+
324
+ def __call__(
325
+ self,
326
+ attn: Attention,
327
+ hidden_states: torch.Tensor,
328
+ encoder_hidden_states: torch.Tensor,
329
+ attention_mask: Optional[torch.Tensor] = None,
330
+ image_rotary_emb: Optional[torch.Tensor] = None,
331
+ num_frames: int = None,
332
+ height: int = None,
333
+ width: int = None,
334
+ attn2: Attention = None,
335
+ ) -> torch.Tensor:
336
+ text_seq_length = encoder_hidden_states.size(1)
337
+
338
+ batch_size, sequence_length, _ = (
339
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
340
+ )
341
+
342
+ if attn2 is None:
343
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
344
+
345
+ query = attn.to_q(hidden_states)
346
+ key = attn.to_k(hidden_states)
347
+ value = attn.to_v(hidden_states)
348
+
349
+ inner_dim = key.shape[-1]
350
+ head_dim = inner_dim // attn.heads
351
+
352
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
353
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
354
+ value = value.view(batch_size, -1, attn.heads, head_dim)
355
+
356
+ if attn.norm_q is not None:
357
+ query = attn.norm_q(query)
358
+ if attn.norm_k is not None:
359
+ key = attn.norm_k(key)
360
+
361
+ if attn2 is not None:
362
+ query_txt = attn2.to_q(encoder_hidden_states)
363
+ key_txt = attn2.to_k(encoder_hidden_states)
364
+ value_txt = attn2.to_v(encoder_hidden_states)
365
+
366
+ inner_dim = key_txt.shape[-1]
367
+ head_dim = inner_dim // attn.heads
368
+
369
+ query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
370
+ key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
371
+ value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim)
372
+
373
+ if attn2.norm_q is not None:
374
+ query_txt = attn2.norm_q(query_txt)
375
+ if attn2.norm_k is not None:
376
+ key_txt = attn2.norm_k(key_txt)
377
+
378
+ query = torch.cat([query_txt, query], dim=2)
379
+ key = torch.cat([key_txt, key], dim=2)
380
+ value = torch.cat([value_txt, value], dim=1)
381
+
382
+ # Apply RoPE if needed
383
+ if image_rotary_emb is not None:
384
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
385
+ if not attn.is_cross_attention:
386
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
387
+
388
+ query = query.transpose(1, 2).to(value)
389
+ key = key.transpose(1, 2).to(value)
390
+ interval = max((query.size(1) - text_seq_length) // (self.window_size - text_seq_length), 1)
391
+
392
+ cross_key = torch.cat([key[:, :text_seq_length], key[:, text_seq_length::interval]], dim=1)
393
+ cross_val = torch.cat([value[:, :text_seq_length], value[:, text_seq_length::interval]], dim=1)
394
+ cross_hidden_states = flash_attn_func(query, cross_key, cross_val, dropout_p=0.0, causal=False)
395
+
396
+ # Split and rearrange to six directions
397
+ querys = torch.tensor_split(query[:, text_seq_length:], 6, 2)
398
+ keys = torch.tensor_split(key[:, text_seq_length:], 6, 2)
399
+ values = torch.tensor_split(value[:, text_seq_length:], 6, 2)
400
+
401
+ new_querys = [querys[0]]
402
+ new_keys = [keys[0]]
403
+ new_values = [values[0]]
404
+ for index, mode in enumerate(
405
+ [
406
+ "bs (f h w) hn hd -> bs (f w h) hn hd",
407
+ "bs (f h w) hn hd -> bs (h f w) hn hd",
408
+ "bs (f h w) hn hd -> bs (h w f) hn hd",
409
+ "bs (f h w) hn hd -> bs (w f h) hn hd",
410
+ "bs (f h w) hn hd -> bs (w h f) hn hd"
411
+ ]
412
+ ):
413
+ new_querys.append(rearrange(querys[index + 1], mode, f=num_frames, h=height, w=width))
414
+ new_keys.append(rearrange(keys[index + 1], mode, f=num_frames, h=height, w=width))
415
+ new_values.append(rearrange(values[index + 1], mode, f=num_frames, h=height, w=width))
416
+ query = torch.cat(new_querys, dim=2)
417
+ key = torch.cat(new_keys, dim=2)
418
+ value = torch.cat(new_values, dim=2)
419
+
420
+ # apply attention
421
+ hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False, window_size=(self.window_size, self.window_size))
422
+
423
+ hidden_states = torch.tensor_split(hidden_states, 6, 2)
424
+ new_hidden_states = [hidden_states[0]]
425
+ for index, mode in enumerate(
426
+ [
427
+ "bs (f w h) hn hd -> bs (f h w) hn hd",
428
+ "bs (h f w) hn hd -> bs (f h w) hn hd",
429
+ "bs (h w f) hn hd -> bs (f h w) hn hd",
430
+ "bs (w f h) hn hd -> bs (f h w) hn hd",
431
+ "bs (w h f) hn hd -> bs (f h w) hn hd"
432
+ ]
433
+ ):
434
+ new_hidden_states.append(rearrange(hidden_states[index + 1], mode, f=num_frames, h=height, w=width))
435
+ hidden_states = torch.cat([cross_hidden_states[:, :text_seq_length], torch.cat(new_hidden_states, dim=2)], dim=1) + cross_hidden_states
436
+
437
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
438
+
439
+ if attn2 is None:
440
+ # linear proj
441
+ hidden_states = attn.to_out[0](hidden_states)
442
+ # dropout
443
+ hidden_states = attn.to_out[1](hidden_states)
444
+
445
+ encoder_hidden_states, hidden_states = hidden_states.split(
446
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
447
+ )
448
+ else:
449
+ encoder_hidden_states, hidden_states = hidden_states.split(
450
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
451
+ )
452
+ # linear proj
453
+ hidden_states = attn.to_out[0](hidden_states)
454
+ encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
455
+ # dropout
456
+ hidden_states = attn.to_out[1](hidden_states)
457
+ encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
458
+ return hidden_states, encoder_hidden_states
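The sliding-window processor splits the attention heads into six groups and scans five of them along permuted (frames, height, width) orders before flash attention, restoring the original order afterwards. A small standalone check that the forward and restore rearrange patterns used above are exact inverses:

```python
# Standalone check of the rearrange pairs used by EasyAnimateSWAttnProcessor2_0.
import torch
from einops import rearrange

f, h, w, heads, head_dim = 4, 6, 8, 2, 16
x = torch.randn(1, f * h * w, heads, head_dim)

pairs = [
    ("bs (f h w) hn hd -> bs (f w h) hn hd", "bs (f w h) hn hd -> bs (f h w) hn hd"),
    ("bs (f h w) hn hd -> bs (h f w) hn hd", "bs (h f w) hn hd -> bs (f h w) hn hd"),
    ("bs (f h w) hn hd -> bs (h w f) hn hd", "bs (h w f) hn hd -> bs (f h w) hn hd"),
    ("bs (f h w) hn hd -> bs (w f h) hn hd", "bs (w f h) hn hd -> bs (f h w) hn hd"),
    ("bs (f h w) hn hd -> bs (w h f) hn hd", "bs (w h f) hn hd -> bs (f h w) hn hd"),
]
for fwd, inv in pairs:
    y = rearrange(x, fwd, f=f, h=h, w=w)
    assert torch.equal(rearrange(y, inv, f=f, h=h, w=w), x)
print("all scan orders round-trip exactly")
```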
easyanimate/models/transformer3d.py CHANGED
@@ -39,8 +39,9 @@ from torch import nn
39
  from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
40
  SelfAttentionTemporalTransformerBlock,
41
  TemporalTransformerBlock, zero_module)
42
- from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding, TimePositionalEncoding
43
- from .norm import AdaLayerNormSingle
 
44
  from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
45
  TemporalUpsampler3D, UnPatch1D)
46
  from .resampler import Resampler
@@ -142,6 +143,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
142
  norm_eps: float = 1e-5,
143
  attention_type: str = "default",
144
  caption_channels: int = None,
 
145
  # block type
146
  basic_block_type: str = "motionmodule",
147
  # enable_uvit
@@ -168,6 +170,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
168
  after_norm = False,
169
  resize_inpaint_mask_directly: bool = False,
170
  enable_clip_in_inpaint: bool = True,
 
 
171
  enable_text_attention_mask: bool = True,
172
  add_noise_in_inpaint_model: bool = False,
173
  ):
@@ -192,6 +196,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
192
  self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
193
  interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
194
  interpolation_scale = max(interpolation_scale, 1)
 
195
 
196
  if self.casual_3d:
197
  self.pos_embed = CasualPatchEmbed3D(
@@ -397,16 +402,22 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
397
  def forward(
398
  self,
399
  hidden_states: torch.Tensor,
 
 
 
400
  inpaint_latents: torch.Tensor = None,
401
  control_latents: torch.Tensor = None,
402
- encoder_hidden_states: Optional[torch.Tensor] = None,
403
- clip_encoder_hidden_states: Optional[torch.Tensor] = None,
404
- timestep: Optional[torch.LongTensor] = None,
405
  added_cond_kwargs: Dict[str, torch.Tensor] = None,
406
  class_labels: Optional[torch.LongTensor] = None,
407
  cross_attention_kwargs: Dict[str, Any] = None,
408
  attention_mask: Optional[torch.Tensor] = None,
409
- encoder_attention_mask: Optional[torch.Tensor] = None,
410
  clip_attention_mask: Optional[torch.Tensor] = None,
411
  return_dict: bool = True,
412
  ):
@@ -432,7 +443,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
432
  An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
433
  is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
434
  negative values to the attention scores corresponding to "discard" tokens.
435
- encoder_attention_mask ( `torch.Tensor`, *optional*):
436
  Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
437
 
438
  * Mask `(batch, sequence_length)` True = keep, False = discard.
@@ -466,11 +477,12 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
466
  attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
467
  attention_mask = attention_mask.unsqueeze(1)
468
 
 
469
  if clip_attention_mask is not None:
470
- encoder_attention_mask = torch.cat([encoder_attention_mask, clip_attention_mask], dim=1)
471
  # convert encoder_attention_mask to a bias the same way we do for attention_mask
472
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
473
- encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
474
  encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
475
 
476
  if inpaint_latents is not None:
@@ -637,7 +649,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
637
  return Transformer3DModelOutput(sample=output)
638
 
639
  @classmethod
640
- def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
 
 
 
641
  if subfolder is not None:
642
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
643
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
@@ -649,16 +664,73 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
649
  config = json.load(f)
650
 
651
  from diffusers.utils import WEIGHTS_NAME
652
- model = cls.from_config(config, **transformer_additional_kwargs)
653
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
654
  model_file_safetensors = model_file.replace(".bin", ".safetensors")
655
- if os.path.exists(model_file_safetensors):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  from safetensors.torch import load_file, safe_open
657
  state_dict = load_file(model_file_safetensors)
658
  else:
659
- if not os.path.isfile(model_file):
660
- raise RuntimeError(f"{model_file} does not exist")
661
- state_dict = torch.load(model_file, map_location="cpu")
 
 
 
 
662
 
663
  if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
664
  new_shape = model.state_dict()['pos_embed.proj.weight'].size()
@@ -692,6 +764,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
692
  params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
693
  print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
694
 
 
695
  return model
696
 
697
  class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
@@ -769,6 +842,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
769
  after_norm = False,
770
  resize_inpaint_mask_directly: bool = False,
771
  enable_clip_in_inpaint: bool = True,
 
772
  enable_text_attention_mask: bool = True,
773
  add_noise_in_inpaint_model: bool = False,
774
  ):
@@ -909,6 +983,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
909
  control_latents: torch.Tensor = None,
910
  clip_encoder_hidden_states: Optional[torch.Tensor]=None,
911
  clip_attention_mask: Optional[torch.Tensor]=None,
 
912
  return_dict=True,
913
  ):
914
  """
@@ -1085,7 +1160,10 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
1085
  return Transformer2DModelOutput(sample=output)
1086
 
1087
  @classmethod
1088
- def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
 
 
 
1089
  if subfolder is not None:
1090
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1091
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
@@ -1097,16 +1175,73 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
1097
  config = json.load(f)
1098
 
1099
  from diffusers.utils import WEIGHTS_NAME
1100
- model = cls.from_config(config, **transformer_additional_kwargs)
1101
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1102
  model_file_safetensors = model_file.replace(".bin", ".safetensors")
1103
- if os.path.exists(model_file_safetensors):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1104
  from safetensors.torch import load_file, safe_open
1105
  state_dict = load_file(model_file_safetensors)
1106
  else:
1107
- if not os.path.isfile(model_file):
1108
- raise RuntimeError(f"{model_file} does not exist")
1109
- state_dict = torch.load(model_file, map_location="cpu")
 
 
 
 
1110
 
1111
  if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
1112
  new_shape = model.state_dict()['pos_embed.proj.weight'].size()
@@ -1156,6 +1291,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
1156
  params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1157
  print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1158
 
 
1159
  return model
1160
 
1161
  class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
@@ -1178,8 +1314,11 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1178
  timestep_activation_fn: str = "silu",
1179
  freq_shift: int = 0,
1180
  num_layers: int = 30,
 
 
1181
  dropout: float = 0.0,
1182
  time_embed_dim: int = 512,
 
1183
  text_embed_dim: int = 4096,
1184
  text_embed_dim_t5: int = 4096,
1185
  norm_eps: float = 1e-5,
@@ -1191,8 +1330,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1191
  after_norm = False,
1192
  resize_inpaint_mask_directly: bool = False,
1193
  enable_clip_in_inpaint: bool = True,
 
1194
  enable_text_attention_mask: bool = True,
1195
  add_noise_in_inpaint_model: bool = False,
 
1196
  ):
1197
  super().__init__()
1198
  self.num_heads = num_attention_heads
@@ -1211,8 +1352,20 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1211
  self.proj = nn.Conv2d(
1212
  in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
1213
  )
1214
- self.text_proj = nn.Linear(text_embed_dim, self.inner_dim)
1215
- self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim)
 
 
 
 
 
 
 
 
 
 
 
 
1216
 
1217
  if ref_channels is not None:
1218
  self.ref_proj = nn.Conv2d(
@@ -1224,23 +1377,45 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1224
 
1225
  if clip_channels is not None:
1226
  self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
1227
-
1228
- self.transformer_blocks = nn.ModuleList(
1229
- [
1230
- EasyAnimateDiTBlock(
1231
- dim=self.inner_dim,
1232
- num_attention_heads=num_attention_heads,
1233
- attention_head_dim=attention_head_dim,
1234
- time_embed_dim=time_embed_dim,
1235
- dropout=dropout,
1236
- activation_fn=activation_fn,
1237
- norm_elementwise_affine=norm_elementwise_affine,
1238
- norm_eps=norm_eps,
1239
- after_norm=after_norm
1240
- )
1241
- for _ in range(num_layers)
1242
- ]
1243
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1244
  self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)
1245
 
1246
  # 5. Output blocks
@@ -1275,6 +1450,7 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1275
  ref_latents: Optional[torch.Tensor] = None,
1276
  clip_encoder_hidden_states: Optional[torch.Tensor] = None,
1277
  clip_attention_mask: Optional[torch.Tensor] = None,
 
1278
  return_dict=True,
1279
  ):
1280
  batch_size, channels, video_length, height, width = hidden_states.size()
@@ -1343,6 +1519,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1343
  encoder_hidden_states,
1344
  temb,
1345
  image_rotary_emb,
 
 
 
1346
  **ckpt_kwargs,
1347
  )
1348
  else:
@@ -1351,6 +1530,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1351
  encoder_hidden_states=encoder_hidden_states,
1352
  temb=temb,
1353
  image_rotary_emb=image_rotary_emb,
 
 
 
1354
  )
1355
 
1356
  hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -1371,7 +1553,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1371
  return Transformer2DModelOutput(sample=output)
1372
 
1373
  @classmethod
1374
- def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
 
 
 
1375
  if subfolder is not None:
1376
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1377
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
@@ -1383,9 +1568,60 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1383
  config = json.load(f)
1384
 
1385
  from diffusers.utils import WEIGHTS_NAME
1386
- model = cls.from_config(config, **transformer_additional_kwargs)
1387
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1388
  model_file_safetensors = model_file.replace(".bin", ".safetensors")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1389
  if os.path.exists(model_file):
1390
  state_dict = torch.load(model_file, map_location="cpu")
1391
  elif os.path.exists(model_file_safetensors):
@@ -1433,4 +1669,5 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
1433
  params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1434
  print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1435
 
 
1436
  return model
 
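The hunks above and below rework how the text attention mask reaches the attention blocks: the renamed `text_embedding_mask` is concatenated with `clip_attention_mask` and then converted into an additive bias. A minimal standalone sketch of that bias conversion (values chosen arbitrarily for illustration):

```python
# Illustration of the mask handling in the forward passes of this file: a {0, 1}
# keep/discard mask of shape (batch, key_tokens) becomes an additive bias so that
# "discard" positions receive a large negative score before the attention softmax.
import torch

text_embedding_mask = torch.tensor([[1, 1, 1, 0, 0]])       # 1 = keep, 0 = discard
bias = (1 - text_embedding_mask.float()) * -10000.0         # (batch, key_tokens)
bias = bias.unsqueeze(1)                                     # broadcast over query tokens
print(bias)
```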
39
  from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
40
  SelfAttentionTemporalTransformerBlock,
41
  TemporalTransformerBlock, zero_module)
42
+ from .embeddings import (HunyuanCombinedTimestepTextSizeStyleEmbedding,
43
+ TimePositionalEncoding)
44
+ from .norm import AdaLayerNormSingle, EasyAnimateRMSNorm
45
  from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
46
  TemporalUpsampler3D, UnPatch1D)
47
  from .resampler import Resampler
 
143
  norm_eps: float = 1e-5,
144
  attention_type: str = "default",
145
  caption_channels: int = None,
146
+ n_query=8,
147
  # block type
148
  basic_block_type: str = "motionmodule",
149
  # enable_uvit
 
170
  after_norm = False,
171
  resize_inpaint_mask_directly: bool = False,
172
  enable_clip_in_inpaint: bool = True,
173
+ position_of_clip_embedding: str = "head",
174
+ enable_zero_in_inpaint: bool = False,
175
  enable_text_attention_mask: bool = True,
176
  add_noise_in_inpaint_model: bool = False,
177
  ):
 
196
  self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
197
  interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
198
  interpolation_scale = max(interpolation_scale, 1)
199
+ self.n_query = n_query
200
 
201
  if self.casual_3d:
202
  self.pos_embed = CasualPatchEmbed3D(
 
402
  def forward(
403
  self,
404
  hidden_states: torch.Tensor,
405
+ timestep: Optional[torch.LongTensor] = None,
406
+ timestep_cond = None,
407
+ encoder_hidden_states: Optional[torch.Tensor] = None,
408
+ text_embedding_mask: Optional[torch.Tensor] = None,
409
+ encoder_hidden_states_t5: Optional[torch.Tensor] = None,
410
+ text_embedding_mask_t5: Optional[torch.Tensor] = None,
411
+ image_meta_size = None,
412
+ style = None,
413
+ image_rotary_emb: Optional[torch.Tensor] = None,
414
  inpaint_latents: torch.Tensor = None,
415
  control_latents: torch.Tensor = None,
 
 
 
416
  added_cond_kwargs: Dict[str, torch.Tensor] = None,
417
  class_labels: Optional[torch.LongTensor] = None,
418
  cross_attention_kwargs: Dict[str, Any] = None,
419
  attention_mask: Optional[torch.Tensor] = None,
420
+ clip_encoder_hidden_states: Optional[torch.Tensor] = None,
421
  clip_attention_mask: Optional[torch.Tensor] = None,
422
  return_dict: bool = True,
423
  ):
 
443
  An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
444
  is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
445
  negative values to the attention scores corresponding to "discard" tokens.
446
+ text_embedding_mask ( `torch.Tensor`, *optional*):
447
  Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
448
 
449
  * Mask `(batch, sequence_length)` True = keep, False = discard.
 
477
  attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
478
  attention_mask = attention_mask.unsqueeze(1)
479
 
480
+ text_embedding_mask = text_embedding_mask.squeeze(1)
481
  if clip_attention_mask is not None:
482
+ text_embedding_mask = torch.cat([text_embedding_mask, clip_attention_mask], dim=1)
483
  # convert encoder_attention_mask to a bias the same way we do for attention_mask
484
+ if text_embedding_mask is not None and text_embedding_mask.ndim == 2:
485
+ encoder_attention_mask = (1 - text_embedding_mask.to(encoder_hidden_states.dtype)) * -10000.0
486
  encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
487
 
488
  if inpaint_latents is not None:
 
649
  return Transformer3DModelOutput(sample=output)
650
 
651
  @classmethod
652
+ def from_pretrained_2d(
653
+ cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={},
654
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
655
+ ):
656
  if subfolder is not None:
657
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
658
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
 
664
  config = json.load(f)
665
 
666
  from diffusers.utils import WEIGHTS_NAME
 
667
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
668
  model_file_safetensors = model_file.replace(".bin", ".safetensors")
669
+
670
+ if low_cpu_mem_usage:
671
+ try:
672
+ import re
673
+
674
+ from diffusers.models.modeling_utils import \
675
+ load_model_dict_into_meta
676
+ from diffusers.utils import is_accelerate_available
677
+ if is_accelerate_available():
678
+ import accelerate
679
+
680
+ # Instantiate model with empty weights
681
+ with accelerate.init_empty_weights():
682
+ model = cls.from_config(config, **transformer_additional_kwargs)
683
+
684
+ param_device = "cpu"
685
+ from safetensors.torch import load_file, safe_open
686
+ state_dict = load_file(model_file_safetensors)
687
+ model._convert_deprecated_attention_blocks(state_dict)
688
+ # move the params from meta device to cpu
689
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
690
+ if len(missing_keys) > 0:
691
+ raise ValueError(
692
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
693
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
694
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
695
+ " those weights or else make sure your checkpoint file is correct."
696
+ )
697
+
698
+ unexpected_keys = load_model_dict_into_meta(
699
+ model,
700
+ state_dict,
701
+ device=param_device,
702
+ dtype=torch_dtype,
703
+ model_name_or_path=pretrained_model_path,
704
+ )
705
+
706
+ if cls._keys_to_ignore_on_load_unexpected is not None:
707
+ for pat in cls._keys_to_ignore_on_load_unexpected:
708
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
709
+
710
+ if len(unexpected_keys) > 0:
711
+ print(
712
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
713
+ )
714
+ return model
715
+ except Exception as e:
716
+ print(
717
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
718
+ )
719
+
720
+ model = cls.from_config(config, **transformer_additional_kwargs)
721
+ if os.path.exists(model_file):
722
+ state_dict = torch.load(model_file, map_location="cpu")
723
+ elif os.path.exists(model_file_safetensors):
724
  from safetensors.torch import load_file, safe_open
725
  state_dict = load_file(model_file_safetensors)
726
  else:
727
+ from safetensors.torch import load_file, safe_open
728
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
729
+ state_dict = {}
730
+ for model_file_safetensors in model_files_safetensors:
731
+ _state_dict = load_file(model_file_safetensors)
732
+ for key in _state_dict:
733
+ state_dict[key] = _state_dict[key]
734
 
735
  if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
736
  new_shape = model.state_dict()['pos_embed.proj.weight'].size()
 
764
  params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
765
  print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
766
 
767
+ model = model.to(torch_dtype)
768
  return model
769
 
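A hedged usage sketch for the extended loader above; the checkpoint path is a placeholder, and only `low_cpu_mem_usage=True` exercises the new meta-device branch:

```python
import torch

from easyanimate.models.transformer3d import Transformer3DModel

transformer = Transformer3DModel.from_pretrained_2d(
    "path/to/EasyAnimate-transformer-checkpoint",  # placeholder path
    subfolder="transformer",
    low_cpu_mem_usage=True,          # load weights straight into a meta-initialized module
    torch_dtype=torch.bfloat16,      # final dtype; also applied by the new model.to(torch_dtype)
)
```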
770
  class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
 
842
  after_norm = False,
843
  resize_inpaint_mask_directly: bool = False,
844
  enable_clip_in_inpaint: bool = True,
845
+ position_of_clip_embedding: str = "full",
846
  enable_text_attention_mask: bool = True,
847
  add_noise_in_inpaint_model: bool = False,
848
  ):
 
983
  control_latents: torch.Tensor = None,
984
  clip_encoder_hidden_states: Optional[torch.Tensor]=None,
985
  clip_attention_mask: Optional[torch.Tensor]=None,
986
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
987
  return_dict=True,
988
  ):
989
  """
 
1160
  return Transformer2DModelOutput(sample=output)
1161
 
1162
  @classmethod
1163
+ def from_pretrained_2d(
1164
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1165
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1166
+ ):
1167
  if subfolder is not None:
1168
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1169
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
 
1175
  config = json.load(f)
1176
 
1177
  from diffusers.utils import WEIGHTS_NAME
 
1178
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1179
  model_file_safetensors = model_file.replace(".bin", ".safetensors")
1180
+
1181
+ if low_cpu_mem_usage:
1182
+ try:
1183
+ import re
1184
+
1185
+ from diffusers.models.modeling_utils import \
1186
+ load_model_dict_into_meta
1187
+ from diffusers.utils import is_accelerate_available
1188
+ if is_accelerate_available():
1189
+ import accelerate
1190
+
1191
+ # Instantiate model with empty weights
1192
+ with accelerate.init_empty_weights():
1193
+ model = cls.from_config(config, **transformer_additional_kwargs)
1194
+
1195
+ param_device = "cpu"
1196
+ from safetensors.torch import load_file, safe_open
1197
+ state_dict = load_file(model_file_safetensors)
1198
+ model._convert_deprecated_attention_blocks(state_dict)
1199
+ # move the params from meta device to cpu
1200
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
1201
+ if len(missing_keys) > 0:
1202
+ raise ValueError(
1203
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
1204
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
1205
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
1206
+ " those weights or else make sure your checkpoint file is correct."
1207
+ )
1208
+
1209
+ unexpected_keys = load_model_dict_into_meta(
1210
+ model,
1211
+ state_dict,
1212
+ device=param_device,
1213
+ dtype=torch_dtype,
1214
+ model_name_or_path=pretrained_model_path,
1215
+ )
1216
+
1217
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1218
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1219
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1220
+
1221
+ if len(unexpected_keys) > 0:
1222
+ print(
1223
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1224
+ )
1225
+ return model
1226
+ except Exception as e:
1227
+ print(
1228
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1229
+ )
1230
+
1231
+ model = cls.from_config(config, **transformer_additional_kwargs)
1232
+ if os.path.exists(model_file):
1233
+ state_dict = torch.load(model_file, map_location="cpu")
1234
+ elif os.path.exists(model_file_safetensors):
1235
  from safetensors.torch import load_file, safe_open
1236
  state_dict = load_file(model_file_safetensors)
1237
  else:
1238
+ from safetensors.torch import load_file, safe_open
1239
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1240
+ state_dict = {}
1241
+ for model_file_safetensors in model_files_safetensors:
1242
+ _state_dict = load_file(model_file_safetensors)
1243
+ for key in _state_dict:
1244
+ state_dict[key] = _state_dict[key]
1245
 
1246
  if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
1247
  new_shape = model.state_dict()['pos_embed.proj.weight'].size()
 
1291
  params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1292
  print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1293
 
1294
+ model = model.to(torch_dtype)
1295
  return model
1296
 
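Both loaders above gained the same fallback for sharded checkpoints: when neither the `.bin` file nor a single `.safetensors` file exists, every `*.safetensors` shard in the folder is merged into one state dict. A minimal sketch of that pattern (the folder path is a placeholder):

```python
import glob
import os

from safetensors.torch import load_file


def load_sharded_state_dict(folder: str) -> dict:
    """Merge all *.safetensors shards in `folder` into a single state dict."""
    state_dict = {}
    for shard in sorted(glob.glob(os.path.join(folder, "*.safetensors"))):
        state_dict.update(load_file(shard))  # later shards win on duplicate keys
    return state_dict

# state_dict = load_sharded_state_dict("path/to/transformer")  # placeholder usage
```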
1297
  class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
 
1314
  timestep_activation_fn: str = "silu",
1315
  freq_shift: int = 0,
1316
  num_layers: int = 30,
1317
+ mmdit_layers: int = 10000,
1318
+ swa_layers: list = None,
1319
  dropout: float = 0.0,
1320
  time_embed_dim: int = 512,
1321
+ add_norm_text_encoder: bool = False,
1322
  text_embed_dim: int = 4096,
1323
  text_embed_dim_t5: int = 4096,
1324
  norm_eps: float = 1e-5,
 
1330
  after_norm = False,
1331
  resize_inpaint_mask_directly: bool = False,
1332
  enable_clip_in_inpaint: bool = True,
1333
+ position_of_clip_embedding: str = "full",
1334
  enable_text_attention_mask: bool = True,
1335
  add_noise_in_inpaint_model: bool = False,
1336
+ add_ref_latent_in_control_model: bool = False,
1337
  ):
1338
  super().__init__()
1339
  self.num_heads = num_attention_heads
 
1352
  self.proj = nn.Conv2d(
1353
  in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
1354
  )
1355
+ if not add_norm_text_encoder:
1356
+ self.text_proj = nn.Linear(text_embed_dim, self.inner_dim)
1357
+ if text_embed_dim_t5 is not None:
1358
+ self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim)
1359
+ else:
1360
+ self.text_proj = nn.Sequential(
1361
+ EasyAnimateRMSNorm(text_embed_dim),
1362
+ nn.Linear(text_embed_dim, self.inner_dim)
1363
+ )
1364
+ if text_embed_dim_t5 is not None:
1365
+ self.text_proj_t5 = nn.Sequential(
1366
+ EasyAnimateRMSNorm(text_embed_dim_t5),
1367
+ nn.Linear(text_embed_dim_t5, self.inner_dim)
1368
+ )
1369
 
1370
  if ref_channels is not None:
1371
  self.ref_proj = nn.Conv2d(
 
1377
 
1378
  if clip_channels is not None:
1379
  self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
1380
+
1381
+ self.swa_layers = swa_layers
1382
+ if swa_layers is not None:
1383
+ self.transformer_blocks = nn.ModuleList(
1384
+ [
1385
+ EasyAnimateDiTBlock(
1386
+ dim=self.inner_dim,
1387
+ num_attention_heads=num_attention_heads,
1388
+ attention_head_dim=attention_head_dim,
1389
+ time_embed_dim=time_embed_dim,
1390
+ dropout=dropout,
1391
+ activation_fn=activation_fn,
1392
+ norm_elementwise_affine=norm_elementwise_affine,
1393
+ norm_eps=norm_eps,
1394
+ after_norm=after_norm,
1395
+ is_mmdit_block=True if index < mmdit_layers else False,
1396
+ is_swa=True if index in swa_layers else False,
1397
+ )
1398
+ for index in range(num_layers)
1399
+ ]
1400
+ )
1401
+ else:
1402
+ self.transformer_blocks = nn.ModuleList(
1403
+ [
1404
+ EasyAnimateDiTBlock(
1405
+ dim=self.inner_dim,
1406
+ num_attention_heads=num_attention_heads,
1407
+ attention_head_dim=attention_head_dim,
1408
+ time_embed_dim=time_embed_dim,
1409
+ dropout=dropout,
1410
+ activation_fn=activation_fn,
1411
+ norm_elementwise_affine=norm_elementwise_affine,
1412
+ norm_eps=norm_eps,
1413
+ after_norm=after_norm,
1414
+ is_mmdit_block=True if _ < mmdit_layers else False,
1415
+ )
1416
+ for _ in range(num_layers)
1417
+ ]
1418
+ )
1419
  self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)
1420
 
1421
  # 5. Output blocks
 
1450
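How the new `mmdit_layers` and `swa_layers` options translate into per-block flags in the constructor above, sketched with made-up values: the first `mmdit_layers` blocks are built as MMDiT-style blocks, and only the listed indices use sliding-window attention.

```python
# Illustrative values only; the defaults above are num_layers=30, mmdit_layers=10000, swa_layers=None.
num_layers, mmdit_layers, swa_layers = 8, 4, [2, 5]

block_flags = [
    {"index": i, "is_mmdit_block": i < mmdit_layers, "is_swa": i in swa_layers}
    for i in range(num_layers)
]
for flags in block_flags:
    print(flags)
```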
  ref_latents: Optional[torch.Tensor] = None,
1451
  clip_encoder_hidden_states: Optional[torch.Tensor] = None,
1452
  clip_attention_mask: Optional[torch.Tensor] = None,
1453
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
1454
  return_dict=True,
1455
  ):
1456
  batch_size, channels, video_length, height, width = hidden_states.size()
 
1519
  encoder_hidden_states,
1520
  temb,
1521
  image_rotary_emb,
1522
+ video_length,
1523
+ height // self.patch_size,
1524
+ width // self.patch_size,
1525
  **ckpt_kwargs,
1526
  )
1527
  else:
 
1530
  encoder_hidden_states=encoder_hidden_states,
1531
  temb=temb,
1532
  image_rotary_emb=image_rotary_emb,
1533
+ num_frames=video_length,
1534
+ height=height // self.patch_size,
1535
+ width=width // self.patch_size
1536
  )
1537
 
1538
  hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
1553
  return Transformer2DModelOutput(sample=output)
1554
 
1555
  @classmethod
1556
+ def from_pretrained_2d(
1557
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1558
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1559
+ ):
1560
  if subfolder is not None:
1561
  pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1562
  print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
 
1568
  config = json.load(f)
1569
 
1570
  from diffusers.utils import WEIGHTS_NAME
 
1571
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1572
  model_file_safetensors = model_file.replace(".bin", ".safetensors")
1573
+
1574
+ if low_cpu_mem_usage:
1575
+ try:
1576
+ import re
1577
+
1578
+ from diffusers.models.modeling_utils import \
1579
+ load_model_dict_into_meta
1580
+ from diffusers.utils import is_accelerate_available
1581
+ if is_accelerate_available():
1582
+ import accelerate
1583
+
1584
+ # Instantiate model with empty weights
1585
+ with accelerate.init_empty_weights():
1586
+ model = cls.from_config(config, **transformer_additional_kwargs)
1587
+
1588
+ param_device = "cpu"
1589
+ from safetensors.torch import load_file, safe_open
1590
+ state_dict = load_file(model_file_safetensors)
1591
+ model._convert_deprecated_attention_blocks(state_dict)
1592
+ # move the params from meta device to cpu
1593
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
1594
+ if len(missing_keys) > 0:
1595
+ raise ValueError(
1596
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
1597
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
1598
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
1599
+ " those weights or else make sure your checkpoint file is correct."
1600
+ )
1601
+
1602
+ unexpected_keys = load_model_dict_into_meta(
1603
+ model,
1604
+ state_dict,
1605
+ device=param_device,
1606
+ dtype=torch_dtype,
1607
+ model_name_or_path=pretrained_model_path,
1608
+ )
1609
+
1610
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1611
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1612
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1613
+
1614
+ if len(unexpected_keys) > 0:
1615
+ print(
1616
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1617
+ )
1618
+ return model
1619
+ except Exception as e:
1620
+ print(
1621
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1622
+ )
1623
+
1624
+ model = cls.from_config(config, **transformer_additional_kwargs)
1625
  if os.path.exists(model_file):
1626
  state_dict = torch.load(model_file, map_location="cpu")
1627
  elif os.path.exists(model_file_safetensors):
 
1669
  params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1670
  print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1671
 
1672
+ model = model.to(torch_dtype)
1673
  return model
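For reference, the low-memory technique shared by all three `from_pretrained_2d` variants above, reduced to a self-contained sketch: allocate the module on the meta device with `accelerate.init_empty_weights()`, then materialize parameters directly from a safetensors file. The demo builds a throwaway `nn.Linear` checkpoint, and `load_state_dict(..., assign=True)` (PyTorch >= 2.1) stands in for diffusers' `load_model_dict_into_meta` used in the actual code.

```python
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file, save_file

# Throwaway checkpoint just for the demo.
save_file({"weight": torch.randn(4, 4), "bias": torch.zeros(4)}, "demo_linear.safetensors")

# Build the module skeleton without allocating real weight memory.
with init_empty_weights():
    model = torch.nn.Linear(4, 4)            # parameters live on the meta device

state_dict = load_file("demo_linear.safetensors")
model.load_state_dict(state_dict, assign=True)   # materialize meta params from the loaded tensors
print(next(model.parameters()).device)           # cpu
```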
easyanimate/pipeline/pipeline_easyanimate.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -12,61 +12,113 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- import copy
16
- import html
17
  import inspect
18
- import re
19
- import urllib.parse as ul
20
  from dataclasses import dataclass
21
- from typing import Callable, List, Optional, Tuple, Union
22
 
23
  import numpy as np
24
  import torch
25
- from diffusers import DiffusionPipeline, ImagePipelineOutput
 
 
26
  from diffusers.image_processor import VaeImageProcessor
27
- from diffusers.models import AutoencoderKL
28
- from diffusers.schedulers import DPMSolverMultistepScheduler
 
 
 
 
 
 
29
  from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
30
- is_bs4_available, is_ftfy_available, logging,
 
31
  replace_example_docstring)
32
  from diffusers.utils.torch_utils import randn_tensor
33
  from einops import rearrange
 
34
  from tqdm import tqdm
35
- from transformers import T5EncoderModel, T5Tokenizer
 
 
36
 
37
- from ..models.transformer3d import Transformer3DModel
 
38
 
39
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
40
 
41
- if is_bs4_available():
42
- from bs4 import BeautifulSoup
 
43
 
44
- if is_ftfy_available():
45
- import ftfy
46
 
 
47
 
48
  EXAMPLE_DOC_STRING = """
49
  Examples:
50
- ```py
51
  >>> import torch
52
  >>> from diffusers import EasyAnimatePipeline
53
-
54
- >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
55
- >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
56
- >>> # Enable memory optimizations.
57
- >>> pipe.enable_model_cpu_offload()
58
-
59
- >>> prompt = "A small cactus with a happy face in the Sahara desert."
60
- >>> image = pipe(prompt).images[0]
 
 
 
 
 
 
61
  ```
62
  """
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
65
  def retrieve_timesteps(
66
  scheduler,
67
  num_inference_steps: Optional[int] = None,
68
  device: Optional[Union[str, torch.device]] = None,
69
  timesteps: Optional[List[int]] = None,
 
70
  **kwargs,
71
  ):
72
  """
@@ -77,19 +129,23 @@ def retrieve_timesteps(
77
  scheduler (`SchedulerMixin`):
78
  The scheduler to get timesteps from.
79
  num_inference_steps (`int`):
80
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
81
- `timesteps` must be `None`.
82
  device (`str` or `torch.device`, *optional*):
83
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
84
  timesteps (`List[int]`, *optional*):
85
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
86
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
87
- must be `None`.
 
 
88
 
89
  Returns:
90
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
91
  second element is the number of inference steps.
92
  """
 
 
93
  if timesteps is not None:
94
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
95
  if not accepts_timesteps:
@@ -100,86 +156,113 @@ def retrieve_timesteps(
100
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
101
  timesteps = scheduler.timesteps
102
  num_inference_steps = len(timesteps)
 
 
 
 
 
 
 
 
 
 
103
  else:
104
  scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
105
  timesteps = scheduler.timesteps
106
  return timesteps, num_inference_steps
107
 
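A quick usage sketch for the `retrieve_timesteps` helper kept above, assuming a stock diffusers scheduler that does not accept custom `timesteps`:

```python
from diffusers import DDIMScheduler

from easyanimate.pipeline.pipeline_easyanimate import retrieve_timesteps

scheduler = DDIMScheduler()
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=25, device="cpu")
print(num_inference_steps, timesteps[:5])
```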
108
- @dataclass
109
- class EasyAnimatePipelineOutput(BaseOutput):
110
- videos: Union[torch.Tensor, np.ndarray]
111
 
112
  class EasyAnimatePipeline(DiffusionPipeline):
113
  r"""
114
- Pipeline for text-to-image generation using PixArt-Alpha.
115
 
116
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
117
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
118
 
 
 
 
 
119
  Args:
120
- vae ([`AutoencoderKL`]):
121
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
122
- text_encoder ([`T5EncoderModel`]):
123
- Frozen text-encoder. PixArt-Alpha uses
124
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
125
- [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
126
- tokenizer (`T5Tokenizer`):
127
- Tokenizer of class
128
- [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
129
- transformer ([`Transformer3DModel`]):
130
- A text conditioned `Transformer3DModel` to denoise the encoded image latents.
131
- scheduler ([`SchedulerMixin`]):
132
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
 
 
 
133
  """
134
- bad_punct_regex = re.compile(
135
- r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
136
- ) # noqa
137
 
138
- _optional_components = ["tokenizer", "text_encoder"]
139
- model_cpu_offload_seq = "text_encoder->transformer->vae"
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def __init__(
142
  self,
143
- tokenizer: T5Tokenizer,
144
- text_encoder: T5EncoderModel,
145
- vae: AutoencoderKL,
146
- transformer: Transformer3DModel,
147
- scheduler: DPMSolverMultistepScheduler,
 
 
148
  ):
149
  super().__init__()
150
 
151
  self.register_modules(
152
- tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
 
 
 
 
 
 
153
  )
154
 
155
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
156
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
157
- self.enable_autocast_float8_transformer_flag = False
158
-
159
- # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
160
- def mask_text_embeddings(self, emb, mask):
161
- if emb.shape[0] == 1:
162
- keep_index = mask.sum().item()
163
- return emb[:, :, :keep_index, :], keep_index
164
- else:
165
- masked_feature = emb * mask[:, None, :, None]
166
- return masked_feature, emb.shape[2]
167
 
168
- # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
 
 
 
 
 
 
169
  def encode_prompt(
170
  self,
171
- prompt: Union[str, List[str]],
172
- do_classifier_free_guidance: bool = True,
173
- negative_prompt: str = "",
174
  num_images_per_prompt: int = 1,
175
- device: Optional[torch.device] = None,
176
- prompt_embeds: Optional[torch.FloatTensor] = None,
177
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
178
- prompt_attention_mask: Optional[torch.FloatTensor] = None,
179
- negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
180
- clean_caption: bool = False,
181
- max_sequence_length: int = 120,
182
- **kwargs,
 
183
  ):
184
  r"""
185
  Encodes the prompt into text encoder hidden states.
@@ -187,33 +270,46 @@ class EasyAnimatePipeline(DiffusionPipeline):
187
  Args:
188
  prompt (`str` or `List[str]`, *optional*):
189
  prompt to be encoded
190
- negative_prompt (`str` or `List[str]`, *optional*):
191
- The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
192
- instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
193
- PixArt-Alpha, this should be "".
194
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
195
- whether to use classifier free guidance or not
196
- num_images_per_prompt (`int`, *optional*, defaults to 1):
197
  number of images that should be generated per prompt
198
- device: (`torch.device`, *optional*):
199
- torch device to place the resulting embeddings on
200
- prompt_embeds (`torch.FloatTensor`, *optional*):
 
 
 
 
201
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
202
  provided, text embeddings will be generated from `prompt` input argument.
203
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
204
- Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
205
- string.
206
- clean_caption (`bool`, defaults to `False`):
207
- If `True`, the function will preprocess and clean the provided caption before encoding.
208
- max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
 
 
 
 
 
209
  """
 
 
210
 
211
- if "mask_feature" in kwargs:
212
- deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
213
- deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
214
 
215
- if device is None:
216
- device = self._execution_device
 
 
 
 
 
217
 
218
  if prompt is not None and isinstance(prompt, str):
219
  batch_size = 1
@@ -222,74 +318,199 @@ class EasyAnimatePipeline(DiffusionPipeline):
222
  else:
223
  batch_size = prompt_embeds.shape[0]
224
 
225
- # See Section 3.1. of the paper.
226
- max_length = max_sequence_length
227
-
228
  if prompt_embeds is None:
229
- prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
230
- text_inputs = self.tokenizer(
231
- prompt,
232
- padding="max_length",
233
- max_length=max_length,
234
- truncation=True,
235
- add_special_tokens=True,
236
- return_tensors="pt",
237
- )
238
- text_input_ids = text_inputs.input_ids
239
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
240
-
241
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
242
- text_input_ids, untruncated_ids
243
- ):
244
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
245
- logger.warning(
246
- "The following part of your input was truncated because CLIP can only handle sequences up to"
247
- f" {max_length} tokens: {removed_text}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  )
249
 
250
- prompt_attention_mask = text_inputs.attention_mask
251
- prompt_attention_mask = prompt_attention_mask.to(device)
252
-
253
- prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
254
- prompt_embeds = prompt_embeds[0]
255
-
256
- if self.text_encoder is not None:
257
- dtype = self.text_encoder.dtype
258
- elif self.transformer is not None:
259
- dtype = self.transformer.dtype
260
- else:
261
- dtype = None
262
-
 
 
 
 
 
 
 
 
 
 
263
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
264
 
265
  bs_embed, seq_len, _ = prompt_embeds.shape
266
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
267
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
268
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
269
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
270
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
271
 
272
  # get unconditional embeddings for classifier free guidance
273
  if do_classifier_free_guidance and negative_prompt_embeds is None:
274
- uncond_tokens = [negative_prompt] * batch_size
275
- uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
276
- max_length = prompt_embeds.shape[1]
277
- uncond_input = self.tokenizer(
278
- uncond_tokens,
279
- padding="max_length",
280
- max_length=max_length,
281
- truncation=True,
282
- return_attention_mask=True,
283
- add_special_tokens=True,
284
- return_tensors="pt",
285
- )
286
- negative_prompt_attention_mask = uncond_input.attention_mask
287
- negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
- negative_prompt_embeds = self.text_encoder(
290
- uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
291
- )
292
- negative_prompt_embeds = negative_prompt_embeds[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  if do_classifier_free_guidance:
295
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -299,14 +520,9 @@ class EasyAnimatePipeline(DiffusionPipeline):
299
 
300
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
301
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
302
-
303
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
304
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
305
- else:
306
- negative_prompt_embeds = None
307
- negative_prompt_attention_mask = None
308
-
309
- return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
310
 
311
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
312
  def prepare_extra_step_kwargs(self, generator, eta):
@@ -331,20 +547,25 @@ class EasyAnimatePipeline(DiffusionPipeline):
331
  prompt,
332
  height,
333
  width,
334
- negative_prompt,
335
- callback_steps,
336
  prompt_embeds=None,
337
  negative_prompt_embeds=None,
 
 
 
 
 
 
 
338
  ):
339
- if height % 8 != 0 or width % 8 != 0:
340
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
341
 
342
- if (callback_steps is None) or (
343
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
344
  ):
345
  raise ValueError(
346
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
347
- f" {type(callback_steps)}."
348
  )
349
 
350
  if prompt is not None and prompt_embeds is not None:
@@ -356,14 +577,18 @@ class EasyAnimatePipeline(DiffusionPipeline):
356
  raise ValueError(
357
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
358
  )
 
 
 
 
359
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
360
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
361
 
362
- if prompt is not None and negative_prompt_embeds is not None:
363
- raise ValueError(
364
- f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
365
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
366
- )
367
 
368
  if negative_prompt is not None and negative_prompt_embeds is not None:
369
  raise ValueError(
@@ -371,6 +596,13 @@ class EasyAnimatePipeline(DiffusionPipeline):
371
  f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
372
  )
373
 
 
 
 
 
 
 
 
374
  if prompt_embeds is not None and negative_prompt_embeds is not None:
375
  if prompt_embeds.shape != negative_prompt_embeds.shape:
376
  raise ValueError(
@@ -378,153 +610,25 @@ class EasyAnimatePipeline(DiffusionPipeline):
378
  f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
379
  f" {negative_prompt_embeds.shape}."
380
  )
381
-
382
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
383
- def _text_preprocessing(self, text, clean_caption=False):
384
- if clean_caption and not is_bs4_available():
385
- logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
386
- logger.warn("Setting `clean_caption` to False...")
387
- clean_caption = False
388
-
389
- if clean_caption and not is_ftfy_available():
390
- logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
391
- logger.warn("Setting `clean_caption` to False...")
392
- clean_caption = False
393
-
394
- if not isinstance(text, (tuple, list)):
395
- text = [text]
396
-
397
- def process(text: str):
398
- if clean_caption:
399
- text = self._clean_caption(text)
400
- text = self._clean_caption(text)
401
- else:
402
- text = text.lower().strip()
403
- return text
404
-
405
- return [process(t) for t in text]
406
-
407
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
408
- def _clean_caption(self, caption):
409
- caption = str(caption)
410
- caption = ul.unquote_plus(caption)
411
- caption = caption.strip().lower()
412
- caption = re.sub("<person>", "person", caption)
413
- # urls:
414
- caption = re.sub(
415
- r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
416
- "",
417
- caption,
418
- ) # regex for urls
419
- caption = re.sub(
420
- r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
421
- "",
422
- caption,
423
- ) # regex for urls
424
- # html:
425
- caption = BeautifulSoup(caption, features="html.parser").text
426
-
427
- # @<nickname>
428
- caption = re.sub(r"@[\w\d]+\b", "", caption)
429
-
430
- # 31C0—31EF CJK Strokes
431
- # 31F0—31FF Katakana Phonetic Extensions
432
- # 3200—32FF Enclosed CJK Letters and Months
433
- # 3300—33FF CJK Compatibility
434
- # 3400—4DBF CJK Unified Ideographs Extension A
435
- # 4DC0—4DFF Yijing Hexagram Symbols
436
- # 4E00—9FFF CJK Unified Ideographs
437
- caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
438
- caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
439
- caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
440
- caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
441
- caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
442
- caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
443
- caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
444
- #######################################################
445
-
446
- # все виды тире / all types of dash --> "-"
447
- caption = re.sub(
448
- r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
449
- "-",
450
- caption,
451
- )
452
-
453
- # кавычки к одному стандарту
454
- caption = re.sub(r"[`´«»“”¨]", '"', caption)
455
- caption = re.sub(r"[‘’]", "'", caption)
456
-
457
- # &quot;
458
- caption = re.sub(r"&quot;?", "", caption)
459
- # &amp
460
- caption = re.sub(r"&amp", "", caption)
461
-
462
- # ip adresses:
463
- caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
464
-
465
- # article ids:
466
- caption = re.sub(r"\d:\d\d\s+$", "", caption)
467
-
468
- # \n
469
- caption = re.sub(r"\\n", " ", caption)
470
-
471
- # "#123"
472
- caption = re.sub(r"#\d{1,3}\b", "", caption)
473
- # "#12345.."
474
- caption = re.sub(r"#\d{5,}\b", "", caption)
475
- # "123456.."
476
- caption = re.sub(r"\b\d{6,}\b", "", caption)
477
- # filenames:
478
- caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
479
-
480
- #
481
- caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
482
- caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
483
-
484
- caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
485
- caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
486
-
487
- # this-is-my-cute-cat / this_is_my_cute_cat
488
- regex2 = re.compile(r"(?:\-|\_)")
489
- if len(re.findall(regex2, caption)) > 3:
490
- caption = re.sub(regex2, " ", caption)
491
-
492
- caption = ftfy.fix_text(caption)
493
- caption = html.unescape(html.unescape(caption))
494
-
495
- caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
496
- caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
497
- caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
498
-
499
- caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
500
- caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
501
- caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
502
- caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
503
- caption = re.sub(r"\bpage\s+\d+\b", "", caption)
504
-
505
- caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
506
-
507
- caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
508
-
509
- caption = re.sub(r"\b\s+\:\s+", r": ", caption)
510
- caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
511
- caption = re.sub(r"\s+", " ", caption)
512
-
513
- caption.strip()
514
-
515
- caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
516
- caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
517
- caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
518
- caption = re.sub(r"^\.\S+$", "", caption)
519
-
520
- return caption.strip()
521
 
522
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
523
  def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
524
- if self.vae.quant_conv.weight.ndim==5:
525
- mini_batch_encoder = self.vae.mini_batch_encoder
526
- mini_batch_decoder = self.vae.mini_batch_decoder
527
- shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
 
 
 
 
528
  else:
529
  shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
530
 
@@ -538,11 +642,12 @@ class EasyAnimatePipeline(DiffusionPipeline):
538
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
539
  else:
540
  latents = latents.to(device)
541
-
542
  # scale the initial noise by the standard deviation required by the scheduler
543
- latents = latents * self.scheduler.init_noise_sigma
 
544
  return latents
545
-
546
  def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
547
  if video.size()[2] <= mini_batch_encoder:
548
  return video
@@ -558,16 +663,17 @@ class EasyAnimatePipeline(DiffusionPipeline):
558
 
559
  video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
560
  return video
561
-
562
  def decode_latents(self, latents):
563
  video_length = latents.shape[2]
564
  latents = 1 / self.vae.config.scaling_factor * latents
565
- if self.vae.quant_conv.weight.ndim==5:
566
  mini_batch_encoder = self.vae.mini_batch_encoder
567
  mini_batch_decoder = self.vae.mini_batch_decoder
568
  video = self.vae.decode(latents)[0]
569
  video = video.clamp(-1, 1)
570
- video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
 
571
  else:
572
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
573
  video = []
@@ -580,8 +686,28 @@ class EasyAnimatePipeline(DiffusionPipeline):
580
  video = video.cpu().float().numpy()
581
  return video
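The `decode_latents` method above follows the usual latent-diffusion recipe: undo the VAE scaling factor, decode, clamp, and convert to numpy (with an extra temporal mini-batch smoothing pass for the video VAE). A hedged single-image sketch of the same steps with a plain diffusers `AutoencoderKL`:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")      # any KL image VAE works for the demo
latents = torch.randn(1, 4, 64, 64)                                   # fake latents for illustration
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor).sample    # undo the encode-time scaling
image = (image / 2 + 0.5).clamp(0, 1).cpu().float().numpy()           # map [-1, 1] -> [0, 1], then to numpy
print(image.shape)
```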
582
 
583
- def enable_autocast_float8_transformer(self):
584
- self.enable_autocast_float8_transformer_flag = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
  @torch.no_grad()
587
  @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -589,103 +715,131 @@ class EasyAnimatePipeline(DiffusionPipeline):
589
  self,
590
  prompt: Union[str, List[str]] = None,
591
  video_length: Optional[int] = None,
592
- negative_prompt: str = "",
593
- num_inference_steps: int = 20,
594
- timesteps: List[int] = None,
595
- guidance_scale: float = 4.5,
596
- num_images_per_prompt: Optional[int] = 1,
597
  height: Optional[int] = None,
598
  width: Optional[int] = None,
599
- eta: float = 0.0,
 
 
 
 
600
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
601
- latents: Optional[torch.FloatTensor] = None,
602
- prompt_embeds: Optional[torch.FloatTensor] = None,
603
- prompt_attention_mask: Optional[torch.FloatTensor] = None,
604
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
605
- negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
 
 
 
 
606
  output_type: Optional[str] = "latent",
607
  return_dict: bool = True,
608
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
609
- callback_steps: int = 1,
610
- clean_caption: bool = True,
611
- max_sequence_length: int = 120,
 
 
 
 
612
  comfyui_progressbar: bool = False,
613
- **kwargs,
614
- ) -> Union[EasyAnimatePipelineOutput, Tuple]:
615
- """
616
- Function invoked when calling the pipeline for generation.
617
-
618
- Args:
619
- prompt (`str` or `List[str]`, *optional*):
620
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
621
- instead.
622
- negative_prompt (`str` or `List[str]`, *optional*):
623
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
624
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
625
- less than `1`).
626
- num_inference_steps (`int`, *optional*, defaults to 100):
627
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
628
- expense of slower inference.
629
- timesteps (`List[int]`, *optional*):
630
- Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
631
- timesteps are used. Must be in descending order.
632
- guidance_scale (`float`, *optional*, defaults to 7.0):
633
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
634
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
635
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
636
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
637
- usually at the expense of lower image quality.
638
- num_images_per_prompt (`int`, *optional*, defaults to 1):
639
- The number of images to generate per prompt.
640
- height (`int`, *optional*, defaults to self.unet.config.sample_size):
641
- The height in pixels of the generated image.
642
- width (`int`, *optional*, defaults to self.unet.config.sample_size):
643
- The width in pixels of the generated image.
644
- eta (`float`, *optional*, defaults to 0.0):
645
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
646
- [`schedulers.DDIMScheduler`], will be ignored for others.
647
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
648
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
649
- to make generation deterministic.
650
- latents (`torch.FloatTensor`, *optional*):
651
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
652
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
653
- tensor will ge generated by sampling using the supplied random `generator`.
654
- prompt_embeds (`torch.FloatTensor`, *optional*):
655
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
656
- provided, text embeddings will be generated from `prompt` input argument.
657
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
658
- Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
659
- provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
660
- output_type (`str`, *optional*, defaults to `"pil"`):
661
- The output format of the generate image. Choose between
662
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
663
- return_dict (`bool`, *optional*, defaults to `True`):
664
- Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
665
- callback (`Callable`, *optional*):
666
- A function that will be called every `callback_steps` steps during inference. The function will be
667
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
668
- callback_steps (`int`, *optional*, defaults to 1):
669
- The frequency at which the `callback` function will be called. If not specified, the callback will be
670
- called at every step.
671
- clean_caption (`bool`, *optional*, defaults to `True`):
672
- Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
673
- be installed. If the dependencies are not installed, the embeddings will be created from the raw
674
- prompt.
675
- mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
676
 
677
  Examples:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
678
 
679
  Returns:
680
- [`~pipelines.ImagePipelineOutput`] or `tuple`:
681
- If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
682
- returned where the first element is a list with the generated images
 
 
683
  """
 
 
 
 
 
 
 
 
684
  # 1. Check inputs. Raise error if not correct
685
- height = height or self.transformer.config.sample_size * self.vae_scale_factor
686
- width = width or self.transformer.config.sample_size * self.vae_scale_factor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
 
688
- # 2. Default height and width to transformer
689
  if prompt is not None and isinstance(prompt, str):
690
  batch_size = 1
691
  elif prompt is not None and isinstance(prompt, list):
@@ -694,136 +848,223 @@ class EasyAnimatePipeline(DiffusionPipeline):
694
  batch_size = prompt_embeds.shape[0]
695
 
696
  device = self._execution_device
697
-
698
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
699
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
700
- # corresponds to doing no classifier free guidance.
701
- do_classifier_free_guidance = guidance_scale > 1.0
 
702
 
703
  # 3. Encode input prompt
704
  (
705
  prompt_embeds,
706
- prompt_attention_mask,
707
  negative_prompt_embeds,
 
708
  negative_prompt_attention_mask,
709
  ) = self.encode_prompt(
710
- prompt,
711
- do_classifier_free_guidance,
712
- negative_prompt=negative_prompt,
713
- num_images_per_prompt=num_images_per_prompt,
714
  device=device,
 
 
 
 
715
  prompt_embeds=prompt_embeds,
716
  negative_prompt_embeds=negative_prompt_embeds,
717
  prompt_attention_mask=prompt_attention_mask,
718
  negative_prompt_attention_mask=negative_prompt_attention_mask,
719
- clean_caption=clean_caption,
720
- max_sequence_length=max_sequence_length,
721
  )
722
- if do_classifier_free_guidance:
723
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
724
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
 
725
 
726
  # 4. Prepare timesteps
727
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
 
728
 
729
- # 5. Prepare latents.
730
- latent_channels = self.transformer.config.in_channels
731
  latents = self.prepare_latents(
732
  batch_size * num_images_per_prompt,
733
- latent_channels,
734
  video_length,
735
  height,
736
  width,
737
- prompt_embeds.dtype,
738
  device,
739
  generator,
740
  latents,
741
  )
 
 
742
 
743
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
744
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
745
 
746
- # 6.1 Prepare micro-conditions.
 
 
747
  added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
748
- if self.transformer.config.sample_size == 128:
749
  resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
750
  aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
751
- resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
752
- aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
753
 
754
- if do_classifier_free_guidance:
755
  resolution = torch.cat([resolution, resolution], dim=0)
756
  aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
757
 
758
  added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
759
 
760
- torch.cuda.empty_cache()
761
- if self.enable_autocast_float8_transformer_flag:
762
- origin_weight_dtype = self.transformer.dtype
763
- self.transformer = self.transformer.to(torch.float8_e4m3fn)
764
-
765
- # 7. Denoising loop
766
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
767
- if comfyui_progressbar:
768
- from comfy.utils import ProgressBar
769
- pbar = ProgressBar(num_inference_steps)
 
 
770
  with self.progress_bar(total=num_inference_steps) as progress_bar:
771
  for i, t in enumerate(timesteps):
772
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
773
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
774
-
775
- current_timestep = t
776
- if not torch.is_tensor(current_timestep):
777
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
778
- # This would be a good case for the `match` statement (Python 3.10+)
779
- is_mps = latent_model_input.device.type == "mps"
780
- if isinstance(current_timestep, float):
781
- dtype = torch.float32 if is_mps else torch.float64
782
- else:
783
- dtype = torch.int32 if is_mps else torch.int64
784
- current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
785
- elif len(current_timestep.shape) == 0:
786
- current_timestep = current_timestep[None].to(latent_model_input.device)
787
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
788
- current_timestep = current_timestep.expand(latent_model_input.shape[0])
789
-
790
- # predict noise model_output
791
  noise_pred = self.transformer(
792
  latent_model_input,
 
793
  encoder_hidden_states=prompt_embeds,
794
- encoder_attention_mask=prompt_attention_mask,
795
- timestep=current_timestep,
 
 
 
 
796
  added_cond_kwargs=added_cond_kwargs,
797
  return_dict=False,
798
  )[0]
 
 
 
799
 
800
  # perform guidance
801
- if do_classifier_free_guidance:
802
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
803
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
804
 
805
- # learned sigma
806
- if self.transformer.config.out_channels // 2 == latent_channels:
807
- noise_pred = noise_pred.chunk(2, dim=1)[0]
808
- else:
809
- noise_pred = noise_pred
810
 
811
- # compute previous image: x_t -> x_t-1
812
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
813
 
814
- # call the callback, if provided
 
 
815
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
816
  progress_bar.update()
817
- if callback is not None and i % callback_steps == 0:
818
- step_idx = i // getattr(self.scheduler, "order", 1)
819
- callback(step_idx, t, latents)
820
 
821
  if comfyui_progressbar:
822
  pbar.update(1)
823
 
824
- if self.enable_autocast_float8_transformer_flag:
825
- self.transformer = self.transformer.to("cpu", origin_weight_dtype)
826
-
827
  # Post-processing
828
  video = self.decode_latents(latents)
829
 
@@ -831,7 +1072,10 @@ class EasyAnimatePipeline(DiffusionPipeline):
831
  if output_type == "latent":
832
  video = torch.from_numpy(video)
833
 
 
 
 
834
  if not return_dict:
835
  return video
836
 
837
- return EasyAnimatePipelineOutput(videos=video)
 
1
+ # Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
 
15
  import inspect
 
 
16
  from dataclasses import dataclass
17
+ from typing import Callable, Dict, List, Optional, Tuple, Union
18
 
19
  import numpy as np
20
  import torch
21
+ import torch.nn.functional as F
22
+ from diffusers import DiffusionPipeline
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
  from diffusers.image_processor import VaeImageProcessor
25
+ from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
26
+ from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
27
+ get_3d_rotary_pos_embed)
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
30
+ from diffusers.pipelines.stable_diffusion.safety_checker import \
31
+ StableDiffusionSafetyChecker
32
+ from diffusers.schedulers import DDIMScheduler, FlowMatchEulerDiscreteScheduler
33
  from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
34
+ is_bs4_available, is_ftfy_available,
35
+ is_torch_xla_available, logging,
36
  replace_example_docstring)
37
  from diffusers.utils.torch_utils import randn_tensor
38
  from einops import rearrange
39
+ from PIL import Image
40
  from tqdm import tqdm
41
+ from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
42
+ Qwen2Tokenizer, Qwen2VLForConditionalGeneration,
43
+ T5EncoderModel, T5Tokenizer)
44
 
45
+ from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
46
+ from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput
47
 
48
+ if is_torch_xla_available():
49
+ import torch_xla.core.xla_model as xm
50
 
51
+ XLA_AVAILABLE = True
52
+ else:
53
+ XLA_AVAILABLE = False
54
 
 
 
55
 
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
 
58
  EXAMPLE_DOC_STRING = """
59
  Examples:
60
+ ```python
61
  >>> import torch
62
  >>> from diffusers import EasyAnimatePipeline
63
+ >>> from diffusers.utils import export_to_video
64
+
65
+ >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh" or "alibaba-pai/EasyAnimateV5.1-7b-zh"
66
+ >>> pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.float16).to("cuda")
67
+ >>> prompt = (
68
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
69
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
70
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
71
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
72
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
73
+ ... "atmosphere of this unique musical performance."
74
+ ... )
75
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
76
+ >>> export_to_video(video, "output.mp4", fps=8)
77
  ```
78
  """
79
 
80
+
81
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
82
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
83
+ tw = tgt_width
84
+ th = tgt_height
85
+ h, w = src
86
+ r = h / w
87
+ if r > (th / tw):
88
+ resize_height = th
89
+ resize_width = int(round(th / h * w))
90
+ else:
91
+ resize_width = tw
92
+ resize_height = int(round(tw / w * h))
93
+
94
+ crop_top = int(round((th - resize_height) / 2.0))
95
+ crop_left = int(round((tw - resize_width) / 2.0))
96
+
97
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
98
+
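A worked example for the helper above (assumes the function is in scope; the two grids correspond to 480x720 and 720x720 inputs with `patch_size=2` and a 480x720 base resolution):
```python
# Illustrative only: the first grid matches the base grid exactly, so no crop is needed;
# the square grid is taller than the base aspect ratio and gets cropped horizontally.
print(get_resize_crop_region_for_grid((30, 45), 45, 30))  # ((0, 0), (30, 45))
print(get_resize_crop_region_for_grid((45, 45), 45, 30))  # ((0, 8), (30, 38))
```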
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
101
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
102
+ """
103
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
104
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
105
+ """
106
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
107
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
108
+ # rescale the results from guidance (fixes overexposure)
109
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
110
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
111
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
112
+ return noise_cfg
113
+
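A minimal sketch (not part of the commit) of how the rescale above slots into a standard classifier-free-guidance update; shapes and scales are arbitrary:
```python
import torch

# Random stand-ins for the unconditional and text-conditioned predictions.
noise_pred_uncond = torch.randn(1, 4, 8, 30, 45)
noise_pred_text = torch.randn(1, 4, 8, 30, 45)
guidance_scale, guidance_rescale = 6.0, 0.7

# Standard CFG combination, then pull the per-sample std back toward the text
# branch to counteract the over-exposure described in arXiv:2305.08891.
noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_cfg = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=guidance_rescale)
```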
114
+
115
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
116
  def retrieve_timesteps(
117
  scheduler,
118
  num_inference_steps: Optional[int] = None,
119
  device: Optional[Union[str, torch.device]] = None,
120
  timesteps: Optional[List[int]] = None,
121
+ sigmas: Optional[List[float]] = None,
122
  **kwargs,
123
  ):
124
  """
 
129
  scheduler (`SchedulerMixin`):
130
  The scheduler to get timesteps from.
131
  num_inference_steps (`int`):
132
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
133
+ must be `None`.
134
  device (`str` or `torch.device`, *optional*):
135
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
136
  timesteps (`List[int]`, *optional*):
137
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
138
+ `num_inference_steps` and `sigmas` must be `None`.
139
+ sigmas (`List[float]`, *optional*):
140
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
141
+ `num_inference_steps` and `timesteps` must be `None`.
142
 
143
  Returns:
144
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
145
  second element is the number of inference steps.
146
  """
147
+ if timesteps is not None and sigmas is not None:
148
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
149
  if timesteps is not None:
150
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
151
  if not accepts_timesteps:
 
156
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
157
  timesteps = scheduler.timesteps
158
  num_inference_steps = len(timesteps)
159
+ elif sigmas is not None:
160
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
161
+ if not accept_sigmas:
162
+ raise ValueError(
163
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
164
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
165
+ )
166
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
167
+ timesteps = scheduler.timesteps
168
+ num_inference_steps = len(timesteps)
169
  else:
170
  scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
171
  timesteps = scheduler.timesteps
172
  return timesteps, num_inference_steps
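A hedged usage sketch for the helper above; `EulerDiscreteScheduler` is used here only because its `set_timesteps` accepts custom `sigmas` in recent diffusers releases:
```python
from diffusers import EulerDiscreteScheduler

# Assumes `retrieve_timesteps` (defined above) is in scope.
scheduler = EulerDiscreteScheduler.from_config({"num_train_timesteps": 1000})

# Standard path: let the scheduler space the steps itself.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_steps)  # -> 30

# Custom path (only if this scheduler version's `set_timesteps` accepts `sigmas`):
# timesteps, num_steps = retrieve_timesteps(scheduler, device="cpu",
#                                           sigmas=[10.0, 5.0, 2.0, 1.0, 0.5])
```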
173
 
 
 
 
174
 
175
  class EasyAnimatePipeline(DiffusionPipeline):
176
  r"""
177
+ Pipeline for text-to-video generation using EasyAnimate.
178
 
179
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
180
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
181
 
182
+ EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
183
+ EasyAnimate uses two text encoders in V5: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP
184
+ fine-tuned by the HunyuanDiT team.
185
+
186
  Args:
187
+ vae ([`AutoencoderKLMagvit`]):
188
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
189
+ text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
190
+ EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
191
+ EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5.
192
+ tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
193
+ A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
194
+ transformer ([`EasyAnimateTransformer3DModel`]):
195
+ The EasyAnimate model designed by EasyAnimate Team.
196
+ text_encoder_2 (`T5EncoderModel`):
197
+ EasyAnimate does not use text_encoder_2 in V5.1.
198
+ EasyAnimate uses the [mT5](https://huggingface.co/google/mt5-base) embedder in V5.
199
+ tokenizer_2 (`T5Tokenizer`):
200
+ The tokenizer for the mT5 embedder.
201
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
202
+ A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
203
  """
 
 
 
204
 
205
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
206
+ _optional_components = [
207
+ "text_encoder_2",
208
+ "tokenizer_2",
209
+ "text_encoder",
210
+ "tokenizer",
211
+ ]
212
+ _callback_tensor_inputs = [
213
+ "latents",
214
+ "prompt_embeds",
215
+ "negative_prompt_embeds",
216
+ "prompt_embeds_2",
217
+ "negative_prompt_embeds_2",
218
+ ]
219
 
220
  def __init__(
221
  self,
222
+ vae: AutoencoderKLMagvit,
223
+ text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
224
+ tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
225
+ text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]],
226
+ tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]],
227
+ transformer: EasyAnimateTransformer3DModel,
228
+ scheduler: FlowMatchEulerDiscreteScheduler,
229
  ):
230
  super().__init__()
231
 
232
  self.register_modules(
233
+ vae=vae,
234
+ text_encoder=text_encoder,
235
+ text_encoder_2=text_encoder_2,
236
+ tokenizer=tokenizer,
237
+ tokenizer_2=tokenizer_2,
238
+ transformer=transformer,
239
+ scheduler=scheduler,
240
  )
241
 
242
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
 
 
 
 
 
 
 
 
 
 
243
 
244
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
245
+ super().enable_sequential_cpu_offload(*args, **kwargs)
246
+ if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
247
+ import accelerate
248
+ accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
249
+ self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
250
+
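A usage sketch of the override above (prompt and sizes are placeholders): sequential offload streams each submodule to the GPU on demand, while `clip_projection`, if present, has its hook removed and is kept resident on CUDA.
```python
# Assumes `pipe` is an EasyAnimatePipeline instance on a CUDA machine.
pipe.enable_sequential_cpu_offload()
out = pipe(prompt="a corgi running on a beach", video_length=49,
           height=512, width=512, num_inference_steps=50)
```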
251
  def encode_prompt(
252
  self,
253
+ prompt: str,
254
+ device: torch.device,
255
+ dtype: torch.dtype,
256
  num_images_per_prompt: int = 1,
257
+ do_classifier_free_guidance: bool = True,
258
+ negative_prompt: Optional[str] = None,
259
+ prompt_embeds: Optional[torch.Tensor] = None,
260
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
261
+ prompt_attention_mask: Optional[torch.Tensor] = None,
262
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
263
+ max_sequence_length: Optional[int] = None,
264
+ text_encoder_index: int = 0,
265
+ actual_max_sequence_length: int = 256
266
  ):
267
  r"""
268
  Encodes the prompt into text encoder hidden states.
 
270
  Args:
271
  prompt (`str` or `List[str]`, *optional*):
272
  prompt to be encoded
273
+ device: (`torch.device`):
274
+ torch device
275
+ dtype (`torch.dtype`):
276
+ torch dtype
277
+ num_images_per_prompt (`int`):
 
 
278
  number of images that should be generated per prompt
279
+ do_classifier_free_guidance (`bool`):
280
+ whether to use classifier free guidance or not
281
+ negative_prompt (`str` or `List[str]`, *optional*):
282
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
283
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
284
+ less than `1`).
285
+ prompt_embeds (`torch.Tensor`, *optional*):
286
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
287
  provided, text embeddings will be generated from `prompt` input argument.
288
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
289
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
290
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
291
+ argument.
292
+ prompt_attention_mask (`torch.Tensor`, *optional*):
293
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
294
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
295
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
296
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
297
+ text_encoder_index (`int`, *optional*):
298
+ Index of the text encoder to use. `0` for clip and `1` for T5.
299
  """
300
+ tokenizers = [self.tokenizer, self.tokenizer_2]
301
+ text_encoders = [self.text_encoder, self.text_encoder_2]
302
 
303
+ tokenizer = tokenizers[text_encoder_index]
304
+ text_encoder = text_encoders[text_encoder_index]
 
305
 
306
+ if max_sequence_length is None:
307
+ if text_encoder_index == 0:
308
+ max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
309
+ if text_encoder_index == 1:
310
+ max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
311
+ else:
312
+ max_length = max_sequence_length
313
 
314
  if prompt is not None and isinstance(prompt, str):
315
  batch_size = 1
 
318
  else:
319
  batch_size = prompt_embeds.shape[0]
320
 
 
 
 
321
  if prompt_embeds is None:
322
+ if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
323
+ text_inputs = tokenizer(
324
+ prompt,
325
+ padding="max_length",
326
+ max_length=max_length,
327
+ truncation=True,
328
+ return_attention_mask=True,
329
+ return_tensors="pt",
330
+ )
331
+ text_input_ids = text_inputs.input_ids
332
+ if text_input_ids.shape[-1] > actual_max_sequence_length:
333
+ reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
334
+ text_inputs = tokenizer(
335
+ reprompt,
336
+ padding="max_length",
337
+ max_length=max_length,
338
+ truncation=True,
339
+ return_attention_mask=True,
340
+ return_tensors="pt",
341
+ )
342
+ text_input_ids = text_inputs.input_ids
343
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
344
+
345
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
346
+ text_input_ids, untruncated_ids
347
+ ):
348
+ _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
349
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
350
+ logger.warning(
351
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
352
+ f" {_actual_max_sequence_length} tokens: {removed_text}"
353
+ )
354
+
355
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
356
+
357
+ if self.transformer.config.enable_text_attention_mask:
358
+ prompt_embeds = text_encoder(
359
+ text_input_ids.to(device),
360
+ attention_mask=prompt_attention_mask,
361
+ )
362
+ else:
363
+ prompt_embeds = text_encoder(
364
+ text_input_ids.to(device)
365
+ )
366
+ prompt_embeds = prompt_embeds[0]
367
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
368
+ else:
369
+ if prompt is not None and isinstance(prompt, str):
370
+ messages = [
371
+ {
372
+ "role": "user",
373
+ "content": [{"type": "text", "text": prompt}],
374
+ }
375
+ ]
376
+ else:
377
+ messages = [
378
+ {
379
+ "role": "user",
380
+ "content": [{"type": "text", "text": _prompt}],
381
+ } for _prompt in prompt
382
+ ]
383
+ text = tokenizer.apply_chat_template(
384
+ messages, tokenize=False, add_generation_prompt=True
385
  )
386
 
387
+ text_inputs = tokenizer(
388
+ text=[text],
389
+ padding="max_length",
390
+ max_length=max_length,
391
+ truncation=True,
392
+ return_attention_mask=True,
393
+ padding_side="right",
394
+ return_tensors="pt",
395
+ )
396
+ text_inputs = text_inputs.to(text_encoder.device)
397
+
398
+ text_input_ids = text_inputs.input_ids
399
+ prompt_attention_mask = text_inputs.attention_mask
400
+ if self.transformer.config.enable_text_attention_mask:
401
+ # Inference: Generation of the output
402
+ prompt_embeds = text_encoder(
403
+ input_ids=text_input_ids,
404
+ attention_mask=prompt_attention_mask,
405
+ output_hidden_states=True).hidden_states[-2]
406
+ else:
407
+ raise ValueError("LLM needs attention_mask")
408
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
409
+
410
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
411
 
412
  bs_embed, seq_len, _ = prompt_embeds.shape
413
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
414
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
415
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
416
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
 
417
 
418
  # get unconditional embeddings for classifier free guidance
419
  if do_classifier_free_guidance and negative_prompt_embeds is None:
420
+ if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
421
+ uncond_tokens: List[str]
422
+ if negative_prompt is None:
423
+ uncond_tokens = [""] * batch_size
424
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
425
+ raise TypeError(
426
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
427
+ f" {type(prompt)}."
428
+ )
429
+ elif isinstance(negative_prompt, str):
430
+ uncond_tokens = [negative_prompt]
431
+ elif batch_size != len(negative_prompt):
432
+ raise ValueError(
433
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
434
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
435
+ " the batch size of `prompt`."
436
+ )
437
+ else:
438
+ uncond_tokens = negative_prompt
439
+
440
+ max_length = prompt_embeds.shape[1]
441
+ uncond_input = tokenizer(
442
+ uncond_tokens,
443
+ padding="max_length",
444
+ max_length=max_length,
445
+ truncation=True,
446
+ return_tensors="pt",
447
+ )
448
+ uncond_input_ids = uncond_input.input_ids
449
+ if uncond_input_ids.shape[-1] > actual_max_sequence_length:
450
+ reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
451
+ uncond_input = tokenizer(
452
+ reuncond_tokens,
453
+ padding="max_length",
454
+ max_length=max_length,
455
+ truncation=True,
456
+ return_attention_mask=True,
457
+ return_tensors="pt",
458
+ )
459
+ uncond_input_ids = uncond_input.input_ids
460
+
461
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
462
+ if self.transformer.config.enable_text_attention_mask:
463
+ negative_prompt_embeds = text_encoder(
464
+ uncond_input.input_ids.to(device),
465
+ attention_mask=negative_prompt_attention_mask,
466
+ )
467
+ else:
468
+ negative_prompt_embeds = text_encoder(
469
+ uncond_input.input_ids.to(device)
470
+ )
471
+ negative_prompt_embeds = negative_prompt_embeds[0]
472
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
473
+ else:
474
+ if negative_prompt is not None and isinstance(negative_prompt, str):
475
+ messages = [
476
+ {
477
+ "role": "user",
478
+ "content": [{"type": "text", "text": negative_prompt}],
479
+ }
480
+ ]
481
+ else:
482
+ messages = [
483
+ {
484
+ "role": "user",
485
+ "content": [{"type": "text", "text": _negative_prompt}],
486
+ } for _negative_prompt in negative_prompt
487
+ ]
488
+ text = tokenizer.apply_chat_template(
489
+ messages, tokenize=False, add_generation_prompt=True
490
+ )
491
 
492
+ text_inputs = tokenizer(
493
+ text=[text],
494
+ padding="max_length",
495
+ max_length=max_length,
496
+ truncation=True,
497
+ return_attention_mask=True,
498
+ padding_side="right",
499
+ return_tensors="pt",
500
+ )
501
+ text_inputs = text_inputs.to(text_encoder.device)
502
+
503
+ text_input_ids = text_inputs.input_ids
504
+ negative_prompt_attention_mask = text_inputs.attention_mask
505
+ if self.transformer.config.enable_text_attention_mask:
506
+ # Inference: Generation of the output
507
+ negative_prompt_embeds = text_encoder(
508
+ input_ids=text_input_ids,
509
+ attention_mask=negative_prompt_attention_mask,
510
+ output_hidden_states=True).hidden_states[-2]
511
+ else:
512
+ raise ValueError("LLM needs attention_mask")
513
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
514
 
515
  if do_classifier_free_guidance:
516
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
 
520
 
521
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
522
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
523
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
524
+
525
+ return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
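The Qwen2-VL branch above wraps the raw prompt in a chat template before encoding. A standalone sketch of that tokenization step (the model id is the one referenced in the docstring; treat it as illustrative, any Qwen2-style tokenizer with a chat template behaves the same way):
```python
from transformers import Qwen2Tokenizer

tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
messages = [{"role": "user", "content": [{"type": "text", "text": "A panda playing guitar"}]}]

# The chat template adds the role/system markers the LLM encoder was trained with.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text=[text], padding="max_length", max_length=256,
                   truncation=True, return_attention_mask=True, return_tensors="pt")
print(inputs.input_ids.shape, int(inputs.attention_mask.sum()))
```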
 
 
 
 
 
526
 
527
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
528
  def prepare_extra_step_kwargs(self, generator, eta):
 
547
  prompt,
548
  height,
549
  width,
550
+ negative_prompt=None,
 
551
  prompt_embeds=None,
552
  negative_prompt_embeds=None,
553
+ prompt_attention_mask=None,
554
+ negative_prompt_attention_mask=None,
555
+ prompt_embeds_2=None,
556
+ negative_prompt_embeds_2=None,
557
+ prompt_attention_mask_2=None,
558
+ negative_prompt_attention_mask_2=None,
559
+ callback_on_step_end_tensor_inputs=None,
560
  ):
561
+ if height % 16 != 0 or width % 16 != 0:
562
  raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
563
 
564
+ if callback_on_step_end_tensor_inputs is not None and not all(
565
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
566
  ):
567
  raise ValueError(
568
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
 
569
  )
570
 
571
  if prompt is not None and prompt_embeds is not None:
 
577
  raise ValueError(
578
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
579
  )
580
+ elif prompt is None and prompt_embeds_2 is None:
581
+ raise ValueError(
582
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
583
+ )
584
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
585
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
586
 
587
+ if prompt_embeds is not None and prompt_attention_mask is None:
588
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
589
+
590
+ if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
591
+ raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
592
 
593
  if negative_prompt is not None and negative_prompt_embeds is not None:
594
  raise ValueError(
 
596
  f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
597
  )
598
 
599
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
600
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
601
+
602
+ if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
603
+ raise ValueError(
604
+ "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
605
+ )
606
  if prompt_embeds is not None and negative_prompt_embeds is not None:
607
  if prompt_embeds.shape != negative_prompt_embeds.shape:
608
  raise ValueError(
 
610
  f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
611
  f" {negative_prompt_embeds.shape}."
612
  )
613
+ if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
614
+ if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
615
+ raise ValueError(
616
+ "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
617
+ f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
618
+ f" {negative_prompt_embeds_2.shape}."
619
+ )
 
 
 
 
620
 
621
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
622
  def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
623
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
624
+ if self.vae.cache_mag_vae:
625
+ mini_batch_encoder = self.vae.mini_batch_encoder
626
+ mini_batch_decoder = self.vae.mini_batch_decoder
627
+ shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
628
+ else:
629
+ mini_batch_encoder = self.vae.mini_batch_encoder
630
+ mini_batch_decoder = self.vae.mini_batch_decoder
631
+ shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
632
  else:
633
  shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
634
 
 
642
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
643
  else:
644
  latents = latents.to(device)
645
+
646
  # scale the initial noise by the standard deviation required by the scheduler
647
+ if hasattr(self.scheduler, "init_noise_sigma"):
648
+ latents = latents * self.scheduler.init_noise_sigma
649
  return latents
650
+
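To make the shape rule above concrete, a small arithmetic sketch of the `cache_mag_vae` branch (all numbers are illustrative, not taken from a specific checkpoint):
```python
video_length = 49            # requested frames
mini_batch_encoder = 4       # frames consumed per VAE encode step
mini_batch_decoder = 1       # latent frames produced per step
vae_scale_factor = 8
num_channels_latents = 16    # illustrative transformer in_channels
height, width = 512, 512

# Temporal size: one latent frame per encoder mini-batch, plus the cached first frame.
latent_frames = (video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1 if video_length != 1 else 1
latent_shape = (1, num_channels_latents, latent_frames,
                height // vae_scale_factor, width // vae_scale_factor)
print(latent_shape)  # (1, 16, 13, 64, 64)
```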
651
  def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
652
  if video.size()[2] <= mini_batch_encoder:
653
  return video
 
663
 
664
  video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
665
  return video
666
+
667
  def decode_latents(self, latents):
668
  video_length = latents.shape[2]
669
  latents = 1 / self.vae.config.scaling_factor * latents
670
+ if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
671
  mini_batch_encoder = self.vae.mini_batch_encoder
672
  mini_batch_decoder = self.vae.mini_batch_decoder
673
  video = self.vae.decode(latents)[0]
674
  video = video.clamp(-1, 1)
675
+ if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
676
+ video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
677
  else:
678
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
679
  video = []
 
686
  video = video.cpu().float().numpy()
687
  return video
688
 
689
+ @property
690
+ def guidance_scale(self):
691
+ return self._guidance_scale
692
+
693
+ @property
694
+ def guidance_rescale(self):
695
+ return self._guidance_rescale
696
+
697
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
698
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
699
+ # corresponds to doing no classifier free guidance.
700
+ @property
701
+ def do_classifier_free_guidance(self):
702
+ return self._guidance_scale > 1
703
+
704
+ @property
705
+ def num_timesteps(self):
706
+ return self._num_timesteps
707
+
708
+ @property
709
+ def interrupt(self):
710
+ return self._interrupt
711
 
712
  @torch.no_grad()
713
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
715
  self,
716
  prompt: Union[str, List[str]] = None,
717
  video_length: Optional[int] = None,
 
 
 
 
 
718
  height: Optional[int] = None,
719
  width: Optional[int] = None,
720
+ num_inference_steps: Optional[int] = 50,
721
+ guidance_scale: Optional[float] = 5.0,
722
+ negative_prompt: Optional[Union[str, List[str]]] = None,
723
+ num_images_per_prompt: Optional[int] = 1,
724
+ eta: Optional[float] = 0.0,
725
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
726
+ latents: Optional[torch.Tensor] = None,
727
+ prompt_embeds: Optional[torch.Tensor] = None,
728
+ prompt_embeds_2: Optional[torch.Tensor] = None,
729
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
730
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
731
+ prompt_attention_mask: Optional[torch.Tensor] = None,
732
+ prompt_attention_mask_2: Optional[torch.Tensor] = None,
733
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
734
+ negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
735
  output_type: Optional[str] = "latent",
736
  return_dict: bool = True,
737
+ callback_on_step_end: Optional[
738
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
739
+ ] = None,
740
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
741
+ guidance_rescale: float = 0.0,
742
+ original_size: Optional[Tuple[int, int]] = (1024, 1024),
743
+ target_size: Optional[Tuple[int, int]] = None,
744
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
745
  comfyui_progressbar: bool = False,
746
+ timesteps: Optional[List[int]] = None,
747
+ ):
748
+ r"""
749
+ Generates images or video using the EasyAnimate pipeline based on the provided prompts.
 
 
 
 
 
750
 
751
  Examples:
752
+ prompt (`str` or `List[str]`, *optional*):
753
+ Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
754
+ video_length (`int`, *optional*):
755
+ Length of the generated video (in frames).
756
+ height (`int`, *optional*):
757
+ Height of the generated image in pixels.
758
+ width (`int`, *optional*):
759
+ Width of the generated image in pixels.
760
+ num_inference_steps (`int`, *optional*, defaults to 50):
761
+ Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
762
+ guidance_scale (`float`, *optional*, defaults to 5.0):
763
+ Encourages the model to align outputs with prompts. A higher value may decrease image quality.
764
+ negative_prompt (`str` or `List[str]`, *optional*):
765
+ Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
766
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
767
+ Number of images to generate for each prompt.
768
+ eta (`float`, *optional*, defaults to 0.0):
769
+ Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
770
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
771
+ A generator to ensure reproducibility in image generation.
772
+ latents (`torch.Tensor`, *optional*):
773
+ Predefined latent tensors to condition generation.
774
+ prompt_embeds (`torch.Tensor`, *optional*):
775
+ Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
776
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
777
+ Secondary text embeddings to supplement or replace the initial prompt embeddings.
778
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
779
+ Embeddings for negative prompts. Overrides string inputs if defined.
780
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
781
+ Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
782
+ prompt_attention_mask (`torch.Tensor`, *optional*):
783
+ Attention mask for the primary prompt embeddings.
784
+ prompt_attention_mask_2 (`torch.Tensor`, *optional*):
785
+ Attention mask for the secondary prompt embeddings.
786
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
787
+ Attention mask for negative prompt embeddings.
788
+ negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
789
+ Attention mask for secondary negative prompt embeddings.
790
+ output_type (`str`, *optional*, defaults to "latent"):
791
+ Format of the generated output, either as a PIL image or as a NumPy array.
792
+ return_dict (`bool`, *optional*, defaults to `True`):
793
+ If `True`, returns a structured output. Otherwise returns a simple tuple.
794
+ callback_on_step_end (`Callable`, *optional*):
795
+ Functions called at the end of each denoising step.
796
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
797
+ Tensor names to be included in callback function calls.
798
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
799
+ Adjusts noise levels based on guidance scale.
800
+ original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
801
+ Original dimensions of the output.
802
+ target_size (`Tuple[int, int]`, *optional*):
803
+ Desired output dimensions for calculations.
804
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
805
+ Coordinates for cropping.
806
 
807
  Returns:
808
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
809
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
810
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
811
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
812
+ "not-safe-for-work" (nsfw) content.
813
  """
814
+
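A sketch of a `callback_on_step_end` hook compatible with the loop below: it receives the pipeline, the step index, the timestep, and the tensors named in `callback_on_step_end_tensor_inputs`, and returns a dict of any tensors it changed.
```python
def log_and_keep_latents(pipe, step, timestep, callback_kwargs):
    # Called once per denoising step; "latents" is requested via
    # callback_on_step_end_tensor_inputs below.
    latents = callback_kwargs["latents"]
    print(f"step {step:03d}  t={int(timestep)}  latents std={latents.std().item():.3f}")
    return {"latents": latents}

# video = pipe(prompt="...", callback_on_step_end=log_and_keep_latents,
#              callback_on_step_end_tensor_inputs=["latents"]).frames[0]
```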
815
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
816
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
817
+
818
+ # 0. default height and width
819
+ height = int((height // 16) * 16)
820
+ width = int((width // 16) * 16)
821
+
822
  # 1. Check inputs. Raise error if not correct
823
+ self.check_inputs(
824
+ prompt,
825
+ height,
826
+ width,
827
+ negative_prompt,
828
+ prompt_embeds,
829
+ negative_prompt_embeds,
830
+ prompt_attention_mask,
831
+ negative_prompt_attention_mask,
832
+ prompt_embeds_2,
833
+ negative_prompt_embeds_2,
834
+ prompt_attention_mask_2,
835
+ negative_prompt_attention_mask_2,
836
+ callback_on_step_end_tensor_inputs,
837
+ )
838
+ self._guidance_scale = guidance_scale
839
+ self._guidance_rescale = guidance_rescale
840
+ self._interrupt = False
841
 
842
+ # 2. Define call parameters
843
  if prompt is not None and isinstance(prompt, str):
844
  batch_size = 1
845
  elif prompt is not None and isinstance(prompt, list):
 
848
  batch_size = prompt_embeds.shape[0]
849
 
850
  device = self._execution_device
851
+ if self.text_encoder is not None:
852
+ dtype = self.text_encoder.dtype
853
+ elif self.text_encoder_2 is not None:
854
+ dtype = self.text_encoder_2.dtype
855
+ else:
856
+ dtype = self.transformer.dtype
857
 
858
  # 3. Encode input prompt
859
  (
860
  prompt_embeds,
 
861
  negative_prompt_embeds,
862
+ prompt_attention_mask,
863
  negative_prompt_attention_mask,
864
  ) = self.encode_prompt(
865
+ prompt=prompt,
 
 
 
866
  device=device,
867
+ dtype=dtype,
868
+ num_images_per_prompt=num_images_per_prompt,
869
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
870
+ negative_prompt=negative_prompt,
871
  prompt_embeds=prompt_embeds,
872
  negative_prompt_embeds=negative_prompt_embeds,
873
  prompt_attention_mask=prompt_attention_mask,
874
  negative_prompt_attention_mask=negative_prompt_attention_mask,
875
+ text_encoder_index=0,
 
876
  )
877
+ if self.tokenizer_2 is not None:
878
+ (
879
+ prompt_embeds_2,
880
+ negative_prompt_embeds_2,
881
+ prompt_attention_mask_2,
882
+ negative_prompt_attention_mask_2,
883
+ ) = self.encode_prompt(
884
+ prompt=prompt,
885
+ device=device,
886
+ dtype=dtype,
887
+ num_images_per_prompt=num_images_per_prompt,
888
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
889
+ negative_prompt=negative_prompt,
890
+ prompt_embeds=prompt_embeds_2,
891
+ negative_prompt_embeds=negative_prompt_embeds_2,
892
+ prompt_attention_mask=prompt_attention_mask_2,
893
+ negative_prompt_attention_mask=negative_prompt_attention_mask_2,
894
+ text_encoder_index=1,
895
+ )
896
+ else:
897
+ prompt_embeds_2 = None
898
+ negative_prompt_embeds_2 = None
899
+ prompt_attention_mask_2 = None
900
+ negative_prompt_attention_mask_2 = None
901
 
902
  # 4. Prepare timesteps
903
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
904
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
905
+ else:
906
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
907
+ if comfyui_progressbar:
908
+ from comfy.utils import ProgressBar
909
+ pbar = ProgressBar(num_inference_steps + 1)
910
 
911
+ # 5. Prepare latent variables
912
+ num_channels_latents = self.transformer.config.in_channels
913
  latents = self.prepare_latents(
914
  batch_size * num_images_per_prompt,
915
+ num_channels_latents,
916
  video_length,
917
  height,
918
  width,
919
+ dtype,
920
  device,
921
  generator,
922
  latents,
923
  )
924
+ if comfyui_progressbar:
925
+ pbar.update(1)
926
 
927
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
928
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
929
 
930
+ # 7 create image_rotary_emb, style embedding & time ids
931
+ grid_height = height // 8 // self.transformer.config.patch_size
932
+ grid_width = width // 8 // self.transformer.config.patch_size
933
+ if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
934
+ base_size_width = 720 // 8 // self.transformer.config.patch_size
935
+ base_size_height = 480 // 8 // self.transformer.config.patch_size
936
+
937
+ grid_crops_coords = get_resize_crop_region_for_grid(
938
+ (grid_height, grid_width), base_size_width, base_size_height
939
+ )
940
+ image_rotary_emb = get_3d_rotary_pos_embed(
941
+ self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
942
+ temporal_size=latents.size(2), use_real=True,
943
+ )
944
+ else:
945
+ base_size = 512 // 8 // self.transformer.config.patch_size
946
+ grid_crops_coords = get_resize_crop_region_for_grid(
947
+ (grid_height, grid_width), base_size, base_size
948
+ )
949
+ image_rotary_emb = get_2d_rotary_pos_embed(
950
+ self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
951
+ )
952
+
953
+ # Get other hunyuan params
954
+ target_size = target_size or (height, width)
955
+ add_time_ids = list(original_size + target_size + crops_coords_top_left)
956
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
957
+ style = torch.tensor([0], device=device)
958
+
959
+ if self.do_classifier_free_guidance:
960
+ add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
961
+ style = torch.cat([style] * 2, dim=0)
962
+
963
+ # To latents.device
964
+ add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat(
965
+ batch_size * num_images_per_prompt, 1
966
+ )
967
+ style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
968
+
969
+ # Get other pixart params
970
  added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
971
+ if self.transformer.config.get("sample_size", 64) == 128:
972
  resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
973
  aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
974
+ resolution = resolution.to(dtype=dtype, device=device)
975
+ aspect_ratio = aspect_ratio.to(dtype=dtype, device=device)
976
 
977
+ if self.do_classifier_free_guidance:
978
  resolution = torch.cat([resolution, resolution], dim=0)
979
  aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
980
 
981
  added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
982
 
983
+ if self.do_classifier_free_guidance:
984
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
985
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
986
+ if prompt_embeds_2 is not None:
987
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
988
+ prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
989
+
990
+ # To latents.device
991
+ prompt_embeds = prompt_embeds.to(device=device)
992
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
993
+ if prompt_embeds_2 is not None:
994
+ prompt_embeds_2 = prompt_embeds_2.to(device=device)
995
+ prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
996
+
997
+ # 8. Denoising loop
998
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
999
+ self._num_timesteps = len(timesteps)
1000
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1001
  for i, t in enumerate(timesteps):
1002
+ if self.interrupt:
1003
+ continue
1004
+
1005
+ # expand the latents if we are doing classifier free guidance
1006
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1007
+ if hasattr(self.scheduler, "scale_model_input"):
1008
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1009
+
1010
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
1011
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
1012
+ dtype=latent_model_input.dtype
1013
+ )
1014
+
1015
+ # predict the noise residual
 
 
 
 
 
1016
  noise_pred = self.transformer(
1017
  latent_model_input,
1018
+ t_expand,
1019
  encoder_hidden_states=prompt_embeds,
1020
+ text_embedding_mask=prompt_attention_mask,
1021
+ encoder_hidden_states_t5=prompt_embeds_2,
1022
+ text_embedding_mask_t5=prompt_attention_mask_2,
1023
+ image_meta_size=add_time_ids,
1024
+ style=style,
1025
+ image_rotary_emb=image_rotary_emb,
1026
  added_cond_kwargs=added_cond_kwargs,
1027
  return_dict=False,
1028
  )[0]
1029
+
1030
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
1031
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
1032
 
1033
  # perform guidance
1034
+ if self.do_classifier_free_guidance:
1035
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1036
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1037
 
1038
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
1039
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1040
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
 
1041
 
1042
+ # compute the previous noisy sample x_t -> x_t-1
1043
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1044
 
1045
+ if callback_on_step_end is not None:
1046
+ callback_kwargs = {}
1047
+ for k in callback_on_step_end_tensor_inputs:
1048
+ callback_kwargs[k] = locals()[k]
1049
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1050
+
1051
+ latents = callback_outputs.pop("latents", latents)
1052
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1053
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1054
+ prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
1055
+ negative_prompt_embeds_2 = callback_outputs.pop(
1056
+ "negative_prompt_embeds_2", negative_prompt_embeds_2
1057
+ )
1058
+
1059
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1060
  progress_bar.update()
1061
+
1062
+ if XLA_AVAILABLE:
1063
+ xm.mark_step()
1064
 
1065
  if comfyui_progressbar:
1066
  pbar.update(1)
1067
 
 
 
 
1068
  # Post-processing
1069
  video = self.decode_latents(latents)
1070
 
 
1072
  if output_type == "latent":
1073
  video = torch.from_numpy(video)
1074
 
1075
+ # Offload all models
1076
+ self.maybe_free_model_hooks()
1077
+
1078
  if not return_dict:
1079
  return video
1080
 
1081
+ return EasyAnimatePipelineOutput(frames=video)
easyanimate/pipeline/{pipeline_easyanimate_multi_text_encoder_control.py → pipeline_easyanimate_control.py} RENAMED
@@ -31,7 +31,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
32
  from diffusers.pipelines.stable_diffusion.safety_checker import \
33
  StableDiffusionSafetyChecker
34
- from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler
 
35
  from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
36
  is_bs4_available, is_ftfy_available,
37
  is_torch_xla_available, logging,
@@ -41,11 +42,12 @@ from einops import rearrange
41
  from PIL import Image
42
  from tqdm import tqdm
43
  from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
44
- CLIPVisionModelWithProjection,
45
- T5EncoderModel, T5Tokenizer)
 
46
 
47
  from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
48
- from .pipeline_easyanimate import EasyAnimatePipelineOutput
49
 
50
  if is_torch_xla_available():
51
  import torch_xla.core.xla_model as xm
@@ -64,6 +66,7 @@ EXAMPLE_DOC_STRING = """
64
  ```
65
  """
66
 
 
67
  def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
68
  tw = tgt_width
69
  th = tgt_height
@@ -97,44 +100,140 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
97
  return noise_cfg
98
 
99
 
100
- class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
 
 
 
 
101
  r"""
102
  Pipeline for text-to-video generation using EasyAnimate.
103
 
104
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
105
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
106
 
 
107
  EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
108
- HunyuanDiT team)
109
 
110
  Args:
111
  vae ([`AutoencoderKLMagvit`]):
112
  Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
113
- text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
114
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
115
- EasyAnimate uses a fine-tuned [bilingual CLIP].
116
- tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
117
- A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
118
  transformer ([`EasyAnimateTransformer3DModel`]):
119
- The EasyAnimate model designed by Tencent Hunyuan.
120
  text_encoder_2 (`T5EncoderModel`):
121
- The mT5 embedder.
 
122
  tokenizer_2 (`T5Tokenizer`):
123
  The tokenizer for the mT5 embedder.
124
- scheduler ([`DDIMScheduler`]):
125
  A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
126
  """
127
 
128
  model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
129
  _optional_components = [
130
- "safety_checker",
131
- "feature_extractor",
132
  "text_encoder_2",
133
  "tokenizer_2",
134
  "text_encoder",
135
  "tokenizer",
136
  ]
137
- _exclude_from_cpu_offload = ["safety_checker"]
138
  _callback_tensor_inputs = [
139
  "latents",
140
  "prompt_embeds",
@@ -146,53 +245,30 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
146
  def __init__(
147
  self,
148
  vae: AutoencoderKLMagvit,
149
- text_encoder: BertModel,
150
- tokenizer: BertTokenizer,
151
- text_encoder_2: T5EncoderModel,
152
- tokenizer_2: T5Tokenizer,
153
  transformer: EasyAnimateTransformer3DModel,
154
- scheduler: DDIMScheduler,
155
- safety_checker: StableDiffusionSafetyChecker,
156
- feature_extractor: CLIPImageProcessor,
157
- requires_safety_checker: bool = True
158
  ):
159
  super().__init__()
160
 
161
  self.register_modules(
162
  vae=vae,
163
  text_encoder=text_encoder,
 
164
  tokenizer=tokenizer,
165
  tokenizer_2=tokenizer_2,
166
  transformer=transformer,
167
  scheduler=scheduler,
168
- safety_checker=safety_checker,
169
- feature_extractor=feature_extractor,
170
- text_encoder_2=text_encoder_2
171
  )
172
 
173
- if safety_checker is None and requires_safety_checker:
174
- logger.warning(
175
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
176
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
177
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
178
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
179
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
180
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
181
- )
182
-
183
- if safety_checker is not None and feature_extractor is None:
184
- raise ValueError(
185
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
186
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
187
- )
188
-
189
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
190
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
191
  self.mask_processor = VaeImageProcessor(
192
  vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
193
  )
194
- self.enable_autocast_float8_transformer_flag = False
195
- self.register_to_config(requires_safety_checker=requires_safety_checker)
196
 
197
  def enable_sequential_cpu_offload(self, *args, **kwargs):
198
  super().enable_sequential_cpu_offload(*args, **kwargs)
@@ -272,19 +348,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
272
  batch_size = prompt_embeds.shape[0]
273
 
274
  if prompt_embeds is None:
275
- text_inputs = tokenizer(
276
- prompt,
277
- padding="max_length",
278
- max_length=max_length,
279
- truncation=True,
280
- return_attention_mask=True,
281
- return_tensors="pt",
282
- )
283
- text_input_ids = text_inputs.input_ids
284
- if text_input_ids.shape[-1] > actual_max_sequence_length:
285
- reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
286
  text_inputs = tokenizer(
287
- reprompt,
288
  padding="max_length",
289
  max_length=max_length,
290
  truncation=True,
@@ -292,91 +358,188 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
292
  return_tensors="pt",
293
  )
294
  text_input_ids = text_inputs.input_ids
295
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
296
-
297
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
298
- text_input_ids, untruncated_ids
299
- ):
300
- _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
301
- removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
302
- logger.warning(
303
- "The following part of your input was truncated because CLIP can only handle sequences up to"
304
- f" {_actual_max_sequence_length} tokens: {removed_text}"
305
- )
306
- prompt_attention_mask = text_inputs.attention_mask.to(device)
307
- if self.transformer.config.enable_text_attention_mask:
308
- prompt_embeds = text_encoder(
309
- text_input_ids.to(device),
310
- attention_mask=prompt_attention_mask,
311
- )
 
 
312
  else:
313
- prompt_embeds = text_encoder(
314
- text_input_ids.to(device
 
 
315
  )
316
- prompt_embeds = prompt_embeds[0]
317
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
 
 
319
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
320
 
321
  bs_embed, seq_len, _ = prompt_embeds.shape
322
  # duplicate text embeddings for each generation per prompt, using mps friendly method
323
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
324
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
325
 
326
  # get unconditional embeddings for classifier free guidance
327
  if do_classifier_free_guidance and negative_prompt_embeds is None:
328
- uncond_tokens: List[str]
329
- if negative_prompt is None:
330
- uncond_tokens = [""] * batch_size
331
- elif prompt is not None and type(prompt) is not type(negative_prompt):
332
- raise TypeError(
333
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
334
- f" {type(prompt)}."
335
- )
336
- elif isinstance(negative_prompt, str):
337
- uncond_tokens = [negative_prompt]
338
- elif batch_size != len(negative_prompt):
339
- raise ValueError(
340
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
341
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
342
- " the batch size of `prompt`."
343
- )
344
- else:
345
- uncond_tokens = negative_prompt
346
-
347
- max_length = prompt_embeds.shape[1]
348
- uncond_input = tokenizer(
349
- uncond_tokens,
350
- padding="max_length",
351
- max_length=max_length,
352
- truncation=True,
353
- return_tensors="pt",
354
- )
355
- uncond_input_ids = uncond_input.input_ids
356
- if uncond_input_ids.shape[-1] > actual_max_sequence_length:
357
- reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
358
  uncond_input = tokenizer(
359
- reuncond_tokens,
360
  padding="max_length",
361
  max_length=max_length,
362
  truncation=True,
363
- return_attention_mask=True,
364
  return_tensors="pt",
365
  )
366
  uncond_input_ids = uncond_input.input_ids
 
 
367
 
368
- negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
369
- if self.transformer.config.enable_text_attention_mask:
370
- negative_prompt_embeds = text_encoder(
371
- uncond_input.input_ids.to(device),
372
- attention_mask=negative_prompt_attention_mask,
373
- )
 
 
374
  else:
375
- negative_prompt_embeds = text_encoder(
376
- uncond_input.input_ids.to(device
 
 
377
  )
378
- negative_prompt_embeds = negative_prompt_embeds[0]
379
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
 
 
380
 
381
  if do_classifier_free_guidance:
382
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -386,24 +549,10 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
386
 
387
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
388
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
389
-
 
390
  return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
391
 
392
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
393
- def run_safety_checker(self, image, device, dtype):
394
- if self.safety_checker is None:
395
- has_nsfw_concept = None
396
- else:
397
- if torch.is_tensor(image):
398
- feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
399
- else:
400
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
401
- safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
402
- image, has_nsfw_concept = self.safety_checker(
403
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
404
- )
405
- return image, has_nsfw_concept
406
-
407
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
408
  def prepare_extra_step_kwargs(self, generator, eta):
409
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -438,8 +587,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
438
  negative_prompt_attention_mask_2=None,
439
  callback_on_step_end_tensor_inputs=None,
440
  ):
441
- if height % 8 != 0 or width % 8 != 0:
442
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
443
 
444
  if callback_on_step_end_tensor_inputs is not None and not all(
445
  k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -524,43 +673,44 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
524
  latents = latents.to(device)
525
 
526
  # scale the initial noise by the standard deviation required by the scheduler
527
- latents = latents * self.scheduler.init_noise_sigma
 
528
  return latents
529
 
530
  def prepare_control_latents(
531
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
532
  ):
533
- # resize the mask to latents shape as we concatenate the mask to the latents
534
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
535
  # and half precision
536
 
537
- if mask is not None:
538
- mask = mask.to(device=device, dtype=self.vae.dtype)
539
  bs = 1
540
- new_mask = []
541
- for i in range(0, mask.shape[0], bs):
542
- mask_bs = mask[i : i + bs]
543
- mask_bs = self.vae.encode(mask_bs)[0]
544
- mask_bs = mask_bs.mode()
545
- new_mask.append(mask_bs)
546
- mask = torch.cat(new_mask, dim = 0)
547
- mask = mask * self.vae.config.scaling_factor
548
-
549
- if masked_image is not None:
550
- masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
551
  bs = 1
552
- new_mask_pixel_values = []
553
- for i in range(0, masked_image.shape[0], bs):
554
- mask_pixel_values_bs = masked_image[i : i + bs]
555
- mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
556
- mask_pixel_values_bs = mask_pixel_values_bs.mode()
557
- new_mask_pixel_values.append(mask_pixel_values_bs)
558
- masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
559
- masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
560
  else:
561
- masked_image_latents = None
562
 
563
- return mask, masked_image_latents
564
 
565
  def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
566
  if video.size()[2] <= mini_batch_encoder:
@@ -623,9 +773,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
623
  def interrupt(self):
624
  return self._interrupt
625
 
626
- def enable_autocast_float8_transformer(self):
627
- self.enable_autocast_float8_transformer_flag = True
628
-
629
  @torch.no_grad()
630
  @replace_example_docstring(EXAMPLE_DOC_STRING)
631
  def __call__(
@@ -635,6 +782,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
635
  height: Optional[int] = None,
636
  width: Optional[int] = None,
637
  control_video: Union[torch.FloatTensor] = None,
 
 
638
  num_inference_steps: Optional[int] = 50,
639
  guidance_scale: Optional[float] = 5.0,
640
  negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -661,6 +810,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
661
  target_size: Optional[Tuple[int, int]] = None,
662
  crops_coords_top_left: Tuple[int, int] = (0, 0),
663
  comfyui_progressbar: bool = False,
 
664
  ):
665
  r"""
666
  Generates images or video using the EasyAnimate pipeline based on the provided prompts.
@@ -765,6 +915,12 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
765
  batch_size = prompt_embeds.shape[0]
766
 
767
  device = self._execution_device
 
 
768
 
769
  # 3. Encode input prompt
770
  (
@@ -775,7 +931,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
775
  ) = self.encode_prompt(
776
  prompt=prompt,
777
  device=device,
778
- dtype=self.transformer.dtype,
779
  num_images_per_prompt=num_images_per_prompt,
780
  do_classifier_free_guidance=self.do_classifier_free_guidance,
781
  negative_prompt=negative_prompt,
@@ -785,28 +941,36 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
785
  negative_prompt_attention_mask=negative_prompt_attention_mask,
786
  text_encoder_index=0,
787
  )
788
- (
789
- prompt_embeds_2,
790
- negative_prompt_embeds_2,
791
- prompt_attention_mask_2,
792
- negative_prompt_attention_mask_2,
793
- ) = self.encode_prompt(
794
- prompt=prompt,
795
- device=device,
796
- dtype=self.transformer.dtype,
797
- num_images_per_prompt=num_images_per_prompt,
798
- do_classifier_free_guidance=self.do_classifier_free_guidance,
799
- negative_prompt=negative_prompt,
800
- prompt_embeds=prompt_embeds_2,
801
- negative_prompt_embeds=negative_prompt_embeds_2,
802
- prompt_attention_mask=prompt_attention_mask_2,
803
- negative_prompt_attention_mask=negative_prompt_attention_mask_2,
804
- text_encoder_index=1,
805
- )
806
- torch.cuda.empty_cache()
 
 
807
 
808
  # 4. Prepare timesteps
809
- self.scheduler.set_timesteps(num_inference_steps, device=device)
 
 
 
810
  timesteps = self.scheduler.timesteps
811
  if comfyui_progressbar:
812
  from comfy.utils import ProgressBar
@@ -820,7 +984,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
820
  video_length,
821
  height,
822
  width,
823
- prompt_embeds.dtype,
824
  device,
825
  generator,
826
  latents,
@@ -828,27 +992,69 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
828
  if comfyui_progressbar:
829
  pbar.update(1)
830
 
831
- if control_video is not None:
 
 
832
  video_length = control_video.shape[2]
833
  control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
834
  control_video = control_video.to(dtype=torch.float32)
835
  control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
 
 
836
  else:
837
- control_video = None
838
- control_video_latents = self.prepare_control_latents(
839
- None,
840
- control_video,
841
- batch_size,
842
- height,
843
- width,
844
- prompt_embeds.dtype,
845
- device,
846
- generator,
847
- self.do_classifier_free_guidance
848
- )[1]
849
- control_latents = (
850
- torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
851
- )
 
 
852
 
853
  if comfyui_progressbar:
854
  pbar.update(1)
@@ -880,34 +1086,49 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
880
  )
881
 
882
  # Get other hunyuan params
883
- style = torch.tensor([0], device=device)
884
-
885
  target_size = target_size or (height, width)
886
  add_time_ids = list(original_size + target_size + crops_coords_top_left)
887
- add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
 
888
 
889
  if self.do_classifier_free_guidance:
890
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
891
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
892
- prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
893
- prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
894
  add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
895
  style = torch.cat([style] * 2, dim=0)
896
 
897
  # To latents.device
898
- prompt_embeds = prompt_embeds.to(device=device)
899
- prompt_attention_mask = prompt_attention_mask.to(device=device)
900
- prompt_embeds_2 = prompt_embeds_2.to(device=device)
901
- prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
902
- add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
903
  batch_size * num_images_per_prompt, 1
904
  )
905
  style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
906
 
907
- torch.cuda.empty_cache()
908
- if self.enable_autocast_float8_transformer_flag:
909
- origin_weight_dtype = self.transformer.dtype
910
- self.transformer = self.transformer.to(torch.float8_e4m3fn)
 
 
911
  # 8. Denoising loop
912
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
913
  self._num_timesteps = len(timesteps)
@@ -918,7 +1139,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
918
 
919
  # expand the latents if we are doing classifier free guidance
920
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
921
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
922
 
923
  # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
924
  t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
@@ -935,8 +1157,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
935
  image_meta_size=add_time_ids,
936
  style=style,
937
  image_rotary_emb=image_rotary_emb,
938
- return_dict=False,
939
  control_latents=control_latents,
 
940
  )[0]
941
  if noise_pred.size()[1] != self.vae.config.latent_channels:
942
  noise_pred, _ = noise_pred.chunk(2, dim=1)
@@ -976,10 +1199,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
976
  if comfyui_progressbar:
977
  pbar.update(1)
978
 
979
- if self.enable_autocast_float8_transformer_flag:
980
- self.transformer = self.transformer.to("cpu", origin_weight_dtype)
981
-
982
- torch.cuda.empty_cache()
983
  # Post-processing
984
  video = self.decode_latents(latents)
985
 
@@ -993,4 +1212,4 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
993
  if not return_dict:
994
  return video
995
 
996
- return EasyAnimatePipelineOutput(videos=video)
 
31
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
32
  from diffusers.pipelines.stable_diffusion.safety_checker import \
33
  StableDiffusionSafetyChecker
34
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
35
+ FlowMatchEulerDiscreteScheduler)
36
  from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
37
  is_bs4_available, is_ftfy_available,
38
  is_torch_xla_available, logging,
 
42
  from PIL import Image
43
  from tqdm import tqdm
44
  from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
45
+ CLIPVisionModelWithProjection, Qwen2Tokenizer,
46
+ Qwen2VLForConditionalGeneration, T5EncoderModel,
47
+ T5Tokenizer)
48
 
49
  from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
50
+ from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput
51
 
52
  if is_torch_xla_available():
53
  import torch_xla.core.xla_model as xm
 
66
  ```
67
  """
68
 
69
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
70
  def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
71
  tw = tgt_width
72
  th = tgt_height
 
100
  return noise_cfg
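rescale_noise_cfg (its full body also appears further down, in the removed pipeline_easyanimate_multi_text_encoder.py) pulls the guided prediction back toward the per-sample standard deviation of the text-conditioned prediction, following "Common Diffusion Noise Schedules and Sample Steps are Flawed", Sec. 3.4. A minimal usage sketch with placeholder tensors (the shapes and the 0.7 value are illustrative assumptions, not repository defaults):

import torch

# stand-in predictions with a video-latent-like shape (B, C, F, H, W)
noise_pred_text = torch.randn(1, 16, 13, 48, 84)
noise_cfg = torch.randn(1, 16, 13, 48, 84)   # stand-in for the CFG-combined output

# blends 70% of the std-rescaled result with 30% of the raw CFG output
noise_cfg = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)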
101
 
102
 
103
+ # Resize mask information in magvit
104
+ def resize_mask(mask, latent, process_first_frame_only=True):
105
+ latent_size = latent.size()
106
+
107
+ if process_first_frame_only:
108
+ target_size = list(latent_size[2:])
109
+ target_size[0] = 1
110
+ first_frame_resized = F.interpolate(
111
+ mask[:, :, 0:1, :, :],
112
+ size=target_size,
113
+ mode='trilinear',
114
+ align_corners=False
115
+ )
116
+
117
+ target_size = list(latent_size[2:])
118
+ target_size[0] = target_size[0] - 1
119
+ if target_size[0] != 0:
120
+ remaining_frames_resized = F.interpolate(
121
+ mask[:, :, 1:, :, :],
122
+ size=target_size,
123
+ mode='trilinear',
124
+ align_corners=False
125
+ )
126
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
127
+ else:
128
+ resized_mask = first_frame_resized
129
+ else:
130
+ target_size = list(latent_size[2:])
131
+ resized_mask = F.interpolate(
132
+ mask,
133
+ size=target_size,
134
+ mode='trilinear',
135
+ align_corners=False
136
+ )
137
+ return resized_mask
138
+
139
+
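resize_mask downsamples a pixel-space mask (here, the camera-control video) to the latent grid; the first frame is interpolated on its own so the result lines up with MagViT's temporal compression, which keeps the first frame and compresses the remaining ones. A minimal sketch of the resulting shapes; the sizes below are illustrative assumptions, not values from the repository:

import torch
import torch.nn.functional as F  # resize_mask relies on F.interpolate

mask = torch.ones(1, 6, 49, 480, 720)      # (B, C, F, H, W) pixel-space control input
latent = torch.zeros(1, 16, 13, 60, 90)    # (B, C', F', H', W') MagViT latent

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)                        # torch.Size([1, 6, 13, 60, 90])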
140
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
141
+ def retrieve_timesteps(
142
+ scheduler,
143
+ num_inference_steps: Optional[int] = None,
144
+ device: Optional[Union[str, torch.device]] = None,
145
+ timesteps: Optional[List[int]] = None,
146
+ sigmas: Optional[List[float]] = None,
147
+ **kwargs,
148
+ ):
149
+ """
150
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
151
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
152
+
153
+ Args:
154
+ scheduler (`SchedulerMixin`):
155
+ The scheduler to get timesteps from.
156
+ num_inference_steps (`int`):
157
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
158
+ must be `None`.
159
+ device (`str` or `torch.device`, *optional*):
160
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
161
+ timesteps (`List[int]`, *optional*):
162
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
163
+ `num_inference_steps` and `sigmas` must be `None`.
164
+ sigmas (`List[float]`, *optional*):
165
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
166
+ `num_inference_steps` and `timesteps` must be `None`.
167
+
168
+ Returns:
169
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
170
+ second element is the number of inference steps.
171
+ """
172
+ if timesteps is not None and sigmas is not None:
173
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
174
+ if timesteps is not None:
175
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
176
+ if not accepts_timesteps:
177
+ raise ValueError(
178
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
179
+ f" timestep schedules. Please check whether you are using the correct scheduler."
180
+ )
181
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
182
+ timesteps = scheduler.timesteps
183
+ num_inference_steps = len(timesteps)
184
+ elif sigmas is not None:
185
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
186
+ if not accept_sigmas:
187
+ raise ValueError(
188
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
189
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
190
+ )
191
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
192
+ timesteps = scheduler.timesteps
193
+ num_inference_steps = len(timesteps)
194
+ else:
195
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
196
+ timesteps = scheduler.timesteps
197
+ return timesteps, num_inference_steps
198
+
199
+
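retrieve_timesteps lets the caller hand the configured scheduler either a step count, explicit timesteps, or sigmas; the control pipeline below additionally passes mu=1 when a FlowMatchEulerDiscreteScheduler is in use. A small usage sketch (constructing a default scheduler here is illustrative, and whether set_timesteps accepts mu depends on the installed diffusers version):

from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps=50, device="cpu", mu=1
)
print(num_inference_steps, timesteps[:3])   # 50 and the first three timesteps of the schedule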
200
+ class EasyAnimateControlPipeline(DiffusionPipeline):
201
  r"""
202
  Pipeline for text-to-video generation using EasyAnimate.
203
 
204
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
205
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
206
 
207
+ EasyAnimate uses one text encoder, [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct), in V5.1.
  EasyAnimate uses two text encoders, [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP fine-tuned by
+ the HunyuanDiT team, in V5.
210
 
211
  Args:
212
  vae ([`AutoencoderKLMagvit`]):
213
  Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
214
+ text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
215
+ EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
216
+ EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5.
217
+ tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
218
+ A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
219
  transformer ([`EasyAnimateTransformer3DModel`]):
220
+ The EasyAnimate model designed by the EasyAnimate team.
221
  text_encoder_2 (`T5EncoderModel`):
222
+ EasyAnimate does not use text_encoder_2 in V5.1.
223
+ EasyAnimate uses the [mT5](https://huggingface.co/google/mt5-base) embedder in V5.
224
  tokenizer_2 (`T5Tokenizer`):
225
  The tokenizer for the mT5 embedder.
226
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
227
  A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
228
  """
229
 
230
  model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
231
  _optional_components = [
 
 
232
  "text_encoder_2",
233
  "tokenizer_2",
234
  "text_encoder",
235
  "tokenizer",
236
  ]
 
237
  _callback_tensor_inputs = [
238
  "latents",
239
  "prompt_embeds",
 
245
  def __init__(
246
  self,
247
  vae: AutoencoderKLMagvit,
248
+ text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
249
+ tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
250
+ text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]],
251
+ tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]],
252
  transformer: EasyAnimateTransformer3DModel,
253
+ scheduler: FlowMatchEulerDiscreteScheduler,
 
 
 
254
  ):
255
  super().__init__()
256
 
257
  self.register_modules(
258
  vae=vae,
259
  text_encoder=text_encoder,
260
+ text_encoder_2=text_encoder_2,
261
  tokenizer=tokenizer,
262
  tokenizer_2=tokenizer_2,
263
  transformer=transformer,
264
  scheduler=scheduler,
 
 
 
265
  )
 
 
267
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
268
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
269
  self.mask_processor = VaeImageProcessor(
270
  vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
271
  )
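The spatial scale factor is read straight off the VAE config, and together with the divisible-by-16 check in check_inputs it fixes the latent grid the transformer sees. A small sketch with an assumed four-entry block_out_channels (illustrative, not read from a checkpoint):

block_out_channels = (128, 256, 512, 512)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)        # -> 8

height, width = 384, 672                                      # both divisible by 16
latent_height, latent_width = height // vae_scale_factor, width // vae_scale_factor
print(vae_scale_factor, latent_height, latent_width)          # 8 48 84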
 
 
272
 
273
  def enable_sequential_cpu_offload(self, *args, **kwargs):
274
  super().enable_sequential_cpu_offload(*args, **kwargs)
 
348
  batch_size = prompt_embeds.shape[0]
349
 
350
  if prompt_embeds is None:
351
+ if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
 
 
352
  text_inputs = tokenizer(
353
+ prompt,
354
  padding="max_length",
355
  max_length=max_length,
356
  truncation=True,
 
358
  return_tensors="pt",
359
  )
360
  text_input_ids = text_inputs.input_ids
361
+ if text_input_ids.shape[-1] > actual_max_sequence_length:
362
+ reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
363
+ text_inputs = tokenizer(
364
+ reprompt,
365
+ padding="max_length",
366
+ max_length=max_length,
367
+ truncation=True,
368
+ return_attention_mask=True,
369
+ return_tensors="pt",
370
+ )
371
+ text_input_ids = text_inputs.input_ids
372
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
373
+
374
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
375
+ text_input_ids, untruncated_ids
376
+ ):
377
+ _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
378
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
379
+ logger.warning(
380
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
381
+ f" {_actual_max_sequence_length} tokens: {removed_text}"
382
+ )
383
+
384
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
385
+
386
+ if self.transformer.config.enable_text_attention_mask:
387
+ prompt_embeds = text_encoder(
388
+ text_input_ids.to(device),
389
+ attention_mask=prompt_attention_mask,
390
+ )
391
+ else:
392
+ prompt_embeds = text_encoder(
393
+ text_input_ids.to(device)
394
+ )
395
+ prompt_embeds = prompt_embeds[0]
396
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
397
  else:
398
+ if prompt is not None and isinstance(prompt, str):
399
+ messages = [
400
+ {
401
+ "role": "user",
402
+ "content": [{"type": "text", "text": prompt}],
403
+ }
404
+ ]
405
+ else:
406
+ messages = [
407
+ {
408
+ "role": "user",
409
+ "content": [{"type": "text", "text": _prompt}],
410
+ } for _prompt in prompt
411
+ ]
412
+ text = tokenizer.apply_chat_template(
413
+ messages, tokenize=False, add_generation_prompt=True
414
  )
 
 
415
 
416
+ text_inputs = tokenizer(
417
+ text=[text],
418
+ padding="max_length",
419
+ max_length=max_length,
420
+ truncation=True,
421
+ return_attention_mask=True,
422
+ padding_side="right",
423
+ return_tensors="pt",
424
+ )
425
+ text_inputs = text_inputs.to(text_encoder.device)
426
+
427
+ text_input_ids = text_inputs.input_ids
428
+ prompt_attention_mask = text_inputs.attention_mask
429
+ if self.transformer.config.enable_text_attention_mask:
430
+ # Inference: Generation of the output
431
+ prompt_embeds = text_encoder(
432
+ input_ids=text_input_ids,
433
+ attention_mask=prompt_attention_mask,
434
+ output_hidden_states=True).hidden_states[-2]
435
+ else:
436
+ raise ValueError("LLM needs attention_mask")
437
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
438
+
439
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
440
 
441
  bs_embed, seq_len, _ = prompt_embeds.shape
442
  # duplicate text embeddings for each generation per prompt, using mps friendly method
443
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
444
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
445
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
446
 
447
  # get unconditional embeddings for classifier free guidance
448
  if do_classifier_free_guidance and negative_prompt_embeds is None:
449
+ if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
450
+ uncond_tokens: List[str]
451
+ if negative_prompt is None:
452
+ uncond_tokens = [""] * batch_size
453
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
454
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
457
+ )
458
+ elif isinstance(negative_prompt, str):
459
+ uncond_tokens = [negative_prompt]
460
+ elif batch_size != len(negative_prompt):
461
+ raise ValueError(
462
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
463
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
464
+ " the batch size of `prompt`."
465
+ )
466
+ else:
467
+ uncond_tokens = negative_prompt
468
+
469
+ max_length = prompt_embeds.shape[1]
 
 
470
  uncond_input = tokenizer(
471
+ uncond_tokens,
472
  padding="max_length",
473
  max_length=max_length,
474
  truncation=True,
 
475
  return_tensors="pt",
476
  )
477
  uncond_input_ids = uncond_input.input_ids
478
+ if uncond_input_ids.shape[-1] > actual_max_sequence_length:
479
+ reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
480
+ uncond_input = tokenizer(
481
+ reuncond_tokens,
482
+ padding="max_length",
483
+ max_length=max_length,
484
+ truncation=True,
485
+ return_attention_mask=True,
486
+ return_tensors="pt",
487
+ )
488
+ uncond_input_ids = uncond_input.input_ids
489
 
490
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
491
+ if self.transformer.config.enable_text_attention_mask:
492
+ negative_prompt_embeds = text_encoder(
493
+ uncond_input.input_ids.to(device),
494
+ attention_mask=negative_prompt_attention_mask,
495
+ )
496
+ else:
497
+ negative_prompt_embeds = text_encoder(
498
+ uncond_input.input_ids.to(device)
499
+ )
500
+ negative_prompt_embeds = negative_prompt_embeds[0]
501
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
502
  else:
503
+ if negative_prompt is not None and isinstance(negative_prompt, str):
504
+ messages = [
505
+ {
506
+ "role": "user",
507
+ "content": [{"type": "text", "text": negative_prompt}],
508
+ }
509
+ ]
510
+ else:
511
+ messages = [
512
+ {
513
+ "role": "user",
514
+ "content": [{"type": "text", "text": _negative_prompt}],
515
+ } for _negative_prompt in negative_prompt
516
+ ]
517
+ text = tokenizer.apply_chat_template(
518
+ messages, tokenize=False, add_generation_prompt=True
519
  )
520
+
521
+ text_inputs = tokenizer(
522
+ text=[text],
523
+ padding="max_length",
524
+ max_length=max_length,
525
+ truncation=True,
526
+ return_attention_mask=True,
527
+ padding_side="right",
528
+ return_tensors="pt",
529
+ )
530
+ text_inputs = text_inputs.to(text_encoder.device)
531
+
532
+ text_input_ids = text_inputs.input_ids
533
+ negative_prompt_attention_mask = text_inputs.attention_mask
534
+ if self.transformer.config.enable_text_attention_mask:
535
+ # Inference: Generation of the output
536
+ negative_prompt_embeds = text_encoder(
537
+ input_ids=text_input_ids,
538
+ attention_mask=negative_prompt_attention_mask,
539
+ output_hidden_states=True).hidden_states[-2]
540
+ else:
541
+ raise ValueError("LLM needs attention_mask")
542
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
543
 
544
  if do_classifier_free_guidance:
545
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
 
549
 
550
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
551
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
552
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
553
+
554
  return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
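In the V5.1 path above, encode_prompt wraps the prompt in Qwen2-VL's chat template and uses the second-to-last hidden state as the text embedding. A standalone sketch of that path; the checkpoint name, max_length and dtype are illustrative assumptions:

import torch
from transformers import Qwen2Tokenizer, Qwen2VLForConditionalGeneration

tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16
)

messages = [{"role": "user", "content": [{"type": "text", "text": "A fox runs through fresh snow"}]}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text=[text], padding="max_length", max_length=256,
                   truncation=True, return_attention_mask=True, return_tensors="pt")

with torch.no_grad():
    prompt_embeds = text_encoder(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        output_hidden_states=True,
    ).hidden_states[-2]
print(prompt_embeds.shape)   # (1, 256, hidden_size)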
 
 
556
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
557
  def prepare_extra_step_kwargs(self, generator, eta):
558
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
 
587
  negative_prompt_attention_mask_2=None,
588
  callback_on_step_end_tensor_inputs=None,
589
  ):
590
+ if height % 16 != 0 or width % 16 != 0:
591
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
592
 
593
  if callback_on_step_end_tensor_inputs is not None and not all(
594
  k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
 
673
  latents = latents.to(device)
674
 
675
  # scale the initial noise by the standard deviation required by the scheduler
676
+ if hasattr(self.scheduler, "init_noise_sigma"):
677
+ latents = latents * self.scheduler.init_noise_sigma
678
  return latents
679
 
680
  def prepare_control_latents(
681
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
682
  ):
683
+ # resize the control to latents shape as we concatenate the control to the latents
684
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
685
  # and half precision
686
 
687
+ if control is not None:
688
+ control = control.to(device=device, dtype=dtype)
689
  bs = 1
690
+ new_control = []
691
+ for i in range(0, control.shape[0], bs):
692
+ control_bs = control[i : i + bs]
693
+ control_bs = self.vae.encode(control_bs)[0]
694
+ control_bs = control_bs.mode()
695
+ new_control.append(control_bs)
696
+ control = torch.cat(new_control, dim = 0)
697
+ control = control * self.vae.config.scaling_factor
698
+
699
+ if control_image is not None:
700
+ control_image = control_image.to(device=device, dtype=dtype)
701
  bs = 1
702
+ new_control_pixel_values = []
703
+ for i in range(0, control_image.shape[0], bs):
704
+ control_pixel_values_bs = control_image[i : i + bs]
705
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
706
+ control_pixel_values_bs = control_pixel_values_bs.mode()
707
+ new_control_pixel_values.append(control_pixel_values_bs)
708
+ control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
709
+ control_image_latents = control_image_latents * self.vae.config.scaling_factor
710
  else:
711
+ control_image_latents = None
712
 
713
+ return control, control_image_latents
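prepare_control_latents pushes the frames through the VAE one sample at a time (bs = 1) to bound peak memory, takes the mode of the latent distribution instead of sampling, and applies the VAE scaling factor. The same pattern in isolation; vae and the scaling factor are placeholders for whatever the loaded AutoencoderKLMagvit provides:

import torch

def encode_in_slices(vae, pixel_values, scaling_factor, bs=1):
    # pixel_values: (B, C, F, H, W) in [-1, 1]; encoded slice-by-slice to limit memory
    latents = []
    for i in range(0, pixel_values.shape[0], bs):
        posterior = vae.encode(pixel_values[i : i + bs])[0]
        latents.append(posterior.mode())          # deterministic: no sampling noise
    return torch.cat(latents, dim=0) * scaling_factor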
714
 
715
  def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
716
  if video.size()[2] <= mini_batch_encoder:
 
773
  def interrupt(self):
774
  return self._interrupt
775
 
 
 
 
776
  @torch.no_grad()
777
  @replace_example_docstring(EXAMPLE_DOC_STRING)
778
  def __call__(
 
782
  height: Optional[int] = None,
783
  width: Optional[int] = None,
784
  control_video: Union[torch.FloatTensor] = None,
785
+ control_camera_video: Union[torch.FloatTensor] = None,
786
+ ref_image: Union[torch.FloatTensor] = None,
787
  num_inference_steps: Optional[int] = 50,
788
  guidance_scale: Optional[float] = 5.0,
789
  negative_prompt: Optional[Union[str, List[str]]] = None,
 
810
  target_size: Optional[Tuple[int, int]] = None,
811
  crops_coords_top_left: Tuple[int, int] = (0, 0),
812
  comfyui_progressbar: bool = False,
813
+ timesteps: Optional[List[int]] = None,
814
  ):
815
  r"""
816
  Generates images or video using the EasyAnimate pipeline based on the provided prompts.
 
915
  batch_size = prompt_embeds.shape[0]
916
 
917
  device = self._execution_device
918
+ if self.text_encoder is not None:
919
+ dtype = self.text_encoder.dtype
920
+ elif self.text_encoder_2 is not None:
921
+ dtype = self.text_encoder_2.dtype
922
+ else:
923
+ dtype = self.transformer.dtype
924
 
925
  # 3. Encode input prompt
926
  (
 
931
  ) = self.encode_prompt(
932
  prompt=prompt,
933
  device=device,
934
+ dtype=dtype,
935
  num_images_per_prompt=num_images_per_prompt,
936
  do_classifier_free_guidance=self.do_classifier_free_guidance,
937
  negative_prompt=negative_prompt,
 
941
  negative_prompt_attention_mask=negative_prompt_attention_mask,
942
  text_encoder_index=0,
943
  )
944
+ if self.tokenizer_2 is not None:
945
+ (
946
+ prompt_embeds_2,
947
+ negative_prompt_embeds_2,
948
+ prompt_attention_mask_2,
949
+ negative_prompt_attention_mask_2,
950
+ ) = self.encode_prompt(
951
+ prompt=prompt,
952
+ device=device,
953
+ dtype=dtype,
954
+ num_images_per_prompt=num_images_per_prompt,
955
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
956
+ negative_prompt=negative_prompt,
957
+ prompt_embeds=prompt_embeds_2,
958
+ negative_prompt_embeds=negative_prompt_embeds_2,
959
+ prompt_attention_mask=prompt_attention_mask_2,
960
+ negative_prompt_attention_mask=negative_prompt_attention_mask_2,
961
+ text_encoder_index=1,
962
+ )
963
+ else:
964
+ prompt_embeds_2 = None
965
+ negative_prompt_embeds_2 = None
966
+ prompt_attention_mask_2 = None
967
+ negative_prompt_attention_mask_2 = None
968
 
969
  # 4. Prepare timesteps
970
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
971
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
972
+ else:
973
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
974
  timesteps = self.scheduler.timesteps
975
  if comfyui_progressbar:
976
  from comfy.utils import ProgressBar
 
984
  video_length,
985
  height,
986
  width,
987
+ dtype,
988
  device,
989
  generator,
990
  latents,
 
992
  if comfyui_progressbar:
993
  pbar.update(1)
994
 
995
+ if control_camera_video is not None:
996
+ control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
997
+ control_video_latents = control_video_latents * 6
998
+ control_latents = (
999
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
1000
+ ).to(device, dtype)
1001
+ elif control_video is not None:
1002
  video_length = control_video.shape[2]
1003
  control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
1004
  control_video = control_video.to(dtype=torch.float32)
1005
  control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
1006
+ control_video_latents = self.prepare_control_latents(
1007
+ None,
1008
+ control_video,
1009
+ batch_size,
1010
+ height,
1011
+ width,
1012
+ dtype,
1013
+ device,
1014
+ generator,
1015
+ self.do_classifier_free_guidance
1016
+ )[1]
1017
+ control_latents = (
1018
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
1019
+ ).to(device, dtype)
1020
  else:
1021
+ control_video_latents = torch.zeros_like(latents).to(device, dtype)
1022
+ control_latents = (
1023
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
1024
+ ).to(device, dtype)
1025
+
1026
+ if ref_image is not None:
1027
+ video_length = ref_image.shape[2]
1028
+ ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
1029
+ ref_image = ref_image.to(dtype=torch.float32)
1030
+ ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
1031
+
1032
+ ref_image_latentes = self.prepare_control_latents(
1033
+ None,
1034
+ ref_image,
1035
+ batch_size,
1036
+ height,
1037
+ width,
1038
+ prompt_embeds.dtype,
1039
+ device,
1040
+ generator,
1041
+ self.do_classifier_free_guidance
1042
+ )[1]
1043
+
1044
+ ref_image_latentes_conv_in = torch.zeros_like(latents)
1045
+ if latents.size()[2] != 1:
1046
+ ref_image_latentes_conv_in[:, :, :1] = ref_image_latentes
1047
+ ref_image_latentes_conv_in = (
1048
+ torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in
1049
+ ).to(device, dtype)
1050
+ control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
1051
+ else:
1052
+ if self.transformer.config.get("add_ref_latent_in_control_model", False):
1053
+ ref_image_latentes_conv_in = torch.zeros_like(latents)
1054
+ ref_image_latentes_conv_in = (
1055
+ torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in
1056
+ ).to(device, dtype)
1057
+ control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
1058
 
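The branches above all end with a control_latents tensor shaped like the noise latents (camera video resized with resize_mask and scaled, control video VAE-encoded, or zeros), and the optional reference image is written into the first latent frame and concatenated along the channel axis. A shape-only sketch with illustrative sizes:

import torch

B, C, F_, H_, W_ = 1, 16, 13, 48, 84                      # illustrative latent shape
control_video_latents = torch.zeros(B, C, F_, H_, W_)     # stands in for the encoded control frames

ref_latents_conv_in = torch.zeros(B, C, F_, H_, W_)
ref_latents_conv_in[:, :, :1] = 0.1                       # first latent frame carries the reference image

control_latents = torch.cat([control_video_latents, ref_latents_conv_in], dim=1)
print(control_latents.shape)   # torch.Size([1, 32, 13, 48, 84]); doubled on dim 0 under CFG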
1059
  if comfyui_progressbar:
1060
  pbar.update(1)
 
1086
  )
1087
 
1088
  # Get other hunyuan params
 
 
1089
  target_size = target_size or (height, width)
1090
  add_time_ids = list(original_size + target_size + crops_coords_top_left)
1091
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1092
+ style = torch.tensor([0], device=device)
1093
 
1094
  if self.do_classifier_free_guidance:
 
 
 
 
1095
  add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
1096
  style = torch.cat([style] * 2, dim=0)
1097
 
1098
  # To latents.device
1099
+ add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat(
 
 
 
 
1100
  batch_size * num_images_per_prompt, 1
1101
  )
1102
  style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
1103
 
1104
+ # Get other pixart params
1105
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
1106
+ if self.transformer.config.get("sample_size", 64) == 128:
1107
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
1108
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
1109
+ resolution = resolution.to(dtype=dtype, device=device)
1110
+ aspect_ratio = aspect_ratio.to(dtype=dtype, device=device)
1111
+
1112
+ if self.do_classifier_free_guidance:
1113
+ resolution = torch.cat([resolution, resolution], dim=0)
1114
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
1115
+
1116
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
1117
+
1118
+ if self.do_classifier_free_guidance:
1119
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1120
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
1121
+ if prompt_embeds_2 is not None:
1122
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
1123
+ prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
1124
+
1125
+ # To latents.device
1126
+ prompt_embeds = prompt_embeds.to(device=device)
1127
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
1128
+ if prompt_embeds_2 is not None:
1129
+ prompt_embeds_2 = prompt_embeds_2.to(device=device)
1130
+ prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
1131
+
1132
  # 8. Denoising loop
1133
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1134
  self._num_timesteps = len(timesteps)
 
1139
 
1140
  # expand the latents if we are doing classifier free guidance
1141
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1142
+ if hasattr(self.scheduler, "scale_model_input"):
1143
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1144
 
1145
  # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
1146
  t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
 
1157
  image_meta_size=add_time_ids,
1158
  style=style,
1159
  image_rotary_emb=image_rotary_emb,
1160
+ added_cond_kwargs=added_cond_kwargs,
1161
  control_latents=control_latents,
1162
+ return_dict=False,
1163
  )[0]
1164
  if noise_pred.size()[1] != self.vae.config.latent_channels:
1165
  noise_pred, _ = noise_pred.chunk(2, dim=1)
 
1199
  if comfyui_progressbar:
1200
  pbar.update(1)
1201
 
 
 
 
 
1202
  # Post-processing
1203
  video = self.decode_latents(latents)
1204
 
 
1212
  if not return_dict:
1213
  return video
1214
 
1215
+ return EasyAnimatePipelineOutput(frames=video)
easyanimate/pipeline/pipeline_easyanimate_inpaint.py CHANGED
The diff for this file is too large to render. See raw diff
 
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py DELETED
@@ -1,925 +0,0 @@
1
- # Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from typing import Callable, Dict, List, Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import torch
20
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
21
- from diffusers.image_processor import VaeImageProcessor
22
- from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
23
- get_3d_rotary_pos_embed)
24
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
26
- from diffusers.pipelines.stable_diffusion.safety_checker import \
27
- StableDiffusionSafetyChecker
28
- from diffusers.schedulers import DDIMScheduler
29
- from diffusers.utils import (is_torch_xla_available, logging,
30
- replace_example_docstring)
31
- from diffusers.utils.torch_utils import randn_tensor
32
- from einops import rearrange
33
- from tqdm import tqdm
34
- from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
35
- T5Tokenizer, T5EncoderModel)
36
-
37
- from .pipeline_easyanimate import EasyAnimatePipelineOutput
38
- from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
39
-
40
- if is_torch_xla_available():
41
- import torch_xla.core.xla_model as xm
42
-
43
- XLA_AVAILABLE = True
44
- else:
45
- XLA_AVAILABLE = False
46
-
47
-
48
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
-
50
- EXAMPLE_DOC_STRING = """
51
- Examples:
52
- ```py
53
- >>> pass
54
- ```
55
- """
56
-
57
-
58
- def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
59
- tw = tgt_width
60
- th = tgt_height
61
- h, w = src
62
- r = h / w
63
- if r > (th / tw):
64
- resize_height = th
65
- resize_width = int(round(th / h * w))
66
- else:
67
- resize_width = tw
68
- resize_height = int(round(tw / w * h))
69
-
70
- crop_top = int(round((th - resize_height) / 2.0))
71
- crop_left = int(round((tw - resize_width) / 2.0))
72
-
73
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
74
-
75
-
76
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
77
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
78
- """
79
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
80
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
81
- """
82
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
83
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
84
- # rescale the results from guidance (fixes overexposure)
85
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
86
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
87
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
88
- return noise_cfg
89
-
90
-
91
- class EasyAnimatePipeline_Multi_Text_Encoder(DiffusionPipeline):
92
- r"""
93
- Pipeline for text-to-video generation using EasyAnimate.
94
-
95
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
96
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
97
-
98
- EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
99
- HunyuanDiT team)
100
-
101
- Args:
102
- vae ([`AutoencoderKLMagvit`]):
103
- Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
104
- text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
105
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
106
- EasyAnimate uses a fine-tuned [bilingual CLIP].
107
- tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
108
- A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
109
- transformer ([`EasyAnimateTransformer3DModel`]):
110
- The EasyAnimate model designed by Tencent Hunyuan.
111
- text_encoder_2 (`T5EncoderModel`):
112
- The mT5 embedder.
113
- tokenizer_2 (`T5Tokenizer`):
114
- The tokenizer for the mT5 embedder.
115
- scheduler ([`DDIMScheduler`]):
116
- A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
117
- """
118
-
119
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
120
- _optional_components = [
121
- "safety_checker",
122
- "feature_extractor",
123
- "text_encoder_2",
124
- "tokenizer_2",
125
- "text_encoder",
126
- "tokenizer",
127
- ]
128
- _exclude_from_cpu_offload = ["safety_checker"]
129
- _callback_tensor_inputs = [
130
- "latents",
131
- "prompt_embeds",
132
- "negative_prompt_embeds",
133
- "prompt_embeds_2",
134
- "negative_prompt_embeds_2",
135
- ]
136
-
137
- def __init__(
138
- self,
139
- vae: AutoencoderKLMagvit,
140
- text_encoder: BertModel,
141
- tokenizer: BertTokenizer,
142
- text_encoder_2: T5EncoderModel,
143
- tokenizer_2: T5Tokenizer,
144
- transformer: EasyAnimateTransformer3DModel,
145
- scheduler: DDIMScheduler,
146
- safety_checker: StableDiffusionSafetyChecker,
147
- feature_extractor: CLIPImageProcessor,
148
- requires_safety_checker: bool = True,
149
- ):
150
- super().__init__()
151
-
152
- self.register_modules(
153
- vae=vae,
154
- text_encoder=text_encoder,
155
- tokenizer=tokenizer,
156
- tokenizer_2=tokenizer_2,
157
- transformer=transformer,
158
- scheduler=scheduler,
159
- safety_checker=safety_checker,
160
- feature_extractor=feature_extractor,
161
- text_encoder_2=text_encoder_2,
162
- )
163
-
164
- if safety_checker is None and requires_safety_checker:
165
- logger.warning(
166
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
167
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
168
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
169
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
170
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
171
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
172
- )
173
-
174
- if safety_checker is not None and feature_extractor is None:
175
- raise ValueError(
176
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
177
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
178
- )
179
-
180
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
181
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
182
- self.enable_autocast_float8_transformer_flag = False
183
- self.register_to_config(requires_safety_checker=requires_safety_checker)
184
-
185
- def enable_sequential_cpu_offload(self, *args, **kwargs):
186
- super().enable_sequential_cpu_offload(*args, **kwargs)
187
- if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
188
- import accelerate
189
- accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
190
- self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
191
-
192
- def encode_prompt(
193
- self,
194
- prompt: str,
195
- device: torch.device,
196
- dtype: torch.dtype,
197
- num_images_per_prompt: int = 1,
198
- do_classifier_free_guidance: bool = True,
199
- negative_prompt: Optional[str] = None,
200
- prompt_embeds: Optional[torch.Tensor] = None,
201
- negative_prompt_embeds: Optional[torch.Tensor] = None,
202
- prompt_attention_mask: Optional[torch.Tensor] = None,
203
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
204
- max_sequence_length: Optional[int] = None,
205
- text_encoder_index: int = 0,
206
- actual_max_sequence_length: int = 256
207
- ):
208
- r"""
209
- Encodes the prompt into text encoder hidden states.
210
-
211
- Args:
212
- prompt (`str` or `List[str]`, *optional*):
213
- prompt to be encoded
214
- device: (`torch.device`):
215
- torch device
216
- dtype (`torch.dtype`):
217
- torch dtype
218
- num_images_per_prompt (`int`):
219
- number of images that should be generated per prompt
220
- do_classifier_free_guidance (`bool`):
221
- whether to use classifier free guidance or not
222
- negative_prompt (`str` or `List[str]`, *optional*):
223
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
224
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
225
- less than `1`).
226
- prompt_embeds (`torch.Tensor`, *optional*):
227
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
228
- provided, text embeddings will be generated from `prompt` input argument.
229
- negative_prompt_embeds (`torch.Tensor`, *optional*):
230
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
231
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
232
- argument.
233
- prompt_attention_mask (`torch.Tensor`, *optional*):
234
- Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
235
- negative_prompt_attention_mask (`torch.Tensor`, *optional*):
236
- Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
237
- max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
238
- text_encoder_index (`int`, *optional*):
239
- Index of the text encoder to use. `0` for clip and `1` for T5.
240
- """
241
- tokenizers = [self.tokenizer, self.tokenizer_2]
242
- text_encoders = [self.text_encoder, self.text_encoder_2]
243
-
244
- tokenizer = tokenizers[text_encoder_index]
245
- text_encoder = text_encoders[text_encoder_index]
246
-
247
- if max_sequence_length is None:
248
- if text_encoder_index == 0:
249
- max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
250
- if text_encoder_index == 1:
251
- max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
252
- else:
253
- max_length = max_sequence_length
254
-
255
- if prompt is not None and isinstance(prompt, str):
256
- batch_size = 1
257
- elif prompt is not None and isinstance(prompt, list):
258
- batch_size = len(prompt)
259
- else:
260
- batch_size = prompt_embeds.shape[0]
261
-
262
- if prompt_embeds is None:
263
- text_inputs = tokenizer(
264
- prompt,
265
- padding="max_length",
266
- max_length=max_length,
267
- truncation=True,
268
- return_attention_mask=True,
269
- return_tensors="pt",
270
- )
271
- text_input_ids = text_inputs.input_ids
272
- if text_input_ids.shape[-1] > actual_max_sequence_length:
273
- reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
274
- text_inputs = tokenizer(
275
- reprompt,
276
- padding="max_length",
277
- max_length=max_length,
278
- truncation=True,
279
- return_attention_mask=True,
280
- return_tensors="pt",
281
- )
282
- text_input_ids = text_inputs.input_ids
283
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
284
-
285
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
286
- text_input_ids, untruncated_ids
287
- ):
288
- _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
289
- removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
290
- logger.warning(
291
- "The following part of your input was truncated because CLIP can only handle sequences up to"
292
- f" {_actual_max_sequence_length} tokens: {removed_text}"
293
- )
294
- prompt_attention_mask = text_inputs.attention_mask.to(device)
295
-
296
- if self.transformer.config.enable_text_attention_mask:
297
- prompt_embeds = text_encoder(
298
- text_input_ids.to(device),
299
- attention_mask=prompt_attention_mask,
300
- )
301
- else:
302
- prompt_embeds = text_encoder(
303
- text_input_ids.to(device)
304
- )
305
- prompt_embeds = prompt_embeds[0]
306
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
307
-
308
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
309
-
310
- bs_embed, seq_len, _ = prompt_embeds.shape
311
- # duplicate text embeddings for each generation per prompt, using mps friendly method
312
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
313
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
314
-
315
- # get unconditional embeddings for classifier free guidance
316
- if do_classifier_free_guidance and negative_prompt_embeds is None:
317
- uncond_tokens: List[str]
318
- if negative_prompt is None:
319
- uncond_tokens = [""] * batch_size
320
- elif prompt is not None and type(prompt) is not type(negative_prompt):
321
- raise TypeError(
322
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
323
- f" {type(prompt)}."
324
- )
325
- elif isinstance(negative_prompt, str):
326
- uncond_tokens = [negative_prompt]
327
- elif batch_size != len(negative_prompt):
328
- raise ValueError(
329
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
330
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
331
- " the batch size of `prompt`."
332
- )
333
- else:
334
- uncond_tokens = negative_prompt
335
-
336
- max_length = prompt_embeds.shape[1]
337
- uncond_input = tokenizer(
338
- uncond_tokens,
339
- padding="max_length",
340
- max_length=max_length,
341
- truncation=True,
342
- return_tensors="pt",
343
- )
344
- uncond_input_ids = uncond_input.input_ids
345
- if uncond_input_ids.shape[-1] > actual_max_sequence_length:
346
- reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
347
- uncond_input = tokenizer(
348
- reuncond_tokens,
349
- padding="max_length",
350
- max_length=max_length,
351
- truncation=True,
352
- return_attention_mask=True,
353
- return_tensors="pt",
354
- )
355
- uncond_input_ids = uncond_input.input_ids
356
-
357
- negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
358
- if self.transformer.config.enable_text_attention_mask:
359
- negative_prompt_embeds = text_encoder(
360
- uncond_input.input_ids.to(device),
361
- attention_mask=negative_prompt_attention_mask,
362
- )
363
- else:
364
- negative_prompt_embeds = text_encoder(
365
- uncond_input.input_ids.to(device)
366
- )
367
- negative_prompt_embeds = negative_prompt_embeds[0]
368
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
369
-
370
- if do_classifier_free_guidance:
371
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
372
- seq_len = negative_prompt_embeds.shape[1]
373
-
374
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
375
-
376
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
377
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
378
-
379
- return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
380
-
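For reference, a minimal sketch of how the dual-encoder method above is driven; `pipe` (an already-loaded pipeline instance) and the prompts are assumptions for illustration, not part of this file:

# Hypothetical usage sketch: encode one prompt with both encoders (0 = bilingual CLIP, 1 = mT5).
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb, neg_emb, mask, neg_mask = pipe.encode_prompt(
    prompt="A panda playing a guitar", device=device, dtype=torch.float16,
    negative_prompt="blurry", text_encoder_index=0,
)
emb_2, neg_emb_2, mask_2, neg_mask_2 = pipe.encode_prompt(
    prompt="A panda playing a guitar", device=device, dtype=torch.float16,
    negative_prompt="blurry", text_encoder_index=1,
)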
381
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
382
- def run_safety_checker(self, image, device, dtype):
383
- if self.safety_checker is None:
384
- has_nsfw_concept = None
385
- else:
386
- if torch.is_tensor(image):
387
- feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
388
- else:
389
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
390
- safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
391
- image, has_nsfw_concept = self.safety_checker(
392
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
393
- )
394
- return image, has_nsfw_concept
395
-
396
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
397
- def prepare_extra_step_kwargs(self, generator, eta):
398
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
399
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
400
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
401
- # and should be between [0, 1]
402
-
403
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
404
- extra_step_kwargs = {}
405
- if accepts_eta:
406
- extra_step_kwargs["eta"] = eta
407
-
408
- # check if the scheduler accepts generator
409
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
410
- if accepts_generator:
411
- extra_step_kwargs["generator"] = generator
412
- return extra_step_kwargs
413
-
414
- def check_inputs(
415
- self,
416
- prompt,
417
- height,
418
- width,
419
- negative_prompt=None,
420
- prompt_embeds=None,
421
- negative_prompt_embeds=None,
422
- prompt_attention_mask=None,
423
- negative_prompt_attention_mask=None,
424
- prompt_embeds_2=None,
425
- negative_prompt_embeds_2=None,
426
- prompt_attention_mask_2=None,
427
- negative_prompt_attention_mask_2=None,
428
- callback_on_step_end_tensor_inputs=None,
429
- ):
430
- if height % 8 != 0 or width % 8 != 0:
431
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
432
-
433
- if callback_on_step_end_tensor_inputs is not None and not all(
434
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
435
- ):
436
- raise ValueError(
437
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
438
- )
439
-
440
- if prompt is not None and prompt_embeds is not None:
441
- raise ValueError(
442
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
443
- " only forward one of the two."
444
- )
445
- elif prompt is None and prompt_embeds is None:
446
- raise ValueError(
447
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
448
- )
449
- elif prompt is None and prompt_embeds_2 is None:
450
- raise ValueError(
451
- "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
452
- )
453
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
454
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
455
-
456
- if prompt_embeds is not None and prompt_attention_mask is None:
457
- raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
458
-
459
- if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
460
- raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
461
-
462
- if negative_prompt is not None and negative_prompt_embeds is not None:
463
- raise ValueError(
464
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
465
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
466
- )
467
-
468
- if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
469
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
470
-
471
- if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
472
- raise ValueError(
473
- "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
474
- )
475
- if prompt_embeds is not None and negative_prompt_embeds is not None:
476
- if prompt_embeds.shape != negative_prompt_embeds.shape:
477
- raise ValueError(
478
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
479
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
480
- f" {negative_prompt_embeds.shape}."
481
- )
482
- if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
483
- if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
484
- raise ValueError(
485
- "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
486
- f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
487
- f" {negative_prompt_embeds_2.shape}."
488
- )
489
-
490
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
491
- def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
492
- if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
493
- if self.vae.cache_mag_vae:
494
- mini_batch_encoder = self.vae.mini_batch_encoder
495
- mini_batch_decoder = self.vae.mini_batch_decoder
496
- shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
497
- else:
498
- mini_batch_encoder = self.vae.mini_batch_encoder
499
- mini_batch_decoder = self.vae.mini_batch_decoder
500
- shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
501
- else:
502
- shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
503
-
504
- if isinstance(generator, list) and len(generator) != batch_size:
505
- raise ValueError(
506
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
507
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
508
- )
509
-
510
- if latents is None:
511
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
512
- else:
513
- latents = latents.to(device)
514
-
515
- # scale the initial noise by the standard deviation required by the scheduler
516
- latents = latents * self.scheduler.init_noise_sigma
517
- return latents
518
-
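As a concrete check of the temporal-shape arithmetic in `prepare_latents` above (the mini-batch sizes are illustrative assumptions, not values read from a checkpoint):

# Illustration of the cache_mag_vae branch: the first frame is kept as-is and the
# remaining frames are compressed in windows of mini_batch_encoder frames.
video_length = 49         # assumed pixel-space frame count
mini_batch_encoder = 4    # assumed temporal window consumed by the VAE encoder
mini_batch_decoder = 1    # assumed latent frames produced per window

latent_frames = (video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1
print(latent_frames)      # 13 latent frames for 49 input frames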
519
- def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
520
- if video.size()[2] <= mini_batch_encoder:
521
- return video
522
- prefix_index_before = mini_batch_encoder // 2
523
- prefix_index_after = mini_batch_encoder - prefix_index_before
524
- pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
525
-
526
- # Encode middle videos
527
- latents = self.vae.encode(pixel_values)[0]
528
- latents = latents.mode()
529
- # Decode middle videos
530
- middle_video = self.vae.decode(latents)[0]
531
-
532
- video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
533
- return video
534
-
535
- def decode_latents(self, latents):
536
- video_length = latents.shape[2]
537
- latents = 1 / self.vae.config.scaling_factor * latents
538
- if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
539
- mini_batch_encoder = self.vae.mini_batch_encoder
540
- mini_batch_decoder = self.vae.mini_batch_decoder
541
- video = self.vae.decode(latents)[0]
542
- video = video.clamp(-1, 1)
543
- if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
544
- video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
545
- else:
546
- latents = rearrange(latents, "b c f h w -> (b f) c h w")
547
- video = []
548
- for frame_idx in tqdm(range(latents.shape[0])):
549
- video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
550
- video = torch.cat(video)
551
- video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
552
- video = (video / 2 + 0.5).clamp(0, 1)
553
-         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
554
- video = video.cpu().float().numpy()
555
- return video
556
-
557
- @property
558
- def guidance_scale(self):
559
- return self._guidance_scale
560
-
561
- @property
562
- def guidance_rescale(self):
563
- return self._guidance_rescale
564
-
565
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
566
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
567
- # corresponds to doing no classifier free guidance.
568
- @property
569
- def do_classifier_free_guidance(self):
570
- return self._guidance_scale > 1
571
-
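A toy numeric illustration of the classifier-free-guidance weighting referenced above (the tensor values are made up):

import torch

guidance_scale = 5.0
noise_pred_uncond = torch.tensor([0.10, -0.20])
noise_pred_text = torch.tensor([0.30, -0.10])

# Same combination used in the denoising loop below; guidance_scale = 1 reduces to the
# text-conditioned prediction alone, i.e. no classifier-free guidance.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred)  # tensor([1.1000, 0.3000])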
572
- @property
573
- def num_timesteps(self):
574
- return self._num_timesteps
575
-
576
- @property
577
- def interrupt(self):
578
- return self._interrupt
579
-
580
- def enable_autocast_float8_transformer(self):
581
- self.enable_autocast_float8_transformer_flag = True
582
-
583
- @torch.no_grad()
584
- @replace_example_docstring(EXAMPLE_DOC_STRING)
585
- def __call__(
586
- self,
587
- prompt: Union[str, List[str]] = None,
588
- video_length: Optional[int] = None,
589
- height: Optional[int] = None,
590
- width: Optional[int] = None,
591
- num_inference_steps: Optional[int] = 50,
592
- guidance_scale: Optional[float] = 5.0,
593
- negative_prompt: Optional[Union[str, List[str]]] = None,
594
- num_images_per_prompt: Optional[int] = 1,
595
- eta: Optional[float] = 0.0,
596
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
597
- latents: Optional[torch.Tensor] = None,
598
- prompt_embeds: Optional[torch.Tensor] = None,
599
- prompt_embeds_2: Optional[torch.Tensor] = None,
600
- negative_prompt_embeds: Optional[torch.Tensor] = None,
601
- negative_prompt_embeds_2: Optional[torch.Tensor] = None,
602
- prompt_attention_mask: Optional[torch.Tensor] = None,
603
- prompt_attention_mask_2: Optional[torch.Tensor] = None,
604
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
605
- negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
606
- output_type: Optional[str] = "latent",
607
- return_dict: bool = True,
608
- callback_on_step_end: Optional[
609
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
610
- ] = None,
611
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
612
- guidance_rescale: float = 0.0,
613
- original_size: Optional[Tuple[int, int]] = (1024, 1024),
614
- target_size: Optional[Tuple[int, int]] = None,
615
- crops_coords_top_left: Tuple[int, int] = (0, 0),
616
- comfyui_progressbar: bool = False,
617
- ):
618
- r"""
619
- Generates images or video using the EasyAnimate pipeline based on the provided prompts.
620
-
621
-         Args:
622
- prompt (`str` or `List[str]`, *optional*):
623
- Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
624
- video_length (`int`, *optional*):
625
- Length of the generated video (in frames).
626
- height (`int`, *optional*):
627
- Height of the generated image in pixels.
628
- width (`int`, *optional*):
629
- Width of the generated image in pixels.
630
- num_inference_steps (`int`, *optional*, defaults to 50):
631
- Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
632
- guidance_scale (`float`, *optional*, defaults to 5.0):
633
- Encourages the model to align outputs with prompts. A higher value may decrease image quality.
634
- negative_prompt (`str` or `List[str]`, *optional*):
635
- Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
636
- num_images_per_prompt (`int`, *optional*, defaults to 1):
637
- Number of images to generate for each prompt.
638
- eta (`float`, *optional*, defaults to 0.0):
639
-                 Corresponds to the eta (η) parameter from the DDIM paper (https://arxiv.org/abs/2010.02502); it only applies to the DDIMScheduler and is ignored by other schedulers.
640
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
641
- A generator to ensure reproducibility in image generation.
642
- latents (`torch.Tensor`, *optional*):
643
- Predefined latent tensors to condition generation.
644
- prompt_embeds (`torch.Tensor`, *optional*):
645
- Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
646
- prompt_embeds_2 (`torch.Tensor`, *optional*):
647
- Secondary text embeddings to supplement or replace the initial prompt embeddings.
648
- negative_prompt_embeds (`torch.Tensor`, *optional*):
649
- Embeddings for negative prompts. Overrides string inputs if defined.
650
- negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
651
- Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
652
- prompt_attention_mask (`torch.Tensor`, *optional*):
653
- Attention mask for the primary prompt embeddings.
654
- prompt_attention_mask_2 (`torch.Tensor`, *optional*):
655
- Attention mask for the secondary prompt embeddings.
656
- negative_prompt_attention_mask (`torch.Tensor`, *optional*):
657
- Attention mask for negative prompt embeddings.
658
- negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
659
- Attention mask for secondary negative prompt embeddings.
660
- output_type (`str`, *optional*, defaults to "latent"):
661
-                 Output format of the generated video. With the default `"latent"`, the decoded video is returned as a `torch.Tensor`; otherwise it is returned as a NumPy array.
662
- return_dict (`bool`, *optional*, defaults to `True`):
663
- If `True`, returns a structured output. Otherwise returns a simple tuple.
664
- callback_on_step_end (`Callable`, *optional*):
665
- Functions called at the end of each denoising step.
666
- callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
667
- Tensor names to be included in callback function calls.
668
- guidance_rescale (`float`, *optional*, defaults to 0.0):
669
- Adjusts noise levels based on guidance scale.
670
- original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
671
- Original dimensions of the output.
672
- target_size (`Tuple[int, int]`, *optional*):
673
- Desired output dimensions for calculations.
674
- crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
675
- Coordinates for cropping.
676
-
677
- Returns:
678
-             [`EasyAnimatePipelineOutput`] or `tuple`:
-                 If `return_dict` is `True`, an [`EasyAnimatePipelineOutput`] containing the generated videos is
-                 returned; otherwise a `tuple` whose first element is the generated videos is returned.
683
- """
684
-
685
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
686
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
687
-
688
- # 0. default height and width
689
- height = int((height // 16) * 16)
690
- width = int((width // 16) * 16)
691
-
692
- # 1. Check inputs. Raise error if not correct
693
- self.check_inputs(
694
- prompt,
695
- height,
696
- width,
697
- negative_prompt,
698
- prompt_embeds,
699
- negative_prompt_embeds,
700
- prompt_attention_mask,
701
- negative_prompt_attention_mask,
702
- prompt_embeds_2,
703
- negative_prompt_embeds_2,
704
- prompt_attention_mask_2,
705
- negative_prompt_attention_mask_2,
706
- callback_on_step_end_tensor_inputs,
707
- )
708
- self._guidance_scale = guidance_scale
709
- self._guidance_rescale = guidance_rescale
710
- self._interrupt = False
711
-
712
- # 2. Define call parameters
713
- if prompt is not None and isinstance(prompt, str):
714
- batch_size = 1
715
- elif prompt is not None and isinstance(prompt, list):
716
- batch_size = len(prompt)
717
- else:
718
- batch_size = prompt_embeds.shape[0]
719
-
720
- device = self._execution_device
721
-
722
- # 3. Encode input prompt
723
- (
724
- prompt_embeds,
725
- negative_prompt_embeds,
726
- prompt_attention_mask,
727
- negative_prompt_attention_mask,
728
- ) = self.encode_prompt(
729
- prompt=prompt,
730
- device=device,
731
- dtype=self.transformer.dtype,
732
- num_images_per_prompt=num_images_per_prompt,
733
- do_classifier_free_guidance=self.do_classifier_free_guidance,
734
- negative_prompt=negative_prompt,
735
- prompt_embeds=prompt_embeds,
736
- negative_prompt_embeds=negative_prompt_embeds,
737
- prompt_attention_mask=prompt_attention_mask,
738
- negative_prompt_attention_mask=negative_prompt_attention_mask,
739
- text_encoder_index=0,
740
- )
741
- (
742
- prompt_embeds_2,
743
- negative_prompt_embeds_2,
744
- prompt_attention_mask_2,
745
- negative_prompt_attention_mask_2,
746
- ) = self.encode_prompt(
747
- prompt=prompt,
748
- device=device,
749
- dtype=self.transformer.dtype,
750
- num_images_per_prompt=num_images_per_prompt,
751
- do_classifier_free_guidance=self.do_classifier_free_guidance,
752
- negative_prompt=negative_prompt,
753
- prompt_embeds=prompt_embeds_2,
754
- negative_prompt_embeds=negative_prompt_embeds_2,
755
- prompt_attention_mask=prompt_attention_mask_2,
756
- negative_prompt_attention_mask=negative_prompt_attention_mask_2,
757
- text_encoder_index=1,
758
- )
759
- torch.cuda.empty_cache()
760
-
761
- # 4. Prepare timesteps
762
- self.scheduler.set_timesteps(num_inference_steps, device=device)
763
- timesteps = self.scheduler.timesteps
764
- if comfyui_progressbar:
765
- from comfy.utils import ProgressBar
766
- pbar = ProgressBar(num_inference_steps + 1)
767
-
768
- # 5. Prepare latent variables
769
- num_channels_latents = self.transformer.config.in_channels
770
- latents = self.prepare_latents(
771
- batch_size * num_images_per_prompt,
772
- num_channels_latents,
773
- video_length,
774
- height,
775
- width,
776
- prompt_embeds.dtype,
777
- device,
778
- generator,
779
- latents,
780
- )
781
- if comfyui_progressbar:
782
- pbar.update(1)
783
-
784
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
785
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
786
-
787
- # 7 create image_rotary_emb, style embedding & time ids
788
- grid_height = height // 8 // self.transformer.config.patch_size
789
- grid_width = width // 8 // self.transformer.config.patch_size
790
- if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
791
- base_size_width = 720 // 8 // self.transformer.config.patch_size
792
- base_size_height = 480 // 8 // self.transformer.config.patch_size
793
-
794
- grid_crops_coords = get_resize_crop_region_for_grid(
795
- (grid_height, grid_width), base_size_width, base_size_height
796
- )
797
- image_rotary_emb = get_3d_rotary_pos_embed(
798
- self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
799
- temporal_size=latents.size(2), use_real=True,
800
- )
801
- else:
802
- base_size = 512 // 8 // self.transformer.config.patch_size
803
- grid_crops_coords = get_resize_crop_region_for_grid(
804
- (grid_height, grid_width), base_size, base_size
805
- )
806
- image_rotary_emb = get_2d_rotary_pos_embed(
807
- self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
808
- )
809
-
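For intuition, the grid sizes fed into the rotary embedding above work out as follows (the patch size and resolution are illustrative assumptions):

height, width = 480, 720   # assumed output resolution
patch_size = 2             # assumed transformer patch size

grid_height = height // 8 // patch_size   # 30 latent patches vertically
grid_width = width // 8 // patch_size     # 45 latent patches horizontally

# 3d_rope branch: the base grid comes from a 720x480 reference resolution, so at this
# resolution the computed crop region spans the whole grid.
base_size_width = 720 // 8 // patch_size   # 45
base_size_height = 480 // 8 // patch_size  # 30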
810
- # Get other hunyuan params
811
- style = torch.tensor([0], device=device)
812
-
813
- target_size = target_size or (height, width)
814
- add_time_ids = list(original_size + target_size + crops_coords_top_left)
815
- add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
816
-
817
- if self.do_classifier_free_guidance:
818
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
819
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
820
- prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
821
- prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
822
- add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
823
- style = torch.cat([style] * 2, dim=0)
824
-
825
- # To latents.device
826
- prompt_embeds = prompt_embeds.to(device=device)
827
- prompt_attention_mask = prompt_attention_mask.to(device=device)
828
- prompt_embeds_2 = prompt_embeds_2.to(device=device)
829
- prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
830
- add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
831
- batch_size * num_images_per_prompt, 1
832
- )
833
- style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
834
-
835
- torch.cuda.empty_cache()
836
- if self.enable_autocast_float8_transformer_flag:
837
- origin_weight_dtype = self.transformer.dtype
838
- self.transformer = self.transformer.to(torch.float8_e4m3fn)
839
- # 8. Denoising loop
840
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
841
- self._num_timesteps = len(timesteps)
842
- with self.progress_bar(total=num_inference_steps) as progress_bar:
843
- for i, t in enumerate(timesteps):
844
- if self.interrupt:
845
- continue
846
-
847
- # expand the latents if we are doing classifier free guidance
848
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
849
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
850
-
851
- # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
852
- t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
853
- dtype=latent_model_input.dtype
854
- )
855
-
856
- # predict the noise residual
857
- noise_pred = self.transformer(
858
- latent_model_input,
859
- t_expand,
860
- encoder_hidden_states=prompt_embeds,
861
- text_embedding_mask=prompt_attention_mask,
862
- encoder_hidden_states_t5=prompt_embeds_2,
863
- text_embedding_mask_t5=prompt_attention_mask_2,
864
- image_meta_size=add_time_ids,
865
- style=style,
866
- image_rotary_emb=image_rotary_emb,
867
- return_dict=False,
868
- )[0]
869
-
870
- if noise_pred.size()[1] != self.vae.config.latent_channels:
871
- noise_pred, _ = noise_pred.chunk(2, dim=1)
872
-
873
- # perform guidance
874
- if self.do_classifier_free_guidance:
875
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
876
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
877
-
878
- if self.do_classifier_free_guidance and guidance_rescale > 0.0:
879
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
880
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
881
-
882
- # compute the previous noisy sample x_t -> x_t-1
883
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
884
-
885
- if callback_on_step_end is not None:
886
- callback_kwargs = {}
887
- for k in callback_on_step_end_tensor_inputs:
888
- callback_kwargs[k] = locals()[k]
889
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
890
-
891
- latents = callback_outputs.pop("latents", latents)
892
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
893
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
894
- prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
895
- negative_prompt_embeds_2 = callback_outputs.pop(
896
- "negative_prompt_embeds_2", negative_prompt_embeds_2
897
- )
898
-
899
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
900
- progress_bar.update()
901
-
902
- if XLA_AVAILABLE:
903
- xm.mark_step()
904
-
905
- if comfyui_progressbar:
906
- pbar.update(1)
907
-
908
- if self.enable_autocast_float8_transformer_flag:
909
- self.transformer = self.transformer.to("cpu", origin_weight_dtype)
910
-
911
- torch.cuda.empty_cache()
912
- # Post-processing
913
- video = self.decode_latents(latents)
914
-
915
- # Convert to tensor
916
- if output_type == "latent":
917
- video = torch.from_numpy(video)
918
-
919
- # Offload all models
920
- self.maybe_free_model_hooks()
921
-
922
- if not return_dict:
923
- return video
924
-
925
- return EasyAnimatePipelineOutput(videos=video)
 
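For context, a hedged sketch of how this now-removed text-to-video pipeline was typically invoked; `pipe` (an instance of the pipeline class defined in this file, loaded on GPU) and the argument values are assumptions for illustration:

import torch

with torch.no_grad():
    output = pipe(
        prompt="A dog running through a field of sunflowers",
        video_length=49,          # illustrative frame count
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=5.0,
    )
videos = output.videos            # EasyAnimatePipelineOutput field
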
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py DELETED
@@ -1,1334 +0,0 @@
1
- # Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from typing import Callable, Dict, List, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn.functional as F
20
- from diffusers import DiffusionPipeline
21
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
22
- from diffusers.image_processor import VaeImageProcessor
23
- from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
24
- from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
25
- get_3d_rotary_pos_embed)
26
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27
- from diffusers.pipelines.stable_diffusion.safety_checker import \
28
- StableDiffusionSafetyChecker
29
- from diffusers.schedulers import DDIMScheduler
30
- from diffusers.utils import (is_torch_xla_available, logging,
31
- replace_example_docstring)
32
- from diffusers.utils.torch_utils import randn_tensor
33
- from einops import rearrange
34
- from PIL import Image
35
- from tqdm import tqdm
36
- from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
37
- CLIPVisionModelWithProjection, T5Tokenizer,
38
- T5EncoderModel)
39
-
40
- from .pipeline_easyanimate import EasyAnimatePipelineOutput
41
- from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
42
-
43
- if is_torch_xla_available():
44
- import torch_xla.core.xla_model as xm
45
-
46
- XLA_AVAILABLE = True
47
- else:
48
- XLA_AVAILABLE = False
49
-
50
-
51
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
-
53
- EXAMPLE_DOC_STRING = """
54
- Examples:
55
- ```py
56
- >>> pass
57
- ```
58
- """
59
-
60
-
61
- def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
62
- tw = tgt_width
63
- th = tgt_height
64
- h, w = src
65
- r = h / w
66
- if r > (th / tw):
67
- resize_height = th
68
- resize_width = int(round(th / h * w))
69
- else:
70
- resize_width = tw
71
- resize_height = int(round(tw / w * h))
72
-
73
- crop_top = int(round((th - resize_height) / 2.0))
74
- crop_left = int(round((tw - resize_width) / 2.0))
75
-
76
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
77
-
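A quick numeric check of the crop-region helper above (grid sizes chosen only for illustration):

# A square 30x30 source grid mapped into a 45 (width) x 30 (height) target grid:
# the resized grid is centered horizontally inside the target.
top_left, bottom_right = get_resize_crop_region_for_grid((30, 30), 45, 30)
print(top_left, bottom_right)   # (0, 8) (30, 38)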
78
-
79
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
80
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
81
- """
82
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
83
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
84
- """
85
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
86
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
87
- # rescale the results from guidance (fixes overexposure)
88
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
89
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
90
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
91
- return noise_cfg
92
-
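A small sanity check of the rescaling above (random tensors, illustrative only):

import torch

torch.manual_seed(0)
noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = noise_pred_text * 3.0   # stand-in for an over-amplified CFG result

# With guidance_rescale=1.0 the result is scaled back to the std of the text prediction.
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
print(torch.allclose(rescaled.std(dim=(1, 2, 3)), noise_pred_text.std(dim=(1, 2, 3))))  # True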
93
-
94
- def resize_mask(mask, latent, process_first_frame_only=True):
95
- latent_size = latent.size()
96
-
97
- if process_first_frame_only:
98
- target_size = list(latent_size[2:])
99
- target_size[0] = 1
100
- first_frame_resized = F.interpolate(
101
- mask[:, :, 0:1, :, :],
102
- size=target_size,
103
- mode='trilinear',
104
- align_corners=False
105
- )
106
-
107
- target_size = list(latent_size[2:])
108
- target_size[0] = target_size[0] - 1
109
- if target_size[0] != 0:
110
- remaining_frames_resized = F.interpolate(
111
- mask[:, :, 1:, :, :],
112
- size=target_size,
113
- mode='trilinear',
114
- align_corners=False
115
- )
116
- resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
117
- else:
118
- resized_mask = first_frame_resized
119
- else:
120
- target_size = list(latent_size[2:])
121
- resized_mask = F.interpolate(
122
- mask,
123
- size=target_size,
124
- mode='trilinear',
125
- align_corners=False
126
- )
127
- return resized_mask
128
-
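For shape intuition, an illustrative call of the mask-resizing helper above:

import torch

mask = torch.ones(1, 1, 17, 64, 64)      # pixel-space mask: 17 frames
latent = torch.zeros(1, 16, 5, 8, 8)     # target latent: 5 temporal slots

# Frame 0 is resized on its own; frames 1..16 are squeezed into the remaining 4 slots.
resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)                      # torch.Size([1, 1, 5, 8, 8])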
129
-
130
- def add_noise_to_reference_video(image, ratio=None):
131
- if ratio is None:
132
- sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
133
- sigma = torch.exp(sigma).to(image.dtype)
134
- else:
135
- sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
136
-
137
- image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
138
- image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
139
- image = image + image_noise
140
- return image
141
-
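A short illustration of the reference-video noising above (the ratio is an arbitrary choice):

import torch

image = torch.full((1, 3, 4, 8, 8), -1.0)   # fully padded clip (-1 marks padding)
image[:, :, :2] = 0.5                       # first two frames carry real content

# Fixed-ratio noising perturbs real content but leaves the -1 padded regions untouched.
noised = add_noise_to_reference_video(image.clone(), ratio=0.1)
print(torch.equal(noised[:, :, 2:], image[:, :, 2:]))   # True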
142
-
143
- class EasyAnimatePipeline_Multi_Text_Encoder_Inpaint(DiffusionPipeline):
144
- r"""
145
- Pipeline for text-to-video generation using EasyAnimate.
146
-
147
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
148
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
149
-
150
-     EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP
-     fine-tuned by the HunyuanDiT team.
152
-
153
- Args:
154
- vae ([`AutoencoderKLMagvit`]):
155
- Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
156
- text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
157
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
158
-             EasyAnimate uses a fine-tuned bilingual CLIP.
159
- tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
160
- A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
161
- transformer ([`EasyAnimateTransformer3DModel`]):
162
-             The EasyAnimate 3D transformer used to denoise the encoded video latents.
163
- text_encoder_2 (`T5EncoderModel`):
164
- The mT5 embedder.
165
- tokenizer_2 (`T5Tokenizer`):
166
- The tokenizer for the mT5 embedder.
167
- scheduler ([`DDIMScheduler`]):
168
- A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
169
-         clip_image_processor (`CLIPImageProcessor`):
-             The image processor used to preprocess reference images for the CLIP image encoder.
-         clip_image_encoder (`CLIPVisionModelWithProjection`):
-             The CLIP vision encoder used to embed reference images.
173
- """
174
-
175
- model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae"
176
- _optional_components = [
177
- "safety_checker",
178
- "feature_extractor",
179
- "text_encoder_2",
180
- "tokenizer_2",
181
- "text_encoder",
182
- "tokenizer",
183
- "clip_image_encoder",
184
- ]
185
- _exclude_from_cpu_offload = ["safety_checker"]
186
- _callback_tensor_inputs = [
187
- "latents",
188
- "prompt_embeds",
189
- "negative_prompt_embeds",
190
- "prompt_embeds_2",
191
- "negative_prompt_embeds_2",
192
- ]
193
-
194
- def __init__(
195
- self,
196
- vae: AutoencoderKLMagvit,
197
- text_encoder: BertModel,
198
- tokenizer: BertTokenizer,
199
- text_encoder_2: T5EncoderModel,
200
- tokenizer_2: T5Tokenizer,
201
- transformer: EasyAnimateTransformer3DModel,
202
- scheduler: DDIMScheduler,
203
- safety_checker: StableDiffusionSafetyChecker,
204
- feature_extractor: CLIPImageProcessor,
205
- requires_safety_checker: bool = True,
206
- clip_image_processor: CLIPImageProcessor = None,
207
- clip_image_encoder: CLIPVisionModelWithProjection = None,
208
- ):
209
- super().__init__()
210
-
211
- self.register_modules(
212
- vae=vae,
213
- text_encoder=text_encoder,
214
- tokenizer=tokenizer,
215
- tokenizer_2=tokenizer_2,
216
- transformer=transformer,
217
- scheduler=scheduler,
218
- safety_checker=safety_checker,
219
- feature_extractor=feature_extractor,
220
- text_encoder_2=text_encoder_2,
221
- clip_image_processor=clip_image_processor,
222
- clip_image_encoder=clip_image_encoder,
223
- )
224
-
225
- if safety_checker is None and requires_safety_checker:
226
- logger.warning(
227
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
228
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
229
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
230
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
231
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
232
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
233
- )
234
-
235
- if safety_checker is not None and feature_extractor is None:
236
- raise ValueError(
237
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
238
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
239
- )
240
-
241
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
242
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
243
- self.mask_processor = VaeImageProcessor(
244
- vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
245
- )
246
- self.enable_autocast_float8_transformer_flag = False
247
- self.register_to_config(requires_safety_checker=requires_safety_checker)
248
-
249
- def enable_sequential_cpu_offload(self, *args, **kwargs):
250
- super().enable_sequential_cpu_offload(*args, **kwargs)
251
- if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
252
- import accelerate
253
- accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
254
- self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
255
-
256
- def encode_prompt(
257
- self,
258
- prompt: str,
259
- device: torch.device,
260
- dtype: torch.dtype,
261
- num_images_per_prompt: int = 1,
262
- do_classifier_free_guidance: bool = True,
263
- negative_prompt: Optional[str] = None,
264
- prompt_embeds: Optional[torch.Tensor] = None,
265
- negative_prompt_embeds: Optional[torch.Tensor] = None,
266
- prompt_attention_mask: Optional[torch.Tensor] = None,
267
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
268
- max_sequence_length: Optional[int] = None,
269
- text_encoder_index: int = 0,
270
- actual_max_sequence_length: int = 256
271
- ):
272
- r"""
273
- Encodes the prompt into text encoder hidden states.
274
-
275
- Args:
276
- prompt (`str` or `List[str]`, *optional*):
277
- prompt to be encoded
278
- device: (`torch.device`):
279
- torch device
280
- dtype (`torch.dtype`):
281
- torch dtype
282
- num_images_per_prompt (`int`):
283
- number of images that should be generated per prompt
284
- do_classifier_free_guidance (`bool`):
285
- whether to use classifier free guidance or not
286
- negative_prompt (`str` or `List[str]`, *optional*):
287
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
288
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
289
- less than `1`).
290
- prompt_embeds (`torch.Tensor`, *optional*):
291
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
292
- provided, text embeddings will be generated from `prompt` input argument.
293
- negative_prompt_embeds (`torch.Tensor`, *optional*):
294
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
295
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
296
- argument.
297
- prompt_attention_mask (`torch.Tensor`, *optional*):
298
- Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
299
- negative_prompt_attention_mask (`torch.Tensor`, *optional*):
300
- Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
301
- max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
302
- text_encoder_index (`int`, *optional*):
303
- Index of the text encoder to use. `0` for clip and `1` for T5.
304
- """
305
- tokenizers = [self.tokenizer, self.tokenizer_2]
306
- text_encoders = [self.text_encoder, self.text_encoder_2]
307
-
308
- tokenizer = tokenizers[text_encoder_index]
309
- text_encoder = text_encoders[text_encoder_index]
310
-
311
- if max_sequence_length is None:
312
- if text_encoder_index == 0:
313
- max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
314
- if text_encoder_index == 1:
315
- max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
316
- else:
317
- max_length = max_sequence_length
318
-
319
- if prompt is not None and isinstance(prompt, str):
320
- batch_size = 1
321
- elif prompt is not None and isinstance(prompt, list):
322
- batch_size = len(prompt)
323
- else:
324
- batch_size = prompt_embeds.shape[0]
325
-
326
- if prompt_embeds is None:
327
- text_inputs = tokenizer(
328
- prompt,
329
- padding="max_length",
330
- max_length=max_length,
331
- truncation=True,
332
- return_attention_mask=True,
333
- return_tensors="pt",
334
- )
335
- text_input_ids = text_inputs.input_ids
336
- if text_input_ids.shape[-1] > actual_max_sequence_length:
337
- reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
338
- text_inputs = tokenizer(
339
- reprompt,
340
- padding="max_length",
341
- max_length=max_length,
342
- truncation=True,
343
- return_attention_mask=True,
344
- return_tensors="pt",
345
- )
346
- text_input_ids = text_inputs.input_ids
347
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
348
-
349
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
350
- text_input_ids, untruncated_ids
351
- ):
352
- _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
353
- removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
354
- logger.warning(
355
- "The following part of your input was truncated because CLIP can only handle sequences up to"
356
- f" {_actual_max_sequence_length} tokens: {removed_text}"
357
- )
358
- prompt_attention_mask = text_inputs.attention_mask.to(device)
359
- if self.transformer.config.enable_text_attention_mask:
360
- prompt_embeds = text_encoder(
361
- text_input_ids.to(device),
362
- attention_mask=prompt_attention_mask,
363
- )
364
- else:
365
- prompt_embeds = text_encoder(
366
- text_input_ids.to(device)
367
- )
368
- prompt_embeds = prompt_embeds[0]
369
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
370
-
371
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
372
-
373
- bs_embed, seq_len, _ = prompt_embeds.shape
374
- # duplicate text embeddings for each generation per prompt, using mps friendly method
375
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
376
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
377
-
378
- # get unconditional embeddings for classifier free guidance
379
- if do_classifier_free_guidance and negative_prompt_embeds is None:
380
- uncond_tokens: List[str]
381
- if negative_prompt is None:
382
- uncond_tokens = [""] * batch_size
383
- elif prompt is not None and type(prompt) is not type(negative_prompt):
384
- raise TypeError(
385
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
386
- f" {type(prompt)}."
387
- )
388
- elif isinstance(negative_prompt, str):
389
- uncond_tokens = [negative_prompt]
390
- elif batch_size != len(negative_prompt):
391
- raise ValueError(
392
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
393
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
394
- " the batch size of `prompt`."
395
- )
396
- else:
397
- uncond_tokens = negative_prompt
398
-
399
- max_length = prompt_embeds.shape[1]
400
- uncond_input = tokenizer(
401
- uncond_tokens,
402
- padding="max_length",
403
- max_length=max_length,
404
- truncation=True,
405
- return_tensors="pt",
406
- )
407
- uncond_input_ids = uncond_input.input_ids
408
- if uncond_input_ids.shape[-1] > actual_max_sequence_length:
409
- reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
410
- uncond_input = tokenizer(
411
- reuncond_tokens,
412
- padding="max_length",
413
- max_length=max_length,
414
- truncation=True,
415
- return_attention_mask=True,
416
- return_tensors="pt",
417
- )
418
- uncond_input_ids = uncond_input.input_ids
419
-
420
- negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
421
- if self.transformer.config.enable_text_attention_mask:
422
- negative_prompt_embeds = text_encoder(
423
- uncond_input.input_ids.to(device),
424
- attention_mask=negative_prompt_attention_mask,
425
- )
426
- else:
427
- negative_prompt_embeds = text_encoder(
428
- uncond_input.input_ids.to(device)
429
- )
430
- negative_prompt_embeds = negative_prompt_embeds[0]
431
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
432
-
433
- if do_classifier_free_guidance:
434
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
435
- seq_len = negative_prompt_embeds.shape[1]
436
-
437
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
438
-
439
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
440
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
441
-
442
- return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
443
-
444
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
445
- def run_safety_checker(self, image, device, dtype):
446
- if self.safety_checker is None:
447
- has_nsfw_concept = None
448
- else:
449
- if torch.is_tensor(image):
450
- feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
451
- else:
452
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
453
- safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
454
- image, has_nsfw_concept = self.safety_checker(
455
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
456
- )
457
- return image, has_nsfw_concept
458
-
459
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
460
- def prepare_extra_step_kwargs(self, generator, eta):
461
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
462
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
463
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
464
- # and should be between [0, 1]
465
-
466
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
467
- extra_step_kwargs = {}
468
- if accepts_eta:
469
- extra_step_kwargs["eta"] = eta
470
-
471
- # check if the scheduler accepts generator
472
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
473
- if accepts_generator:
474
- extra_step_kwargs["generator"] = generator
475
- return extra_step_kwargs
476
-
477
- def check_inputs(
478
- self,
479
- prompt,
480
- height,
481
- width,
482
- negative_prompt=None,
483
- prompt_embeds=None,
484
- negative_prompt_embeds=None,
485
- prompt_attention_mask=None,
486
- negative_prompt_attention_mask=None,
487
- prompt_embeds_2=None,
488
- negative_prompt_embeds_2=None,
489
- prompt_attention_mask_2=None,
490
- negative_prompt_attention_mask_2=None,
491
- callback_on_step_end_tensor_inputs=None,
492
- ):
493
- if height % 8 != 0 or width % 8 != 0:
494
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
495
-
496
- if callback_on_step_end_tensor_inputs is not None and not all(
497
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
498
- ):
499
- raise ValueError(
500
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
501
- )
502
-
503
- if prompt is not None and prompt_embeds is not None:
504
- raise ValueError(
505
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
506
- " only forward one of the two."
507
- )
508
- elif prompt is None and prompt_embeds is None:
509
- raise ValueError(
510
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
511
- )
512
- elif prompt is None and prompt_embeds_2 is None:
513
- raise ValueError(
514
- "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
515
- )
516
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
517
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
518
-
519
- if prompt_embeds is not None and prompt_attention_mask is None:
520
- raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
521
-
522
- if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
523
- raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
524
-
525
- if negative_prompt is not None and negative_prompt_embeds is not None:
526
- raise ValueError(
527
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
528
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
529
- )
530
-
531
- if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
532
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
533
-
534
- if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
535
- raise ValueError(
536
- "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
537
- )
538
- if prompt_embeds is not None and negative_prompt_embeds is not None:
539
- if prompt_embeds.shape != negative_prompt_embeds.shape:
540
- raise ValueError(
541
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
542
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
543
- f" {negative_prompt_embeds.shape}."
544
- )
545
- if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
546
- if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
547
- raise ValueError(
548
- "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
549
- f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
550
- f" {negative_prompt_embeds_2.shape}."
551
- )
552
-
553
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
554
- def get_timesteps(self, num_inference_steps, strength, device):
555
- # get the original timestep using init_timestep
556
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
557
-
558
- t_start = max(num_inference_steps - init_timestep, 0)
559
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
560
-
561
- return timesteps, num_inference_steps - t_start
562
-
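A minimal sketch of how `strength` truncates the schedule returned by `get_timesteps` above, using assumed example values; only the index arithmetic is shown, no scheduler object is needed for it:

num_inference_steps, strength = 50, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)   # 30 steps are actually denoised
t_start = max(num_inference_steps - init_timestep, 0)                           # the first 20 steps are skipped
# timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]         # -> 30 remaining timesteps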
563
- def prepare_mask_latents(
564
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
565
- ):
566
- # resize the mask to latents shape as we concatenate the mask to the latents
567
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
568
- # and half precision
569
- if mask is not None:
570
- mask = mask.to(device=device, dtype=self.vae.dtype)
571
- if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
572
- bs = 1
573
- new_mask = []
574
- for i in range(0, mask.shape[0], bs):
575
- mask_bs = mask[i : i + bs]
576
- mask_bs = self.vae.encode(mask_bs)[0]
577
- mask_bs = mask_bs.mode()
578
- new_mask.append(mask_bs)
579
- mask = torch.cat(new_mask, dim = 0)
580
- mask = mask * self.vae.config.scaling_factor
581
-
582
- else:
583
- if mask.shape[1] == 4:
584
- mask = mask
585
- else:
586
- video_length = mask.shape[2]
587
- mask = rearrange(mask, "b c f h w -> (b f) c h w")
588
- mask = self._encode_vae_image(mask, generator=generator)
589
- mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
590
-
591
- if masked_image is not None:
592
- masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
593
- if self.transformer.config.add_noise_in_inpaint_model:
594
- masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
595
- if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
596
- bs = 1
597
- new_mask_pixel_values = []
598
- for i in range(0, masked_image.shape[0], bs):
599
- mask_pixel_values_bs = masked_image[i : i + bs]
600
- mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
601
- mask_pixel_values_bs = mask_pixel_values_bs.mode()
602
- new_mask_pixel_values.append(mask_pixel_values_bs)
603
- masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
604
- masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
605
-
606
- else:
607
- if masked_image.shape[1] == 4:
608
- masked_image_latents = masked_image
609
- else:
610
- video_length = masked_image.shape[2]
611
- masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
612
- masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
613
- masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
614
-
615
- # aligning device to prevent device errors when concating it with the latent model input
616
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
617
- else:
618
- masked_image_latents = None
619
-
620
- return mask, masked_image_latents
621
-
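When the VAE is a per-frame 2D autoencoder, `prepare_mask_latents` folds the frame axis into the batch axis before encoding and unfolds it afterwards. A shape-only sketch of that rearrange pattern, with the VAE call stubbed out and an assumed 8x spatial compression:

import torch
from einops import rearrange

b, c, f, h, w = 1, 3, 8, 64, 64
video = torch.randn(b, c, f, h, w)
frames = rearrange(video, "b c f h w -> (b f) c h w")          # (8, 3, 64, 64)
latents = torch.randn(b * f, 4, h // 8, w // 8)                 # stand-in for vae.encode(frames)
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=f)   # (1, 4, 8, 8, 8)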
622
- def prepare_latents(
623
- self,
624
- batch_size,
625
- num_channels_latents,
626
- height,
627
- width,
628
- video_length,
629
- dtype,
630
- device,
631
- generator,
632
- latents=None,
633
- video=None,
634
- timestep=None,
635
- is_strength_max=True,
636
- return_noise=False,
637
- return_video_latents=False,
638
- ):
639
- if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
640
- if self.vae.cache_mag_vae:
641
- mini_batch_encoder = self.vae.mini_batch_encoder
642
- mini_batch_decoder = self.vae.mini_batch_decoder
643
- shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
644
- else:
645
- mini_batch_encoder = self.vae.mini_batch_encoder
646
- mini_batch_decoder = self.vae.mini_batch_decoder
647
- shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
648
- else:
649
- shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
650
-
651
- if isinstance(generator, list) and len(generator) != batch_size:
652
- raise ValueError(
653
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
654
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
655
- )
656
-
657
- if return_video_latents or (latents is None and not is_strength_max):
658
- video = video.to(device=device, dtype=self.vae.dtype)
659
- if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
660
- bs = 1
661
- new_video = []
662
- for i in range(0, video.shape[0], bs):
663
- video_bs = video[i : i + bs]
664
- video_bs = self.vae.encode(video_bs)[0]
665
- video_bs = video_bs.sample()
666
- new_video.append(video_bs)
667
- video = torch.cat(new_video, dim = 0)
668
- video = video * self.vae.config.scaling_factor
669
-
670
- else:
671
- if video.shape[1] == 4:
672
- video = video
673
- else:
674
- video_length = video.shape[2]
675
- video = rearrange(video, "b c f h w -> (b f) c h w")
676
- video = self._encode_vae_image(video, generator=generator)
677
- video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
678
- video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
679
- video_latents = video_latents.to(device=device, dtype=dtype)
680
-
681
- if latents is None:
682
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
683
- # if strength is 1. then initialise the latents to noise, else initial to image + noise
684
- latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
685
- # if pure noise then scale the initial latents by the Scheduler's init sigma
686
- latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
687
- else:
688
- noise = latents.to(device)
689
- latents = noise * self.scheduler.init_noise_sigma
690
-
691
- # scale the initial noise by the standard deviation required by the scheduler
692
- outputs = (latents,)
693
-
694
- if return_noise:
695
- outputs += (noise,)
696
-
697
- if return_video_latents:
698
- outputs += (video_latents,)
699
-
700
- return outputs
701
-
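A minimal sketch of the two initialisation branches in `prepare_latents`, assuming a diffusers `DDIMScheduler` and dummy tensors: with `strength == 1` the latents start as pure noise, otherwise the encoded video is noised to the first timestep that will actually be run.

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)
shape = (1, 4, 9, 8, 8)                        # (batch, latent channels, latent frames, h/8, w/8)
noise = torch.randn(shape)
video_latents = torch.randn(shape)             # stand-in for the VAE-encoded input video

is_strength_max = False
t0 = scheduler.timesteps[:1]                   # first timestep of the (possibly truncated) schedule
if is_strength_max:
    latents = noise * scheduler.init_noise_sigma              # text/inpaint-only start: pure noise
else:
    latents = scheduler.add_noise(video_latents, noise, t0)   # video-to-video start: noised input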
702
- def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
703
- if video.size()[2] <= mini_batch_encoder:
704
- return video
705
- prefix_index_before = mini_batch_encoder // 2
706
- prefix_index_after = mini_batch_encoder - prefix_index_before
707
- pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
708
-
709
- # Encode middle videos
710
- latents = self.vae.encode(pixel_values)[0]
711
- latents = latents.mode()
712
- # Decode middle videos
713
- middle_video = self.vae.decode(latents)[0]
714
-
715
- video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
716
- return video
717
-
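A sketch of the seam-smoothing idea in `smooth_output` above: the middle frames are round-tripped through the VAE and averaged with the original decode so mini-batch boundaries are less visible. The encode/decode round trip is stubbed out here:

import torch

mini_batch_encoder = 4
video = torch.rand(1, 3, 16, 32, 32)
before = mini_batch_encoder // 2
after = mini_batch_encoder - before
middle = video[:, :, before:-after]
middle_roundtrip = middle.clone()              # stand-in for vae.decode(vae.encode(middle)[0].mode())
video[:, :, before:-after] = (middle + middle_roundtrip) / 2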
718
- def decode_latents(self, latents):
719
- video_length = latents.shape[2]
720
- latents = 1 / self.vae.config.scaling_factor * latents
721
- if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
722
- mini_batch_encoder = self.vae.mini_batch_encoder
723
- mini_batch_decoder = self.vae.mini_batch_decoder
724
- video = self.vae.decode(latents)[0]
725
- video = video.clamp(-1, 1)
726
- if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
727
- video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
728
- else:
729
- latents = rearrange(latents, "b c f h w -> (b f) c h w")
730
- video = []
731
- for frame_idx in tqdm(range(latents.shape[0])):
732
- video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
733
- video = torch.cat(video)
734
- video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
735
- video = (video / 2 + 0.5).clamp(0, 1)
736
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
737
- video = video.cpu().float().numpy()
738
- return video
739
-
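The scaling convention used by `decode_latents` above, shown on dummy tensors: latents are divided by the VAE scaling factor before decoding, and the decoder output in [-1, 1] is mapped to [0, 1] and cast to float32 NumPy. The actual `vae.decode` call is stubbed out and the scaling-factor value is only an assumed example:

import torch

scaling_factor = 0.18215                       # the pipeline reads this from vae.config.scaling_factor
latents = torch.randn(1, 4, 9, 8, 8)
latents = latents / scaling_factor
decoded = torch.tanh(latents)                  # stand-in for vae.decode(latents)[0], already in [-1, 1]
video = (decoded / 2 + 0.5).clamp(0, 1).cpu().float().numpy()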
740
- @property
741
- def guidance_scale(self):
742
- return self._guidance_scale
743
-
744
- @property
745
- def guidance_rescale(self):
746
- return self._guidance_rescale
747
-
748
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
749
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
750
- # corresponds to doing no classifier free guidance.
751
- @property
752
- def do_classifier_free_guidance(self):
753
- return self._guidance_scale > 1
754
-
755
- @property
756
- def num_timesteps(self):
757
- return self._num_timesteps
758
-
759
- @property
760
- def interrupt(self):
761
- return self._interrupt
762
-
763
- def enable_autocast_float8_transformer(self):
764
- self.enable_autocast_float8_transformer_flag = True
765
-
766
- @torch.no_grad()
767
- @replace_example_docstring(EXAMPLE_DOC_STRING)
768
- def __call__(
769
- self,
770
- prompt: Union[str, List[str]] = None,
771
- video_length: Optional[int] = None,
772
- video: Union[torch.FloatTensor] = None,
773
- mask_video: Union[torch.FloatTensor] = None,
774
- masked_video_latents: Union[torch.FloatTensor] = None,
775
- height: Optional[int] = None,
776
- width: Optional[int] = None,
777
- num_inference_steps: Optional[int] = 50,
778
- guidance_scale: Optional[float] = 5.0,
779
- negative_prompt: Optional[Union[str, List[str]]] = None,
780
- num_images_per_prompt: Optional[int] = 1,
781
- eta: Optional[float] = 0.0,
782
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
783
- latents: Optional[torch.Tensor] = None,
784
- prompt_embeds: Optional[torch.Tensor] = None,
785
- prompt_embeds_2: Optional[torch.Tensor] = None,
786
- negative_prompt_embeds: Optional[torch.Tensor] = None,
787
- negative_prompt_embeds_2: Optional[torch.Tensor] = None,
788
- prompt_attention_mask: Optional[torch.Tensor] = None,
789
- prompt_attention_mask_2: Optional[torch.Tensor] = None,
790
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
791
- negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
792
- output_type: Optional[str] = "latent",
793
- return_dict: bool = True,
794
- callback_on_step_end: Optional[
795
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
796
- ] = None,
797
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
798
- guidance_rescale: float = 0.0,
799
- original_size: Optional[Tuple[int, int]] = (1024, 1024),
800
- target_size: Optional[Tuple[int, int]] = None,
801
- crops_coords_top_left: Tuple[int, int] = (0, 0),
802
- clip_image: Image = None,
803
- clip_apply_ratio: float = 0.40,
804
- strength: float = 1.0,
805
- noise_aug_strength: float = 0.0563,
806
- comfyui_progressbar: bool = False,
807
- ):
808
- r"""
809
- The call function to the pipeline for video generation with EasyAnimate.
810
-
811
- Args:
812
- prompt (`str` or `List[str]`, *optional*):
813
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
814
- video_length (`int`, *optional*):
815
- The number of frames of the video to be generated. This parameter determines the temporal length and
816
- continuity of generated content.
817
- video (`torch.FloatTensor`, *optional*):
818
- A tensor representing an input video, which can be modified depending on the prompts provided.
819
- mask_video (`torch.FloatTensor`, *optional*):
820
- A tensor to specify areas of the video to be masked (omitted from generation).
821
- masked_video_latents (`torch.FloatTensor`, *optional*):
822
- Latents from masked portions of the video, utilized during image generation.
823
- height (`int`, *optional*):
824
- The height in pixels of the generated image or video frames.
825
- width (`int`, *optional*):
826
- The width in pixels of the generated image or video frames.
827
- num_inference_steps (`int`, *optional*, defaults to 50):
828
- The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
829
- inference time. This parameter is modulated by `strength`.
830
- guidance_scale (`float`, *optional*, defaults to 5.0):
831
- A higher guidance scale value encourages the model to generate images closely linked to the text
832
- `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
833
- negative_prompt (`str` or `List[str]`, *optional*):
834
- The prompt or prompts to guide what to exclude in image generation. If not defined, you need to
835
- provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale <= 1`).
836
- num_images_per_prompt (`int`, *optional*, defaults to 1):
837
- The number of images to generate per prompt.
838
- eta (`float`, *optional*, defaults to 0.0):
839
- A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
840
- [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the
841
- inference process.
842
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
843
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
844
- random seeds which helps in making generation deterministic.
845
- latents (`torch.Tensor`, *optional*):
846
- A pre-computed latent representation which can be used to guide the generation process.
847
- prompt_embeds (`torch.Tensor`, *optional*):
848
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
849
- provided, embeddings are generated from the `prompt` input argument.
850
- prompt_embeds_2 (`torch.Tensor`, *optional*):
851
- Secondary set of pre-generated text embeddings, useful for advanced prompt weighting.
852
- negative_prompt_embeds (`torch.Tensor`, *optional*):
853
- Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs.
854
- If not provided, embeddings are generated from the `negative_prompt` argument.
855
- negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
856
- Secondary set of pre-generated negative text embeddings for further control.
857
- prompt_attention_mask (`torch.Tensor`, *optional*):
858
- Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
859
- `prompt_embeds`.
860
- prompt_attention_mask_2 (`torch.Tensor`, *optional*):
861
- Attention mask for the secondary prompt embedding.
862
- negative_prompt_attention_mask (`torch.Tensor`, *optional*):
863
- Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
864
- negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
865
- Attention mask for the secondary negative prompt embedding.
866
- output_type (`str`, *optional*, defaults to `"latent"`):
867
- The output format of the generated video. With the default `"latent"`, a `torch.Tensor` is returned;
868
- any other value returns the decoded frames as a NumPy array.
869
- return_dict (`bool`, *optional*, defaults to `True`):
870
- If set to `True`, an [`EasyAnimatePipelineOutput`] is returned; otherwise the generated video is
871
- returned directly without any wrapper.
872
- callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
873
- A callback function (or a list of them) that will be executed at the end of each denoising step,
874
- allowing for custom processing during generation.
875
- callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
876
- Specifies which tensor inputs should be included in the callback function. If not defined, all tensor
877
- inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
878
- guidance_rescale (`float`, *optional*, defaults to 0.0):
879
- Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
880
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
881
- original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
882
- The original dimensions of the image. Used to compute time ids during the generation process.
883
- target_size (`Tuple[int, int]`, *optional*):
884
- The targeted dimensions of the generated image, also utilized in the time id calculations.
885
- crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
886
- Coordinates defining the top left corner of any cropping, utilized while calculating the time ids.
887
- clip_image (`Image`, *optional*):
888
- An optional reference image; its CLIP image features are injected into the transformer as an additional visual condition.
889
- clip_apply_ratio (`float`, *optional*, defaults to 0.40):
890
- Fraction of the denoising steps during which the CLIP image features are applied; they are zeroed out for the first `1 - clip_apply_ratio` portion of the schedule.
891
- strength (`float`, *optional*, defaults to 1.0):
892
- Indicates how much to transform the reference `video`. A value of 1 re-noises the input completely so the
893
- prompt dominates, while lower values preserve more of the input video and skip part of the schedule.
894
- comfyui_progressbar (`bool`, *optional*, defaults to `False`):
895
- Enables a progress bar in ComfyUI, providing visual feedback during the generation process.
896
-
897
- Examples:
898
- # Example usage of the function for generating images based on prompts.
899
-
900
- Returns:
901
- [`EasyAnimatePipelineOutput`] or `torch.Tensor`/`np.ndarray`:
902
- Returns an [`EasyAnimatePipelineOutput`] whose `videos` field holds the generated frames when
903
- `return_dict` is `True`; otherwise the generated video is returned directly. No safety-checker
904
- metadata is produced by this pipeline.
905
- """
906
-
907
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
908
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
909
-
910
- # 0. default height and width
911
- height = int(height // 16 * 16)
912
- width = int(width // 16 * 16)
913
-
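The rounding at step 0 snaps the requested resolution down to a multiple of 16 so it divides cleanly by the VAE stride and transformer patch size; for example:

height, width = 1080, 1920
height, width = int(height // 16 * 16), int(width // 16 * 16)   # -> 1072, 1920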
914
- # 1. Check inputs. Raise error if not correct
915
- self.check_inputs(
916
- prompt,
917
- height,
918
- width,
919
- negative_prompt,
920
- prompt_embeds,
921
- negative_prompt_embeds,
922
- prompt_attention_mask,
923
- negative_prompt_attention_mask,
924
- prompt_embeds_2,
925
- negative_prompt_embeds_2,
926
- prompt_attention_mask_2,
927
- negative_prompt_attention_mask_2,
928
- callback_on_step_end_tensor_inputs,
929
- )
930
- self._guidance_scale = guidance_scale
931
- self._guidance_rescale = guidance_rescale
932
- self._interrupt = False
933
-
934
- # 2. Define call parameters
935
- if prompt is not None and isinstance(prompt, str):
936
- batch_size = 1
937
- elif prompt is not None and isinstance(prompt, list):
938
- batch_size = len(prompt)
939
- else:
940
- batch_size = prompt_embeds.shape[0]
941
-
942
- device = self._execution_device
943
-
944
- # 3. Encode input prompt
945
- (
946
- prompt_embeds,
947
- negative_prompt_embeds,
948
- prompt_attention_mask,
949
- negative_prompt_attention_mask,
950
- ) = self.encode_prompt(
951
- prompt=prompt,
952
- device=device,
953
- dtype=self.transformer.dtype,
954
- num_images_per_prompt=num_images_per_prompt,
955
- do_classifier_free_guidance=self.do_classifier_free_guidance,
956
- negative_prompt=negative_prompt,
957
- prompt_embeds=prompt_embeds,
958
- negative_prompt_embeds=negative_prompt_embeds,
959
- prompt_attention_mask=prompt_attention_mask,
960
- negative_prompt_attention_mask=negative_prompt_attention_mask,
961
- text_encoder_index=0,
962
- )
963
- (
964
- prompt_embeds_2,
965
- negative_prompt_embeds_2,
966
- prompt_attention_mask_2,
967
- negative_prompt_attention_mask_2,
968
- ) = self.encode_prompt(
969
- prompt=prompt,
970
- device=device,
971
- dtype=self.transformer.dtype,
972
- num_images_per_prompt=num_images_per_prompt,
973
- do_classifier_free_guidance=self.do_classifier_free_guidance,
974
- negative_prompt=negative_prompt,
975
- prompt_embeds=prompt_embeds_2,
976
- negative_prompt_embeds=negative_prompt_embeds_2,
977
- prompt_attention_mask=prompt_attention_mask_2,
978
- negative_prompt_attention_mask=negative_prompt_attention_mask_2,
979
- text_encoder_index=1,
980
- )
981
- torch.cuda.empty_cache()
982
-
983
- # 4. set timesteps
984
- self.scheduler.set_timesteps(num_inference_steps, device=device)
985
- timesteps, num_inference_steps = self.get_timesteps(
986
- num_inference_steps=num_inference_steps, strength=strength, device=device
987
- )
988
- if comfyui_progressbar:
989
- from comfy.utils import ProgressBar
990
- pbar = ProgressBar(num_inference_steps + 3)
991
- # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
992
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
993
- # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
994
- is_strength_max = strength == 1.0
995
-
996
- if video is not None:
997
- video_length = video.shape[2]
998
- init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
999
- init_video = init_video.to(dtype=torch.float32)
1000
- init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
1001
- else:
1002
- init_video = None
1003
-
1004
- # Prepare latent variables
1005
- num_channels_latents = self.vae.config.latent_channels
1006
- num_channels_transformer = self.transformer.config.in_channels
1007
- return_image_latents = num_channels_transformer == num_channels_latents
1008
-
1009
- # 5. Prepare latents.
1010
- latents_outputs = self.prepare_latents(
1011
- batch_size * num_images_per_prompt,
1012
- num_channels_latents,
1013
- height,
1014
- width,
1015
- video_length,
1016
- prompt_embeds.dtype,
1017
- device,
1018
- generator,
1019
- latents,
1020
- video=init_video,
1021
- timestep=latent_timestep,
1022
- is_strength_max=is_strength_max,
1023
- return_noise=True,
1024
- return_video_latents=return_image_latents,
1025
- )
1026
- if return_image_latents:
1027
- latents, noise, image_latents = latents_outputs
1028
- else:
1029
- latents, noise = latents_outputs
1030
-
1031
- if comfyui_progressbar:
1032
- pbar.update(1)
1033
-
1034
- # 6. Prepare clip latents if it needs.
1035
- if clip_image is not None and self.transformer.enable_clip_in_inpaint:
1036
- inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
1037
- inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
1038
- clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:]
1039
- clip_encoder_hidden_states_neg = torch.zeros(
1040
- [
1041
- batch_size,
1042
- int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
1043
- int(self.clip_image_encoder.config.hidden_size)
1044
- ]
1045
- ).to(latents.device, dtype=latents.dtype)
1046
-
1047
- clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
1048
- clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
1049
-
1050
- clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states
1051
- clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask
1052
-
1053
- elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint:
1054
- clip_encoder_hidden_states = torch.zeros(
1055
- [
1056
- batch_size,
1057
- int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
1058
- int(self.clip_image_encoder.config.hidden_size)
1059
- ]
1060
- ).to(latents.device, dtype=latents.dtype)
1061
-
1062
- clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query])
1063
- clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
1064
-
1065
- clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states
1066
- clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask
1067
-
1068
- else:
1069
- clip_encoder_hidden_states_input = None
1070
- clip_attention_mask_input = None
1071
- if comfyui_progressbar:
1072
- pbar.update(1)
1073
-
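A shape-only sketch of the CLIP conditioning prepared in step 6: the zeroed "negative" states and masks are stacked in front of the real ones for classifier-free guidance, mirroring how the prompt embeddings are concatenated later. The encoder geometry below (image_size=224, patch_size=14, hidden_size=1280) and `n_query` are assumed example values:

import torch

batch_size, n_query = 1, 32
num_patches = (224 // 14) ** 2                                   # matches last_hidden_state[:, 1:]
clip_states = torch.randn(batch_size, num_patches, 1280)
clip_states_neg = torch.zeros_like(clip_states)
clip_states_input = torch.cat([clip_states_neg, clip_states])    # unconditional first, conditional second
clip_mask_input = torch.cat([torch.zeros(batch_size, n_query), torch.ones(batch_size, n_query)])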
1074
- # 7. Prepare inpaint latents if it needs.
1075
- if mask_video is not None:
1076
- if (mask_video == 255).all():
1077
- # Use zero latents if we want to t2v.
1078
- if self.transformer.resize_inpaint_mask_directly:
1079
- mask_latents = torch.zeros_like(latents)[:, :1].to(latents.device, latents.dtype)
1080
- else:
1081
- mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1082
- masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1083
-
1084
- mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
1085
- masked_video_latents_input = (
1086
- torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
1087
- )
1088
- inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1089
- else:
1090
- # Prepare mask latent variables
1091
- video_length = video.shape[2]
1092
- mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
1093
- mask_condition = mask_condition.to(dtype=torch.float32)
1094
- mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
1095
-
1096
- if num_channels_transformer != num_channels_latents:
1097
- mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
1098
- if masked_video_latents is None:
1099
- masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
1100
- else:
1101
- masked_video = masked_video_latents
1102
-
1103
- if self.transformer.resize_inpaint_mask_directly:
1104
- _, masked_video_latents = self.prepare_mask_latents(
1105
- None,
1106
- masked_video,
1107
- batch_size,
1108
- height,
1109
- width,
1110
- prompt_embeds.dtype,
1111
- device,
1112
- generator,
1113
- self.do_classifier_free_guidance,
1114
- noise_aug_strength=noise_aug_strength,
1115
- )
1116
- mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae)
1117
- mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
1118
- else:
1119
- mask_latents, masked_video_latents = self.prepare_mask_latents(
1120
- mask_condition_tile,
1121
- masked_video,
1122
- batch_size,
1123
- height,
1124
- width,
1125
- prompt_embeds.dtype,
1126
- device,
1127
- generator,
1128
- self.do_classifier_free_guidance,
1129
- noise_aug_strength=noise_aug_strength,
1130
- )
1131
-
1132
- mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
1133
- masked_video_latents_input = (
1134
- torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
1135
- )
1136
- inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1137
- else:
1138
- inpaint_latents = None
1139
-
1140
- mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1141
- mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1142
- else:
1143
- if num_channels_transformer != num_channels_latents:
1144
- mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
1145
- masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1146
-
1147
- mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
1148
- masked_video_latents_input = (
1149
- torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
1150
- )
1151
- inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1152
- else:
1153
- mask = torch.zeros_like(init_video[:, :1])
1154
- mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
1155
- mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1156
-
1157
- inpaint_latents = None
1158
- if comfyui_progressbar:
1159
- pbar.update(1)
1160
-
1161
- # Check that sizes of mask, masked image and latents match
1162
- if num_channels_transformer != num_channels_latents:
1163
- num_channels_mask = mask_latents.shape[1]
1164
- num_channels_masked_image = masked_video_latents.shape[1]
1165
- if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
1166
- raise ValueError(
1167
- f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
1168
- f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1169
- f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1170
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1171
- " `pipeline.transformer` or your `mask_image` or `image` input."
1172
- )
1173
-
1174
- # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1175
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1176
-
1177
- # 9 create image_rotary_emb, style embedding & time ids
1178
- grid_height = height // 8 // self.transformer.config.patch_size
1179
- grid_width = width // 8 // self.transformer.config.patch_size
1180
- if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
1181
- base_size_width = 720 // 8 // self.transformer.config.patch_size
1182
- base_size_height = 480 // 8 // self.transformer.config.patch_size
1183
-
1184
- grid_crops_coords = get_resize_crop_region_for_grid(
1185
- (grid_height, grid_width), base_size_width, base_size_height
1186
- )
1187
- image_rotary_emb = get_3d_rotary_pos_embed(
1188
- self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
1189
- temporal_size=latents.size(2), use_real=True,
1190
- )
1191
- else:
1192
- base_size = 512 // 8 // self.transformer.config.patch_size
1193
- grid_crops_coords = get_resize_crop_region_for_grid(
1194
- (grid_height, grid_width), base_size, base_size
1195
- )
1196
- image_rotary_emb = get_2d_rotary_pos_embed(
1197
- self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
1198
- )
1199
-
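A worked example of the rotary-embedding grid sizes computed in step 9, assuming a 480x720 target, `patch_size=2`, and the 8x spatial compression of the VAE:

height, width, patch_size = 480, 720, 2
grid_height = height // 8 // patch_size    # 30
grid_width = width // 8 // patch_size      # 45
base_size_height = 480 // 8 // patch_size  # 30
base_size_width = 720 // 8 // patch_size   # 45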
1200
- # Get other hunyuan params
1201
- style = torch.tensor([0], device=device)
1202
-
1203
- target_size = target_size or (height, width)
1204
- add_time_ids = list(original_size + target_size + crops_coords_top_left)
1205
- add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
1206
-
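A worked example of the micro-conditioning vector built just above, following the original-size / target-size / crop-offset convention:

original_size, target_size, crops_coords_top_left = (1024, 1024), (480, 720), (0, 0)
add_time_ids = list(original_size + target_size + crops_coords_top_left)   # [1024, 1024, 480, 720, 0, 0]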
1207
- if self.do_classifier_free_guidance:
1208
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1209
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
1210
- prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
1211
- prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
1212
- add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
1213
- style = torch.cat([style] * 2, dim=0)
1214
-
1215
- prompt_embeds = prompt_embeds.to(device=device)
1216
- prompt_attention_mask = prompt_attention_mask.to(device=device)
1217
- prompt_embeds_2 = prompt_embeds_2.to(device=device)
1218
- prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
1219
- add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
1220
- batch_size * num_images_per_prompt, 1
1221
- )
1222
- style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
1223
-
1224
- torch.cuda.empty_cache()
1225
- if self.enable_autocast_float8_transformer_flag:
1226
- origin_weight_dtype = self.transformer.dtype
1227
- self.transformer = self.transformer.to(torch.float8_e4m3fn)
1228
- # 10. Denoising loop
1229
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1230
- self._num_timesteps = len(timesteps)
1231
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1232
- for i, t in enumerate(timesteps):
1233
- if self.interrupt:
1234
- continue
1235
-
1236
- # expand the latents if we are doing classifier free guidance
1237
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1238
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1239
-
1240
- if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
1241
- clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
1242
- clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
1243
- else:
1244
- clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
1245
- clip_attention_mask_actual_input = clip_attention_mask_input
1246
-
1247
- # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
1248
- t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
1249
- dtype=latent_model_input.dtype
1250
- )
1251
-
1252
- # predict the noise residual
1253
- noise_pred = self.transformer(
1254
- latent_model_input,
1255
- t_expand,
1256
- encoder_hidden_states=prompt_embeds,
1257
- text_embedding_mask=prompt_attention_mask,
1258
- encoder_hidden_states_t5=prompt_embeds_2,
1259
- text_embedding_mask_t5=prompt_attention_mask_2,
1260
- image_meta_size=add_time_ids,
1261
- style=style,
1262
- image_rotary_emb=image_rotary_emb,
1263
- inpaint_latents=inpaint_latents,
1264
- clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
1265
- clip_attention_mask=clip_attention_mask_actual_input,
1266
- return_dict=False,
1267
- )[0]
1268
- if noise_pred.size()[1] != self.vae.config.latent_channels:
1269
- noise_pred, _ = noise_pred.chunk(2, dim=1)
1270
-
1271
- # perform guidance
1272
- if self.do_classifier_free_guidance:
1273
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1274
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1275
-
1276
- if self.do_classifier_free_guidance and guidance_rescale > 0.0:
1277
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1278
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1279
-
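A minimal sketch of the guidance step above on dummy tensors: standard classifier-free guidance followed by the optional variance rescale applied by `rescale_noise_cfg` (the formula below mirrors the diffusers helper and is reproduced here only as an approximation):

import torch

guidance_scale, guidance_rescale = 5.0, 0.7
noise_pred_uncond, noise_pred_text = torch.randn(2, 1, 4, 9, 8, 8)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# rescale towards the std of the text-conditioned prediction to counter over-exposure
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
noise_pred = guidance_rescale * (noise_pred * (std_text / std_cfg)) + (1 - guidance_rescale) * noise_pred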
1280
- # compute the previous noisy sample x_t -> x_t-1
1281
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1282
-
1283
- if num_channels_transformer == 4:
1284
- init_latents_proper = image_latents
1285
- init_mask = mask
1286
- if i < len(timesteps) - 1:
1287
- noise_timestep = timesteps[i + 1]
1288
- init_latents_proper = self.scheduler.add_noise(
1289
- init_latents_proper, noise, torch.tensor([noise_timestep])
1290
- )
1291
-
1292
- latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1293
-
1294
- if callback_on_step_end is not None:
1295
- callback_kwargs = {}
1296
- for k in callback_on_step_end_tensor_inputs:
1297
- callback_kwargs[k] = locals()[k]
1298
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1299
-
1300
- latents = callback_outputs.pop("latents", latents)
1301
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1302
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1303
- prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
1304
- negative_prompt_embeds_2 = callback_outputs.pop(
1305
- "negative_prompt_embeds_2", negative_prompt_embeds_2
1306
- )
1307
-
1308
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1309
- progress_bar.update()
1310
-
1311
- if XLA_AVAILABLE:
1312
- xm.mark_step()
1313
-
1314
- if comfyui_progressbar:
1315
- pbar.update(1)
1316
-
1317
- if self.enable_autocast_float8_transformer_flag:
1318
- self.transformer = self.transformer.to("cpu", origin_weight_dtype)
1319
-
1320
- torch.cuda.empty_cache()
1321
- # Post-processing
1322
- video = self.decode_latents(latents)
1323
-
1324
- # Convert to tensor
1325
- if output_type == "latent":
1326
- video = torch.from_numpy(video)
1327
-
1328
- # Offload all models
1329
- self.maybe_free_model_hooks()
1330
-
1331
- if not return_dict:
1332
- return video
1333
-
1334
- return EasyAnimatePipelineOutput(videos=video)
 
easyanimate/ui/ui.py CHANGED
@@ -17,41 +17,42 @@ import torch
17
  from diffusers import (AutoencoderKL, DDIMScheduler,
18
  DPMSolverMultistepScheduler,
19
  EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
20
- PNDMScheduler)
21
  from diffusers.utils.import_utils import is_xformers_available
22
  from omegaconf import OmegaConf
23
  from PIL import Image
24
  from safetensors import safe_open
25
  from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
26
- CLIPVisionModelWithProjection, T5Tokenizer,
27
- T5EncoderModel, T5Tokenizer)
 
28
 
29
- from easyanimate.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
30
- from easyanimate.models import (name_to_autoencoder_magvit,
31
  name_to_transformer3d)
32
- from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
33
- from easyanimate.models.transformer3d import (HunyuanTransformer3DModel,
34
- Transformer3DModel)
35
- from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
36
- from easyanimate.pipeline.pipeline_easyanimate_inpaint import \
37
  EasyAnimateInpaintPipeline
38
- from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder import \
39
- EasyAnimatePipeline_Multi_Text_Encoder
40
- from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder_inpaint import \
41
- EasyAnimatePipeline_Multi_Text_Encoder_Inpaint
42
- from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
43
- from easyanimate.utils.utils import (
44
  get_image_to_video_latent, get_video_to_video_latent,
45
  get_width_and_height_from_image_and_base_resolution, save_videos_grid)
46
- from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper
47
 
48
- scheduler_dict = {
49
  "Euler": EulerDiscreteScheduler,
50
  "Euler A": EulerAncestralDiscreteScheduler,
51
  "DPM++": DPMSolverMultistepScheduler,
52
  "PNDM": PNDMScheduler,
53
  "DDIM": DDIMScheduler,
54
  }
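This mapping is what lets the UI swap samplers at request time: the dropdown value indexes `scheduler_dict` and the chosen class is rebuilt from the current scheduler's config. A small standalone sketch of that pattern, using default configs rather than the loaded pipeline:

from diffusers import DDIMScheduler, DPMSolverMultistepScheduler

base = DDIMScheduler()                                          # stands in for pipeline.scheduler
swapped = DPMSolverMultistepScheduler.from_config(base.config)
# in the UI: pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(pipeline.scheduler.config)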
 
 
 
 
55
 
56
  gradio_version = pkg_resources.get_distribution("gradio").version
57
  gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
@@ -98,8 +99,8 @@ class EasyAnimateController:
98
  self.GPU_memory_mode = GPU_memory_mode
99
 
100
  self.weight_dtype = weight_dtype
101
- self.edition = "v5"
102
- self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml"))
103
 
104
  def refresh_diffusion_transformer(self):
105
  self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
@@ -121,26 +122,37 @@ class EasyAnimateController:
121
  if edition == "v1":
122
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v1_motion_module.yaml"))
123
  return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
 
124
  gr.update(value=512, minimum=384, maximum=704, step=32), \
125
  gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
126
  elif edition == "v2":
127
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v2_magvit_motion_module.yaml"))
128
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
 
129
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
130
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
131
  elif edition == "v3":
132
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v3_slicevae_motion_module.yaml"))
133
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
 
134
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
135
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
136
  elif edition == "v4":
137
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v4_slicevae_multi_text_encoder.yaml"))
138
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
 
139
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
140
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
141
  elif edition == "v5":
142
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml"))
143
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
 
 
 
 
 
 
 
144
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
145
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)
146
 
@@ -170,33 +182,55 @@ class EasyAnimateController:
170
  self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
171
  diffusion_transformer_dropdown,
172
  subfolder="transformer",
173
- transformer_additional_kwargs=transformer_additional_kwargs
174
- ).to(self.weight_dtype)
 
 
175
 
176
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
177
  tokenizer = BertTokenizer.from_pretrained(
178
  diffusion_transformer_dropdown, subfolder="tokenizer"
179
  )
180
- tokenizer_2 = T5Tokenizer.from_pretrained(
181
- diffusion_transformer_dropdown, subfolder="tokenizer_2"
182
- )
 
 
 
 
 
183
  else:
184
- tokenizer = T5Tokenizer.from_pretrained(
185
- diffusion_transformer_dropdown, subfolder="tokenizer"
186
- )
 
 
 
 
 
187
  tokenizer_2 = None
188
 
189
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
190
  text_encoder = BertModel.from_pretrained(
191
  diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
192
  )
193
- text_encoder_2 = T5EncoderModel.from_pretrained(
194
- diffusion_transformer_dropdown, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
195
- )
196
- else:
197
- text_encoder = T5EncoderModel.from_pretrained(
198
- diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
199
- )
 
 
 
 
 
 
 
 
 
 
200
  text_encoder_2 = None
201
 
202
  # Get pipeline
@@ -212,23 +246,18 @@ class EasyAnimateController:
212
  clip_image_processor = None
213
 
214
  # Get Scheduler
215
- Choosen_Scheduler = scheduler_dict = {
216
- "Euler": EulerDiscreteScheduler,
217
- "Euler A": EulerAncestralDiscreteScheduler,
218
- "DPM++": DPMSolverMultistepScheduler,
219
- "PNDM": PNDMScheduler,
220
- "DDIM": DDIMScheduler,
221
- }["Euler"]
222
-
223
  scheduler = Choosen_Scheduler.from_pretrained(
224
  diffusion_transformer_dropdown,
225
  subfolder="scheduler"
226
  )
227
 
228
- if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
229
  if self.transformer.config.in_channels != self.vae.config.latent_channels:
230
- self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained(
231
- diffusion_transformer_dropdown,
232
  text_encoder=text_encoder,
233
  text_encoder_2=text_encoder_2,
234
  tokenizer=tokenizer,
@@ -236,13 +265,11 @@ class EasyAnimateController:
236
  vae=self.vae,
237
  transformer=self.transformer,
238
  scheduler=scheduler,
239
- torch_dtype=self.weight_dtype,
240
  clip_image_encoder=clip_image_encoder,
241
  clip_image_processor=clip_image_processor,
242
- )
243
  else:
244
- self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder.from_pretrained(
245
- diffusion_transformer_dropdown,
246
  text_encoder=text_encoder,
247
  text_encoder_2=text_encoder_2,
248
  tokenizer=tokenizer,
@@ -250,40 +277,25 @@ class EasyAnimateController:
250
  vae=self.vae,
251
  transformer=self.transformer,
252
  scheduler=scheduler,
253
- torch_dtype=self.weight_dtype
254
- )
255
  else:
256
- if self.transformer.config.in_channels != self.vae.config.latent_channels:
257
- self.pipeline = EasyAnimateInpaintPipeline(
258
- diffusion_transformer_dropdown,
259
- text_encoder=text_encoder,
260
- tokenizer=tokenizer,
261
- vae=self.vae,
262
- transformer=self.transformer,
263
- scheduler=scheduler,
264
- torch_dtype=self.weight_dtype,
265
- clip_image_encoder=clip_image_encoder,
266
- clip_image_processor=clip_image_processor,
267
- )
268
- else:
269
- self.pipeline = EasyAnimatePipeline(
270
- diffusion_transformer_dropdown,
271
- text_encoder=text_encoder,
272
- tokenizer=tokenizer,
273
- vae=self.vae,
274
- transformer=self.transformer,
275
- scheduler=scheduler,
276
- torch_dtype=self.weight_dtype
277
- )
278
 
279
  if self.GPU_memory_mode == "sequential_cpu_offload":
280
  self.pipeline.enable_sequential_cpu_offload()
281
  elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
282
  self.pipeline.enable_model_cpu_offload()
283
- self.pipeline.enable_autocast_float8_transformer()
284
  convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
285
  else:
286
- self.GPU_memory_mode.enable_model_cpu_offload()
287
  print("Update diffusion transformer done")
288
  return gr.update()
289
 
@@ -374,8 +386,10 @@ class EasyAnimateController:
374
  if self.base_model_path != base_model_dropdown:
375
  self.update_base_model(base_model_dropdown)
376
 
 
 
 
377
  if self.lora_model_path != lora_model_dropdown:
378
- print("Update lora model")
379
  self.update_lora_model(lora_model_dropdown)
380
 
381
  if control_video is not None and self.model_type == "Inpaint":
@@ -426,19 +440,21 @@ class EasyAnimateController:
426
  else:
427
  raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
428
 
429
- fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition]
430
  is_image = True if generation_method == "Image Generation" else False
431
 
432
- if is_xformers_available() and not self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): self.transformer.enable_xformers_memory_efficient_attention()
 
 
433
 
434
- self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
 
 
 
 
435
  if self.lora_model_path != "none":
436
  # lora part
437
  self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
438
-
439
- if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
440
- else: seed_textbox = np.random.randint(0, 1e10)
441
- generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
442
 
443
  try:
444
  if self.model_type == "Inpaint":
@@ -480,7 +496,7 @@ class EasyAnimateController:
480
  video = input_video,
481
  mask_video = input_video_mask,
482
  strength = 1,
483
- ).videos
484
 
485
  if init_frames != 0:
486
  mix_ratio = torch.from_numpy(
@@ -531,7 +547,7 @@ class EasyAnimateController:
531
  video = input_video,
532
  mask_video = input_video_mask,
533
  strength = strength,
534
- ).videos
535
  else:
536
  if self.vae.cache_mag_vae:
537
  length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
@@ -547,7 +563,7 @@ class EasyAnimateController:
547
  height = height_slider,
548
  video_length = length_slider if not is_image else 1,
549
  generator = generator
550
- ).videos
551
  else:
552
  if self.vae.cache_mag_vae:
553
  length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
@@ -566,7 +582,7 @@ class EasyAnimateController:
566
  generator = generator,
567
 
568
  control_video = input_video,
569
- ).videos
570
  except Exception as e:
571
  gc.collect()
572
  torch.cuda.empty_cache()
@@ -676,8 +692,8 @@ def ui(GPU_memory_mode, weight_dtype):
676
  with gr.Row():
677
  easyanimate_edition_dropdown = gr.Dropdown(
678
  label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
679
- choices=["v1", "v2", "v3", "v4", "v5"],
680
- value="v5",
681
  interactive=True,
682
  )
683
  gr.Markdown(
@@ -751,13 +767,22 @@ def ui(GPU_memory_mode, weight_dtype):
751
  """
752
  )
753
 
754
- prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
755
- negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Blurring, mutation, deformation, distortion, dark and solid, comics." )
 
 
 
 
 
 
756
 
757
  with gr.Row():
758
  with gr.Column():
759
  with gr.Row():
760
- sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
 
 
 
761
  sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
762
 
763
  resize_method = gr.Radio(
@@ -794,11 +819,11 @@ def ui(GPU_memory_mode, weight_dtype):
794
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
795
  def select_template(evt: gr.SelectData):
796
  text = {
797
- "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
798
- "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
799
- "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
800
- "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
801
- "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
802
  }[template_gallery_path[evt.index]]
803
  return template_gallery_path[evt.index], text
804
 
@@ -838,6 +863,7 @@ def ui(GPU_memory_mode, weight_dtype):
838
  gr.Markdown(
839
  """
840
  Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
 
841
  """
842
  )
843
  control_video = gr.Video(
@@ -927,6 +953,7 @@ def ui(GPU_memory_mode, weight_dtype):
927
  diffusion_transformer_dropdown,
928
  motion_module_dropdown,
929
  motion_module_refresh_button,
 
930
  width_slider,
931
  height_slider,
932
  length_slider,
@@ -1003,33 +1030,55 @@ class EasyAnimateController_Modelscope:
1003
  self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
1004
  model_name,
1005
  subfolder="transformer",
1006
- transformer_additional_kwargs=transformer_additional_kwargs
1007
- ).to(self.weight_dtype)
 
 
1008
 
1009
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
1010
  tokenizer = BertTokenizer.from_pretrained(
1011
  model_name, subfolder="tokenizer"
1012
  )
1013
- tokenizer_2 = T5Tokenizer.from_pretrained(
1014
- model_name, subfolder="tokenizer_2"
1015
- )
 
 
 
 
 
1016
  else:
1017
- tokenizer = T5Tokenizer.from_pretrained(
1018
- model_name, subfolder="tokenizer"
1019
- )
 
 
 
 
 
1020
  tokenizer_2 = None
1021
 
1022
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
1023
  text_encoder = BertModel.from_pretrained(
1024
  model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
1025
  )
1026
- text_encoder_2 = T5EncoderModel.from_pretrained(
1027
- model_name, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
1028
- )
1029
- else:
1030
- text_encoder = T5EncoderModel.from_pretrained(
1031
- model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
1032
- )
 
 
 
 
 
 
 
 
 
 
1033
  text_encoder_2 = None
1034
 
1035
  # Get pipeline
@@ -1045,23 +1094,18 @@ class EasyAnimateController_Modelscope:
1045
  clip_image_processor = None
1046
 
1047
  # Get Scheduler
1048
- Choosen_Scheduler = scheduler_dict = {
1049
- "Euler": EulerDiscreteScheduler,
1050
- "Euler A": EulerAncestralDiscreteScheduler,
1051
- "DPM++": DPMSolverMultistepScheduler,
1052
- "PNDM": PNDMScheduler,
1053
- "DDIM": DDIMScheduler,
1054
- }["Euler"]
1055
-
1056
  scheduler = Choosen_Scheduler.from_pretrained(
1057
  model_name,
1058
  subfolder="scheduler"
1059
  )
1060
 
1061
- if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
1062
  if self.transformer.config.in_channels != self.vae.config.latent_channels:
1063
- self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained(
1064
- model_name,
1065
  text_encoder=text_encoder,
1066
  text_encoder_2=text_encoder_2,
1067
  tokenizer=tokenizer,
@@ -1069,51 +1113,34 @@ class EasyAnimateController_Modelscope:
1069
  vae=self.vae,
1070
  transformer=self.transformer,
1071
  scheduler=scheduler,
1072
- torch_dtype=self.weight_dtype,
1073
  clip_image_encoder=clip_image_encoder,
1074
  clip_image_processor=clip_image_processor,
1075
- )
1076
  else:
1077
- self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder.from_pretrained(
1078
- model_name,
1079
  text_encoder=text_encoder,
1080
  text_encoder_2=text_encoder_2,
1081
  tokenizer=tokenizer,
1082
  tokenizer_2=tokenizer_2,
1083
  vae=self.vae,
1084
  transformer=self.transformer,
1085
- scheduler=scheduler,
1086
- torch_dtype=self.weight_dtype
1087
- )
1088
  else:
1089
- if self.transformer.config.in_channels != self.vae.config.latent_channels:
1090
- self.pipeline = EasyAnimateInpaintPipeline(
1091
- model_name,
1092
- text_encoder=text_encoder,
1093
- tokenizer=tokenizer,
1094
- vae=self.vae,
1095
- transformer=self.transformer,
1096
- scheduler=scheduler,
1097
- torch_dtype=self.weight_dtype,
1098
- clip_image_encoder=clip_image_encoder,
1099
- clip_image_processor=clip_image_processor,
1100
- )
1101
- else:
1102
- self.pipeline = EasyAnimatePipeline(
1103
- model_name,
1104
- text_encoder=text_encoder,
1105
- tokenizer=tokenizer,
1106
- vae=self.vae,
1107
- transformer=self.transformer,
1108
- scheduler=scheduler,
1109
- torch_dtype=self.weight_dtype
1110
- )
1111
 
1112
  if GPU_memory_mode == "sequential_cpu_offload":
1113
  self.pipeline.enable_sequential_cpu_offload()
1114
  elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
1115
  self.pipeline.enable_model_cpu_offload()
1116
- self.pipeline.enable_autocast_float8_transformer()
1117
  convert_weight_dtype_wrapper(self.pipeline.transformer, weight_dtype)
1118
  else:
1119
  GPU_memory_mode.enable_model_cpu_offload()
@@ -1214,17 +1241,17 @@ class EasyAnimateController_Modelscope:
1214
  else:
1215
  raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
1216
 
1217
- fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition]
1218
  is_image = True if generation_method == "Image Generation" else False
1219
 
1220
- self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
1221
- if self.lora_model_path != "none":
1222
- # lora part
1223
- self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
1224
-
1225
  if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
1226
  else: seed_textbox = np.random.randint(0, 1e10)
1227
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
1228
 
1229
  try:
1230
  if self.model_type == "Inpaint":
@@ -1254,7 +1281,7 @@ class EasyAnimateController_Modelscope:
1254
  video = input_video,
1255
  mask_video = input_video_mask,
1256
  strength = strength,
1257
- ).videos
1258
  else:
1259
  sample = self.pipeline(
1260
  prompt_textbox,
@@ -1265,7 +1292,7 @@ class EasyAnimateController_Modelscope:
1265
  height = height_slider,
1266
  video_length = length_slider if not is_image else 1,
1267
  generator = generator
1268
- ).videos
1269
  else:
1270
  if self.vae.cache_mag_vae:
1271
  length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
@@ -1285,7 +1312,7 @@ class EasyAnimateController_Modelscope:
1285
  generator = generator,
1286
 
1287
  control_video = input_video,
1288
- ).videos
1289
  except Exception as e:
1290
  gc.collect()
1291
  torch.cuda.empty_cache()
@@ -1406,13 +1433,28 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample,
1406
  """
1407
  )
1408
 
1409
- prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1410
- negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Blurring, mutation, deformation, distortion, dark and solid, comics." )
1411
 
1412
  with gr.Row():
1413
  with gr.Column():
1414
  with gr.Row():
1415
- sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1416
  sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
1417
 
1418
  if edition == "v1":
@@ -1466,11 +1508,11 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample,
1466
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1467
  def select_template(evt: gr.SelectData):
1468
  text = {
1469
- "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1470
- "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1471
- "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1472
- "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1473
- "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1474
  }[template_gallery_path[evt.index]]
1475
  return template_gallery_path[evt.index], text
1476
 
@@ -1510,6 +1552,7 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample,
1510
  gr.Markdown(
1511
  """
1512
  Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
 
1513
  """
1514
  )
1515
  control_video = gr.Video(
@@ -1820,13 +1863,28 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
1820
  """
1821
  )
1822
 
1823
- prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1824
- negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Blurring, mutation, deformation, distortion, dark and solid, comics." )
1825
 
1826
  with gr.Row():
1827
  with gr.Column():
1828
  with gr.Row():
1829
- sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1830
  sample_step_slider = gr.Slider(label="Sampling steps", value=40, minimum=10, maximum=40, step=1, interactive=False)
1831
 
1832
  if edition == "v1":
@@ -1875,11 +1933,11 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
1875
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1876
  def select_template(evt: gr.SelectData):
1877
  text = {
1878
- "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1879
- "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1880
- "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1881
- "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1882
- "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1883
  }[template_gallery_path[evt.index]]
1884
  return template_gallery_path[evt.index], text
1885
 
 
17
  from diffusers import (AutoencoderKL, DDIMScheduler,
18
  DPMSolverMultistepScheduler,
19
  EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
20
+ FlowMatchEulerDiscreteScheduler, PNDMScheduler)
21
  from diffusers.utils.import_utils import is_xformers_available
22
  from omegaconf import OmegaConf
23
  from PIL import Image
24
  from safetensors import safe_open
25
  from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
26
+ CLIPVisionModelWithProjection, Qwen2Tokenizer,
27
+ Qwen2VLForConditionalGeneration, T5EncoderModel,
28
+ T5Tokenizer)
29
 
30
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
31
+ from ..models import (name_to_autoencoder_magvit,
32
  name_to_transformer3d)
33
+ from ..pipeline.pipeline_easyanimate import \
34
+ EasyAnimatePipeline
35
+ from ..pipeline.pipeline_easyanimate_control import \
36
+ EasyAnimateControlPipeline
37
+ from ..pipeline.pipeline_easyanimate_inpaint import \
38
  EasyAnimateInpaintPipeline
39
+ from ..utils.fp8_optimization import convert_weight_dtype_wrapper
40
+ from ..utils.lora_utils import merge_lora, unmerge_lora
41
+ from ..utils.utils import (
 
 
 
42
  get_image_to_video_latent, get_video_to_video_latent,
43
  get_width_and_height_from_image_and_base_resolution, save_videos_grid)
 
44
 
45
+ ddpm_scheduler_dict = {
46
  "Euler": EulerDiscreteScheduler,
47
  "Euler A": EulerAncestralDiscreteScheduler,
48
  "DPM++": DPMSolverMultistepScheduler,
49
  "PNDM": PNDMScheduler,
50
  "DDIM": DDIMScheduler,
51
  }
52
+ flow_scheduler_dict = {
53
+ "Flow": FlowMatchEulerDiscreteScheduler,
54
+ }
55
+ all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
56
 
57
  gradio_version = pkg_resources.get_distribution("gradio").version
58
  gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
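The three dictionaries above are what the new sampler dropdowns read from: editions v1–v5 keep the DDPM-style samplers, while v5.1 ships with the flow-matching scheduler and only exposes the Flow sampler. A minimal standalone sketch of that selection logic, assuming a diffusers pipeline object named `pipe`; the helper name `pick_scheduler` is illustrative and not part of this commit:

    from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                           EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                           FlowMatchEulerDiscreteScheduler, PNDMScheduler)

    ddpm_scheduler_dict = {
        "Euler": EulerDiscreteScheduler,
        "Euler A": EulerAncestralDiscreteScheduler,
        "DPM++": DPMSolverMultistepScheduler,
        "PNDM": PNDMScheduler,
        "DDIM": DDIMScheduler,
    }
    flow_scheduler_dict = {"Flow": FlowMatchEulerDiscreteScheduler}

    def pick_scheduler(pipe, edition, sampler_name):
        # v5.1 uses the flow-matching scheduler only; earlier editions keep
        # the DDPM-style samplers listed above.
        choices = flow_scheduler_dict if edition == "v5.1" else ddpm_scheduler_dict
        # Rebuild the scheduler from the existing config, as ui.py does per run.
        return choices[sampler_name].from_config(pipe.scheduler.config)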
 
99
  self.GPU_memory_mode = GPU_memory_mode
100
 
101
  self.weight_dtype = weight_dtype
102
+ self.edition = "v5.1"
103
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml"))
104
 
105
  def refresh_diffusion_transformer(self):
106
  self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
 
122
  if edition == "v1":
123
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v1_motion_module.yaml"))
124
  return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
125
+ gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
126
  gr.update(value=512, minimum=384, maximum=704, step=32), \
127
  gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
128
  elif edition == "v2":
129
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v2_magvit_motion_module.yaml"))
130
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
131
+ gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
132
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
133
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
134
  elif edition == "v3":
135
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v3_slicevae_motion_module.yaml"))
136
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
137
+ gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
138
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
139
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
140
  elif edition == "v4":
141
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v4_slicevae_multi_text_encoder.yaml"))
142
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
143
+ gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
144
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
145
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
146
  elif edition == "v5":
147
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml"))
148
  return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
149
+ gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
150
+ gr.update(value=672, minimum=128, maximum=1344, step=16), \
151
+ gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)
152
+ elif edition == "v5.1":
153
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml"))
154
+ return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
155
+ gr.update(choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]), \
156
  gr.update(value=672, minimum=128, maximum=1344, step=16), \
157
  gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)
158
 
 
182
  self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
183
  diffusion_transformer_dropdown,
184
  subfolder="transformer",
185
+ transformer_additional_kwargs=transformer_additional_kwargs,
186
+ torch_dtype=torch.float8_e4m3fn if self.GPU_memory_mode == "model_cpu_offload_and_qfloat8" else self.weight_dtype,
187
+ low_cpu_mem_usage=True,
188
+ )
189
 
190
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
191
  tokenizer = BertTokenizer.from_pretrained(
192
  diffusion_transformer_dropdown, subfolder="tokenizer"
193
  )
194
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
195
+ tokenizer_2 = Qwen2Tokenizer.from_pretrained(
196
+ os.path.join(diffusion_transformer_dropdown, "tokenizer_2")
197
+ )
198
+ else:
199
+ tokenizer_2 = T5Tokenizer.from_pretrained(
200
+ diffusion_transformer_dropdown, subfolder="tokenizer_2"
201
+ )
202
  else:
203
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
204
+ tokenizer = Qwen2Tokenizer.from_pretrained(
205
+ os.path.join(diffusion_transformer_dropdown, "tokenizer")
206
+ )
207
+ else:
208
+ tokenizer = T5Tokenizer.from_pretrained(
209
+ diffusion_transformer_dropdown, subfolder="tokenizer"
210
+ )
211
  tokenizer_2 = None
212
 
213
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
214
  text_encoder = BertModel.from_pretrained(
215
  diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
216
  )
217
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
218
+ text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained(
219
+ os.path.join(diffusion_transformer_dropdown, "text_encoder_2")
220
+ )
221
+ else:
222
+ text_encoder_2 = T5EncoderModel.from_pretrained(
223
+ diffusion_transformer_dropdown, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
224
+ )
225
+ else:
226
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
227
+ text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
228
+ os.path.join(diffusion_transformer_dropdown, "text_encoder")
229
+ )
230
+ else:
231
+ text_encoder = T5EncoderModel.from_pretrained(
232
+ diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
233
+ )
234
  text_encoder_2 = None
235
 
236
  # Get pipeline
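The tokenizer and text-encoder branching added above (and mirrored in EasyAnimateController_Modelscope further down) reduces to two flags read from the YAML's text_encoder_kwargs: enable_multi_text_encoder and replace_t5_to_llm. A condensed sketch of the same selection, shown only for readability; the helper name `load_text_encoders` is illustrative and not part of this commit:

    import os
    from transformers import (BertModel, BertTokenizer, Qwen2Tokenizer,
                              Qwen2VLForConditionalGeneration, T5EncoderModel,
                              T5Tokenizer)

    def load_text_encoders(model_name, text_encoder_kwargs, weight_dtype):
        multi = text_encoder_kwargs.get('enable_multi_text_encoder', False)
        use_llm = text_encoder_kwargs.get('replace_t5_to_llm', False)

        if multi:
            # v4/v5-style: BERT as the first encoder, T5 or Qwen2-VL as the second.
            tokenizer = BertTokenizer.from_pretrained(model_name, subfolder="tokenizer")
            text_encoder = BertModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=weight_dtype)
            if use_llm:
                tokenizer_2 = Qwen2Tokenizer.from_pretrained(os.path.join(model_name, "tokenizer_2"))
                text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained(
                    os.path.join(model_name, "text_encoder_2"), torch_dtype=weight_dtype)
            else:
                tokenizer_2 = T5Tokenizer.from_pretrained(model_name, subfolder="tokenizer_2")
                text_encoder_2 = T5EncoderModel.from_pretrained(model_name, subfolder="text_encoder_2", torch_dtype=weight_dtype)
        else:
            # Single text encoder: Qwen2-VL for v5.1, T5 for earlier editions.
            if use_llm:
                tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(model_name, "tokenizer"))
                text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
                    os.path.join(model_name, "text_encoder"), torch_dtype=weight_dtype)
            else:
                tokenizer = T5Tokenizer.from_pretrained(model_name, subfolder="tokenizer")
                text_encoder = T5EncoderModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=weight_dtype)
            tokenizer_2, text_encoder_2 = None, None
        return tokenizer, tokenizer_2, text_encoder, text_encoder_2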
 
246
  clip_image_processor = None
247
 
248
  # Get Scheduler
249
+ if self.edition in ["v5.1"]:
250
+ Choosen_Scheduler = all_cheduler_dict["Flow"]
251
+ else:
252
+ Choosen_Scheduler = all_cheduler_dict["Euler"]
 
 
 
 
253
  scheduler = Choosen_Scheduler.from_pretrained(
254
  diffusion_transformer_dropdown,
255
  subfolder="scheduler"
256
  )
257
 
258
+ if self.model_type == "Inpaint":
259
  if self.transformer.config.in_channels != self.vae.config.latent_channels:
260
+ self.pipeline = EasyAnimateInpaintPipeline(
 
261
  text_encoder=text_encoder,
262
  text_encoder_2=text_encoder_2,
263
  tokenizer=tokenizer,
 
265
  vae=self.vae,
266
  transformer=self.transformer,
267
  scheduler=scheduler,
 
268
  clip_image_encoder=clip_image_encoder,
269
  clip_image_processor=clip_image_processor,
270
+ ).to(self.weight_dtype)
271
  else:
272
+ self.pipeline = EasyAnimatePipeline(
 
273
  text_encoder=text_encoder,
274
  text_encoder_2=text_encoder_2,
275
  tokenizer=tokenizer,
 
277
  vae=self.vae,
278
  transformer=self.transformer,
279
  scheduler=scheduler,
280
+ ).to(self.weight_dtype)
 
281
  else:
282
+ self.pipeline = EasyAnimateControlPipeline(
283
+ text_encoder=text_encoder,
284
+ text_encoder_2=text_encoder_2,
285
+ tokenizer=tokenizer,
286
+ tokenizer_2=tokenizer_2,
287
+ vae=self.vae,
288
+ transformer=self.transformer,
289
+ scheduler=scheduler,
290
+ ).to(self.weight_dtype)
291
 
292
  if self.GPU_memory_mode == "sequential_cpu_offload":
293
  self.pipeline.enable_sequential_cpu_offload()
294
  elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
295
  self.pipeline.enable_model_cpu_offload()
 
296
  convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
297
  else:
298
+ self.pipeline.enable_model_cpu_offload()
299
  print("Update diffusion transformer done")
300
  return gr.update()
301
 
 
386
  if self.base_model_path != base_model_dropdown:
387
  self.update_base_model(base_model_dropdown)
388
 
389
+ if self.motion_module_path != motion_module_dropdown:
390
+ self.update_motion_module(motion_module_dropdown)
391
+
392
  if self.lora_model_path != lora_model_dropdown:
 
393
  self.update_lora_model(lora_model_dropdown)
394
 
395
  if control_video is not None and self.model_type == "Inpaint":
 
440
  else:
441
  raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
442
 
443
+ fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition]
444
  is_image = True if generation_method == "Image Generation" else False
445
 
446
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
447
+ else: seed_textbox = np.random.randint(0, 1e10)
448
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
449
 
450
+ if is_xformers_available() \
451
+ and self.inference_config['transformer_additional_kwargs'].get('transformer_type', 'Transformer3DModel') == 'Transformer3DModel':
452
+ self.transformer.enable_xformers_memory_efficient_attention()
453
+
454
+ self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
455
  if self.lora_model_path != "none":
456
  # lora part
457
  self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
 
 
 
 
458
 
459
  try:
460
  if self.model_type == "Inpaint":
 
496
  video = input_video,
497
  mask_video = input_video_mask,
498
  strength = 1,
499
+ ).frames
500
 
501
  if init_frames != 0:
502
  mix_ratio = torch.from_numpy(
 
547
  video = input_video,
548
  mask_video = input_video_mask,
549
  strength = strength,
550
+ ).frames
551
  else:
552
  if self.vae.cache_mag_vae:
553
  length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
 
563
  height = height_slider,
564
  video_length = length_slider if not is_image else 1,
565
  generator = generator
566
+ ).frames
567
  else:
568
  if self.vae.cache_mag_vae:
569
  length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
 
582
  generator = generator,
583
 
584
  control_video = input_video,
585
+ ).frames
586
  except Exception as e:
587
  gc.collect()
588
  torch.cuda.empty_cache()
 
692
  with gr.Row():
693
  easyanimate_edition_dropdown = gr.Dropdown(
694
  label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
695
+ choices=["v1", "v2", "v3", "v4", "v5", "v5.1"],
696
+ value="v5.1",
697
  interactive=True,
698
  )
699
  gr.Markdown(
 
767
  """
768
  )
769
 
770
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
771
+ gr.Markdown(
772
+ """
773
+ Using a longer negative prompt such as "Blurring, mutation, deformation, distortion, dark and static, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, static" to the negative prompt can increase dynamism.
774
+ 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
775
+ """
776
+ )
777
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
778
 
779
  with gr.Row():
780
  with gr.Column():
781
  with gr.Row():
782
+ sampler_dropdown = gr.Dropdown(
783
+ label="Sampling method (采样器种类)",
784
+ choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
785
+ )
786
  sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
787
 
788
  resize_method = gr.Radio(
 
819
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
820
  def select_template(evt: gr.SelectData):
821
  text = {
822
+ "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
823
+ "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
824
+ "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
825
+ "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
826
+ "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
827
  }[template_gallery_path[evt.index]]
828
  return template_gallery_path[evt.index], text
829
 
 
863
  gr.Markdown(
864
  """
865
  Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
866
+ Only standard control videos are supported in app.py; trajectory control and camera control require ComfyUI, as shown at https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui.
867
  """
868
  )
869
  control_video = gr.Video(
 
953
  diffusion_transformer_dropdown,
954
  motion_module_dropdown,
955
  motion_module_refresh_button,
956
+ sampler_dropdown,
957
  width_slider,
958
  height_slider,
959
  length_slider,
 
1030
  self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
1031
  model_name,
1032
  subfolder="transformer",
1033
+ transformer_additional_kwargs=transformer_additional_kwargs,
1034
+ torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
1035
+ low_cpu_mem_usage=True,
1036
+ )
1037
 
1038
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
1039
  tokenizer = BertTokenizer.from_pretrained(
1040
  model_name, subfolder="tokenizer"
1041
  )
1042
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
1043
+ tokenizer_2 = Qwen2Tokenizer.from_pretrained(
1044
+ os.path.join(model_name, "tokenizer_2")
1045
+ )
1046
+ else:
1047
+ tokenizer_2 = T5Tokenizer.from_pretrained(
1048
+ model_name, subfolder="tokenizer_2"
1049
+ )
1050
  else:
1051
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
1052
+ tokenizer = Qwen2Tokenizer.from_pretrained(
1053
+ os.path.join(model_name, "tokenizer")
1054
+ )
1055
+ else:
1056
+ tokenizer = T5Tokenizer.from_pretrained(
1057
+ model_name, subfolder="tokenizer"
1058
+ )
1059
  tokenizer_2 = None
1060
 
1061
  if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
1062
  text_encoder = BertModel.from_pretrained(
1063
  model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
1064
  )
1065
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
1066
+ text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained(
1067
+ os.path.join(model_name, "text_encoder_2"), torch_dtype=self.weight_dtype
1068
+ )
1069
+ else:
1070
+ text_encoder_2 = T5EncoderModel.from_pretrained(
1071
+ model_name, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
1072
+ )
1073
+ else:
1074
+ if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
1075
+ text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
1076
+ os.path.join(model_name, "text_encoder"), torch_dtype=self.weight_dtype
1077
+ )
1078
+ else:
1079
+ text_encoder = T5EncoderModel.from_pretrained(
1080
+ model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
1081
+ )
1082
  text_encoder_2 = None
1083
 
1084
  # Get pipeline
 
1094
  clip_image_processor = None
1095
 
1096
  # Get Scheduler
1097
+ if self.edition in ["v5.1"]:
1098
+ Choosen_Scheduler = all_cheduler_dict["Flow"]
1099
+ else:
1100
+ Choosen_Scheduler = all_cheduler_dict["Euler"]
 
 
 
 
1101
  scheduler = Choosen_Scheduler.from_pretrained(
1102
  model_name,
1103
  subfolder="scheduler"
1104
  )
1105
 
1106
+ if model_type == "Inpaint":
1107
  if self.transformer.config.in_channels != self.vae.config.latent_channels:
1108
+ self.pipeline = EasyAnimateInpaintPipeline(
 
1109
  text_encoder=text_encoder,
1110
  text_encoder_2=text_encoder_2,
1111
  tokenizer=tokenizer,
 
1113
  vae=self.vae,
1114
  transformer=self.transformer,
1115
  scheduler=scheduler,
 
1116
  clip_image_encoder=clip_image_encoder,
1117
  clip_image_processor=clip_image_processor,
1118
+ ).to(weight_dtype)
1119
  else:
1120
+ self.pipeline = EasyAnimatePipeline(
 
1121
  text_encoder=text_encoder,
1122
  text_encoder_2=text_encoder_2,
1123
  tokenizer=tokenizer,
1124
  tokenizer_2=tokenizer_2,
1125
  vae=self.vae,
1126
  transformer=self.transformer,
1127
+ scheduler=scheduler
1128
+ ).to(weight_dtype)
 
1129
  else:
1130
+ self.pipeline = EasyAnimateControlPipeline(
1131
+ text_encoder=text_encoder,
1132
+ text_encoder_2=text_encoder_2,
1133
+ tokenizer=tokenizer,
1134
+ tokenizer_2=tokenizer_2,
1135
+ vae=self.vae,
1136
+ transformer=self.transformer,
1137
+ scheduler=scheduler,
1138
+ ).to(weight_dtype)
1139
 
1140
  if GPU_memory_mode == "sequential_cpu_offload":
1141
  self.pipeline.enable_sequential_cpu_offload()
1142
  elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
1143
  self.pipeline.enable_model_cpu_offload()
 
1144
  convert_weight_dtype_wrapper(self.pipeline.transformer, weight_dtype)
1145
  else:
1146
  GPU_memory_mode.enable_model_cpu_offload()
 
1241
  else:
1242
  raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
1243
 
1244
+ fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition]
1245
  is_image = True if generation_method == "Image Generation" else False
1246
 
 
 
 
 
 
1247
  if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
1248
  else: seed_textbox = np.random.randint(0, 1e10)
1249
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
1250
+
1251
+ self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
1252
+ if self.lora_model_path != "none":
1253
+ # lora part
1254
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
1255
 
1256
  try:
1257
  if self.model_type == "Inpaint":
 
1281
  video = input_video,
1282
  mask_video = input_video_mask,
1283
  strength = strength,
1284
+ ).frames
1285
  else:
1286
  sample = self.pipeline(
1287
  prompt_textbox,
 
1292
  height = height_slider,
1293
  video_length = length_slider if not is_image else 1,
1294
  generator = generator
1295
+ ).frames
1296
  else:
1297
  if self.vae.cache_mag_vae:
1298
  length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
 
1312
  generator = generator,
1313
 
1314
  control_video = input_video,
1315
+ ).frames
1316
  except Exception as e:
1317
  gc.collect()
1318
  torch.cuda.empty_cache()
 
1433
  """
1434
  )
1435
 
1436
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
1437
+ gr.Markdown(
1438
+ """
1439
+ Using a longer negative prompt such as "Blurring, mutation, deformation, distortion, dark and static, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, static" to the negative prompt can increase dynamism.
1440
+ 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
1441
+ """
1442
+ )
1443
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
1444
 
1445
  with gr.Row():
1446
  with gr.Column():
1447
  with gr.Row():
1448
+ if edition in ["v5.1"]:
1449
+ sampler_dropdown = gr.Dropdown(
1450
+ label="Sampling method (采样器种类)",
1451
+ choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
1452
+ )
1453
+ else:
1454
+ sampler_dropdown = gr.Dropdown(
1455
+ label="Sampling method (采样器种类)",
1456
+ choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]
1457
+ )
1458
  sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
1459
 
1460
  if edition == "v1":
 
1508
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1509
  def select_template(evt: gr.SelectData):
1510
  text = {
1511
+ "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
1512
+ "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
1513
+ "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
1514
+ "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
1515
+ "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
1516
  }[template_gallery_path[evt.index]]
1517
  return template_gallery_path[evt.index], text
1518
 
 
1552
  gr.Markdown(
1553
  """
1554
  Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
1555
+ Only standard control videos are supported in app.py; trajectory control and camera control require ComfyUI, as shown at https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui.
1556
  """
1557
  )
1558
  control_video = gr.Video(
 
1863
  """
1864
  )
1865
 
1866
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
1867
+ gr.Markdown(
1868
+ """
1869
+ Using a longer negative prompt such as "Blurring, mutation, deformation, distortion, dark and static, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, static" to the negative prompt can increase dynamism.
1870
+ 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
1871
+ """
1872
+ )
1873
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
1874
 
1875
  with gr.Row():
1876
  with gr.Column():
1877
  with gr.Row():
1878
+ if edition in ["v5.1"]:
1879
+ sampler_dropdown = gr.Dropdown(
1880
+ label="Sampling method (采样器种类)",
1881
+ choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
1882
+ )
1883
+ else:
1884
+ sampler_dropdown = gr.Dropdown(
1885
+ label="Sampling method (采样器种类)",
1886
+ choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]
1887
+ )
1888
  sample_step_slider = gr.Slider(label="Sampling steps", value=40, minimum=10, maximum=40, step=1, interactive=False)
1889
 
1890
  if edition == "v1":
 
1933
  template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1934
  def select_template(evt: gr.SelectData):
1935
  text = {
1936
+ "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
1937
+ "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
1938
+ "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
1939
+ "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
1940
+ "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
1941
  }[template_gallery_path[evt.index]]
1942
  return template_gallery_path[evt.index], text
1943
 
easyanimate/utils/lora_utils.py CHANGED
@@ -369,7 +369,6 @@ def create_network(
369
  def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
370
  LORA_PREFIX_TRANSFORMER = "lora_unet"
371
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
372
- SPECIAL_LAYER_NAME = ["text_proj_t5"]
373
  if state_dict is None:
374
  state_dict = load_file(lora_path, device=device)
375
  else:
@@ -410,20 +409,25 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
410
  else:
411
  temp_name = layer_infos.pop(0)
412
 
413
- weight_up = elems['lora_up.weight'].to(dtype)
414
- weight_down = elems['lora_down.weight'].to(dtype)
415
  if 'alpha' in elems.keys():
416
  alpha = elems['alpha'].item() / weight_up.shape[1]
417
  else:
418
  alpha = 1.0
419
 
420
- curr_layer.weight.data = curr_layer.weight.data.to(device)
421
  if len(weight_up.shape) == 4:
422
- curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
423
- weight_down.squeeze(3).squeeze(2)).unsqueeze(
424
- 2).unsqueeze(3)
425
  else:
426
  curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
 
427
 
428
  return pipeline
429
 
@@ -448,35 +452,43 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
448
  layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
449
  curr_layer = pipeline.transformer
450
 
451
- temp_name = layer_infos.pop(0)
452
- print(layer, curr_layer)
453
- while len(layer_infos) > -1:
454
- try:
455
- curr_layer = curr_layer.__getattr__(temp_name)
456
- if len(layer_infos) > 0:
457
- temp_name = layer_infos.pop(0)
458
- elif len(layer_infos) == 0:
459
- break
460
- except Exception:
461
- if len(layer_infos) == 0:
462
- print('Error loading layer')
463
- if len(temp_name) > 0:
464
- temp_name += "_" + layer_infos.pop(0)
465
- else:
466
- temp_name = layer_infos.pop(0)
467
-
468
- weight_up = elems['lora_up.weight'].to(dtype)
469
- weight_down = elems['lora_down.weight'].to(dtype)
470
  if 'alpha' in elems.keys():
471
  alpha = elems['alpha'].item() / weight_up.shape[1]
472
  else:
473
  alpha = 1.0
474
 
475
- curr_layer.weight.data = curr_layer.weight.data.to(device)
476
  if len(weight_up.shape) == 4:
477
- curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
478
- weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
 
479
  else:
480
  curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
 
481
 
482
- return pipeline
 
369
  def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
370
  LORA_PREFIX_TRANSFORMER = "lora_unet"
371
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
 
372
  if state_dict is None:
373
  state_dict = load_file(lora_path, device=device)
374
  else:
 
409
  else:
410
  temp_name = layer_infos.pop(0)
411
 
412
+ origin_dtype = curr_layer.weight.data.dtype
413
+ origin_device = curr_layer.weight.data.device
414
+
415
+ curr_layer = curr_layer.to(device, dtype)
416
+ weight_up = elems['lora_up.weight'].to(device, dtype)
417
+ weight_down = elems['lora_down.weight'].to(device, dtype)
418
+
419
  if 'alpha' in elems.keys():
420
  alpha = elems['alpha'].item() / weight_up.shape[1]
421
  else:
422
  alpha = 1.0
423
 
 
424
  if len(weight_up.shape) == 4:
425
+ curr_layer.weight.data += multiplier * alpha * torch.mm(
426
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
427
+ ).unsqueeze(2).unsqueeze(3)
428
  else:
429
  curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
430
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
431
 
432
  return pipeline
433
 
 
452
  layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
453
  curr_layer = pipeline.transformer
454
 
455
+ try:
456
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
457
+ except Exception:
458
+ temp_name = layer_infos.pop(0)
459
+ while len(layer_infos) > -1:
460
+ try:
461
+ curr_layer = curr_layer.__getattr__(temp_name)
462
+ if len(layer_infos) > 0:
463
+ temp_name = layer_infos.pop(0)
464
+ elif len(layer_infos) == 0:
465
+ break
466
+ except Exception:
467
+ if len(layer_infos) == 0:
468
+ print('Error loading layer')
469
+ if len(temp_name) > 0:
470
+ temp_name += "_" + layer_infos.pop(0)
471
+ else:
472
+ temp_name = layer_infos.pop(0)
473
+
474
+ origin_dtype = curr_layer.weight.data.dtype
475
+ origin_device = curr_layer.weight.data.device
476
+
477
+ curr_layer = curr_layer.to(device, dtype)
478
+ weight_up = elems['lora_up.weight'].to(device, dtype)
479
+ weight_down = elems['lora_down.weight'].to(device, dtype)
480
+
481
  if 'alpha' in elems.keys():
482
  alpha = elems['alpha'].item() / weight_up.shape[1]
483
  else:
484
  alpha = 1.0
485
 
 
486
  if len(weight_up.shape) == 4:
487
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(
488
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
489
+ ).unsqueeze(2).unsqueeze(3)
490
  else:
491
  curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
492
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
493
 
494
+ return pipeline
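A typical call pattern for the revised helpers is to merge a LoRA into the pipeline before sampling and unmerge it afterwards. The new code temporarily moves each affected layer and the LoRA weights to the requested device and dtype for the matmul, then restores the original placement, as shown above. A short usage sketch, assuming `pipeline`, `prompt`, and `generator` were built as in ui.py; the checkpoint path is a placeholder:

    from easyanimate.utils.lora_utils import merge_lora, unmerge_lora

    lora_path = "path/to/your_lora.safetensors"  # placeholder path
    lora_weight = 0.55

    # Merge, sample, then restore the original weights.
    pipeline = merge_lora(pipeline, lora_path, multiplier=lora_weight)
    sample = pipeline(prompt, video_length=49, generator=generator).frames
    pipeline = unmerge_lora(pipeline, lora_path, multiplier=lora_weight)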
easyanimate/utils/utils.py CHANGED
@@ -169,47 +169,67 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide
169
  return input_video, input_video_mask, clip_image
170
 
171
  def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
172
- if isinstance(input_video_path, str):
173
- cap = cv2.VideoCapture(input_video_path)
174
- input_video = []
 
175
 
176
- original_fps = cap.get(cv2.CAP_PROP_FPS)
177
- frame_skip = 1 if fps is None else int(original_fps // fps)
178
 
179
- frame_count = 0
180
 
181
- while True:
182
- ret, frame = cap.read()
183
- if not ret:
184
- break
185
 
186
- if frame_count % frame_skip == 0:
187
- frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
188
- input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
189
 
190
- frame_count += 1
191
 
192
- cap.release()
193
- else:
194
- input_video = input_video_path
 
 
 
195
 
196
- input_video = torch.from_numpy(np.array(input_video))[:video_length]
197
- input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
 
199
  if ref_image is not None:
200
- ref_image = Image.open(ref_image)
201
- ref_image = torch.from_numpy(np.array(ref_image))
202
- ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
 
204
- if validation_video_mask is not None:
205
- validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
206
- input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
207
-
208
- input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
209
- input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
210
- input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
211
- else:
212
- input_video_mask = torch.zeros_like(input_video[:, :1])
213
- input_video_mask[:, :, :] = 255
214
 
215
- return input_video, input_video_mask, ref_image
 
169
  return input_video, input_video_mask, clip_image
170
 
171
  def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
172
+ if input_video_path is not None:
173
+ if isinstance(input_video_path, str):
174
+ cap = cv2.VideoCapture(input_video_path)
175
+ input_video = []
176
 
177
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
178
+ frame_skip = 1 if fps is None else int(original_fps // fps)
179
 
180
+ frame_count = 0
181
 
182
+ while True:
183
+ ret, frame = cap.read()
184
+ if not ret:
185
+ break
186
 
187
+ if frame_count % frame_skip == 0:
188
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
189
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
190
 
191
+ frame_count += 1
192
 
193
+ cap.release()
194
+ else:
195
+ input_video = input_video_path
196
+
197
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
198
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
199
 
200
+ if validation_video_mask is not None:
201
+ validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
202
+ input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
203
+
204
+ input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
205
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
206
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
207
+ else:
208
+ input_video_mask = torch.zeros_like(input_video[:, :1])
209
+ input_video_mask[:, :, :] = 255
210
+ else:
211
+ input_video, input_video_mask = None, None
212
 
213
  if ref_image is not None:
214
+ if isinstance(ref_image, str):
215
+ ref_image = Image.open(ref_image).convert("RGB")
216
+ ref_image = ref_image.resize((sample_size[1], sample_size[0]))
217
+ ref_image = torch.from_numpy(np.array(ref_image))
218
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
219
+ else:
220
+ ref_image = torch.from_numpy(np.array(ref_image))
221
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
222
+ return input_video, input_video_mask, ref_image
223
 
224
+ def get_image_latent(ref_image=None, sample_size=None):
225
+ if ref_image is not None:
226
+ if isinstance(ref_image, str):
227
+ ref_image = Image.open(ref_image).convert("RGB")
228
+ ref_image = ref_image.resize((sample_size[1], sample_size[0]))
229
+ ref_image = torch.from_numpy(np.array(ref_image))
230
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
231
+ else:
232
+ ref_image = torch.from_numpy(np.array(ref_image))
233
+ ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
234
 
235
+ return ref_image
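The reworked helpers above now accept input_video_path=None (returning None for both the video and the mask), take either a file path or an in-memory image for ref_image, and add the new get_image_latent helper. A brief usage sketch of the tensors they return; the file names are placeholders:

    from easyanimate.utils.utils import get_image_latent, get_video_to_video_latent

    sample_size = (384, 672)  # (height, width)

    # Video-to-video / control input: a (1, 3, T, H, W) float tensor scaled to [0, 1],
    # plus a (1, 1, T, H, W) mask (filled with 255 when no mask file is supplied).
    input_video, input_video_mask, ref_image = get_video_to_video_latent(
        "your_control_video.mp4",   # placeholder path
        video_length=49,
        sample_size=sample_size,
        fps=8,
    )

    # Reference image only: a (1, 3, 1, H, W) tensor in [0, 1], or None if no image is given.
    ref_image = get_image_latent("your_reference.png", sample_size=sample_size)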
easyanimate/vae/ldm/models/autoencoder.py CHANGED
@@ -126,13 +126,13 @@ class AutoencoderKLMagvit(pl.LightningModule):
126
 
127
  def configure_optimizers(self):
128
  lr = self.learning_rate
129
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
130
  list(self.decoder.parameters())+
131
  list(self.quant_conv.parameters())+
132
  list(self.post_quant_conv.parameters()),
133
- lr=lr, betas=(0.5, 0.9))
134
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
135
- lr=lr, betas=(0.5, 0.9))
136
  return [opt_ae, opt_disc], []
137
 
138
  def get_last_layer(self):
 
126
 
127
  def configure_optimizers(self):
128
  lr = self.learning_rate
129
+ opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
130
  list(self.decoder.parameters())+
131
  list(self.quant_conv.parameters())+
132
  list(self.post_quant_conv.parameters()),
133
+ lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
134
+ opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
135
+ lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
136
  return [opt_ae, opt_disc], []
137
 
138
  def get_last_layer(self):
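The same optimizer change is repeated for every VAE variant below: Adam with betas=(0.5, 0.9) becomes AdamW with the conventional betas=(0.9, 0.999) and a weight decay of 5e-2, applied to both the autoencoder parameters and the discriminator (the CogVideoX/Omnigen variants also include their 3D discriminator). A condensed sketch of the new configuration; `build_vae_optimizers` is an illustrative helper, not code from this commit:

    import itertools

    import torch

    def build_vae_optimizers(model, lr):
        # Autoencoder side: encoder, decoder and the (post-)quant convolutions.
        ae_params = itertools.chain(model.encoder.parameters(),
                                    model.decoder.parameters(),
                                    model.quant_conv.parameters(),
                                    model.post_quant_conv.parameters())
        opt_ae = torch.optim.AdamW(ae_params, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
        # Discriminator side of the GAN loss uses the same hyperparameters.
        opt_disc = torch.optim.AdamW(model.loss.discriminator.parameters(),
                                     lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
        return [opt_ae, opt_disc], []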
easyanimate/vae/ldm/models/casual3dcnn.py CHANGED
@@ -279,13 +279,13 @@ class AutoencoderKL(pl.LightningModule):
279
 
280
  def configure_optimizers(self):
281
  lr = self.learning_rate
282
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
283
  list(self.decoder.parameters())+
284
  list(self.quant_conv.parameters())+
285
- list(self.post_quant_conv.parameters()),
286
- lr=lr, betas=(0.5, 0.9))
287
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
288
- lr=lr, betas=(0.5, 0.9))
289
  return [opt_ae, opt_disc], []
290
 
291
  def get_last_layer(self):
 
279
 
280
  def configure_optimizers(self):
281
  lr = self.learning_rate
282
+ opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
283
  list(self.decoder.parameters())+
284
  list(self.quant_conv.parameters())+
285
+ list(self.post_quant_conv.parameters()), \
286
+ lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
287
+ opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
288
+ lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
289
  return [opt_ae, opt_disc], []
290
 
291
  def get_last_layer(self):
easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py CHANGED
@@ -277,23 +277,23 @@ class AutoencoderKLMagvit_CogVideoX(pl.LightningModule):
277
  training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
278
  else:
279
  training_list = list(self.decoder.parameters())
280
- opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
281
  elif self.train_encoder_only:
282
  if self.quant_conv is not None:
283
  training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
284
  else:
285
  training_list = list(self.encoder.parameters())
286
- opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
287
  else:
288
  training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
289
  if self.quant_conv is not None:
290
  training_list = training_list + list(self.quant_conv.parameters())
291
  if self.post_quant_conv is not None:
292
  training_list = training_list + list(self.post_quant_conv.parameters())
293
- opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
294
- opt_disc = torch.optim.Adam(
295
  list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
296
- lr=lr, betas=(0.5, 0.9)
297
  )
298
  return [opt_ae, opt_disc], []
299
 
 
277
  training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
278
  else:
279
  training_list = list(self.decoder.parameters())
280
+ opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
281
  elif self.train_encoder_only:
282
  if self.quant_conv is not None:
283
  training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
284
  else:
285
  training_list = list(self.encoder.parameters())
286
+ opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
287
  else:
288
  training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
289
  if self.quant_conv is not None:
290
  training_list = training_list + list(self.quant_conv.parameters())
291
  if self.post_quant_conv is not None:
292
  training_list = training_list + list(self.post_quant_conv.parameters())
293
+ opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
294
+ opt_disc = torch.optim.AdamW(
295
  list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
296
+ lr=lr, betas=(0.9, 0.999), weight_decay=5e-2
297
  )
298
  return [opt_ae, opt_disc], []
299
 
easyanimate/vae/ldm/models/omnigen_casual3dcnn.py CHANGED
@@ -95,6 +95,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
95
  out_channels: int = 3,
96
  ch = 128,
97
  ch_mult = [ 1,2,4,4 ],
 
98
  use_gc_blocks = None,
99
  down_block_types: tuple = None,
100
  up_block_types: tuple = None,
@@ -129,8 +130,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
129
  in_channels=in_channels,
130
  out_channels=latent_channels,
131
  down_block_types=down_block_types,
132
- ch = ch,
133
- ch_mult = ch_mult,
 
134
  use_gc_blocks=use_gc_blocks,
135
  mid_block_type=mid_block_type,
136
  mid_block_use_attention=mid_block_use_attention,
@@ -144,6 +146,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
144
  slice_mag_vae=slice_mag_vae,
145
  slice_compression_vae=slice_compression_vae,
146
  cache_compression_vae=cache_compression_vae,
 
147
  spatial_group_norm=spatial_group_norm,
148
  mini_batch_encoder=mini_batch_encoder,
149
  )
@@ -152,8 +155,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
152
  in_channels=latent_channels,
153
  out_channels=out_channels,
154
  up_block_types=up_block_types,
155
- ch = ch,
156
- ch_mult = ch_mult,
 
157
  use_gc_blocks=use_gc_blocks,
158
  mid_block_type=mid_block_type,
159
  mid_block_use_attention=mid_block_use_attention,
@@ -292,23 +296,23 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
292
  training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
293
  else:
294
  training_list = list(self.decoder.parameters())
295
- opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
296
  elif self.train_encoder_only:
297
  if self.quant_conv is not None:
298
  training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
299
  else:
300
  training_list = list(self.encoder.parameters())
301
- opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
302
  else:
303
  training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
304
  if self.quant_conv is not None:
305
  training_list = training_list + list(self.quant_conv.parameters())
306
  if self.post_quant_conv is not None:
307
  training_list = training_list + list(self.post_quant_conv.parameters())
308
- opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9))
309
- opt_disc = torch.optim.Adam(
310
  list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
311
- lr=lr, betas=(0.5, 0.9)
312
  )
313
  return [opt_ae, opt_disc], []
314
 
 
95
  out_channels: int = 3,
96
  ch = 128,
97
  ch_mult = [ 1,2,4,4 ],
98
+ block_out_channels = [128, 256, 512, 512],
99
  use_gc_blocks = None,
100
  down_block_types: tuple = None,
101
  up_block_types: tuple = None,
 
130
  in_channels=in_channels,
131
  out_channels=latent_channels,
132
  down_block_types=down_block_types,
133
+ ch=ch,
134
+ ch_mult=ch_mult,
135
+ block_out_channels=block_out_channels,
136
  use_gc_blocks=use_gc_blocks,
137
  mid_block_type=mid_block_type,
138
  mid_block_use_attention=mid_block_use_attention,
 
146
  slice_mag_vae=slice_mag_vae,
147
  slice_compression_vae=slice_compression_vae,
148
  cache_compression_vae=cache_compression_vae,
149
+ cache_mag_vae=cache_mag_vae,
150
  spatial_group_norm=spatial_group_norm,
151
  mini_batch_encoder=mini_batch_encoder,
152
  )
 
155
  in_channels=latent_channels,
156
  out_channels=out_channels,
157
  up_block_types=up_block_types,
158
+ ch=ch,
159
+ ch_mult=ch_mult,
160
+ block_out_channels=block_out_channels,
161
  use_gc_blocks=use_gc_blocks,
162
  mid_block_type=mid_block_type,
163
  mid_block_use_attention=mid_block_use_attention,
 
296
  training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
297
  else:
298
  training_list = list(self.decoder.parameters())
299
+ opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
300
  elif self.train_encoder_only:
301
  if self.quant_conv is not None:
302
  training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
303
  else:
304
  training_list = list(self.encoder.parameters())
305
+ opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
306
  else:
307
  training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
308
  if self.quant_conv is not None:
309
  training_list = training_list + list(self.quant_conv.parameters())
310
  if self.post_quant_conv is not None:
311
  training_list = training_list + list(self.post_quant_conv.parameters())
312
+ opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
313
+ opt_disc = torch.optim.AdamW(
314
  list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
315
+ lr=lr, betas=(0.9, 0.999), weight_decay=5e-2
316
  )
317
  return [opt_ae, opt_disc], []
318
 
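AutoencoderKLMagvit_fromOmnigen now accepts block_out_channels explicitly (default [128, 256, 512, 512]) and forwards it to both the encoder and the decoder, and it also passes the cache_mag_vae flag through to the encoder. The explicit default reproduces what the previous ch * ch_mult arithmetic gave; a quick plain-Python check (nothing here imports the repository):

```python
ch = 128
ch_mult = [1, 2, 4, 4]

# Old behaviour: per-stage widths derived from the base channel count.
derived = [ch * m for m in ch_mult]

# New behaviour: the same widths passed in explicitly via block_out_channels.
block_out_channels = [128, 256, 512, 512]

assert derived == block_out_channels
print(block_out_channels)  # [128, 256, 512, 512]
```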
easyanimate/vae/ldm/models/omnigen_enc_dec.py CHANGED
@@ -58,6 +58,7 @@ class Encoder(nn.Module):
58
  down_block_types = ("SpatialDownBlock3D",),
59
  ch = 128,
60
  ch_mult = [1,2,4,4,],
 
61
  use_gc_blocks = None,
62
  mid_block_type: str = "MidBlock3D",
63
  mid_block_use_attention: bool = True,
@@ -77,7 +78,8 @@ class Encoder(nn.Module):
77
  verbose = False,
78
  ):
79
  super().__init__()
80
- block_out_channels = [ch * i for i in ch_mult]
 
81
  assert len(down_block_types) == len(block_out_channels), (
82
  "Number of down block types must match number of block output channels."
83
  )
@@ -364,6 +366,7 @@ class Decoder(nn.Module):
364
  up_block_types = ("SpatialUpBlock3D",),
365
  ch = 128,
366
  ch_mult = [1,2,4,4,],
 
367
  use_gc_blocks = None,
368
  mid_block_type: str = "MidBlock3D",
369
  mid_block_use_attention: bool = True,
@@ -382,7 +385,8 @@ class Decoder(nn.Module):
382
  verbose = False,
383
  ):
384
  super().__init__()
385
- block_out_channels = [ch * i for i in ch_mult]
 
386
  assert len(up_block_types) == len(block_out_channels), (
387
  "Number of up block types must match number of block output channels."
388
  )
 
58
  down_block_types = ("SpatialDownBlock3D",),
59
  ch = 128,
60
  ch_mult = [1,2,4,4,],
61
+ block_out_channels = [128, 256, 512, 512],
62
  use_gc_blocks = None,
63
  mid_block_type: str = "MidBlock3D",
64
  mid_block_use_attention: bool = True,
 
78
  verbose = False,
79
  ):
80
  super().__init__()
81
+ if block_out_channels is None:
82
+ block_out_channels = [ch * i for i in ch_mult]
83
  assert len(down_block_types) == len(block_out_channels), (
84
  "Number of down block types must match number of block output channels."
85
  )
 
366
  up_block_types = ("SpatialUpBlock3D",),
367
  ch = 128,
368
  ch_mult = [1,2,4,4,],
369
+ block_out_channels = [128, 256, 512, 512],
370
  use_gc_blocks = None,
371
  mid_block_type: str = "MidBlock3D",
372
  mid_block_use_attention: bool = True,
 
385
  verbose = False,
386
  ):
387
  super().__init__()
388
+ if block_out_channels is None:
389
+ block_out_channels = [ch * i for i in ch_mult]
390
  assert len(up_block_types) == len(block_out_channels), (
391
  "Number of up block types must match number of block output channels."
392
  )
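In omnigen_enc_dec.py the Encoder and Decoder keep their ch / ch_mult arguments, but the derivation block_out_channels = [ch * i for i in ch_mult] now only runs when block_out_channels is None, and the existing length check against the block-type tuples applies to whichever list ends up in use. A standalone sketch of that resolve-then-validate step; the function name is hypothetical:

```python
from typing import Optional, Sequence, Tuple


def resolve_block_out_channels(
    block_types: Tuple[str, ...],
    ch: int = 128,
    ch_mult: Sequence[int] = (1, 2, 4, 4),
    block_out_channels: Optional[Sequence[int]] = None,
) -> list:
    # Fallback path: derive the widths from ch * ch_mult only when none were given.
    if block_out_channels is None:
        block_out_channels = [ch * m for m in ch_mult]
    # Same invariant the Encoder/Decoder assert on.
    assert len(block_types) == len(block_out_channels), (
        "Number of block types must match number of block output channels."
    )
    return list(block_out_channels)


blocks = ("SpatialDownBlock3D",) * 4
print(resolve_block_out_channels(blocks))  # [128, 256, 512, 512]
print(resolve_block_out_channels(blocks, block_out_channels=[128, 256, 512, 512]))
```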
easyanimate/vae/ldm/modules/losses/contperceptual.py CHANGED
@@ -9,7 +9,8 @@ from ..vaemodules.discriminator import Discriminator3D
9
  class LPIPSWithDiscriminator(nn.Module):
10
  def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
11
  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
12
- perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
 
13
  disc_loss="hinge", l2_loss_weight=0.0, l1_loss_weight=1.0):
14
 
15
  super().__init__()
@@ -34,6 +35,8 @@ class LPIPSWithDiscriminator(nn.Module):
34
  self.disc_factor = disc_factor
35
  self.discriminator_weight = disc_weight
36
  self.disc_conditional = disc_conditional
 
 
37
  self.l1_loss_weight = l1_loss_weight
38
  self.l2_loss_weight = l2_loss_weight
39
 
@@ -50,6 +53,18 @@ class LPIPSWithDiscriminator(nn.Module):
50
  d_weight = d_weight * self.discriminator_weight
51
  return d_weight
52
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
54
  global_step, last_layer=None, cond=None, split="train",
55
  weights=None):
@@ -86,6 +101,8 @@ class LPIPSWithDiscriminator(nn.Module):
86
  kl_loss = posteriors.kl()
87
  kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
88
 
 
 
89
  # now the GAN part
90
  if optimizer_idx == 0:
91
  # generator update
@@ -102,13 +119,13 @@ class LPIPSWithDiscriminator(nn.Module):
102
  try:
103
  d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
104
  except RuntimeError:
105
- assert not self.training
106
  d_weight = torch.tensor(0.0)
107
  else:
108
  d_weight = torch.tensor(0.0)
109
 
110
  disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
111
- loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
112
 
113
  log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
114
  "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
 
9
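The hunks above extend LPIPSWithDiscriminator with an outlier_penalty_loss (its full body appears in the updated listing that follows): latent values lying more than r spatial standard deviations from their per-(batch, channel, frame) spatial mean are penalized, and the weighted penalty is added to the generator objective next to the NLL, KL and GAN terms. A self-contained sketch of the same computation on a dummy latent; the tensor shape is illustrative, while r=3.0 and the 1e5 weight follow the new defaults:

```python
import torch


def outlier_penalty_loss(latents: torch.Tensor, r: float = 3.0) -> torch.Tensor:
    """latents: (B, C, T, H, W). Penalize |x - mean| beyond r spatial stds."""
    _, _, _, height, width = latents.shape
    mean = latents.mean(dim=(3, 4), keepdim=True)   # spatial mean per (b, c, t)
    std = latents.std(dim=(3, 4), keepdim=True)     # spatial std per (b, c, t)
    diff = (latents - mean).abs()
    penalty = torch.clamp(diff - r * std, min=0.0)  # zero inside the r*std band
    # Average over the spatial grid, then over batch, channel and time.
    return (penalty.sum(dim=(3, 4)) / (height * width)).mean()


latents = torch.randn(1, 16, 9, 32, 32)
opl = outlier_penalty_loss(latents, r=3.0)
print(opl.item(), (1e5 * opl).item())  # raw penalty and the weighted term added to the loss
```

With r=3 the band covers almost all values of a roughly Gaussian latent, so the penalty mainly targets rare extreme activations; the large 1e5 weight makes even those rare outliers costly.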
  class LPIPSWithDiscriminator(nn.Module):
10
  def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
11
  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
12
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
13
+ outlier_penalty_loss_r=3.0, outlier_penalty_loss_weight=1e5,
14
  disc_loss="hinge", l2_loss_weight=0.0, l1_loss_weight=1.0):
15
 
16
  super().__init__()
 
35
  self.disc_factor = disc_factor
36
  self.discriminator_weight = disc_weight
37
  self.disc_conditional = disc_conditional
38
+ self.outlier_penalty_loss_r = outlier_penalty_loss_r
39
+ self.outlier_penalty_loss_weight = outlier_penalty_loss_weight
40
  self.l1_loss_weight = l1_loss_weight
41
  self.l2_loss_weight = l2_loss_weight
42
 
 
53
  d_weight = d_weight * self.discriminator_weight
54
  return d_weight
55
 
56
+ def outlier_penalty_loss(self, posteriors, r):
57
+ batch_size, channels, frames, height, width = posteriors.shape
58
+ mean_X = posteriors.mean(dim=(3, 4), keepdim=True)
59
+ std_X = posteriors.std(dim=(3, 4), keepdim=True)
60
+
61
+ diff = torch.abs(posteriors - mean_X)
62
+ penalty = torch.maximum(diff - r * std_X, torch.zeros_like(diff))
63
+
64
+ opl = penalty.sum(dim=(3, 4)) / (height * width)
65
+ opl_final = opl.mean(dim=(0, 1, 2))
66
+ return opl_final
67
+
68
  def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
69
  global_step, last_layer=None, cond=None, split="train",
70
  weights=None):
 
101
  kl_loss = posteriors.kl()
102
  kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
103
 
104
+ outlier_penalty_loss = self.outlier_penalty_loss(posteriors.mode(), self.outlier_penalty_loss_r) * self.outlier_penalty_loss_weight
105
+
106
  # now the GAN part
107
  if optimizer_idx == 0:
108
  # generator update
 
119
  try:
120
  d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
121
  except RuntimeError:
122
+ # assert not self.training
123
  d_weight = torch.tensor(0.0)
124
  else:
125
  d_weight = torch.tensor(0.0)
126
 
127
  disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
128
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + outlier_penalty_loss
129
 
130
  log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
131
  "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
easyanimate/vae/ldm/modules/vaemodules/__init__.py CHANGED
File without changes
easyanimate/vae/ldm/modules/vaemodules/activations.py CHANGED
File without changes
easyanimate/vae/ldm/modules/vaemodules/common.py CHANGED
@@ -8,6 +8,17 @@ from einops import rearrange, repeat
8
  from .activations import get_activation
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
11
  def cast_tuple(t, length = 1):
12
  return t if isinstance(t, tuple) else ((t,) * length)
13
 
@@ -66,10 +77,15 @@ class CausalConv3d(nn.Conv3d):
66
  **kwargs,
67
  )
68
 
 
 
 
 
69
  def forward(self, x: torch.Tensor) -> torch.Tensor:
70
  # x: (B, C, T, H, W)
71
  dtype = x.dtype
72
- x = x.float()
 
73
  if self.padding_flag == 0:
74
  x = F.pad(
75
  x,
@@ -85,7 +101,11 @@ class CausalConv3d(nn.Conv3d):
85
  mode="replicate", # TODO: check if this is necessary
86
  )
87
  x = x.to(dtype=dtype)
88
- self.prev_features = x[:, :, -self.temporal_padding:]
 
 
 
 
89
 
90
  b, c, f, h, w = x.size()
91
  outputs = []
@@ -105,7 +125,11 @@ class CausalConv3d(nn.Conv3d):
105
  [self.prev_features, x], dim = 2
106
  )
107
  x = x.to(dtype=dtype)
108
- self.prev_features = x[:, :, -self.temporal_padding:]
 
 
 
 
109
 
110
  b, c, f, h, w = x.size()
111
  outputs = []
@@ -122,7 +146,12 @@ class CausalConv3d(nn.Conv3d):
122
  mode="replicate", # TODO: check if this is necessary
123
  )
124
  x = x.to(dtype=dtype)
125
- self.prev_features = x[:, :, -self.temporal_padding:]
 
 
 
 
 
126
  return super().forward(x)
127
  elif self.padding_flag == 6:
128
  if self.t_stride == 2:
@@ -133,7 +162,12 @@ class CausalConv3d(nn.Conv3d):
133
  x = torch.concat(
134
  [self.prev_features, x], dim = 2
135
  )
136
- self.prev_features = x[:, :, -self.temporal_padding:]
 
 
 
 
 
137
  x = x.to(dtype=dtype)
138
  return super().forward(x)
139
  else:
 
8
  from .activations import get_activation
9
 
10
 
11
+ try:
12
+ current_version = torch.__version__
13
+ version_numbers = [int(x) for x in current_version.split('.')[:2]]
14
+ if version_numbers[0] < 2 or (version_numbers[0] == 2 and version_numbers[1] < 2):
15
+ need_to_float = True
16
+ else:
17
+ need_to_float = False
18
+ except Exception as e:
19
+ print("Could not parse the torch version; keeping the original dtype in the VAE.")
20
+ need_to_float = False
21
+
22
  def cast_tuple(t, length = 1):
23
  return t if isinstance(t, tuple) else ((t,) * length)
24
 
 
77
  **kwargs,
78
  )
79
 
80
+ def _clear_conv_cache(self):
81
+ del self.prev_features
82
+ self.prev_features = None
83
+
84
  def forward(self, x: torch.Tensor) -> torch.Tensor:
85
  # x: (B, C, T, H, W)
86
  dtype = x.dtype
87
+ if need_to_float:
88
+ x = x.float()
89
  if self.padding_flag == 0:
90
  x = F.pad(
91
  x,
 
101
  mode="replicate", # TODO: check if this is necessary
102
  )
103
  x = x.to(dtype=dtype)
104
+
105
+ # Clear the previously cached features before storing the new ones
106
+ self._clear_conv_cache()
107
+ # These could be moved to the CPU to reduce VRAM usage
108
+ self.prev_features = x[:, :, -self.temporal_padding:].clone()
109
 
110
  b, c, f, h, w = x.size()
111
  outputs = []
 
125
  [self.prev_features, x], dim = 2
126
  )
127
  x = x.to(dtype=dtype)
128
+
129
+ # Clear the previously cached features before storing the new ones
130
+ self._clear_conv_cache()
131
+ # These could be moved to the CPU to reduce VRAM usage
132
+ self.prev_features = x[:, :, -self.temporal_padding:].clone()
133
 
134
  b, c, f, h, w = x.size()
135
  outputs = []
 
146
  mode="replicate", # TODO: check if this is necessary
147
  )
148
  x = x.to(dtype=dtype)
149
+
150
+ # Clear the previously cached features before storing the new ones
151
+ self._clear_conv_cache()
152
+ # These could be moved to the CPU to reduce VRAM usage
153
+ self.prev_features = x[:, :, -self.temporal_padding:].clone()
154
+
155
  return super().forward(x)
156
  elif self.padding_flag == 6:
157
  if self.t_stride == 2:
 
162
  x = torch.concat(
163
  [self.prev_features, x], dim = 2
164
  )
165
+
166
+ # Clear the previously cached features before storing the new ones
167
+ self._clear_conv_cache()
168
+ # These could be moved to the CPU to reduce VRAM usage
169
+ self.prev_features = x[:, :, -self.temporal_padding:].clone()
170
+
171
  x = x.to(dtype=dtype)
172
  return super().forward(x)
173
  else:
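CausalConv3d changes in two ways: the unconditional x.float() cast now only happens when the installed torch is older than 2.2 (via the module-level need_to_float flag), and the prev_features cache used for chunked temporal processing is explicitly cleared and re-cloned each call instead of keeping a view into the previous input tensor. A small sketch of an equivalent version gate built on packaging.version, which also tolerates local suffixes such as '2.1.2+cu121'; this illustrates the idea, it is not the repository's helper:

```python
import torch
from packaging import version

# Cast activations to float32 only on older torch builds, mirroring the
# "< 2.2" cut-off chosen in the diff. version.parse copes with suffixes
# like "2.1.2+cu121" without the manual int() split.
NEED_TO_FLOAT = version.parse(torch.__version__).release < (2, 2)


def maybe_float(x: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    if NEED_TO_FLOAT:
        x = x.float()
    # ... padding and the causal convolution would run here ...
    return x.to(dtype)


print(torch.__version__, NEED_TO_FLOAT)
```

Cloning prev_features (rather than keeping a slice of the input) stores only the few trailing frames needed for the next chunk, so the full previous activation tensor can be freed between calls; the comment in the diff notes it could even be offloaded to the CPU.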
easyanimate/vae/ldm/modules/vaemodules/down_blocks.py CHANGED
File without changes
easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py CHANGED
File without changes
easyanimate/vae/ldm/modules/vaemodules/up_blocks.py CHANGED
File without changes
requirements.txt CHANGED
@@ -6,7 +6,6 @@ tomesd
6
  torch>=2.1.2
7
  torchdiffeq
8
  torchsde
9
- xformers
10
  decord
11
  datasets
12
  numpy
@@ -21,8 +20,6 @@ tensorboard
21
  beautifulsoup4
22
  ftfy
23
  func_timeout
24
- deepspeed
25
  accelerate>=0.25.0
26
- gradio>=3.41.2
27
- diffusers>=0.30.1
28
- transformers>=4.37.2
 
6
  torch>=2.1.2
7
  torchdiffeq
8
  torchsde
 
9
  decord
10
  datasets
11
  numpy
 
20
  beautifulsoup4
21
  ftfy
22
  func_timeout
 
23
  accelerate>=0.25.0
24
+ diffusers==0.30.1
25
+ transformers==4.46.2
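requirements.txt drops xformers, deepspeed and gradio and replaces the open-ended diffusers / transformers ranges with exact pins. An optional post-install sanity check of those pins (plain Python, not part of the repository):

```python
from importlib.metadata import version

# The diff pins these two packages exactly; confirm the environment matches.
assert version("diffusers") == "0.30.1", version("diffusers")
assert version("transformers") == "4.46.2", version("transformers")
print("diffusers", version("diffusers"), "| transformers", version("transformers"))
```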