yanboding committed (verified)
Commit 30a0a93
1 Parent(s): 7b12bef

Upload 32 files

__pycache__/draw_pose.cpython-313.pyc ADDED
Binary file (5.18 kB).
__pycache__/inference_engine.cpython-313.pyc ADDED
Binary file (7.56 kB).
__pycache__/motion_extractor.cpython-313.pyc ADDED
Binary file (2.97 kB).
__pycache__/utils.cpython-313.pyc ADDED
Binary file (4.8 kB).
app.py ADDED
@@ -0,0 +1,103 @@
import os
import traceback
import torch
import gradio as gr
import cv2
from PIL import Image
from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video

device = "cuda" if torch.cuda.is_available() else "cpu"

def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
    # 1. Extract the motion pkl
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    gr.Info("⏳ Extract motion finished and begin animation...", visible=True)

    # 2. Process the reference image (optional)
    if ref_image is not None:
        ref_path = "temp_ref.png"
        ref_image.save(ref_path)
    else:
        ref_path = ""

    # 3. Run inference
    output_path = run_inference(
        device,
        motion_pkl_path,
        ref_path,
        dst_width=width,
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )

    return output_path


def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
    try:
        if video_file is None:
            raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")
        # Show a progress hint
        gr.Info("⏳ Processing... Please wait several minutes.", visible=True)
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy it!", visible=True)
        return result
    except Exception as e:
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None

# Build the UI
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
        """
        # 🎨💃 MTVCrafter Inference Demo

        💡 **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).
        This demo will extract human motion from the input video and animate the reference image accordingly.
        If no reference image is provided, the **first frame** of the video will be used as the reference.

        🎞️ **Note:** The generated output video will contain exactly **49 frames**.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(label="📹 Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            video_preview = gr.Video(label="👀 Preview of Uploaded Video", height=280)  # fixed height to keep the two columns aligned

            def show_video_preview(video_file):
                return video_file.name if video_file else None

            video_input.change(fn=show_video_preview, inputs=video_input, outputs=video_preview)

        with gr.Column(scale=1):
            ref_image = gr.Image(type="pil", label="🖼️ Reference Image (Optional)", height=538)

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")

    with gr.Row(scale=1):
        output_video = gr.Video(label="🎬 Generated Video", interactive=False)

    run_btn = gr.Button("🚀 Run MTVCrafter", variant="primary")

    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )

if __name__ == "__main__":
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
    os.environ["NO_PROXY"] = "localhost,127.0.0.1/8,::1"
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
data/mean.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ababeaabf5ac096ce7c7714ada14aa1de8355c0016de25695be611d51285141
size 416
data/std.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:650e46902a0878e6947be401e4e1995e54a8fd407f2be3ded0dda62bda99a9b3
size 416
draw_pose.py ADDED
@@ -0,0 +1,115 @@
import cv2
import math
import numpy as np
from PIL import Image


def intrinsic_matrix_from_field_of_view(imshape, fov_degrees: float = 55):  # NLF default fov_degrees is 55
    imshape = np.array(imshape)
    fov_radians = fov_degrees * np.array(np.pi / 180)
    larger_side = np.max(imshape)
    focal_length = larger_side / (np.tan(fov_radians / 2) * 2)
    # intrinsic_matrix 3*3
    return np.array([
        [focal_length, 0, imshape[1] / 2],
        [0, focal_length, imshape[0] / 2],
        [0, 0, 1],
    ])


def p3d_to_p2d(point_3d, height, width):  # point_3d: n*1024*3
    camera_matrix = intrinsic_matrix_from_field_of_view((height, width))
    camera_matrix = np.expand_dims(camera_matrix, axis=0)
    camera_matrix = np.expand_dims(camera_matrix, axis=0)  # 1*1*3*3
    point_3d = np.expand_dims(point_3d, axis=-1)  # n*1024*3*1
    point_2d = (camera_matrix @ point_3d).squeeze(-1)
    point_2d[:, :, :2] = point_2d[:, :, :2] / point_2d[:, :, 2:3]
    return point_2d[:, :, :]  # n*1024*2


def get_pose_images(smpl_data, offset):
    pose_images = []
    for data in smpl_data:
        if isinstance(data, np.ndarray):
            joints3d = data
        else:
            joints3d = data.numpy()
        canvas = np.zeros(shape=(offset[0], offset[1], 3), dtype=np.uint8)
        joints3d = p3d_to_p2d(joints3d, offset[0], offset[1])
        canvas = draw_3d_points(canvas, joints3d[0], stickwidth=int(offset[1] / 350))
        pose_images.append(Image.fromarray(canvas))
    return pose_images


def draw_3d_points(canvas, points, stickwidth=2, r=2, draw_line=True):
    colors = [
        [255, 0, 0],    # 0
        [0, 255, 0],    # 1
        [0, 0, 255],    # 2
        [255, 0, 255],  # 3
        [255, 255, 0],  # 4
        [85, 255, 0],   # 5
        [0, 75, 255],   # 6
        [0, 255, 85],   # 7
        [0, 255, 170],  # 8
        [170, 0, 255],  # 9
        [85, 0, 255],   # 10
        [0, 85, 255],   # 11
        [0, 255, 255],  # 12
        [85, 0, 255],   # 13
        [170, 0, 255],  # 14
        [255, 0, 255],  # 15
        [255, 0, 170],  # 16
        [255, 0, 85],   # 17
    ]
    connections = [
        [15, 12], [12, 16], [16, 18], [18, 20], [20, 22],
        [12, 17], [17, 19], [19, 21],
        [21, 23], [12, 9], [9, 6],
        [6, 3], [3, 0], [0, 1],
        [1, 4], [4, 7], [7, 10], [0, 2], [2, 5], [5, 8], [8, 11]
    ]
    connection_colors = [
        [255, 0, 0],    # 0
        [0, 255, 0],    # 1
        [0, 0, 255],    # 2
        [255, 255, 0],  # 3
        [255, 0, 255],  # 4
        [0, 255, 0],    # 5
        [0, 85, 255],   # 6
        [255, 175, 0],  # 7
        [0, 0, 255],    # 8
        [255, 85, 0],   # 9
        [0, 255, 85],   # 10
        [255, 0, 255],  # 11
        [255, 0, 0],    # 12
        [0, 175, 255],  # 13
        [255, 255, 0],  # 14
        [0, 0, 255],    # 15
        [0, 255, 0],    # 16
    ]

    # draw points
    for i in range(len(points)):
        x, y = points[i][0:2]
        x, y = int(x), int(y)
        if i == 13 or i == 14:
            continue
        cv2.circle(canvas, (x, y), r, colors[i % 17], thickness=-1)

    # draw limbs
    if draw_line:
        for i in range(len(connections)):
            point1_idx, point2_idx = connections[i][0:2]
            point1 = points[point1_idx]
            point2 = points[point2_idx]
            Y = [point2[0], point1[0]]
            X = [point2[1], point1[1]]
            mX = int(np.mean(X))
            mY = int(np.mean(Y))
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((mY, mX), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, connection_colors[i % 17])

    return canvas
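For reference, a minimal standalone sketch of how this module might be exercised; the random joints, canvas size, and output filename below are made up for illustration and are not part of the upload.

import numpy as np
from draw_pose import get_pose_images

# Hypothetical input: two frames of 24 joints placed roughly 3 m in front of the
# camera, each frame shaped (1, num_joints, 3) as get_pose_images expects.
rng = np.random.default_rng(0)
frames = [rng.standard_normal((1, 24, 3)) * 0.3 + np.array([0.0, 0.0, 3.0]) for _ in range(2)]

height, width = 512, 512                           # canvas size -> offset[0], offset[1]
images = get_pose_images(frames, (height, width, 0))
images[0].save("pose_frame0.png")                  # each element is a PIL.Image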
inference_engine.py ADDED
@@ -0,0 +1,117 @@
# inference_engine.py
import os
import torch
import decord
import imageio
from PIL import Image
from models import MTVCrafterPipeline, Encoder, VectorQuantizer, Decoder, SMPL_VQVAE
from torchvision.transforms import ToPILImage, transforms, InterpolationMode, functional as F
import numpy as np
import pickle
import copy
from draw_pose import get_pose_images
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width

def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, dst_height=512, num_inference_steps=50, guidance_scale=3.0, seed=6666):
    num_frames = 49
    to_pil = ToPILImage()
    normalize = transforms.Normalize([0.5], [0.5])
    pretrained_model_path = "/gemini/space/human_guozz2/dyb/models/CogVideoX"
    transformer_path = "/gemini/space/human_guozz2/dyb/models/MTVCrafter/MV-DiT/CogVideoX"
    tokenizer_path = "/gemini/space/human_guozz2/dyb/models/MTVCrafter/4DMoT/mp_rank_00_model_states.pt"

    with open(motion_data_path, 'rb') as f:
        data_list = pickle.load(f)
    if not isinstance(data_list, list):
        data_list = [data_list]

    pe_mean = np.load('data/mean.npy')
    pe_std = np.load('data/std.npy')

    pipe = MTVCrafterPipeline.from_pretrained(
        model_path=pretrained_model_path,
        transformer_model_path=transformer_path,
        torch_dtype=torch.bfloat16,
        scheduler_type='dpm',
    ).to(device)
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()

    # load VQVAE
    state_dict = torch.load(tokenizer_path, map_location="cpu")
    motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
    motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
    motion_decoder = Decoder(in_channels=3072, mid_channels=[512, 128], out_channels=3, upsample_rate=2.0, frame_upsample_rate=[2.0, 2.0], joint_upsample_rate=[1.0, 1.0])
    vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
    vqvae.load_state_dict(state_dict['module'], strict=True)

    # Only the first sample is processed here
    data = data_list[0]
    new_height, new_width = get_new_height_width(data, dst_height, dst_width)
    x1 = (new_width - dst_width) // 2
    y1 = (new_height - dst_height) // 2

    sample_indexes = get_sample_indexes(data['video_length'], num_frames, stride=1)
    input_images = sample_video(decord.VideoReader(data['video_path']), sample_indexes)
    input_images = torch.from_numpy(input_images).permute(0, 3, 1, 2).contiguous()
    input_images = F.resize(input_images, (new_height, new_width), InterpolationMode.BILINEAR)
    input_images = F.crop(input_images, y1, x1, dst_height, dst_width)

    if ref_image_path != '':
        ref_image = Image.open(ref_image_path).convert("RGB")
        ref_image = torch.from_numpy(np.array(ref_image)).permute(2, 0, 1).contiguous()
        ref_images = torch.stack([ref_image.clone() for _ in range(num_frames)])
        ref_images = F.resize(ref_images, (new_height, new_width), InterpolationMode.BILINEAR)
        ref_images = F.crop(ref_images, y1, x1, dst_height, dst_width)
    else:
        ref_images = copy.deepcopy(input_images)
        frame0 = input_images[0]
        ref_images[:, :, :, :] = frame0

    try:
        smpl_poses = np.array([pose[0][0].cpu().numpy() for pose in data['pose']['joints3d_nonparam']])
        poses = smpl_poses[sample_indexes]
    except:
        poses = data['pose'][sample_indexes]
    norm_poses = torch.tensor((poses - pe_mean) / pe_std)

    offset = [data['video_height'], data['video_width'], 0]
    pose_images_before = get_pose_images(copy.deepcopy(poses), offset)
    pose_images_before = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_before]
    input_smpl_joints = norm_poses.unsqueeze(0).to(device)
    motion_tokens, vq_loss = vqvae(input_smpl_joints, return_vq=True)
    output_motion, _ = vqvae(input_smpl_joints)
    pose_images_after = get_pose_images(output_motion[0].cpu().detach() * pe_std + pe_mean, offset)
    pose_images_after = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_after]

    # normalize images
    input_images = input_images / 255.0
    ref_images = ref_images / 255.0
    input_images = normalize(input_images)
    ref_images = normalize(ref_images)

    # infer
    output_images = pipe(
        height=dst_height,
        width=dst_width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        ref_images=ref_images,
        motion_embeds=motion_tokens,
        joint_mean=pe_mean,
        joint_std=pe_std,
    ).frames[0]

    # save result
    vis_images = []
    for k in range(len(output_images)):
        vis_image = [to_pil(((input_images[k] + 1) * 127.5).clamp(0, 255).to(torch.uint8)), pose_images_before[k], pose_images_after[k], output_images[k]]
        vis_image = concat_images_grid(vis_image, cols=len(vis_image), pad=2)
        vis_images.append(vis_image)

    output_path = "output.mp4"
    imageio.mimsave(output_path, vis_images, fps=15)

    return output_path
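For orientation, the fields that run_inference reads from the motion pkl can be summarized in a small sketch. The key names below are exactly the ones accessed above; the shapes, values, and single-sample layout are assumptions for illustration rather than the real output of extract_pkl_from_video.

import numpy as np

num_video_frames, num_joints = 180, 24
sample = {
    "video_path": "dance.mp4",                # clip re-read with decord.VideoReader
    "video_length": num_video_frames,         # used to pick the 49 sampled frame indices
    "video_height": 1080,                     # canvas size for the rendered pose images
    "video_width": 1920,
    # Fallback branch: a plain per-frame joint array indexable as pose[sample_indexes].
    # The primary branch instead expects a dict with a 'joints3d_nonparam' entry.
    "pose": np.zeros((num_video_frames, num_joints, 3), dtype=np.float32),
}
motion_pkl = [sample]                         # run_inference only uses the first entry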
models/__init__.py ADDED
@@ -0,0 +1,2 @@
from .dit import *
from .motion4d import *
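Both submodules are star-imported, so the package namespace is what inference_engine.py relies on. A hedged sketch of that import surface follows; models.motion4d is not part of the files shown here, so attributing the VQVAE classes to it is an assumption.

# Names that downstream code (inference_engine.py) imports from the package.
# MTVCrafterPipeline is re-exported from models.dit; Encoder, VectorQuantizer,
# Decoder and SMPL_VQVAE are assumed to be re-exported from models.motion4d.
from models import MTVCrafterPipeline, Encoder, VectorQuantizer, Decoder, SMPL_VQVAE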
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (235 Bytes).
models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (205 Bytes).
models/dit/__init__.py ADDED
@@ -0,0 +1,2 @@
from .mvdit_transformer import Transformer3DModel
from .pipeline_mtvcrafter import MTVCrafterPipeline
models/dit/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (333 Bytes).
models/dit/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (287 Bytes).
models/dit/__pycache__/mvdit_transformer.cpython-311.pyc ADDED
Binary file (38.1 kB).
models/dit/__pycache__/mvdit_transformer.cpython-313.pyc ADDED
Binary file (35.1 kB).
models/dit/__pycache__/pipeline_mtvcrafter.cpython-311.pyc ADDED
Binary file (39.5 kB).
models/dit/__pycache__/pipeline_mtvcrafter.cpython-313.pyc ADDED
Binary file (37 kB).
models/dit/mvdit_transformer.py ADDED
@@ -0,0 +1,758 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, FusedCogVideoXAttnProcessor2_0
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class CogVideoXLayerNormZero(nn.Module):
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate, _, _, _ = self.linear(self.silu(temb)).chunk(6, dim=1)
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        return hidden_states, gate[:, None, :]


class CogVideoXAttnProcessor1_0:
    r"""Processor for implementing scaled dot-product attention for the
    CogVideoX model.

    It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, 'scaled_dot_product_attention'):
            raise ImportError('CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.')

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        motion_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # [batch_size, heads, seq_len, dim]
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
        if motion_rotary_emb is not None:
            key = apply_rotary_emb(key, motion_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CogVideoXAttnProcessor2_0:
    r"""Processor for implementing scaled dot-product attention for the
    CogVideoX model.

    It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, 'scaled_dot_product_attention'):
            raise ImportError('CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.')

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        motion_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # [batch_size, heads, seq_len, dim]
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            image_seq_length = image_rotary_emb[0].shape[0]
            query[:, :, :image_seq_length] = apply_rotary_emb(query[:, :, :image_seq_length], image_rotary_emb)
            if motion_rotary_emb is not None:
                query[:, :, image_seq_length:] = apply_rotary_emb(query[:, :, image_seq_length:], motion_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, :image_seq_length] = apply_rotary_emb(key[:, :, :image_seq_length], image_rotary_emb)
                if motion_rotary_emb is not None:
                    key[:, :, image_seq_length:] = apply_rotary_emb(key[:, :, image_seq_length:], motion_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CogVideoXPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings:
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer('pos_embedding', pos_embedding, persistent=False)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
        joint_pos_embedding = torch.zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

        return joint_pos_embedding

    def forward(self, image_embeds: torch.Tensor):
        r"""
        Args:
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        batch, num_frames, channels, height, width = image_embeds.shape
        image_embeds = image_embeds.reshape(-1, channels, height, width)
        image_embeds = self.proj(image_embeds)  # [2*7, 3072, h/8/2, w/8/2]
        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
        image_embeds = image_embeds.flatten(1, 2).contiguous()  # [batch, num_frames x height x width, channels]

        if self.use_positional_embeddings:
            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
                pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
            else:
                pos_embedding = self.pos_embedding

            # the buffer also reserves room for text tokens; only the image part is added here
            image_embeds = image_embeds + pos_embedding[:, self.max_text_seq_length :]

        return image_embeds


@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        motion_dim: int,
        dropout: float = 0.0,
        activation_fn: str = 'gelu-approximate',
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        cross_attention: bool = False,
    ):
        super().__init__()

        self.is_cross_attention = cross_attention

        if self.is_cross_attention:
            self.attn0 = Attention(
                query_dim=dim,
                cross_attention_dim=dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                qk_norm='layer_norm' if qk_norm else None,
                eps=1e-6,
                bias=attention_bias,
                out_bias=attention_out_bias,
                processor=CogVideoXAttnProcessor1_0(),
            )

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm='layer_norm' if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        motion_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:

        # norm & modulate
        norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)

        # self attention
        attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + gate_msa * attn_hidden_states

        if self.is_cross_attention:
            cross_attn_hidden_states = self.attn0(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
                motion_rotary_emb=motion_rotary_emb,
            )
            hidden_states = hidden_states + cross_attn_hidden_states

        # norm & modulate
        norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)

        # feed-forward
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output

        return hidden_states


class Transformer3DModel(ModelMixin, ConfigMixin):
    """
    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        motion_dim: int = 168,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = 'gelu-approximate',
        timestep_activation_fn: str = 'silu',
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim  # 48 * 64 = 3072

        self.unconditional_motion_token = torch.nn.Parameter(torch.randn(312, 3072))
        print(self.unconditional_motion_token[0])

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)  # 3072 --> 512

        # 3. Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    motion_dim=motion_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    cross_attention=True,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, 'get_processor'):
                processors[f'{name}.processor'] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f'{name}.{sub_name}', child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f'A dict of processors was passed, but the number of processors {len(processor)} does not match the'
                f' number of attention layers: {count}. Please make sure to pass {count} processor classes.'
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, 'set_processor'):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f'{name}.processor'))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f'{name}.{sub_name}', child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with
    # FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """Enables fused QKV projections. For self-attention modules, all
        projection matrices (i.e., query, key, value) are fused. For cross-
        attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if 'Added' in str(attn_processor.__class__.__name__):
                raise ValueError('`fuse_qkv_projections()` is not supported for models having added KV projections.')

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        motion_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        motion_emb: Optional[torch.Tensor] = None,
        camera_emb: Optional[torch.Tensor] = None,
        need_broadcast: bool = True,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)  # (2, 3072)
        emb = self.time_embedding(t_emb, timestep_cond)  # (2, 3072) --> (2, 512)

        # 2. Patch embedding
        hidden_states = self.patch_embed(hidden_states)  # (2, 226+9450, dim=3072)
        hidden_states = self.embedding_dropout(hidden_states)
        image_seq_length = image_rotary_emb[0].shape[0]
        motion_seq_length = motion_emb.shape[1]  # 168
        # hidden_states = hidden_states[:, motion_seq_length:]
        encoder_hidden_states = motion_emb
        # encoder_hidden_states = self.motion_proj(motion_emb)

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:  # train with gradient checkpointing to save memory

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    motion_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    motion_rotary_emb=motion_rotary_emb,
                )

        # 4. Final block
        hidden_states = self.norm_final(hidden_states)
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
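As a small sanity check on the rotary-embedding convention above, the sketch below builds a cos/sin pair with the interleaved layout that `use_real=True` expects (mirroring `get_1d_rotary_pos_embed` in the pipeline file) and applies it with `apply_rotary_emb`. The sizes are arbitrary, and the import assumes the full repository, including models.motion4d, is importable.

import torch
from models.dit.mvdit_transformer import apply_rotary_emb

B, H, S, D = 1, 2, 8, 16                       # arbitrary batch, heads, tokens, head dim
x = torch.randn(B, H, S, D)

# cos/sin tables with the interleaved layout expected when use_real=True
pos = torch.arange(S, dtype=torch.float32)
freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))  # [D/2]
angles = torch.outer(pos, freqs)                                             # [S, D/2]
cos = angles.cos().repeat_interleave(2, dim=1)                               # [S, D]
sin = angles.sin().repeat_interleave(2, dim=1)                               # [S, D]

out = apply_rotary_emb(x, (cos, sin))          # same shape as x, rotated per position
print(out.shape)                               # torch.Size([1, 2, 8, 16])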
models/dit/pipeline_mtvcrafter.py ADDED
@@ -0,0 +1,728 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import inspect
4
+ import numpy as np
5
+ from dataclasses import dataclass
6
+ from typing import Callable, Dict, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.models import AutoencoderKLCogVideoX
11
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
13
+ from diffusers.utils import BaseOutput, logging
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from diffusers.video_processor import VideoProcessor
16
+ from einops import rearrange
17
+ from PIL import Image
18
+ from torchvision import transforms
19
+
20
+ from .mvdit_transformer import Transformer3DModel
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+
24
+ def get_1d_rotary_pos_embed(
25
+ dim: int,
26
+ pos: Union[np.ndarray, int],
27
+ theta: float = 10000.0,
28
+ use_real=False,
29
+ linear_factor=1.0,
30
+ ntk_factor=1.0,
31
+ repeat_interleave_real=True,
32
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
33
+ ):
34
+ """
35
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
36
+
37
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
38
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
39
+ data type.
40
+
41
+ Args:
42
+ dim (`int`): Dimension of the frequency tensor.
43
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
44
+ theta (`float`, *optional*, defaults to 10000.0):
45
+ Scaling factor for frequency computation. Defaults to 10000.0.
46
+ use_real (`bool`, *optional*):
47
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
48
+ linear_factor (`float`, *optional*, defaults to 1.0):
49
+ Scaling factor for the context extrapolation. Defaults to 1.0.
50
+ ntk_factor (`float`, *optional*, defaults to 1.0):
51
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
52
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
53
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
54
+ Otherwise, they are concateanted with themselves.
55
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
56
+ the dtype of the frequency tensor.
57
+ Returns:
58
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
59
+ """
60
+ assert dim % 2 == 0
61
+
62
+ if isinstance(pos, int):
63
+ pos = torch.arange(pos)
64
+ if isinstance(pos, np.ndarray):
65
+ pos = torch.from_numpy(pos) # type: ignore # [S]
66
+
67
+ theta = theta * ntk_factor
68
+ freqs = (
69
+ 1.0
70
+ / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
71
+ / linear_factor
72
+ ) # [D/2]
73
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
74
+ if use_real and repeat_interleave_real:
75
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
76
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
77
+ return freqs_cos, freqs_sin
78
+ elif use_real:
79
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
80
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
81
+ return freqs_cos, freqs_sin
82
+ else:
83
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
84
+ return freqs_cis
85
+
86
+
87
+ def get_3d_rotary_pos_embed(
88
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
89
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
90
+ """
91
+ RoPE for video tokens with 3D structure.
92
+
93
+ Args:
94
+ embed_dim: (`int`):
95
+ The embedding dimension size, corresponding to hidden_size_head.
96
+ crops_coords (`Tuple[int]`):
97
+ The top-left and bottom-right coordinates of the crop.
98
+ grid_size (`Tuple[int]`):
99
+ The grid size of the spatial positional embedding (height, width).
100
+ temporal_size (`int`):
101
+ The size of the temporal dimension.
102
+ theta (`float`):
103
+ Scaling factor for frequency computation.
104
+
105
+ Returns:
106
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
107
+ """
108
+ if use_real is not True:
109
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
110
+ start, stop = crops_coords
111
+ grid_size_h, grid_size_w = grid_size
112
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
113
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
114
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
115
+
116
+ # Compute dimensions for each axis
117
+ dim_t = embed_dim // 4
118
+ dim_h = embed_dim // 8 * 3
119
+ dim_w = embed_dim // 8 * 3
120
+
121
+ # Temporal frequencies
122
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
123
+ # Spatial frequencies for height and width
124
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
125
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
126
+
127
+ # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
128
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
129
+ freqs_t = freqs_t[:, None, None, :].expand(
130
+ -1, grid_size_h, grid_size_w, -1
131
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
132
+ freqs_h = freqs_h[None, :, None, :].expand(
133
+ temporal_size, -1, grid_size_w, -1
134
+ ) # temporal_size, grid_size_h, grid_size_2, dim_h
135
+ freqs_w = freqs_w[None, None, :, :].expand(
136
+ temporal_size, grid_size_h, -1, -1
137
+ ) # temporal_size, grid_size_h, grid_size_2, dim_w
138
+
139
+ freqs = torch.cat(
140
+ [freqs_t, freqs_h, freqs_w], dim=-1
141
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
142
+ freqs = freqs.view(
143
+ temporal_size * grid_size_h * grid_size_w, -1
144
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
145
+ return freqs
146
+
147
+ t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
148
+ h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
149
+ w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
150
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
151
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
152
+ return cos, sin
153
+
154
+
155
+ def get_3d_motion_spatial_embed(
156
+ embed_dim: int, num_joints: int, joints_mean: np.ndarray, joints_std: np.ndarray, theta: float = 10000.0
157
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
158
+ """
159
+ """
160
+ assert embed_dim % 2 == 0 and embed_dim % 3 == 0
161
+
162
+ def create_rope_pe(dim, pos, freqs_dtype=torch.float32):
163
+ if isinstance(pos, np.ndarray):
164
+ pos = torch.from_numpy(pos)
165
+ freqs = (
166
+ 1.0
167
+ / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
168
+ ) # [D/2]
169
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
170
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
171
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
172
+ return freqs_cos, freqs_sin
173
+
174
+ # 为每个轴创建位置编码
175
+ # relative_pos_x = joints_mean[:, 0] - joints_mean[0, 0]
176
+ # relative_pos_y = joints_mean[:, 1] - joints_mean[0, 1]
177
+ # relative_pos_z = joints_mean[:, 2] - joints_mean[0, 2]
178
+
179
+ # normalized_pos_x = relative_pos_x / joints_std[:, 0].mean()
180
+ # normalized_pos_y = relative_pos_y / joints_std[:, 1].mean()
181
+ # normalized_pos_z = relative_pos_z / joints_std[:, 2].mean()
182
+
183
+ pos_x = joints_mean[:, 0]
184
+ pos_y = joints_mean[:, 1]
185
+ pos_z = joints_mean[:, 2]
186
+
187
+ normalized_pos_x = (pos_x - pos_x.mean())
188
+ normalized_pos_y = (pos_y - pos_y.mean())
189
+ normalized_pos_z = (pos_z - pos_z.mean())
190
+
191
+ freqs_cos_x, freqs_sin_x = create_rope_pe(embed_dim // 3, normalized_pos_x)
192
+ freqs_cos_y, freqs_sin_y = create_rope_pe(embed_dim // 3, normalized_pos_y)
193
+ freqs_cos_z, freqs_sin_z = create_rope_pe(embed_dim // 3, normalized_pos_z)
194
+
195
+ freqs_cos = torch.cat([freqs_cos_x, freqs_cos_y, freqs_cos_z], dim=-1)
196
+ freqs_sin = torch.cat([freqs_sin_x, freqs_sin_y, freqs_sin_z], dim=-1)
197
+
198
+ return freqs_cos, freqs_sin
199
+
200
+
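+ # Illustrative call (assuming 24 SMPL joints and attention_head_dim=64, so embed_dim = 64 // 4 * 3 = 48):
+ # cos, sin = get_3d_motion_spatial_embed(48, 24, joints_mean, joints_std)
+ # splits the 48 dims as 16 per x/y/z axis and returns cos/sin of shape (24, 48) each.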
201
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
202
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
203
+ tw = tgt_width
204
+ th = tgt_height
205
+ h, w = src
206
+ r = h / w
207
+ if r > (th / tw):
208
+ resize_height = th
209
+ resize_width = int(round(th / h * w))
210
+ else:
211
+ resize_width = tw
212
+ resize_height = int(round(tw / w * h))
213
+
214
+ crop_top = int(round((th - resize_height) / 2.0))
215
+ crop_left = int(round((tw - resize_width) / 2.0))
216
+
217
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
218
+
219
+
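+ # Worked example: for a source grid (h, w) = (30, 41) and a target grid of tgt_width=45, tgt_height=30,
+ # h / w > th / tw, so the grid keeps size (30, 41) and is centered horizontally,
+ # and the function returns ((0, 2), (30, 43)).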
220
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
221
+ def retrieve_timesteps(
222
+ scheduler,
223
+ num_inference_steps: Optional[int] = None,
224
+ device: Optional[Union[str, torch.device]] = None,
225
+ timesteps: Optional[List[int]] = None,
226
+ sigmas: Optional[List[float]] = None,
227
+ **kwargs,
228
+ ):
229
+ """
230
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
231
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
232
+
233
+ Args:
234
+ scheduler (`SchedulerMixin`):
235
+ The scheduler to get timesteps from.
236
+ num_inference_steps (`int`):
237
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
238
+ must be `None`.
239
+ device (`str` or `torch.device`, *optional*):
240
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
241
+ timesteps (`List[int]`, *optional*):
242
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
243
+ `num_inference_steps` and `sigmas` must be `None`.
244
+ sigmas (`List[float]`, *optional*):
245
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
246
+ `num_inference_steps` and `timesteps` must be `None`.
247
+
248
+ Returns:
249
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
250
+ second element is the number of inference steps.
251
+ """
252
+ if timesteps is not None and sigmas is not None:
253
+ raise ValueError('Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values')
254
+ if timesteps is not None:
255
+ accepts_timesteps = 'timesteps' in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
256
+ if not accepts_timesteps:
257
+ raise ValueError(
258
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
259
+ f' timestep schedules. Please check whether you are using the correct scheduler.'
260
+ )
261
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
262
+ timesteps = scheduler.timesteps
263
+ num_inference_steps = len(timesteps)
264
+ elif sigmas is not None:
265
+ accept_sigmas = 'sigmas' in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
266
+ if not accept_sigmas:
267
+ raise ValueError(
268
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
269
+ f' sigmas schedules. Please check whether you are using the correct scheduler.'
270
+ )
271
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
272
+ timesteps = scheduler.timesteps
273
+ num_inference_steps = len(timesteps)
274
+ else:
275
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
276
+ timesteps = scheduler.timesteps
277
+ return timesteps, num_inference_steps
278
+
279
+
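+ # Typical call in this pipeline (only `num_inference_steps` is given):
+ # timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device=device)
+ # which reduces to scheduler.set_timesteps(50, device=device) followed by reading scheduler.timesteps.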
280
+ @dataclass
281
+ class MTVCrafterPipelineOutput(BaseOutput):
282
+ r"""Output class for the MTVCrafter pipeline.
283
+
284
+ Args:
285
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
286
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
287
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
288
+ `(batch_size, num_frames, channels, height, width)`.
289
+ """
290
+
291
+ frames: torch.Tensor
292
+
293
+
294
+ class MTVCrafterPipeline(DiffusionPipeline):
295
+ r"""Pipeline for MTVCrafter.
296
+
297
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
298
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
299
+
300
+ Args:
301
+ vae ([`AutoencoderKL`]):
302
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
303
+ transformer ([`Transformer3DModel`]):
304
+ An image-conditioned `Transformer3DModel` to denoise the encoded video latents.
305
+ scheduler ([`SchedulerMixin`]):
306
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
307
+ """
308
+
309
+ _callback_tensor_inputs = [
310
+ 'latents',
311
+ 'prompt_embeds',
312
+ 'negative_prompt_embeds',
313
+ ]
314
+
315
+ def __init__(
316
+ self,
317
+ vae: AutoencoderKLCogVideoX,
318
+ transformer: Transformer3DModel,
319
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
320
+ ):
321
+ super().__init__()
322
+
323
+ self.register_modules(
324
+ vae=vae,
325
+ transformer=transformer,
326
+ scheduler=scheduler,
327
+ )
328
+ self.vae_scale_factor_spatial = (
329
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, 'vae') and self.vae is not None else 8
330
+ )
331
+ self.vae_scale_factor_temporal = (
332
+ self.vae.config.temporal_compression_ratio if hasattr(self, 'vae') and self.vae is not None else 4
333
+ )
334
+
335
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
336
+ self.normalize = transforms.Normalize([0.5], [0.5])
337
+
338
+ @classmethod
339
+ def from_pretrained(
340
+ cls,
341
+ model_path,
342
+ transformer_model_path=None,
343
+ scheduler_type='ddim',
344
+ torch_dtype=None,
345
+ **kwargs,
346
+ ):
347
+ if transformer_model_path is None:
348
+ transformer_model_path = os.path.join(model_path, 'transformer')
349
+ transformer = Transformer3DModel.from_pretrained(
350
+ transformer_model_path, torch_dtype=torch_dtype, **kwargs
351
+ )
352
+ if scheduler_type == 'ddim':
353
+ scheduler = CogVideoXDDIMScheduler.from_pretrained(model_path, subfolder='scheduler')
354
+ elif scheduler_type == 'dpm':
355
+ scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder='scheduler')
356
+ else:
357
+ raise ValueError(f'Unsupported scheduler_type: {scheduler_type}. Expected `ddim` or `dpm`.')
358
+ pipe = super().from_pretrained(
359
+ model_path, transformer=transformer, scheduler=scheduler, torch_dtype=torch_dtype, **kwargs
360
+ )
361
+ return pipe
362
+
363
+
364
+ def prepare_latents(
365
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
366
+ ):
367
+ shape = (
368
+ batch_size,
369
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
370
+ num_channels_latents,
371
+ height // self.vae_scale_factor_spatial,
372
+ width // self.vae_scale_factor_spatial,
373
+ )
374
+ if isinstance(generator, list) and len(generator) != batch_size:
375
+ raise ValueError(
376
+ f'You have passed a list of generators of length {len(generator)}, but requested an effective batch'
377
+ f' size of {batch_size}. Make sure the batch size matches the length of the generators.'
378
+ )
379
+
380
+ if latents is None:
381
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
382
+ else:
383
+ latents = latents.to(device)
384
+
385
+ # scale the initial noise by the standard deviation required by the scheduler
386
+ latents = latents * self.scheduler.init_noise_sigma
387
+ return latents
388
+
389
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
390
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
391
+ latents = 1 / self.vae.config.scaling_factor * latents
392
+
393
+ frames = self.vae.decode(latents).sample
394
+ return frames
395
+
396
+ def prepare_extra_step_kwargs(self, generator, eta):
397
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
398
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
399
+ # eta corresponds to η in DDIM paper and should be between [0, 1]
400
+
401
+ accepts_eta = 'eta' in set(inspect.signature(self.scheduler.step).parameters.keys())
402
+ extra_step_kwargs = {}
403
+ if accepts_eta:
404
+ extra_step_kwargs['eta'] = eta
405
+
406
+ # check if the scheduler accepts generator
407
+ accepts_generator = 'generator' in set(inspect.signature(self.scheduler.step).parameters.keys())
408
+ if accepts_generator:
409
+ extra_step_kwargs['generator'] = generator
410
+ return extra_step_kwargs
411
+
412
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
413
+ def check_inputs(
414
+ self,
415
+ height,
416
+ width,
417
+ callback_on_step_end_tensor_inputs,
418
+ ):
419
+ if height % 8 != 0 or width % 8 != 0:
420
+ raise ValueError(f'`height` and `width` have to be divisible by 8 but are {height} and {width}.')
421
+
422
+ if callback_on_step_end_tensor_inputs is not None and not all(
423
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
424
+ ):
425
+ raise ValueError(
426
+ f'`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found '
427
+ f'{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}'
428
+ )
429
+
430
+
431
+ def _prepare_rotary_positional_embeddings(
432
+ self,
433
+ height: int,
434
+ width: int,
435
+ num_frames: int,
436
+ device: torch.device,
437
+ dtype: torch.dtype,
438
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
439
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
440
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
441
+ grid_crops_coords = ((0, 0), (grid_height, grid_width))
442
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
443
+ embed_dim=self.transformer.config.attention_head_dim,
444
+ crops_coords=grid_crops_coords,
445
+ grid_size=(grid_height, grid_width),
446
+ temporal_size=num_frames,
447
+ )
448
+
449
+ freqs_cos = freqs_cos.to(device=device, dtype=dtype)
450
+ freqs_sin = freqs_sin.to(device=device, dtype=dtype)
451
+ return freqs_cos, freqs_sin
452
+
453
+
454
+ def _prepare_motion_embeddings(self, num_frames, num_joints, joints_mean, joints_std, device, dtype):
455
+ time_embed = get_1d_rotary_pos_embed(self.transformer.config.attention_head_dim // 4, num_frames, use_real=True)
456
+ time_embed_cos = time_embed[0][:, None, :].expand(-1, num_joints, -1).reshape(num_frames*num_joints, -1)
457
+ time_embed_sin = time_embed[1][:, None, :].expand(-1, num_joints, -1).reshape(num_frames*num_joints, -1)
458
+ spatial_motion_embed = get_3d_motion_spatial_embed(self.transformer.config.attention_head_dim // 4 * 3, num_joints, joints_mean, joints_std)
459
+ spatial_embed_cos = spatial_motion_embed[0][None, :, :].expand(num_frames, -1, -1).reshape(num_frames*num_joints, -1)
460
+ spatial_embed_sin = spatial_motion_embed[1][None, :, :].expand(num_frames, -1, -1).reshape(num_frames*num_joints, -1)
461
+ motion_embed_cos = torch.cat([time_embed_cos, spatial_embed_cos], dim=-1).to(device=device, dtype=dtype)
462
+ motion_embed_sin = torch.cat([time_embed_sin, spatial_embed_sin], dim=-1).to(device=device, dtype=dtype)
463
+ return motion_embed_cos, motion_embed_sin
464
+
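+ # Shape note (assuming attention_head_dim=64 and 24 joints): the temporal RoPE contributes
+ # 64 // 4 = 16 dims per frame and the spatial RoPE 48 dims per joint; after broadcasting and
+ # concatenation, motion_embed_cos / motion_embed_sin are each of shape (num_frames * 24, 64).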
465
+ @property
466
+ def guidance_scale(self):
467
+ return self._guidance_scale
468
+
469
+ @property
470
+ def num_timesteps(self):
471
+ return self._num_timesteps
472
+
473
+ @property
474
+ def interrupt(self):
475
+ return self._interrupt
476
+
477
+ @torch.no_grad()
478
+ def __call__(
479
+ self,
480
+ prompt: Optional[Union[str, List[str]]] = None,
481
+ negative_prompt: Optional[Union[str, List[str]]] = None,
482
+ height: int = 480,
483
+ width: int = 720,
484
+ num_frames: int = 49,
485
+ num_inference_steps: int = 50,
486
+ timesteps: Optional[List[int]] = None,
487
+ guidance_scale: float = 6,
488
+ use_dynamic_cfg: bool = False,
489
+ num_videos_per_prompt: int = 1,
490
+ eta: float = 0.0,
491
+ seed: Optional[int] = -1,
492
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
493
+ latents: Optional[torch.FloatTensor] = None,
494
+ prompt_embeds: Optional[torch.FloatTensor] = None,
495
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
496
+ output_type: str = 'pil',
497
+ return_dict: bool = True,
498
+ callback_on_step_end: Optional[
499
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
500
+ ] = None,
501
+ callback_on_step_end_tensor_inputs: List[str] = ['latents'],
502
+ max_sequence_length: int = 226,
503
+ ref_images: List[Image.Image] = None,
504
+ motion_embeds: Optional[torch.FloatTensor] = None,
505
+ joint_mean: Optional[np.ndarray] = None,
506
+ joint_std: Optional[np.ndarray] = None,
507
+ ) -> Union[MTVCrafterPipelineOutput, Tuple]:
508
+ """Function invoked when calling the pipeline for generation.
509
+
510
+ Args:
511
+ prompt (`str` or `List[str]`, *optional*):
512
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
513
+ instead.
514
+ negative_prompt (`str` or `List[str]`, *optional*):
515
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
516
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
517
+ less than `1`).
518
+ height (`int`, *optional*, defaults to `480`):
519
+ The height in pixels of the generated video frames.
520
+ width (`int`, *optional*, defaults to `720`):
521
+ The width in pixels of the generated video frames.
522
+ num_frames (`int`, defaults to `49`):
523
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
524
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
525
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
526
+ needs to be satisfied is that of divisibility mentioned above.
527
+ num_inference_steps (`int`, *optional*, defaults to 50):
528
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
529
+ expense of slower inference.
530
+ timesteps (`List[int]`, *optional*):
531
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
532
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
533
+ passed will be used. Must be in descending order.
534
+ guidance_scale (`float`, *optional*, defaults to `6`):
535
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance]. Guidance is enabled by setting `guidance_scale >
536
+ 1`. A higher guidance scale encourages outputs that follow the conditioning more closely,
537
+ usually at the expense of lower image quality.
538
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
539
+ The number of videos to generate per prompt.
540
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
541
+ One or a list of [torch generator(s)]
542
+ to make generation deterministic.
543
+ latents (`torch.FloatTensor`, *optional*):
544
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
545
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
546
+ tensor will be generated by sampling using the supplied random `generator`.
547
+ prompt_embeds (`torch.FloatTensor`, *optional*):
548
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
549
+ provided, text embeddings will be generated from `prompt` input argument.
550
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
551
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
552
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
553
+ argument.
554
+ output_type (`str`, *optional*, defaults to `"pil"`):
555
+ The output format of the generated video. Choose between `PIL.Image.Image` or `np.array`.
556
+ return_dict (`bool`, *optional*, defaults to `True`):
557
+ Whether or not to return a [`MTVCrafterPipelineOutput`] instead
558
+ of a plain tuple.
559
+ callback_on_step_end (`Callable`, *optional*):
560
+ A function that is called at the end of each denoising step during inference. The function is called
561
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
562
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
563
+ `callback_on_step_end_tensor_inputs`.
564
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
565
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
566
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
567
+ `._callback_tensor_inputs` attribute of your pipeline class.
568
+ max_sequence_length (`int`, defaults to `226`):
569
+ Maximum sequence length in encoded prompt. Must be consistent with
570
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
571
+ """
572
+
573
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
574
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
575
+
576
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
577
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
578
+ # 720 * 480
579
+ num_videos_per_prompt = 1
580
+
581
+ # 1. Check inputs. Raise error if not correct
582
+ self.check_inputs(
583
+ height,
584
+ width,
585
+ callback_on_step_end_tensor_inputs,
586
+ )
587
+ self._guidance_scale = guidance_scale
588
+ self._interrupt = False
589
+
590
+ # 2. Default call parameters
591
+ if prompt is not None and isinstance(prompt, str):
592
+ batch_size = 1
593
+ elif prompt is not None and isinstance(prompt, list):
594
+ batch_size = len(prompt)
595
+ elif prompt is None:
596
+ batch_size = 1
597
+ else:
598
+ batch_size = prompt_embeds.shape[0]
599
+
600
+ device = self._execution_device
601
+
602
+ if seed > 0:
603
+ generator = torch.Generator(device=device)
604
+ generator.manual_seed(seed)
605
+ do_classifier_free_guidance = guidance_scale > 1.0
606
+
607
+ # 3. Prepare timesteps
608
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
609
+ self._num_timesteps = len(timesteps)
610
+
611
+ # 4. Prepare latents.
612
+ latent_channels = self.vae.config.latent_channels
613
+ latents = self.prepare_latents(
614
+ batch_size * num_videos_per_prompt,
615
+ latent_channels,
616
+ num_frames,
617
+ height,
618
+ width,
619
+ self.vae.dtype,
620
+ device,
621
+ generator,
622
+ latents,
623
+ ) # [1, x, 16, h/8, w/8]
624
+
625
+ if ref_images is not None:
626
+ ref_images = rearrange(ref_images.unsqueeze(0), 'b f c h w -> b c f h w')
627
+ ref_latents = self.vae.encode(
628
+ ref_images.to(dtype=self.vae.dtype, device=self.vae.device)
629
+ ).latent_dist.sample()
630
+ ref_latents = rearrange(ref_latents, 'b c f h w -> b f c h w')
631
+ if do_classifier_free_guidance:
632
+ ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
633
+
634
+ motion_embeds = motion_embeds.to(latents.dtype) if motion_embeds is not None else None
635
+ if motion_embeds is not None and do_classifier_free_guidance:
636
+ motion_embeds = torch.cat([self.transformer.unconditional_motion_token.unsqueeze(0), motion_embeds], dim=0)
637
+
638
+ # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
639
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
640
+
641
+ # 6. Create rotary embeds if required
642
+ image_rotary_emb = (
643
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, dtype=latents.dtype)
644
+ if self.transformer.config.use_rotary_positional_embeddings
645
+ else None
646
+ )
647
+ motion_rotary_emb = self._prepare_motion_embeddings(latents.size(1), 24, joint_mean, joint_std, device, dtype=latents.dtype)
648
+
649
+ # 7. Denoising loop
650
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
651
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
652
+ # for DPM-solver++
653
+ old_pred_original_sample = None
654
+ for i, t in enumerate(timesteps):
655
+ if self.interrupt:
656
+ continue
657
+
658
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
659
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
660
+
661
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
662
+ timestep = t.expand(latent_model_input.shape[0])
663
+
664
+ if ref_images is not None:
665
+ latent_model_input = torch.cat([latent_model_input, ref_latents], dim=2)
666
+
667
+ # predict noise model_output
668
+ noise_pred = self.transformer(
669
+ hidden_states=latent_model_input,
670
+ timestep=timestep.long(),
671
+ image_rotary_emb=image_rotary_emb,
672
+ motion_rotary_emb=motion_rotary_emb,
673
+ motion_emb=motion_embeds,
674
+ return_dict=False,
675
+ )[0]
676
+ noise_pred = noise_pred.float() # [b, f, c, h, w]
677
+
678
+ # perform guidance
679
+ if use_dynamic_cfg:
680
+ self._guidance_scale = 1 + guidance_scale * (
681
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
682
+ )
683
+ if do_classifier_free_guidance:
684
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
685
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
686
+
687
+ # compute the previous noisy sample x_t -> x_t-1
688
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
689
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
690
+ else:
691
+ latents, old_pred_original_sample = self.scheduler.step(
692
+ noise_pred,
693
+ old_pred_original_sample,
694
+ t,
695
+ timesteps[i - 1] if i > 0 else None,
696
+ latents,
697
+ **extra_step_kwargs,
698
+ return_dict=False,
699
+ )
700
+ latents = latents.to(self.vae.dtype)
701
+
702
+ # call the callback, if provided
703
+ if callback_on_step_end is not None:
704
+ callback_kwargs = {}
705
+ for k in callback_on_step_end_tensor_inputs:
706
+ callback_kwargs[k] = locals()[k]
707
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
708
+
709
+ latents = callback_outputs.pop('latents', latents)
710
+ prompt_embeds = callback_outputs.pop('prompt_embeds', prompt_embeds)
711
+ negative_prompt_embeds = callback_outputs.pop('negative_prompt_embeds', negative_prompt_embeds)
712
+
713
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
714
+ progress_bar.update()
715
+
716
+ if not output_type == 'latent':
717
+ video = self.decode_latents(latents)
718
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
719
+ else:
720
+ video = latents
721
+
722
+ # Offload all models
723
+ self.maybe_free_model_hooks()
724
+
725
+ if not return_dict:
726
+ return (video,)
727
+
728
+ return MTVCrafterPipelineOutput(frames=video)
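+ # Minimal usage sketch (the checkpoint path and the preprocessed inputs below are placeholders):
+ # pipe = MTVCrafterPipeline.from_pretrained('path/to/MTVCrafter', scheduler_type='dpm', torch_dtype=torch.bfloat16).to('cuda')
+ # output = pipe(height=512, width=512, num_frames=49, num_inference_steps=50, guidance_scale=3.0,
+ #               ref_images=ref_frames, motion_embeds=motion_tokens, joint_mean=joint_mean, joint_std=joint_std)
+ # where ref_frames is a normalized reference tensor of shape [f, c, h, w] and motion_tokens are VQVAE motion embeddings.
+ # video = output.frames  # decoded and post-processed video frames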
models/motion4d/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .vqvae import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder
2
+ from .loss import ReConsLoss
models/motion4d/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (375 Bytes). View file
 
models/motion4d/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (314 Bytes). View file
 
models/motion4d/__pycache__/loss.cpython-311.pyc ADDED
Binary file (2.13 kB). View file
 
models/motion4d/__pycache__/loss.cpython-313.pyc ADDED
Binary file (2 kB). View file
 
models/motion4d/__pycache__/vqvae.cpython-311.pyc ADDED
Binary file (28.4 kB). View file
 
models/motion4d/__pycache__/vqvae.cpython-313.pyc ADDED
Binary file (26.5 kB). View file
 
models/motion4d/loss.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class ReConsLoss(nn.Module):
5
+ def __init__(self, recons_loss, nb_joints):
6
+ super(ReConsLoss, self).__init__()
7
+
8
+ if recons_loss == 'l1':
9
+ self.Loss = torch.nn.L1Loss()
10
+ elif recons_loss == 'l2' :
11
+ self.Loss = torch.nn.MSELoss()
12
+ elif recons_loss == 'l1_smooth' :
13
+ self.Loss = torch.nn.SmoothL1Loss()
14
+
15
+ # 4 global motion associated to root
16
+ # 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d)
17
+ # 3 global vel xyz
18
+ # 4 foot contact
19
+ self.nb_joints = nb_joints
20
+ self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4
21
+
22
+ def forward(self, motion_pred, motion_gt) :
23
+ loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim])
24
+ return loss
25
+
26
+ def forward_joint(self, motion_pred, motion_gt) :
27
+ loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4])
28
+ return loss
29
+
30
+
models/motion4d/vqvae.py ADDED
@@ -0,0 +1,501 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from typing import Any, Dict, Optional, Tuple, Union
7
+ from diffusers.models.attention import Attention
8
+
9
+
10
+ class AttnProcessor:
11
+ r"""Processor for implementing scaled dot-product attention for the
12
+ CogVideoX model.
13
+
14
+ It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
15
+ """
16
+
17
+ def __init__(self):
18
+ if not hasattr(F, 'scaled_dot_product_attention'):
19
+ raise ImportError('AttnProcessor requires PyTorch 2.0 or newer; please upgrade PyTorch to use it.')
20
+
21
+ def __call__(
22
+ self,
23
+ attn: Attention,
24
+ hidden_states: torch.Tensor,
25
+ encoder_hidden_states: Optional[torch.Tensor] = None,
26
+ attention_mask: Optional[torch.Tensor] = None,
27
+ image_rotary_emb: Optional[torch.Tensor] = None,
28
+ motion_rotary_emb: Optional[torch.Tensor] = None,
29
+ ) -> torch.Tensor:
31
+ batch_size, sequence_length, _ = hidden_states.shape
32
+
33
+ if attention_mask is not None:
34
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
35
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
36
+
37
+ query = attn.to_q(hidden_states)
38
+ key = attn.to_k(hidden_states)
39
+ value = attn.to_v(hidden_states)
40
+
41
+ inner_dim = key.shape[-1]
42
+ head_dim = inner_dim // attn.heads
43
+
44
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # [batch_size, heads, seq_len, dim]
45
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
46
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
47
+
48
+ if attn.norm_q is not None:
49
+ query = attn.norm_q(query)
50
+ if attn.norm_k is not None:
51
+ key = attn.norm_k(key)
52
+
53
+ # The sequence-parallel branch below relies on helpers (get_sequence_parallel_group, dist,
54
+ # _all_in_all_with_text) and a `text_seq_length` variable that are not defined in this file,
55
+ # so it is kept for reference only and disabled here.
56
+ # sp_group = get_sequence_parallel_group()
57
+ # if sp_group is not None:
58
+ #     sp_size = dist.get_world_size(sp_group)
59
+ #     query = _all_in_all_with_text(query, text_seq_length, sp_group, sp_size, mode=1)
+ #     key = _all_in_all_with_text(key, text_seq_length, sp_group, sp_size, mode=1)
+ #     value = _all_in_all_with_text(value, text_seq_length, sp_group, sp_size, mode=1)
+ #     text_seq_length *= sp_size
60
+
61
+ # Apply RoPE if needed
62
+ if image_rotary_emb is not None:
63
+ from diffusers.models.embeddings import apply_rotary_emb
64
+ image_seq_length = image_rotary_emb[0].shape[0]
65
+ query[:, :, :image_seq_length] = apply_rotary_emb(query[:, :, :image_seq_length], image_rotary_emb)
66
+ if motion_rotary_emb is not None:
67
+ query[:, :, image_seq_length:] = apply_rotary_emb(query[:, :, image_seq_length:], motion_rotary_emb)
68
+ if not attn.is_cross_attention:
69
+ key[:, :, :image_seq_length] = apply_rotary_emb(key[:, :, :image_seq_length], image_rotary_emb)
70
+ if motion_rotary_emb is not None:
71
+ key[:, :, image_seq_length:] = apply_rotary_emb(key[:, :, image_seq_length:], motion_rotary_emb)
72
+
73
+ hidden_states = F.scaled_dot_product_attention(
74
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
75
+ )
76
+
77
+ # if sp_group is not None:  # disabled along with the sequence-parallel branch above
78
+ #     hidden_states = _all_in_all_with_text(hidden_states, text_seq_length, sp_group, sp_size, mode=2)
79
+ #     text_seq_length = text_seq_length // sp_size
80
+
81
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
82
+
83
+ # linear proj
84
+ hidden_states = attn.to_out[0](hidden_states)
85
+ # dropout
86
+ hidden_states = attn.to_out[1](hidden_states)
87
+
88
+ return hidden_states
89
+
90
+
91
+ class Encoder(nn.Module):
92
+ def __init__(
93
+ self,
94
+ in_channels=3,
95
+ mid_channels=[128, 512],
96
+ out_channels=3072,
97
+ downsample_time=[1, 1],
98
+ downsample_joint=[1, 1],
99
+ num_attention_heads=8,
100
+ attention_head_dim=64,
101
+ dim=3072,
102
+ ):
103
+ super(Encoder, self).__init__()
104
+
105
+ self.conv_in = nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, stride=1, padding=1)
106
+ self.resnet1 = nn.ModuleList([ResBlock(mid_channels[0], mid_channels[0]) for _ in range(3)])
107
+ self.downsample1 = Downsample(mid_channels[0], mid_channels[0], downsample_time[0], downsample_joint[0])
108
+ self.resnet2 = ResBlock(mid_channels[0], mid_channels[1])
109
+ self.resnet3 = nn.ModuleList([ResBlock(mid_channels[1], mid_channels[1]) for _ in range(3)])
110
+ self.downsample2 = Downsample(mid_channels[1], mid_channels[1], downsample_time[1], downsample_joint[1])
111
+ # self.attn = Attention(
112
+ # query_dim=dim,
113
+ # dim_head=attention_head_dim,
114
+ # heads=num_attention_heads,
115
+ # qk_norm='layer_norm',
116
+ # eps=1e-6,
117
+ # bias=True,
118
+ # out_bias=True,
119
+ # processor=AttnProcessor(),
120
+ # )
121
+ self.conv_out = nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
122
+
123
+ def forward(self, x):
124
+ x = self.conv_in(x)
125
+ for resnet in self.resnet1:
126
+ x = resnet(x)
127
+ x = self.downsample1(x)
128
+
129
+ x = self.resnet2(x)
130
+ for resnet in self.resnet3:
131
+ x = resnet(x)
132
+ x = self.downsample2(x)
133
+
134
+ # x = x + self.attn(x)
135
+ x = self.conv_out(x)
136
+
137
+ return x
138
+
139
+
140
+
141
+ class VectorQuantizer(nn.Module):
142
+ def __init__(self, nb_code, code_dim, is_train=True):
143
+ super().__init__()
144
+ self.nb_code = nb_code
145
+ self.code_dim = code_dim
146
+ self.mu = 0.99
147
+ self.reset_codebook()
148
+ self.reset_count = 0
149
+ self.usage = torch.zeros((self.nb_code, 1))
150
+ self.is_train = is_train
151
+
152
+ def reset_codebook(self):
153
+ self.init = False
154
+ self.code_sum = None
155
+ self.code_count = None
156
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda())
157
+
158
+ def _tile(self, x):
159
+ nb_code_x, code_dim = x.shape
160
+ if nb_code_x < self.nb_code:
161
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
162
+ std = 0.01 / np.sqrt(code_dim)
163
+ out = x.repeat(n_repeats, 1)
164
+ out = out + torch.randn_like(out) * std
165
+ else:
166
+ out = x
167
+ return out
168
+
169
+ def init_codebook(self, x):
170
+ if torch.all(self.codebook == 0):
171
+ out = self._tile(x)
172
+ self.codebook = out[:self.nb_code]
173
+ self.code_sum = self.codebook.clone()
174
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
175
+ if self.is_train:
176
+ self.init = True
177
+
178
+ @torch.no_grad()
179
+ def update_codebook(self, x, code_idx):
180
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
181
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
182
+
183
+ code_sum = torch.matmul(code_onehot, x) # [nb_code, code_dim]
184
+ code_count = code_onehot.sum(dim=-1) # nb_code
185
+
186
+ out = self._tile(x)
187
+ code_rand = out[torch.randperm(out.shape[0])[:self.nb_code]]
188
+
189
+ # Update centres
190
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
191
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
192
+
193
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
194
+ self.usage = self.usage.to(usage.device)
195
+ if self.reset_count >= 20: # reset codebook every 20 steps for stability
196
+ self.reset_count = 0
197
+ usage = (usage + self.usage >= 1.0).float()
198
+ else:
199
+ self.reset_count += 1
200
+ self.usage = (usage + self.usage >= 1.0).float()
201
+ usage = torch.ones_like(self.usage, device=x.device)
202
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
203
+
204
+ self.codebook = usage * code_update + (1 - usage) * code_rand
205
+ prob = code_count / torch.sum(code_count)
206
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
207
+
208
+ return perplexity
209
+
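+ # The update above is the standard EMA codebook rule:
+ #   code_sum_k   <- mu * code_sum_k   + (1 - mu) * sum_i 1[idx_i == k] * x_i
+ #   code_count_k <- mu * code_count_k + (1 - mu) * #{i : idx_i == k}
+ # with rarely used codes periodically re-seeded from random encoder outputs (code_rand).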
210
+ def preprocess(self, x):
211
+ # [bs, c, f, j] -> [bs * f * j, c]
212
+ x = x.permute(0, 2, 3, 1).contiguous()
213
+ x = x.view(-1, x.shape[-1])
214
+ return x
215
+
216
+ def quantize(self, x):
217
+ # [bs * f * j, dim=3072]
218
+ # Calculate latent code x_l
219
+ k_w = self.codebook.t()
220
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, keepdim=True)
221
+ _, code_idx = torch.min(distance, dim=-1)
222
+ return code_idx
223
+
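+ # `distance` above expands ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2 for every codebook entry c,
+ # so torch.min over the last dimension picks the index of the nearest codebook vector per token.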
224
+ def dequantize(self, code_idx):
225
+ x = F.embedding(code_idx, self.codebook) # indexing: [bs * f * j, 32]
226
+ return x
227
+
228
+ def forward(self, x, return_vq=False):
229
+ # import pdb; pdb.set_trace()
230
+ bs, c, f, j = x.shape # SMPL data frames: [bs, 3072, f, j]
231
+
232
+ # Preprocess
233
+ x = self.preprocess(x)
234
+ # return x.view(bs, f*j, c).contiguous(), None
235
+ assert x.shape[-1] == self.code_dim
236
+
237
+ # Init codebook if not inited
238
+ if not self.init and self.is_train:
239
+ self.init_codebook(x)
240
+
241
+ # quantize and dequantize through bottleneck
242
+ code_idx = self.quantize(x)
243
+ x_d = self.dequantize(code_idx)
244
+
245
+ # Update embeddings
246
+ if self.is_train:
247
+ perplexity = self.update_codebook(x, code_idx)
248
+
249
+ # Loss
250
+ commit_loss = F.mse_loss(x, x_d.detach())
251
+
252
+ # Passthrough
253
+ x_d = x + (x_d - x).detach()
254
+
255
+ if return_vq:
256
+ return x_d.view(bs, f*j, c).contiguous(), commit_loss
257
+ # return (x_d, x_d.view(bs, f, j, c).permute(0, 3, 1, 2).contiguous()), commit_loss, perplexity
258
+
259
+ # Postprocess
260
+ x_d = x_d.view(bs, f, j, c).permute(0, 3, 1, 2).contiguous()
261
+
262
+ if self.is_train:
263
+ return x_d, commit_loss, perplexity
264
+ else:
265
+ return x_d, commit_loss
266
+
267
+
268
+ class Decoder(nn.Module):
269
+ def __init__(
270
+ self,
271
+ in_channels=3072,
272
+ mid_channels=[512, 128],
273
+ out_channels=3,
274
+ upsample_rate=None,
275
+ frame_upsample_rate=[1.0, 1.0],
276
+ joint_upsample_rate=[1.0, 1.0],
277
+ dim=128,
278
+ attention_head_dim=64,
279
+ num_attention_heads=8,
280
+ ):
281
+ super(Decoder, self).__init__()
282
+
283
+ self.conv_in = nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, stride=1, padding=1)
284
+ self.resnet1 = nn.ModuleList([ResBlock(mid_channels[0], mid_channels[0]) for _ in range(3)])
285
+ self.upsample1 = Upsample(mid_channels[0], mid_channels[0], frame_upsample_rate=frame_upsample_rate[0], joint_upsample_rate=joint_upsample_rate[0])
286
+ self.resnet2 = ResBlock(mid_channels[0], mid_channels[1])
287
+ self.resnet3 = nn.ModuleList([ResBlock(mid_channels[1], mid_channels[1]) for _ in range(3)])
288
+ self.upsample2 = Upsample(mid_channels[1], mid_channels[1], frame_upsample_rate=frame_upsample_rate[1], joint_upsample_rate=joint_upsample_rate[1])
289
+ # self.attn = Attention(
290
+ # query_dim=dim,
291
+ # dim_head=attention_head_dim,
292
+ # heads=num_attention_heads,
293
+ # qk_norm='layer_norm',
294
+ # eps=1e-6,
295
+ # bias=True,
296
+ # out_bias=True,
297
+ # processor=AttnProcessor(),
298
+ # )
299
+ self.conv_out = nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
300
+
301
+ def forward(self, x):
302
+ x = self.conv_in(x)
303
+ for resnet in self.resnet1:
304
+ x = resnet(x)
305
+ x = self.upsample1(x)
306
+
307
+ x = self.resnet2(x)
308
+ for resnet in self.resnet3:
309
+ x = resnet(x)
310
+ x = self.upsample2(x)
311
+
312
+ # x = x + self.attn(x)
313
+ x = self.conv_out(x)
314
+
315
+ return x
316
+
317
+
318
+ class Upsample(nn.Module):
319
+ def __init__(
320
+ self,
321
+ in_channels,
322
+ out_channels,
323
+ upsample_rate=None,
324
+ frame_upsample_rate=None,
325
+ joint_upsample_rate=None,
326
+ ):
327
+ super(Upsample, self).__init__()
328
+
329
+ self.upsampler = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
330
+ self.upsample_rate = upsample_rate
331
+ self.frame_upsample_rate = frame_upsample_rate
332
+ self.joint_upsample_rate = joint_upsample_rate
334
+
335
+ def forward(self, inputs):
336
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
337
+ # split first frame
338
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
339
+
340
+ if self.upsample_rate is not None:
341
+ # import pdb; pdb.set_trace()
342
+ x_first = F.interpolate(x_first, scale_factor=self.upsample_rate)
343
+ x_rest = F.interpolate(x_rest, scale_factor=self.upsample_rate)
344
+ else:
345
+ # import pdb; pdb.set_trace()
346
+ # x_first = F.interpolate(x_first, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True)
347
+ x_rest = F.interpolate(x_rest, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True)
348
+ x_first = x_first[:, :, None, :]
349
+ inputs = torch.cat([x_first, x_rest], dim=2)
350
+ elif inputs.shape[2] > 1:
351
+ if self.upsample_rate is not None:
352
+ inputs = F.interpolate(inputs, scale_factor=self.upsample_rate)
353
+ else:
354
+ inputs = F.interpolate(inputs, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True)
355
+ else:
356
+ inputs = inputs.squeeze(2)
357
+ if self.upsample_rate is not None:
358
+ inputs = F.interpolate(inputs, scale_factor=self.upsample_rate)
359
+ else:
360
+ inputs = F.interpolate(inputs, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="linear", align_corners=True)
361
+ inputs = inputs[:, :, None, :, :]
362
+
363
+ b, c, t, j = inputs.shape
364
+ inputs = inputs.permute(0, 2, 1, 3).reshape(b * t, c, j)
365
+ inputs = self.upsampler(inputs)
366
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3)
367
+
368
+ return inputs
369
+
370
+
371
+ class Downsample(nn.Module):
372
+ def __init__(
373
+ self,
374
+ in_channels,
375
+ out_channels,
376
+ frame_downsample_rate,
377
+ joint_downsample_rate
378
+ ):
379
+ super(Downsample, self).__init__()
380
+
381
+ self.frame_downsample_rate = frame_downsample_rate
382
+ self.joint_downsample_rate = joint_downsample_rate
383
+ self.joint_downsample = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=self.joint_downsample_rate, padding=1)
384
+
385
+ def forward(self, x):
386
+ # (batch_size, channels, frames, joints) -> (batch_size * joints, channels, frames)
387
+ if self.frame_downsample_rate > 1:
388
+ batch_size, channels, frames, joints = x.shape
389
+ x = x.permute(0, 3, 1, 2).reshape(batch_size * joints, channels, frames)
390
+ if x.shape[-1] % 2 == 1:
391
+ x_first, x_rest = x[..., 0], x[..., 1:]
392
+ if x_rest.shape[-1] > 0:
393
+ # (batch_size * joints, channels, frames - 1) -> (batch_size * joints, channels, (frames - 1) // frame_downsample_rate)
394
+ x_rest = F.avg_pool1d(x_rest, kernel_size=self.frame_downsample_rate, stride=self.frame_downsample_rate)
395
+
396
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
397
+ # (batch_size * joints, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, joints)
398
+ x = x.reshape(batch_size, joints, channels, x.shape[-1]).permute(0, 2, 3, 1)
399
+ else:
400
+ # (batch_size * joints, channels, frames) -> (batch_size * joints, channels, frames // 2)
401
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
402
+ # (batch_size * joints, channels, frames // 2) -> (batch_size, channels, frames // 2, joints)
403
+ x = x.reshape(batch_size, joints, channels, x.shape[-1]).permute(0, 2, 3, 1)
404
+
405
+ # Pad the tensor
406
+ # pad = (0, 1)
407
+ # x = F.pad(x, pad, mode="constant", value=0)
408
+ batch_size, channels, frames, joints = x.shape
409
+ # (batch_size, channels, frames, joints) -> (batch_size * frames, channels, joints)
410
+ x = x.permute(0, 2, 1, 3).reshape(batch_size * frames, channels, joints)
411
+ x = self.joint_downsample(x)
412
+ # (batch_size * frames, channels, joints) -> (batch_size, channels, frames, joints)
413
+ x = x.reshape(batch_size, frames, x.shape[1], x.shape[2]).permute(0, 2, 1, 3)
414
+ return x
415
+
416
+
417
+
418
+ class ResBlock(nn.Module):
419
+ def __init__(self,
420
+ in_channels,
421
+ out_channels,
422
+ group_num=32,
423
+ max_channels=512):
424
+ super(ResBlock, self).__init__()
425
+ skip = max(1, max_channels // out_channels - 1)
426
+ self.block = nn.Sequential(
427
+ nn.GroupNorm(group_num, in_channels, eps=1e-06, affine=True),
428
+ nn.SiLU(),
429
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=skip, dilation=skip),
430
+ nn.GroupNorm(group_num, out_channels, eps=1e-06, affine=True),
431
+ nn.SiLU(),
432
+ nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0),
433
+ )
434
+ self.conv_short = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if in_channels != out_channels else nn.Identity()
435
+
436
+ def forward(self, x):
437
+ hidden_states = self.block(x)
438
+ if hidden_states.shape != x.shape:
439
+ x = self.conv_short(x)
440
+ x = x + hidden_states
441
+ return x
442
+
443
+
444
+
445
+ class SMPL_VQVAE(nn.Module):
446
+ def __init__(self, encoder, decoder, vq):
447
+ super(SMPL_VQVAE, self).__init__()
448
+
449
+ self.encoder = encoder
450
+ self.decoder = decoder
451
+ self.vq = vq
452
+
453
+ def to(self, device):
454
+ self.encoder = self.encoder.to(device)
455
+ self.decoder = self.decoder.to(device)
456
+ self.vq = self.vq.to(device)
457
+ self.device = device
458
+ return self
459
+
460
+ def encdec_slice_frames(self, x, frame_batch_size, encdec, return_vq):
461
+ num_frames = x.shape[2]
462
+ remaining_frames = num_frames % frame_batch_size
463
+ x_output = []
464
+ loss_output = []
465
+ perplexity_output = []
466
+ for i in range(num_frames // frame_batch_size):
467
+ remaining_frames = num_frames % frame_batch_size
468
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
469
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
470
+ x_intermediate = x[:, :, start_frame:end_frame]
471
+ x_intermediate = encdec(x_intermediate)
472
+ # if encdec == self.encoder and self.vq is not None:
473
+ # x_intermediate, loss, perplexity = self.vq(x_intermediate)
474
+ # x_output.append(x_intermediate)
475
+ # loss_output.append(loss)
476
+ # perplexity_output.append(perplexity)
477
+ # else:
478
+ # x_output.append(x_intermediate)
479
+ x_output.append(x_intermediate)
480
+ if encdec == self.encoder and self.vq is not None and not self.vq.is_train:
481
+ x_output, loss = self.vq(torch.cat(x_output, dim=2), return_vq=return_vq)
482
+ return x_output, loss
483
+ elif encdec == self.encoder and self.vq is not None and self.vq.is_train:
484
+ x_output, loss, perplexity = self.vq(torch.cat(x_output, dim=2))
485
+ return x_output, loss, perplexity
486
+ else:
487
+ return torch.cat(x_output, dim=2), None, None
488
+
489
+ def forward(self, x, return_vq=False):
490
+ x = x.permute(0, 3, 1, 2)
491
+ if not self.vq.is_train:
492
+ x, loss = self.encdec_slice_frames(x, frame_batch_size=8, encdec=self.encoder, return_vq=return_vq)
493
+ else:
494
+ x, loss, perplexity = self.encdec_slice_frames(x, frame_batch_size=8, encdec=self.encoder, return_vq=return_vq)
495
+ if return_vq:
496
+ return x, loss
497
+ x, _, _ = self.encdec_slice_frames(x, frame_batch_size=2, encdec=self.decoder, return_vq=return_vq)
498
+ x = x.permute(0, 2, 3, 1)
499
+ if self.vq.is_train:
500
+ return x, loss, perplexity
501
+ return x, loss
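+ # Minimal usage sketch (the codebook size 8192 is an example value, not a checkpoint setting):
+ # vqvae = SMPL_VQVAE(Encoder(), Decoder(), VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)).to('cuda')
+ # motion = torch.randn(1, 16, 24, 3, device='cuda')  # [batch, frames, joints, xyz]
+ # recon, commit_loss = vqvae(motion)                  # recon keeps the [1, 16, 24, 3] shape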
motion_extractor.py ADDED
@@ -0,0 +1,62 @@
1
+ # motion_extractor.py
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import torch
6
+ import pickle
7
+ import torchvision
8
+
9
+ # Load the TorchScript model once at the top
10
+ model_path = '/gemini/space/human_guozz2/dyb/MTVCrafter-main/nlf_l_multi_0.3.2.torchscript'
11
+ assert os.path.exists(model_path), f"Model file not found at {model_path}"
12
+ model = torch.jit.load(model_path).cuda().eval()
13
+
14
+ def extract_pkl_from_video(video_path):
15
+ output_file = "temp_motion.pkl"
16
+ cap = cv2.VideoCapture(video_path)
17
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
18
+ video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
19
+ video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
20
+
21
+ pose_results = {
22
+ 'joints3d_nonparam': [],
23
+ }
24
+
25
+ with torch.inference_mode(), torch.device('cuda'):
26
+ frame_idx = 0
27
+ while cap.isOpened():
28
+ ret, frame = cap.read()
29
+ if not ret:
30
+ break
31
+
32
+ # Convert frame to tensor
33
+ frame_tensor = torch.from_numpy(frame).cuda()
34
+ frame_batch = frame_tensor.unsqueeze(0).permute(0,3,1,2)
35
+ # Model inference
36
+ pred = model.detect_smpl_batched(frame_batch)
37
+ # Collect pose data
38
+ for key in pose_results.keys():
39
+ if key in pred:
40
+ #pose_results[key].append(pred[key].cpu().numpy())
41
+ pose_results[key].append(pred[key])
42
+ else:
43
+ pose_results[key].append(None)
44
+
45
+ frame_idx += 1
46
+
47
+ cap.release()
48
+
49
+ # Prepare output data
50
+ output_data = {
51
+ 'video_path': video_path,
52
+ 'video_length': frame_count,
53
+ 'video_width': video_width,
54
+ 'video_height': video_height,
55
+ 'pose': pose_results
56
+ }
57
+
58
+ # Save to pkl file
59
+ with open(output_file, 'wb') as f:
60
+ pickle.dump(output_data, f)
61
+
62
+ return output_file
utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+
7
+ def concat_images(images, direction='horizontal', pad=0, pad_value=0):
8
+ if len(images) == 1:
9
+ return images[0]
10
+ is_pil = isinstance(images[0], Image.Image)
11
+ if is_pil:
12
+ images = [np.array(image) for image in images]
13
+ if direction == 'horizontal':
14
+ height = max([image.shape[0] for image in images])
15
+ width = sum([image.shape[1] for image in images]) + pad * (len(images) - 1)
16
+ new_image = np.full((height, width, images[0].shape[2]), pad_value, dtype=images[0].dtype)
17
+ begin = 0
18
+ for image in images:
19
+ end = begin + image.shape[1]
20
+ new_image[: image.shape[0], begin:end] = image
21
+ begin = end + pad
22
+ elif direction == 'vertical':
23
+ height = sum([image.shape[0] for image in images]) + pad * (len(images) - 1)
24
+ width = max([image.shape[1] for image in images])
25
+ new_image = np.full((height, width, images[0].shape[2]), pad_value, dtype=images[0].dtype)
26
+ begin = 0
27
+ for image in images:
28
+ end = begin + image.shape[0]
29
+ new_image[begin:end, : image.shape[1]] = image
30
+ begin = end + pad
31
+ else:
32
+ assert False
33
+ if is_pil:
34
+ new_image = Image.fromarray(new_image)
35
+ return new_image
36
+
37
+ def concat_images_grid(images, cols, pad=0, pad_value=0):
38
+ new_images = []
39
+ while len(images) > 0:
40
+ new_image = concat_images(images[:cols], pad=pad, pad_value=pad_value)
41
+ new_images.append(new_image)
42
+ images = images[cols:]
43
+ new_image = concat_images(new_images, direction='vertical', pad=pad, pad_value=pad_value)
44
+ return new_image
45
+
46
+ def sample_video(video, indexes, method=2):
47
+ if method == 1:
48
+ frames = video.get_batch(indexes)
49
+ frames = frames.numpy() if isinstance(frames, torch.Tensor) else frames.asnumpy()
50
+ elif method == 2:
51
+ max_idx = indexes.max() + 1
52
+ all_indexes = np.arange(max_idx, dtype=int)
53
+ frames = video.get_batch(all_indexes)
54
+ frames = frames.numpy() if isinstance(frames, torch.Tensor) else frames.asnumpy()
55
+ frames = frames[indexes]
56
+ else:
57
+ assert False
58
+ return frames
59
+
60
+ def get_sample_indexes(video_length, num_frames, stride):
61
+ assert num_frames * stride <= video_length
62
+ sample_length = min(video_length, (num_frames - 1) * stride + 1)
63
+ start_idx = 0 + random.randint(0, video_length - sample_length)
64
+ sample_indexes = np.linspace(start_idx, start_idx + sample_length - 1, num_frames, dtype=int)
65
+ return sample_indexes
66
+
67
+ def get_new_height_width(data_dict, dst_height, dst_width):
68
+ height = data_dict['video_height']
69
+ width = data_dict['video_width']
70
+ if float(dst_height) / height < float(dst_width) / width:
71
+ new_height = int(round(float(dst_width) / width * height))
72
+ new_width = dst_width
73
+ else:
74
+ new_height = dst_height
75
+ new_width = int(round(float(dst_height) / height * width))
76
+ assert dst_width <= new_width and dst_height <= new_height
77
+ return new_height, new_width
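+ # Example: for a 1280x720 source video and dst_height=dst_width=512, the height is matched to 512
+ # and the width keeps the aspect ratio, so the function returns (512, 910); both target dimensions
+ # then fit inside the resized frame.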