Upload 32 files
- __pycache__/draw_pose.cpython-313.pyc +0 -0
- __pycache__/inference_engine.cpython-313.pyc +0 -0
- __pycache__/motion_extractor.cpython-313.pyc +0 -0
- __pycache__/utils.cpython-313.pyc +0 -0
- app.py +103 -0
- data/mean.npy +3 -0
- data/std.npy +3 -0
- draw_pose.py +115 -0
- inference_engine.py +117 -0
- models/__init__.py +2 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/__init__.cpython-313.pyc +0 -0
- models/dit/__init__.py +2 -0
- models/dit/__pycache__/__init__.cpython-311.pyc +0 -0
- models/dit/__pycache__/__init__.cpython-313.pyc +0 -0
- models/dit/__pycache__/mvdit_transformer.cpython-311.pyc +0 -0
- models/dit/__pycache__/mvdit_transformer.cpython-313.pyc +0 -0
- models/dit/__pycache__/pipeline_mtvcrafter.cpython-311.pyc +0 -0
- models/dit/__pycache__/pipeline_mtvcrafter.cpython-313.pyc +0 -0
- models/dit/mvdit_transformer.py +758 -0
- models/dit/pipeline_mtvcrafter.py +728 -0
- models/motion4d/__init__.py +2 -0
- models/motion4d/__pycache__/__init__.cpython-311.pyc +0 -0
- models/motion4d/__pycache__/__init__.cpython-313.pyc +0 -0
- models/motion4d/__pycache__/loss.cpython-311.pyc +0 -0
- models/motion4d/__pycache__/loss.cpython-313.pyc +0 -0
- models/motion4d/__pycache__/vqvae.cpython-311.pyc +0 -0
- models/motion4d/__pycache__/vqvae.cpython-313.pyc +0 -0
- models/motion4d/loss.py +30 -0
- models/motion4d/vqvae.py +501 -0
- motion_extractor.py +62 -0
- utils.py +77 -0
__pycache__/draw_pose.cpython-313.pyc
ADDED
Binary file (5.18 kB).
__pycache__/inference_engine.cpython-313.pyc
ADDED
Binary file (7.56 kB).
__pycache__/motion_extractor.cpython-313.pyc
ADDED
Binary file (2.97 kB).
__pycache__/utils.cpython-313.pyc
ADDED
Binary file (4.8 kB).
app.py
ADDED
@@ -0,0 +1,103 @@
import os
import traceback  # needed by the error handler in run_pipeline_with_feedback
import torch
import gradio as gr
import cv2
from PIL import Image
from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video

device = "cuda" if torch.cuda.is_available() else "cpu"

def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
    # 1. Extract the motion pkl
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    gr.Info("⏳ Extract motion finished and begin animation...", visible=True)

    # 2. Process the reference image (optional)
    if ref_image is not None:
        ref_path = "temp_ref.png"
        ref_image.save(ref_path)
    else:
        ref_path = ""

    # 3. Run inference
    output_path = run_inference(
        device,
        motion_pkl_path,
        ref_path,
        dst_width=width,
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )

    return output_path


def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
    try:
        if video_file is None:
            raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")
        # Progress feedback
        gr.Info("⏳ Processing... Please wait several minutes.", visible=True)
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy it!", visible=True)
        return result
    except Exception as e:
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None

# Build the UI
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
        """
        # 🎨💃 MTVCrafter Inference Demo

        💡 **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).
        This demo will extract human motion from the input video and animate the reference image accordingly.
        If no reference image is provided, the **first frame** of the video will be used as the reference.

        🎞️ **Note:** The generated output video will contain exactly **49 frames**.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(label="📹 Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            video_preview = gr.Video(label="👀 Preview of Uploaded Video", height=280)  # fixed height to keep the two columns aligned

            def show_video_preview(video_file):
                return video_file.name if video_file else None

            video_input.change(fn=show_video_preview, inputs=video_input, outputs=video_preview)

        with gr.Column(scale=1):
            ref_image = gr.Image(type="pil", label="🖼️ Reference Image (Optional)", height=538)

    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")

    with gr.Row(scale=1):
        output_video = gr.Video(label="🎬 Generated Video", interactive=False)

    run_btn = gr.Button("🚀 Run MTVCrafter", variant="primary")

    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )

if __name__ == "__main__":
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
    os.environ["NO_PROXY"] = "localhost,127.0.0.1/8,::1"
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
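For reference, a minimal sketch of driving the same two-step flow headlessly, without the Gradio UI. It assumes motion_extractor.py and inference_engine.py are importable and that the checkpoints hard-coded in inference_engine.py exist locally; the animate wrapper and the "dance.mp4" path are hypothetical, not part of the Space.

# Hypothetical headless driver for the flow app.py wires into Gradio.
import torch
from motion_extractor import extract_pkl_from_video
from inference_engine import run_inference

def animate(video_path: str, ref_image_path: str = "") -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    motion_pkl = extract_pkl_from_video(video_path)   # 1. motion extraction
    return run_inference(                             # 2. MV-DiT animation
        device, motion_pkl, ref_image_path,
        dst_width=512, dst_height=512,
        num_inference_steps=50, guidance_scale=3.0, seed=6666,
    )

if __name__ == "__main__":
    print(animate("dance.mp4"))  # placeholder input video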
data/mean.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ababeaabf5ac096ce7c7714ada14aa1de8355c0016de25695be611d51285141
size 416
data/std.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:650e46902a0878e6947be401e4e1995e54a8fd407f2be3ded0dda62bda99a9b3
size 416
draw_pose.py
ADDED
@@ -0,0 +1,115 @@
import cv2
import math
import numpy as np
from PIL import Image


def intrinsic_matrix_from_field_of_view(imshape, fov_degrees: float = 55):  # nlf default fov_degrees 55
    imshape = np.array(imshape)
    fov_radians = fov_degrees * np.array(np.pi / 180)
    larger_side = np.max(imshape)
    focal_length = larger_side / (np.tan(fov_radians / 2) * 2)
    # intrinsic_matrix 3*3
    return np.array([
        [focal_length, 0, imshape[1] / 2],
        [0, focal_length, imshape[0] / 2],
        [0, 0, 1],
    ])


def p3d_to_p2d(point_3d, height, width):  # point3d n*1024*3
    camera_matrix = intrinsic_matrix_from_field_of_view((height, width))
    camera_matrix = np.expand_dims(camera_matrix, axis=0)
    camera_matrix = np.expand_dims(camera_matrix, axis=0)  # 1*1*3*3
    point_3d = np.expand_dims(point_3d, axis=-1)  # n*1024*3*1
    point_2d = (camera_matrix @ point_3d).squeeze(-1)
    point_2d[:, :, :2] = point_2d[:, :, :2] / point_2d[:, :, 2:3]
    return point_2d[:, :, :]  # n*1024*2


def get_pose_images(smpl_data, offset):
    pose_images = []
    for data in smpl_data:
        if isinstance(data, np.ndarray):
            joints3d = data
        else:
            joints3d = data.numpy()
        canvas = np.zeros(shape=(offset[0], offset[1], 3), dtype=np.uint8)
        joints3d = p3d_to_p2d(joints3d, offset[0], offset[1])
        canvas = draw_3d_points(canvas, joints3d[0], stickwidth=int(offset[1] / 350))
        pose_images.append(Image.fromarray(canvas))
    return pose_images


def draw_3d_points(canvas, points, stickwidth=2, r=2, draw_line=True):
    colors = [
        [255, 0, 0],    # 0
        [0, 255, 0],    # 1
        [0, 0, 255],    # 2
        [255, 0, 255],  # 3
        [255, 255, 0],  # 4
        [85, 255, 0],   # 5
        [0, 75, 255],   # 6
        [0, 255, 85],   # 7
        [0, 255, 170],  # 8
        [170, 0, 255],  # 9
        [85, 0, 255],   # 10
        [0, 85, 255],   # 11
        [0, 255, 255],  # 12
        [85, 0, 255],   # 13
        [170, 0, 255],  # 14
        [255, 0, 255],  # 15
        [255, 0, 170],  # 16
        [255, 0, 85],   # 17
    ]
    connections = [
        [15, 12], [12, 16], [16, 18], [18, 20], [20, 22],
        [12, 17], [17, 19], [19, 21],
        [21, 23], [12, 9], [9, 6],
        [6, 3], [3, 0], [0, 1],
        [1, 4], [4, 7], [7, 10], [0, 2], [2, 5], [5, 8], [8, 11]
    ]
    connection_colors = [
        [255, 0, 0],    # 0
        [0, 255, 0],    # 1
        [0, 0, 255],    # 2
        [255, 255, 0],  # 3
        [255, 0, 255],  # 4
        [0, 255, 0],    # 5
        [0, 85, 255],   # 6
        [255, 175, 0],  # 7
        [0, 0, 255],    # 8
        [255, 85, 0],   # 9
        [0, 255, 85],   # 10
        [255, 0, 255],  # 11
        [255, 0, 0],    # 12
        [0, 175, 255],  # 13
        [255, 255, 0],  # 14
        [0, 0, 255],    # 15
        [0, 255, 0],    # 16
    ]

    # draw point
    for i in range(len(points)):
        x, y = points[i][0:2]
        x, y = int(x), int(y)
        if i == 13 or i == 14:
            continue
        cv2.circle(canvas, (x, y), r, colors[i % 17], thickness=-1)

    # draw line
    if draw_line:
        for i in range(len(connections)):
            point1_idx, point2_idx = connections[i][0:2]
            point1 = points[point1_idx]
            point2 = points[point2_idx]
            Y = [point2[0], point1[0]]
            X = [point2[1], point1[1]]
            mX = int(np.mean(X))
            mY = int(np.mean(Y))
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((mY, mX), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, connection_colors[i % 17])

    return canvas
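A minimal sketch of exercising the projection-and-drawing path above in isolation. The 24 random camera-space joints (x, y around the optical axis, positive depth z) are placeholders for the SMPL joints3d the motion extractor produces; only the call pattern comes from the code above.

import numpy as np
from draw_pose import get_pose_images

rng = np.random.default_rng(0)
fake_joints = np.stack([
    rng.uniform(-0.6, 0.6, size=24),  # x (metres, left/right of the optical axis)
    rng.uniform(-0.9, 0.9, size=24),  # y (metres, up/down)
    rng.uniform(2.5, 3.5, size=24),   # z (metres in front of the camera, must stay positive)
], axis=-1)

# offset = (canvas height, canvas width, unused), as passed from inference_engine.py
images = get_pose_images([fake_joints], offset=[512, 512, 0])
images[0].save("pose_debug.png")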
inference_engine.py
ADDED
@@ -0,0 +1,117 @@
# inference_engine.py
import os
import torch
import decord
import imageio
from PIL import Image
from models import MTVCrafterPipeline, Encoder, VectorQuantizer, Decoder, SMPL_VQVAE
from torchvision.transforms import ToPILImage, transforms, InterpolationMode, functional as F
import numpy as np
import pickle
import copy
from draw_pose import get_pose_images
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width

def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, dst_height=512, num_inference_steps=50, guidance_scale=3.0, seed=6666):
    num_frames = 49
    to_pil = ToPILImage()
    normalize = transforms.Normalize([0.5], [0.5])
    pretrained_model_path = "/gemini/space/human_guozz2/dyb/models/CogVideoX"
    transformer_path = "/gemini/space/human_guozz2/dyb/models/MTVCrafter/MV-DiT/CogVideoX"
    tokenizer_path = "/gemini/space/human_guozz2/dyb/models/MTVCrafter/4DMoT/mp_rank_00_model_states.pt"

    with open(motion_data_path, 'rb') as f:
        data_list = pickle.load(f)
    if not isinstance(data_list, list):
        data_list = [data_list]

    pe_mean = np.load('data/mean.npy')
    pe_std = np.load('data/std.npy')

    pipe = MTVCrafterPipeline.from_pretrained(
        model_path=pretrained_model_path,
        transformer_model_path=transformer_path,
        torch_dtype=torch.bfloat16,
        scheduler_type='dpm',
    ).to(device)
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()

    # load VQVAE
    state_dict = torch.load(tokenizer_path, map_location="cpu")
    motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
    motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
    motion_decoder = Decoder(in_channels=3072, mid_channels=[512, 128], out_channels=3, upsample_rate=2.0, frame_upsample_rate=[2.0, 2.0], joint_upsample_rate=[1.0, 1.0])
    vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
    vqvae.load_state_dict(state_dict['module'], strict=True)

    # only run the first sample here
    data = data_list[0]
    new_height, new_width = get_new_height_width(data, dst_height, dst_width)
    x1 = (new_width - dst_width) // 2
    y1 = (new_height - dst_height) // 2

    sample_indexes = get_sample_indexes(data['video_length'], num_frames, stride=1)
    input_images = sample_video(decord.VideoReader(data['video_path']), sample_indexes)
    input_images = torch.from_numpy(input_images).permute(0, 3, 1, 2).contiguous()
    input_images = F.resize(input_images, (new_height, new_width), InterpolationMode.BILINEAR)
    input_images = F.crop(input_images, y1, x1, dst_height, dst_width)

    if ref_image_path != '':
        ref_image = Image.open(ref_image_path).convert("RGB")
        ref_image = torch.from_numpy(np.array(ref_image)).permute(2, 0, 1).contiguous()
        ref_images = torch.stack([ref_image.clone() for _ in range(num_frames)])
        ref_images = F.resize(ref_images, (new_height, new_width), InterpolationMode.BILINEAR)
        ref_images = F.crop(ref_images, y1, x1, dst_height, dst_width)
    else:
        ref_images = copy.deepcopy(input_images)
        frame0 = input_images[0]
        ref_images[:, :, :, :] = frame0

    try:
        smpl_poses = np.array([pose[0][0].cpu().numpy() for pose in data['pose']['joints3d_nonparam']])
        poses = smpl_poses[sample_indexes]
    except:
        poses = data['pose'][sample_indexes]
    norm_poses = torch.tensor((poses - pe_mean) / pe_std)

    offset = [data['video_height'], data['video_width'], 0]
    pose_images_before = get_pose_images(copy.deepcopy(poses), offset)
    pose_images_before = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_before]
    input_smpl_joints = norm_poses.unsqueeze(0).to(device)
    motion_tokens, vq_loss = vqvae(input_smpl_joints, return_vq=True)
    output_motion, _ = vqvae(input_smpl_joints)
    pose_images_after = get_pose_images(output_motion[0].cpu().detach() * pe_std + pe_mean, offset)
    pose_images_after = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_after]

    # normalize images
    input_images = input_images / 255.0
    ref_images = ref_images / 255.0
    input_images = normalize(input_images)
    ref_images = normalize(ref_images)

    # infer
    output_images = pipe(
        height=dst_height,
        width=dst_width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        ref_images=ref_images,
        motion_embeds=motion_tokens,
        joint_mean=pe_mean,
        joint_std=pe_std,
    ).frames[0]

    # save result
    vis_images = []
    for k in range(len(output_images)):
        vis_image = [to_pil(((input_images[k] + 1) * 127.5).clamp(0, 255).to(torch.uint8)), pose_images_before[k], pose_images_after[k], output_images[k]]
        vis_image = concat_images_grid(vis_image, cols=len(vis_image), pad=2)
        vis_images.append(vis_image)

    output_path = "output.mp4"
    imageio.mimsave(output_path, vis_images, fps=15)

    return output_path
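For orientation, a sketch of the motion-data record that run_inference expects to load from motion_data_path. The field names mirror the reads above; the file name, the 24x3 joint layout, and the plain-array fallback for data['pose'] are assumptions (the authoritative producer is motion_extractor.py, which is listed in this upload but not reproduced here).

import pickle
import numpy as np

record = {
    "video_path": "dance.mp4",       # re-opened with decord inside run_inference
    "video_length": 120,             # frame count used by get_sample_indexes
    "video_height": 1080,            # canvas size for the pose drawings
    "video_width": 1920,
    "pose": np.zeros((120, 24, 3)),  # fallback branch: (frames, joints, 3) joints3d array
}
with open("dance_motion.pkl", "wb") as f:
    pickle.dump([record], f)         # run_inference also accepts a bare dict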
models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .dit import *
from .motion4d import *
models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (235 Bytes).
models/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (205 Bytes).
models/dit/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .mvdit_transformer import Transformer3DModel
from .pipeline_mtvcrafter import MTVCrafterPipeline
models/dit/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (333 Bytes).
models/dit/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (287 Bytes).
models/dit/__pycache__/mvdit_transformer.cpython-311.pyc
ADDED
Binary file (38.1 kB).
models/dit/__pycache__/mvdit_transformer.cpython-313.pyc
ADDED
Binary file (35.1 kB).
models/dit/__pycache__/pipeline_mtvcrafter.cpython-311.pyc
ADDED
Binary file (39.5 kB).
models/dit/__pycache__/pipeline_mtvcrafter.cpython-313.pyc
ADDED
Binary file (37 kB).
models/dit/mvdit_transformer.py
ADDED
@@ -0,0 +1,758 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, FusedCogVideoXAttnProcessor2_0
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class CogVideoXLayerNormZero(nn.Module):
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate, _, _, _ = self.linear(self.silu(temb)).chunk(6, dim=1)
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        return hidden_states, gate[:, None, :]


class CogVideoXAttnProcessor1_0:
    r"""Processor for implementing scaled dot-product attention for the
    CogVideoX model.

    It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, 'scaled_dot_product_attention'):
            raise ImportError('CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.')

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        motion_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # [batch_size, heads, seq_len, dim]
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
        if motion_rotary_emb is not None:
            key = apply_rotary_emb(key, motion_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states



class CogVideoXAttnProcessor2_0:
    r"""Processor for implementing scaled dot-product attention for the
    CogVideoX model.

    It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, 'scaled_dot_product_attention'):
            raise ImportError('CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.')

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        motion_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # [batch_size, heads, seq_len, dim]
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            image_seq_length = image_rotary_emb[0].shape[0]
            query[:, :, :image_seq_length] = apply_rotary_emb(query[:, :, :image_seq_length], image_rotary_emb)
            if motion_rotary_emb is not None:
                query[:, :, image_seq_length:] = apply_rotary_emb(query[:, :, image_seq_length:], motion_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, :image_seq_length] = apply_rotary_emb(key[:, :, :image_seq_length], image_rotary_emb)
                if motion_rotary_emb is not None:
                    key[:, :, image_seq_length:] = apply_rotary_emb(key[:, :, image_seq_length:], motion_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CogVideoXPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings:
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer('pos_embedding', pos_embedding, persistent=False)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
        joint_pos_embedding = torch.zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

        return joint_pos_embedding

    def forward(self, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        batch, num_frames, channels, height, width = image_embeds.shape
        image_embeds = image_embeds.reshape(-1, channels, height, width)
        image_embeds = self.proj(image_embeds)  # [2*7, 3072, h/8/2, w/8/2]
        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
        image_embeds = image_embeds.flatten(1, 2).contiguous()  # [batch, num_frames x height x width, channels]

        if self.use_positional_embeddings:
            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
                pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
            else:
                pos_embedding = self.pos_embedding

            # originally referenced an undefined `embeds`; this branch is only taken when sinusoidal
            # positional embeddings are enabled (i.e. use_rotary_positional_embeddings=False)
            image_embeds = image_embeds + pos_embedding

        return image_embeds


@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        motion_dim: int,
        dropout: float = 0.0,
        activation_fn: str = 'gelu-approximate',
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        cross_attention: bool = False,
    ):
        super().__init__()

        self.is_cross_attention = cross_attention

        if self.is_cross_attention:
            self.attn0 = Attention(
                query_dim=dim,
                cross_attention_dim=dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                qk_norm='layer_norm' if qk_norm else None,
                eps=1e-6,
                bias=attention_bias,
                out_bias=attention_out_bias,
                processor=CogVideoXAttnProcessor1_0(),
            )

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm='layer_norm' if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )


    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        motion_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:

        # norm & modulate
        norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)

        # self attention
        attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + gate_msa * attn_hidden_states

        if self.is_cross_attention:
            cross_attn_hidden_states = self.attn0(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
                motion_rotary_emb=motion_rotary_emb,
            )
            hidden_states = hidden_states + cross_attn_hidden_states

        # norm & modulate
        norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)

        # feed-forward
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output

        return hidden_states


class Transformer3DModel(ModelMixin, ConfigMixin):
    """
    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        motion_dim: int = 168,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = 'gelu-approximate',
        timestep_activation_fn: str = 'silu',
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim  # 48 * 64 = 3072

        self.unconditional_motion_token = torch.nn.Parameter(torch.randn(312, 3072))
        print(self.unconditional_motion_token[0])

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)  # 3072 --> 512

        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    motion_dim=motion_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    cross_attention=True,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, 'get_processor'):
                processors[f'{name}.processor'] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f'{name}.{sub_name}', child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f'A dict of processors was passed, but the number of processors {len(processor)} does not match the'
                f' number of attention layers: {count}. Please make sure to pass {count} processor classes.'
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, 'set_processor'):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f'{name}.processor'))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f'{name}.{sub_name}', child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with
    # FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """Enables fused QKV projections. For self-attention modules, all
        projection matrices (i.e., query, key, value) are fused. For cross-
        attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if 'Added' in str(attn_processor.__class__.__name__):
                raise ValueError('`fuse_qkv_projections()` is not supported for models having added KV projections.')

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        motion_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        motion_emb: Optional[torch.Tensor] = None,
        camera_emb: Optional[torch.Tensor] = None,
        need_broadcast: bool = True,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)  # (2, 3072)
        emb = self.time_embedding(t_emb, timestep_cond)  # (2, 3072) --> (2, 512)

        # 2. Patch embedding
        hidden_states = self.patch_embed(hidden_states)  # (2, 226+9450, dim=3072)
        hidden_states = self.embedding_dropout(hidden_states)
        image_seq_length = image_rotary_emb[0].shape[0]
        motion_seq_length = motion_emb.shape[1]  # 168
        # hidden_states = hidden_states[:, motion_seq_length:]
        encoder_hidden_states = motion_emb
        # encoder_hidden_states = self.motion_proj(motion_emb)

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:  # train with gradient checkpointing to save memory

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    motion_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    motion_rotary_emb=motion_rotary_emb,
                )

        # 4. Final block
        hidden_states = self.norm_final(hidden_states)
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
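A hypothetical shape-level smoke test for Transformer3DModel with a deliberately tiny configuration (the shipped MV-DiT checkpoint uses 48 heads of dimension 64 and 30 layers). The random (cos, sin) tables only have the [seq_len, head_dim] shape the attention processors expect; they are not the pipeline's real RoPE frequencies, and the tiny sizes are assumptions made just to exercise the forward pass.

import torch
from models.dit.mvdit_transformer import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=2, attention_head_dim=16,
    in_channels=16, out_channels=16,
    time_embed_dim=32, num_layers=1,
    use_rotary_positional_embeddings=True,  # keep the sinusoidal pos-embed branch disabled
).eval()

B, F_lat, C, H, W = 1, 2, 16, 8, 8          # (batch, latent frames, channels, height, width)
inner_dim = 2 * 16                          # num_attention_heads * attention_head_dim
image_seq = F_lat * (H // 2) * (W // 2)     # patch_size=2 -> 32 image tokens
motion_len = 6

latents = torch.randn(B, F_lat, C, H, W)
image_rope = (torch.randn(image_seq, 16), torch.randn(image_seq, 16))    # (cos, sin) placeholders
motion_rope = (torch.randn(motion_len, 16), torch.randn(motion_len, 16))
motion_tokens = torch.randn(B, motion_len, inner_dim)

with torch.no_grad():
    out = model(
        latents,
        timestep=torch.tensor([500]),
        image_rotary_emb=image_rope,
        motion_rotary_emb=motion_rope,
        motion_emb=motion_tokens,
    )
print(out.sample.shape)  # torch.Size([1, 2, 16, 8, 8])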
models/dit/pipeline_mtvcrafter.py
ADDED
@@ -0,0 +1,728 @@
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import inspect
|
4 |
+
import numpy as np
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
10 |
+
from diffusers.models import AutoencoderKLCogVideoX
|
11 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
12 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
13 |
+
from diffusers.utils import BaseOutput, logging
|
14 |
+
from diffusers.utils.torch_utils import randn_tensor
|
15 |
+
from diffusers.video_processor import VideoProcessor
|
16 |
+
from einops import rearrange
|
17 |
+
from PIL import Image
|
18 |
+
from torchvision import transforms
|
19 |
+
|
20 |
+
from .mvdit_transformer import Transformer3DModel
|
21 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
22 |
+
|
23 |
+
|
24 |
+
def get_1d_rotary_pos_embed(
|
25 |
+
dim: int,
|
26 |
+
pos: Union[np.ndarray, int],
|
27 |
+
theta: float = 10000.0,
|
28 |
+
use_real=False,
|
29 |
+
linear_factor=1.0,
|
30 |
+
ntk_factor=1.0,
|
31 |
+
repeat_interleave_real=True,
|
32 |
+
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
36 |
+
|
37 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
|
38 |
+
position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
39 |
+
data type.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
dim (`int`): Dimension of the frequency tensor.
|
43 |
+
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
44 |
+
theta (`float`, *optional*, defaults to 10000.0):
|
45 |
+
Scaling factor for frequency computation. Defaults to 10000.0.
|
46 |
+
use_real (`bool`, *optional*):
|
47 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
48 |
+
linear_factor (`float`, *optional*, defaults to 1.0):
|
49 |
+
Scaling factor for the context extrapolation. Defaults to 1.0.
|
50 |
+
ntk_factor (`float`, *optional*, defaults to 1.0):
|
51 |
+
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
52 |
+
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
53 |
+
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
54 |
+
Otherwise, they are concatenated with themselves.
|
55 |
+
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
56 |
+
the dtype of the frequency tensor.
|
57 |
+
Returns:
|
58 |
+
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
59 |
+
"""
|
60 |
+
assert dim % 2 == 0
|
61 |
+
|
62 |
+
if isinstance(pos, int):
|
63 |
+
pos = torch.arange(pos)
|
64 |
+
if isinstance(pos, np.ndarray):
|
65 |
+
pos = torch.from_numpy(pos) # type: ignore # [S]
|
66 |
+
|
67 |
+
theta = theta * ntk_factor
|
68 |
+
freqs = (
|
69 |
+
1.0
|
70 |
+
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
71 |
+
/ linear_factor
|
72 |
+
) # [D/2]
|
73 |
+
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
74 |
+
if use_real and repeat_interleave_real:
|
75 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
76 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
77 |
+
return freqs_cos, freqs_sin
|
78 |
+
elif use_real:
|
79 |
+
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
80 |
+
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
81 |
+
return freqs_cos, freqs_sin
|
82 |
+
else:
|
83 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
84 |
+
return freqs_cis
|
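As a quick sanity check of the helper above: with `use_real=True` it returns a (cos, sin) pair of shape [S, D], and with `use_real=False` a single complex tensor of shape [S, D/2]. A minimal sketch, assuming this file is importable as `models.dit.pipeline_mtvcrafter` (the sizes are arbitrary):

```python
from models.dit.pipeline_mtvcrafter import get_1d_rotary_pos_embed  # assumed module path

cos, sin = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=True)
print(cos.shape, sin.shape)              # torch.Size([16, 64]) torch.Size([16, 64])

freqs_cis = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=False)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([16, 32]) torch.complex64
```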
85 |
+
|
86 |
+
|
87 |
+
def get_3d_rotary_pos_embed(
|
88 |
+
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
89 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
90 |
+
"""
|
91 |
+
RoPE for video tokens with 3D structure.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
embed_dim: (`int`):
|
95 |
+
The embedding dimension size, corresponding to hidden_size_head.
|
96 |
+
crops_coords (`Tuple[int]`):
|
97 |
+
The top-left and bottom-right coordinates of the crop.
|
98 |
+
grid_size (`Tuple[int]`):
|
99 |
+
The grid size of the spatial positional embedding (height, width).
|
100 |
+
temporal_size (`int`):
|
101 |
+
The size of the temporal dimension.
|
102 |
+
theta (`float`):
|
103 |
+
Scaling factor for frequency computation.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
107 |
+
"""
|
108 |
+
if use_real is not True:
|
109 |
+
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
|
110 |
+
start, stop = crops_coords
|
111 |
+
grid_size_h, grid_size_w = grid_size
|
112 |
+
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
113 |
+
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
114 |
+
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
115 |
+
|
116 |
+
# Compute dimensions for each axis
|
117 |
+
dim_t = embed_dim // 4
|
118 |
+
dim_h = embed_dim // 8 * 3
|
119 |
+
dim_w = embed_dim // 8 * 3
|
120 |
+
|
121 |
+
# Temporal frequencies
|
122 |
+
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
|
123 |
+
# Spatial frequencies for height and width
|
124 |
+
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
|
125 |
+
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
|
126 |
+
|
127 |
+
# Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
|
128 |
+
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
|
129 |
+
freqs_t = freqs_t[:, None, None, :].expand(
|
130 |
+
-1, grid_size_h, grid_size_w, -1
|
131 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_t
|
132 |
+
freqs_h = freqs_h[None, :, None, :].expand(
|
133 |
+
temporal_size, -1, grid_size_w, -1
|
134 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_h
|
135 |
+
freqs_w = freqs_w[None, None, :, :].expand(
|
136 |
+
temporal_size, grid_size_h, -1, -1
|
137 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_w
|
138 |
+
|
139 |
+
freqs = torch.cat(
|
140 |
+
[freqs_t, freqs_h, freqs_w], dim=-1
|
141 |
+
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
|
142 |
+
freqs = freqs.view(
|
143 |
+
temporal_size * grid_size_h * grid_size_w, -1
|
144 |
+
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
|
145 |
+
return freqs
|
146 |
+
|
147 |
+
t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
|
148 |
+
h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
|
149 |
+
w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
|
150 |
+
cos = combine_time_height_width(t_cos, h_cos, w_cos)
|
151 |
+
sin = combine_time_height_width(t_sin, h_sin, w_sin)
|
152 |
+
return cos, sin
|
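The 3D variant splits the per-head dimension into a temporal quarter and two spatial three-eighths (height and width), broadcasts each over the (T, H, W) latent grid, and flattens the grid into one rotary pair per video token. A shape-only sketch under illustrative sizes (again assuming the `models.dit.pipeline_mtvcrafter` import path):

```python
from models.dit.pipeline_mtvcrafter import get_3d_rotary_pos_embed  # assumed module path

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,                   # plays the role of attention_head_dim
    crops_coords=((0, 0), (4, 6)),  # top-left / bottom-right of the latent grid
    grid_size=(4, 6),               # (height, width) of the latent grid in patches
    temporal_size=3,                # number of latent frames
)
print(cos.shape, sin.shape)  # torch.Size([72, 64]) twice: 3*4*6 tokens, 16+24+24 dims
```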
153 |
+
|
154 |
+
|
155 |
+
def get_3d_motion_spatial_embed(
|
156 |
+
embed_dim: int, num_joints: int, joints_mean: np.ndarray, joints_std: np.ndarray, theta: float = 10000.0
|
157 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
158 |
+
"""
|
159 |
+
"""
|
160 |
+
assert embed_dim % 2 == 0 and embed_dim % 3 == 0
|
161 |
+
|
162 |
+
def create_rope_pe(dim, pos, freqs_dtype=torch.float32):
|
163 |
+
if isinstance(pos, np.ndarray):
|
164 |
+
pos = torch.from_numpy(pos)
|
165 |
+
freqs = (
|
166 |
+
1.0
|
167 |
+
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
168 |
+
) # [D/2]
|
169 |
+
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
170 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
171 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
172 |
+
return freqs_cos, freqs_sin
|
173 |
+
|
174 |
+
# Create a rotary positional encoding for each axis (x, y, z)
|
175 |
+
# relative_pos_x = joints_mean[:, 0] - joints_mean[0, 0]
|
176 |
+
# relative_pos_y = joints_mean[:, 1] - joints_mean[0, 1]
|
177 |
+
# relative_pos_z = joints_mean[:, 2] - joints_mean[0, 2]
|
178 |
+
|
179 |
+
# normalized_pos_x = relative_pos_x / joints_std[:, 0].mean()
|
180 |
+
# normalized_pos_y = relative_pos_y / joints_std[:, 1].mean()
|
181 |
+
# normalized_pos_z = relative_pos_z / joints_std[:, 2].mean()
|
182 |
+
|
183 |
+
pos_x = joints_mean[:, 0]
|
184 |
+
pos_y = joints_mean[:, 1]
|
185 |
+
pos_z = joints_mean[:, 2]
|
186 |
+
|
187 |
+
normalized_pos_x = (pos_x - pos_x.mean())
|
188 |
+
normalized_pos_y = (pos_y - pos_y.mean())
|
189 |
+
normalized_pos_z = (pos_z - pos_z.mean())
|
190 |
+
|
191 |
+
freqs_cos_x, freqs_sin_x = create_rope_pe(embed_dim // 3, normalized_pos_x)
|
192 |
+
freqs_cos_y, freqs_sin_y = create_rope_pe(embed_dim // 3, normalized_pos_y)
|
193 |
+
freqs_cos_z, freqs_sin_z = create_rope_pe(embed_dim // 3, normalized_pos_z)
|
194 |
+
|
195 |
+
freqs_cos = torch.cat([freqs_cos_x, freqs_cos_y, freqs_cos_z], dim=-1)
|
196 |
+
freqs_sin = torch.cat([freqs_sin_x, freqs_sin_y, freqs_sin_z], dim=-1)
|
197 |
+
|
198 |
+
return freqs_cos, freqs_sin
|
199 |
+
|
200 |
+
|
201 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
202 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
203 |
+
tw = tgt_width
|
204 |
+
th = tgt_height
|
205 |
+
h, w = src
|
206 |
+
r = h / w
|
207 |
+
if r > (th / tw):
|
208 |
+
resize_height = th
|
209 |
+
resize_width = int(round(th / h * w))
|
210 |
+
else:
|
211 |
+
resize_width = tw
|
212 |
+
resize_height = int(round(tw / w * h))
|
213 |
+
|
214 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
215 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
216 |
+
|
217 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
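`get_resize_crop_region_for_grid` scales the source grid to cover the target aspect ratio and then center-crops it; the pipeline uses the result as `crops_coords` for the rotary embedding. A small worked example with made-up grid sizes:

```python
from models.dit.pipeline_mtvcrafter import get_resize_crop_region_for_grid  # assumed module path

# A 48x64 (h, w) source grid fitted into a 60x90 (tgt_height, tgt_width) target grid:
top_left, bottom_right = get_resize_crop_region_for_grid(src=(48, 64), tgt_width=90, tgt_height=60)
print(top_left, bottom_right)  # (0, 5) (60, 85): resized to 60x80 and centered horizontally
```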
218 |
+
|
219 |
+
|
220 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
221 |
+
def retrieve_timesteps(
|
222 |
+
scheduler,
|
223 |
+
num_inference_steps: Optional[int] = None,
|
224 |
+
device: Optional[Union[str, torch.device]] = None,
|
225 |
+
timesteps: Optional[List[int]] = None,
|
226 |
+
sigmas: Optional[List[float]] = None,
|
227 |
+
**kwargs,
|
228 |
+
):
|
229 |
+
"""
|
230 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
231 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
scheduler (`SchedulerMixin`):
|
235 |
+
The scheduler to get timesteps from.
|
236 |
+
num_inference_steps (`int`):
|
237 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
238 |
+
must be `None`.
|
239 |
+
device (`str` or `torch.device`, *optional*):
|
240 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
241 |
+
timesteps (`List[int]`, *optional*):
|
242 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
243 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
244 |
+
sigmas (`List[float]`, *optional*):
|
245 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
246 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
250 |
+
second element is the number of inference steps.
|
251 |
+
"""
|
252 |
+
if timesteps is not None and sigmas is not None:
|
253 |
+
raise ValueError('Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values')
|
254 |
+
if timesteps is not None:
|
255 |
+
accepts_timesteps = 'timesteps' in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
256 |
+
if not accepts_timesteps:
|
257 |
+
raise ValueError(
|
258 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
259 |
+
f' timestep schedules. Please check whether you are using the correct scheduler.'
|
260 |
+
)
|
261 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
262 |
+
timesteps = scheduler.timesteps
|
263 |
+
num_inference_steps = len(timesteps)
|
264 |
+
elif sigmas is not None:
|
265 |
+
accept_sigmas = 'sigmas' in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
266 |
+
if not accept_sigmas:
|
267 |
+
raise ValueError(
|
268 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
269 |
+
f' sigmas schedules. Please check whether you are using the correct scheduler.'
|
270 |
+
)
|
271 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
272 |
+
timesteps = scheduler.timesteps
|
273 |
+
num_inference_steps = len(timesteps)
|
274 |
+
else:
|
275 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
276 |
+
timesteps = scheduler.timesteps
|
277 |
+
return timesteps, num_inference_steps
|
278 |
+
|
279 |
+
|
280 |
+
@dataclass
|
281 |
+
class MTVCrafterPipelineOutput(BaseOutput):
|
282 |
+
r"""Output class for the MTVCrafter pipeline.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
286 |
+
List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
|
287 |
+
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
|
288 |
+
`(batch_size, num_frames, channels, height, width)`.
|
289 |
+
"""
|
290 |
+
|
291 |
+
frames: torch.Tensor
|
292 |
+
|
293 |
+
|
294 |
+
class MTVCrafterPipeline(DiffusionPipeline):
|
295 |
+
r"""Pipeline for MTVCrafter.
|
296 |
+
|
297 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
298 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
299 |
+
|
300 |
+
Args:
|
301 |
+
vae ([`AutoencoderKL`]):
|
302 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
303 |
+
transformer ([`Transformer3DModel`]):
|
304 |
+
An image-conditioned `Transformer3DModel` to denoise the encoded video latents.
|
305 |
+
scheduler ([`SchedulerMixin`]):
|
306 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
307 |
+
"""
|
308 |
+
|
309 |
+
_callback_tensor_inputs = [
|
310 |
+
'latents',
|
311 |
+
'prompt_embeds',
|
312 |
+
'negative_prompt_embeds',
|
313 |
+
]
|
314 |
+
|
315 |
+
def __init__(
|
316 |
+
self,
|
317 |
+
vae: AutoencoderKLCogVideoX,
|
318 |
+
transformer: Transformer3DModel,
|
319 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
320 |
+
):
|
321 |
+
super().__init__()
|
322 |
+
|
323 |
+
self.register_modules(
|
324 |
+
vae=vae,
|
325 |
+
transformer=transformer,
|
326 |
+
scheduler=scheduler,
|
327 |
+
)
|
328 |
+
self.vae_scale_factor_spatial = (
|
329 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, 'vae') and self.vae is not None else 8
|
330 |
+
)
|
331 |
+
self.vae_scale_factor_temporal = (
|
332 |
+
self.vae.config.temporal_compression_ratio if hasattr(self, 'vae') and self.vae is not None else 4
|
333 |
+
)
|
334 |
+
|
335 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
336 |
+
self.normalize = transforms.Normalize([0.5], [0.5])
|
337 |
+
|
338 |
+
@classmethod
|
339 |
+
def from_pretrained(
|
340 |
+
cls,
|
341 |
+
model_path,
|
342 |
+
transformer_model_path=None,
|
343 |
+
scheduler_type='ddim',
|
344 |
+
torch_dtype=None,
|
345 |
+
**kwargs,
|
346 |
+
):
|
347 |
+
if transformer_model_path is None:
|
348 |
+
transformer_model_path = os.path.join(model_path, 'transformer')
|
349 |
+
transformer = Transformer3DModel.from_pretrained(
|
350 |
+
transformer_model_path, torch_dtype=torch_dtype, **kwargs
|
351 |
+
)
|
352 |
+
if scheduler_type == 'ddim':
|
353 |
+
scheduler = CogVideoXDDIMScheduler.from_pretrained(model_path, subfolder='scheduler')
|
354 |
+
elif scheduler_type == 'dpm':
|
355 |
+
scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder='scheduler')
|
356 |
+
else:
|
357 |
+
raise ValueError(f'Unsupported scheduler_type: {scheduler_type}')
|
358 |
+
pipe = super().from_pretrained(
|
359 |
+
model_path, transformer=transformer, scheduler=scheduler, torch_dtype=torch_dtype, **kwargs
|
360 |
+
)
|
361 |
+
return pipe
|
362 |
+
|
363 |
+
|
364 |
+
def prepare_latents(
|
365 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
366 |
+
):
|
367 |
+
shape = (
|
368 |
+
batch_size,
|
369 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
370 |
+
num_channels_latents,
|
371 |
+
height // self.vae_scale_factor_spatial,
|
372 |
+
width // self.vae_scale_factor_spatial,
|
373 |
+
)
|
374 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
375 |
+
raise ValueError(
|
376 |
+
f'You have passed a list of generators of length {len(generator)}, but requested an effective batch'
|
377 |
+
f' size of {batch_size}. Make sure the batch size matches the length of the generators.'
|
378 |
+
)
|
379 |
+
|
380 |
+
if latents is None:
|
381 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
382 |
+
else:
|
383 |
+
latents = latents.to(device)
|
384 |
+
|
385 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
386 |
+
latents = latents * self.scheduler.init_noise_sigma
|
387 |
+
return latents
|
388 |
+
|
389 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
390 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
391 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
392 |
+
|
393 |
+
frames = self.vae.decode(latents).sample
|
394 |
+
return frames
|
395 |
+
|
396 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
397 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
398 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
399 |
+
# eta corresponds to η in DDIM paper and should be between [0, 1]
|
400 |
+
|
401 |
+
accepts_eta = 'eta' in set(inspect.signature(self.scheduler.step).parameters.keys())
|
402 |
+
extra_step_kwargs = {}
|
403 |
+
if accepts_eta:
|
404 |
+
extra_step_kwargs['eta'] = eta
|
405 |
+
|
406 |
+
# check if the scheduler accepts generator
|
407 |
+
accepts_generator = 'generator' in set(inspect.signature(self.scheduler.step).parameters.keys())
|
408 |
+
if accepts_generator:
|
409 |
+
extra_step_kwargs['generator'] = generator
|
410 |
+
return extra_step_kwargs
|
411 |
+
|
412 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
413 |
+
def check_inputs(
|
414 |
+
self,
|
415 |
+
height,
|
416 |
+
width,
|
417 |
+
callback_on_step_end_tensor_inputs,
|
418 |
+
):
|
419 |
+
if height % 8 != 0 or width % 8 != 0:
|
420 |
+
raise ValueError(f'`height` and `width` have to be divisible by 8 but are {height} and {width}.')
|
421 |
+
|
422 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
423 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
424 |
+
):
|
425 |
+
raise ValueError(
|
426 |
+
f'`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found '
|
427 |
+
f'{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}'
|
428 |
+
)
|
429 |
+
|
430 |
+
|
431 |
+
def _prepare_rotary_positional_embeddings(
|
432 |
+
self,
|
433 |
+
height: int,
|
434 |
+
width: int,
|
435 |
+
num_frames: int,
|
436 |
+
device: torch.device,
|
437 |
+
dtype: torch.dtype,
|
438 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
439 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
440 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
441 |
+
grid_crops_coords = ((0, 0), (grid_height, grid_width))
|
442 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
443 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
444 |
+
crops_coords=grid_crops_coords,
|
445 |
+
grid_size=(grid_height, grid_width),
|
446 |
+
temporal_size=num_frames,
|
447 |
+
)
|
448 |
+
|
449 |
+
freqs_cos = freqs_cos.to(device=device, dtype=dtype)
|
450 |
+
freqs_sin = freqs_sin.to(device=device, dtype=dtype)
|
451 |
+
return freqs_cos, freqs_sin
|
452 |
+
|
453 |
+
|
454 |
+
def _prepare_motion_embeddings(self, num_frames, num_joints, joints_mean, joints_std, device, dtype):
|
455 |
+
time_embed = get_1d_rotary_pos_embed(self.transformer.config.attention_head_dim // 4, num_frames, use_real=True)
|
456 |
+
time_embed_cos = time_embed[0][:, None, :].expand(-1, num_joints, -1).reshape(num_frames*num_joints, -1)
|
457 |
+
time_embed_sin = time_embed[1][:, None, :].expand(-1, num_joints, -1).reshape(num_frames*num_joints, -1)
|
458 |
+
spatial_motion_embed = get_3d_motion_spatial_embed(self.transformer.config.attention_head_dim // 4 * 3, num_joints, joints_mean, joints_std)
|
459 |
+
spatial_embed_cos = spatial_motion_embed[0][None, :, :].expand(num_frames, -1, -1).reshape(num_frames*num_joints, -1)
|
460 |
+
spatial_embed_sin = spatial_motion_embed[1][None, :, :].expand(num_frames, -1, -1).reshape(num_frames*num_joints, -1)
|
461 |
+
motion_embed_cos = torch.cat([time_embed_cos, spatial_embed_cos], dim=-1).to(device=device, dtype=dtype)
|
462 |
+
motion_embed_sin = torch.cat([time_embed_sin, spatial_embed_sin], dim=-1).to(device=device, dtype=dtype)
|
463 |
+
return motion_embed_cos, motion_embed_sin
|
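Conceptually, `_prepare_motion_embeddings` gives every (frame, joint) motion token its own rotary pair: the first quarter of the head dimension encodes the frame index, and the remaining three quarters encode the joint's mean 3D position via `get_3d_motion_spatial_embed`. The shape-only sketch below reproduces that composition outside the pipeline; the head dimension, frame/joint counts, and random joint statistics are illustrative:

```python
import numpy as np
import torch
from models.dit.pipeline_mtvcrafter import (  # assumed module path
    get_1d_rotary_pos_embed,
    get_3d_motion_spatial_embed,
)

head_dim, num_frames, num_joints = 64, 13, 24
joints_mean = np.random.randn(num_joints, 3).astype(np.float32)

# Temporal part: head_dim // 4 dims per frame, repeated for every joint.
t_cos, _ = get_1d_rotary_pos_embed(head_dim // 4, num_frames, use_real=True)
t_cos = t_cos[:, None, :].expand(-1, num_joints, -1).reshape(num_frames * num_joints, -1)

# Spatial part: head_dim // 4 * 3 dims per joint, repeated for every frame (joints_std is unused here).
s_cos, _ = get_3d_motion_spatial_embed(head_dim // 4 * 3, num_joints, joints_mean, None)
s_cos = s_cos[None, :, :].expand(num_frames, -1, -1).reshape(num_frames * num_joints, -1)

motion_cos = torch.cat([t_cos, s_cos], dim=-1)
print(motion_cos.shape)  # torch.Size([312, 64]): 13 frames * 24 joints tokens, full head_dim
```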
464 |
+
|
465 |
+
@property
|
466 |
+
def guidance_scale(self):
|
467 |
+
return self._guidance_scale
|
468 |
+
|
469 |
+
@property
|
470 |
+
def num_timesteps(self):
|
471 |
+
return self._num_timesteps
|
472 |
+
|
473 |
+
@property
|
474 |
+
def interrupt(self):
|
475 |
+
return self._interrupt
|
476 |
+
|
477 |
+
@torch.no_grad()
|
478 |
+
def __call__(
|
479 |
+
self,
|
480 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
481 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
482 |
+
height: int = 480,
|
483 |
+
width: int = 720,
|
484 |
+
num_frames: int = 49,
|
485 |
+
num_inference_steps: int = 50,
|
486 |
+
timesteps: Optional[List[int]] = None,
|
487 |
+
guidance_scale: float = 6,
|
488 |
+
use_dynamic_cfg: bool = False,
|
489 |
+
num_videos_per_prompt: int = 1,
|
490 |
+
eta: float = 0.0,
|
491 |
+
seed: Optional[int] = -1,
|
492 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
493 |
+
latents: Optional[torch.FloatTensor] = None,
|
494 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
495 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
496 |
+
output_type: str = 'pil',
|
497 |
+
return_dict: bool = True,
|
498 |
+
callback_on_step_end: Optional[
|
499 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
500 |
+
] = None,
|
501 |
+
callback_on_step_end_tensor_inputs: List[str] = ['latents'],
|
502 |
+
max_sequence_length: int = 226,
|
503 |
+
ref_images: Optional[torch.Tensor] = None,  # stacked reference frames, shape (f, c, h, w)
|
504 |
+
motion_embeds: Optional[torch.FloatTensor] = None,
|
505 |
+
joint_mean: Optional[np.ndarray] = None,
|
506 |
+
joint_std: Optional[np.ndarray] = None,
|
507 |
+
) -> Union[MTVCrafterPipelineOutput, Tuple]:
|
508 |
+
"""Function invoked when calling the pipeline for generation.
|
509 |
+
|
510 |
+
Args:
|
511 |
+
prompt (`str` or `List[str]`, *optional*):
|
512 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
513 |
+
instead.
|
514 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
515 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
516 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
517 |
+
less than `1`).
|
518 |
+
height (`int`, *optional*, defaults to `480`):
|
519 |
+
The height in pixels of the generated video frames.
|
520 |
+
width (`int`, *optional*, defaults to `720`):
|
521 |
+
The width in pixels of the generated video frames.
|
522 |
+
num_frames (`int`, defaults to `49`):
|
523 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
524 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
525 |
+
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
526 |
+
needs to be satisfied is that of divisibility mentioned above.
|
527 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
528 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
529 |
+
expense of slower inference.
|
530 |
+
timesteps (`List[int]`, *optional*):
|
531 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
532 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
533 |
+
passed will be used. Must be in descending order.
|
534 |
+
guidance_scale (`float`, *optional*, defaults to `6`):
|
535 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance]. Guidance scale is enabled by setting `guidance_scale >
|
536 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
537 |
+
usually at the expense of lower image quality.
|
538 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
539 |
+
The number of videos to generate per prompt.
|
540 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
541 |
+
One or a list of [torch generator(s)]
|
542 |
+
to make generation deterministic.
|
543 |
+
latents (`torch.FloatTensor`, *optional*):
|
544 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
545 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
546 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
547 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
548 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
549 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
550 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
551 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
552 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
553 |
+
argument.
|
554 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
555 |
+
The output format of the generated image. Choose between `PIL.Image.Image` or `np.array`.
|
556 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
557 |
+
Whether or not to return a [`MTVCrafterPipelineOutput`] instead
|
558 |
+
of a plain tuple.
|
559 |
+
callback_on_step_end (`Callable`, *optional*):
|
560 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
561 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
562 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
563 |
+
`callback_on_step_end_tensor_inputs`.
|
564 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
565 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
566 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
567 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
568 |
+
max_sequence_length (`int`, defaults to `226`):
|
569 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
570 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
571 |
+
"""
|
572 |
+
|
573 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
574 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
575 |
+
|
576 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
577 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
578 |
+
# 720 * 480
|
579 |
+
num_videos_per_prompt = 1
|
580 |
+
|
581 |
+
# 1. Check inputs. Raise error if not correct
|
582 |
+
self.check_inputs(
|
583 |
+
height,
|
584 |
+
width,
|
585 |
+
callback_on_step_end_tensor_inputs,
|
586 |
+
)
|
587 |
+
self._guidance_scale = guidance_scale
|
588 |
+
self._interrupt = False
|
589 |
+
|
590 |
+
# 2. Default call parameters
|
591 |
+
if prompt is not None and isinstance(prompt, str):
|
592 |
+
batch_size = 1
|
593 |
+
elif prompt is not None and isinstance(prompt, list):
|
594 |
+
batch_size = len(prompt)
|
595 |
+
elif prompt is None:
|
596 |
+
batch_size = 1
|
597 |
+
else:
|
598 |
+
batch_size = prompt_embeds.shape[0]
|
599 |
+
|
600 |
+
device = self._execution_device
|
601 |
+
|
602 |
+
if seed > 0:
|
603 |
+
generator = torch.Generator(device=device)
|
604 |
+
generator.manual_seed(seed)
|
605 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
606 |
+
|
607 |
+
# 3. Prepare timesteps
|
608 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
609 |
+
self._num_timesteps = len(timesteps)
|
610 |
+
|
611 |
+
# 4. Prepare latents.
|
612 |
+
latent_channels = self.vae.config.latent_channels
|
613 |
+
latents = self.prepare_latents(
|
614 |
+
batch_size * num_videos_per_prompt,
|
615 |
+
latent_channels,
|
616 |
+
num_frames,
|
617 |
+
height,
|
618 |
+
width,
|
619 |
+
self.vae.dtype,
|
620 |
+
device,
|
621 |
+
generator,
|
622 |
+
latents,
|
623 |
+
) # [1, x, 16, h/8, w/8]
|
624 |
+
|
625 |
+
if ref_images is not None:
|
626 |
+
ref_images = rearrange(ref_images.unsqueeze(0), 'b f c h w -> b c f h w')
|
627 |
+
ref_latents = self.vae.encode(
|
628 |
+
ref_images.to(dtype=self.vae.dtype, device=self.vae.device)
|
629 |
+
).latent_dist.sample()
|
630 |
+
ref_latents = rearrange(ref_latents, 'b c f h w -> b f c h w')
|
631 |
+
if do_classifier_free_guidance:
|
632 |
+
ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
|
633 |
+
|
634 |
+
motion_embeds = motion_embeds.to(latents.dtype) if motion_embeds is not None else None  # cast only when motion conditioning is provided
|
635 |
+
if motion_embeds is not None and do_classifier_free_guidance:
|
636 |
+
motion_embeds = torch.cat([self.transformer.unconditional_motion_token.unsqueeze(0), motion_embeds], dim=0)
|
637 |
+
|
638 |
+
# 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
639 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
640 |
+
|
641 |
+
# 6. Create rotary embeds if required
|
642 |
+
image_rotary_emb = (
|
643 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, dtype=latents.dtype)
|
644 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
645 |
+
else None
|
646 |
+
)
|
647 |
+
motion_rotary_emb = self._prepare_motion_embeddings(latents.size(1), 24, joint_mean, joint_std, device, dtype=latents.dtype)
|
648 |
+
|
649 |
+
# 7. Denoising loop
|
650 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
651 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
652 |
+
# for DPM-solver++
|
653 |
+
old_pred_original_sample = None
|
654 |
+
for i, t in enumerate(timesteps):
|
655 |
+
if self.interrupt:
|
656 |
+
continue
|
657 |
+
|
658 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
659 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
660 |
+
|
661 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
662 |
+
timestep = t.expand(latent_model_input.shape[0])
|
663 |
+
|
664 |
+
if ref_images is not None:
|
665 |
+
latent_model_input = torch.cat([latent_model_input, ref_latents], dim=2)
|
666 |
+
|
667 |
+
# predict noise model_output
|
668 |
+
noise_pred = self.transformer(
|
669 |
+
hidden_states=latent_model_input,
|
670 |
+
timestep=timestep.long(),
|
671 |
+
image_rotary_emb=image_rotary_emb,
|
672 |
+
motion_rotary_emb=motion_rotary_emb,
|
673 |
+
motion_emb=motion_embeds,
|
674 |
+
return_dict=False,
|
675 |
+
)[0]
|
676 |
+
noise_pred = noise_pred.float() # [b, f, c, h, w]
|
677 |
+
|
678 |
+
# perform guidance
|
679 |
+
if use_dynamic_cfg:
|
680 |
+
self._guidance_scale = 1 + guidance_scale * (
|
681 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
682 |
+
)
|
683 |
+
if do_classifier_free_guidance:
|
684 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
685 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
686 |
+
|
687 |
+
# compute the previous noisy sample x_t -> x_t-1
|
688 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
689 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
690 |
+
else:
|
691 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
692 |
+
noise_pred,
|
693 |
+
old_pred_original_sample,
|
694 |
+
t,
|
695 |
+
timesteps[i - 1] if i > 0 else None,
|
696 |
+
latents,
|
697 |
+
**extra_step_kwargs,
|
698 |
+
return_dict=False,
|
699 |
+
)
|
700 |
+
latents = latents.to(self.vae.dtype)
|
701 |
+
|
702 |
+
# call the callback, if provided
|
703 |
+
if callback_on_step_end is not None:
|
704 |
+
callback_kwargs = {}
|
705 |
+
for k in callback_on_step_end_tensor_inputs:
|
706 |
+
callback_kwargs[k] = locals()[k]
|
707 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
708 |
+
|
709 |
+
latents = callback_outputs.pop('latents', latents)
|
710 |
+
prompt_embeds = callback_outputs.pop('prompt_embeds', prompt_embeds)
|
711 |
+
negative_prompt_embeds = callback_outputs.pop('negative_prompt_embeds', negative_prompt_embeds)
|
712 |
+
|
713 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
714 |
+
progress_bar.update()
|
715 |
+
|
716 |
+
if not output_type == 'latent':
|
717 |
+
video = self.decode_latents(latents)
|
718 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
719 |
+
else:
|
720 |
+
video = latents
|
721 |
+
|
722 |
+
# Offload all models
|
723 |
+
self.maybe_free_model_hooks()
|
724 |
+
|
725 |
+
if not return_dict:
|
726 |
+
return (video,)
|
727 |
+
|
728 |
+
return MTVCrafterPipelineOutput(frames=video)
|
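For orientation, a typical invocation of this pipeline looks roughly like the sketch below. The checkpoint directory, motion token tensor, and joint statistics are placeholders (in the app they come from the pretrained weights, the motion extractor, and the SMPL VQVAE), so treat this as an illustrative outline rather than a verified recipe:

```python
import numpy as np
import torch
from models.dit.pipeline_mtvcrafter import MTVCrafterPipeline  # assumed module path

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder conditioning inputs; real values come from the motion extractor / VQVAE.
motion_embeds = torch.randn(1, 168, 3072)          # quantized motion tokens (illustrative shape)
joint_mean = np.zeros((24, 3), dtype=np.float32)
joint_std = np.ones((24, 3), dtype=np.float32)

pipe = MTVCrafterPipeline.from_pretrained(
    "path/to/MTVCrafter",        # placeholder checkpoint directory
    scheduler_type="dpm",
    torch_dtype=torch.bfloat16,
).to(device)

result = pipe(
    height=512,
    width=512,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=3.0,
    seed=6666,
    motion_embeds=motion_embeds.to(device),
    joint_mean=joint_mean,
    joint_std=joint_std,
    output_type="pil",
)
frames = result.frames  # nested list of PIL frames (see MTVCrafterPipelineOutput)
```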
models/motion4d/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
from .vqvae import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder
|
2 |
+
from .loss import ReConsLoss
|
models/motion4d/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (375 Bytes). View file
|
|
models/motion4d/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (314 Bytes). View file
|
|
models/motion4d/__pycache__/loss.cpython-311.pyc
ADDED
Binary file (2.13 kB). View file
|
|
models/motion4d/__pycache__/loss.cpython-313.pyc
ADDED
Binary file (2 kB). View file
|
|
models/motion4d/__pycache__/vqvae.cpython-311.pyc
ADDED
Binary file (28.4 kB). View file
|
|
models/motion4d/__pycache__/vqvae.cpython-313.pyc
ADDED
Binary file (26.5 kB). View file
|
|
models/motion4d/loss.py
ADDED
@@ -0,0 +1,30 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class ReConsLoss(nn.Module):
|
5 |
+
def __init__(self, recons_loss, nb_joints):
|
6 |
+
super(ReConsLoss, self).__init__()
|
7 |
+
|
8 |
+
if recons_loss == 'l1':
|
9 |
+
self.Loss = torch.nn.L1Loss()
|
10 |
+
elif recons_loss == 'l2' :
|
11 |
+
self.Loss = torch.nn.MSELoss()
|
12 |
+
elif recons_loss == 'l1_smooth' :
|
13 |
+
self.Loss = torch.nn.SmoothL1Loss()
|
14 |
+
|
15 |
+
# 4 global motion associated to root
|
16 |
+
# 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d)
|
17 |
+
# 3 global vel xyz
|
18 |
+
# 4 foot contact
|
19 |
+
self.nb_joints = nb_joints
|
20 |
+
self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4
|
21 |
+
|
22 |
+
def forward(self, motion_pred, motion_gt) :
|
23 |
+
loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim])
|
24 |
+
return loss
|
25 |
+
|
26 |
+
def forward_joint(self, motion_pred, motion_gt) :
|
27 |
+
loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4])
|
28 |
+
return loss
|
29 |
+
|
30 |
+
|
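A minimal usage sketch for `ReConsLoss` (the shapes are illustrative; with `nb_joints=24` only the first (24-1)*12 + 4 + 3 + 4 = 287 feature channels enter the full loss, and `forward_joint` compares just the local joint positions):

```python
import torch
from models.motion4d.loss import ReConsLoss  # assumed module path

criterion = ReConsLoss(recons_loss="l1_smooth", nb_joints=24)

motion_pred = torch.randn(2, 16, 300)  # (batch, frames, feature dim), illustrative
motion_gt = torch.randn(2, 16, 300)

loss = criterion(motion_pred, motion_gt)                      # SmoothL1 over the first 287 channels
joint_loss = criterion.forward_joint(motion_pred, motion_gt)  # local joint channels 4..72 only
```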
models/motion4d/vqvae.py
ADDED
@@ -0,0 +1,501 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
7 |
+
from diffusers.models.attention import Attention
|
8 |
+
|
9 |
+
|
10 |
+
class AttnProcessor:
|
11 |
+
r"""Processor for implementing scaled dot-product attention for the
|
12 |
+
CogVideoX model.
|
13 |
+
|
14 |
+
It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self):
|
18 |
+
if not hasattr(F, 'scaled_dot_product_attention'):
|
19 |
+
raise ImportError('AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.')
|
20 |
+
|
21 |
+
def __call__(
|
22 |
+
self,
|
23 |
+
attn: Attention,
|
24 |
+
hidden_states: torch.Tensor,
|
25 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
26 |
+
attention_mask: Optional[torch.Tensor] = None,
|
27 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
28 |
+
motion_rotary_emb: Optional[torch.Tensor] = None,
|
29 |
+
) -> torch.Tensor:
|
30 |
+
# hidden_states: [batch_size, sequence_length, dim]
|
31 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
32 |
+
|
33 |
+
if attention_mask is not None:
|
34 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
35 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
36 |
+
|
37 |
+
query = attn.to_q(hidden_states)
|
38 |
+
key = attn.to_k(hidden_states)
|
39 |
+
value = attn.to_v(hidden_states)
|
40 |
+
|
41 |
+
inner_dim = key.shape[-1]
|
42 |
+
head_dim = inner_dim // attn.heads
|
43 |
+
|
44 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # [batch_size, heads, seq_len, dim]
|
45 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
46 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
47 |
+
|
48 |
+
if attn.norm_q is not None:
|
49 |
+
query = attn.norm_q(query)
|
50 |
+
if attn.norm_k is not None:
|
51 |
+
key = attn.norm_k(key)
|
52 |
+
|
53 |
+
sp_group = None  # sequence-parallel helpers (get_sequence_parallel_group, _all_in_all_with_text) are not defined in this file, so the parallel path below stays disabled
|
54 |
+
if sp_group is not None:
|
55 |
+
sp_size = dist.get_world_size(sp_group)
|
56 |
+
query = _all_in_all_with_text(query, text_seq_length, sp_group, sp_size, mode=1)
|
57 |
+
key = _all_in_all_with_text(key, text_seq_length, sp_group, sp_size, mode=1)
|
58 |
+
value = _all_in_all_with_text(value, text_seq_length, sp_group, sp_size, mode=1)
|
59 |
+
text_seq_length *= sp_size
|
60 |
+
|
61 |
+
# Apply RoPE if needed
|
62 |
+
if image_rotary_emb is not None:
|
63 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
64 |
+
image_seq_length = image_rotary_emb[0].shape[0]
|
65 |
+
query[:, :, :image_seq_length] = apply_rotary_emb(query[:, :, :image_seq_length], image_rotary_emb)
|
66 |
+
if motion_rotary_emb is not None:
|
67 |
+
query[:, :, image_seq_length:] = apply_rotary_emb(query[:, :, image_seq_length:], motion_rotary_emb)
|
68 |
+
if not attn.is_cross_attention:
|
69 |
+
key[:, :, :image_seq_length] = apply_rotary_emb(key[:, :, :image_seq_length], image_rotary_emb)
|
70 |
+
if motion_rotary_emb is not None:
|
71 |
+
key[:, :, image_seq_length:] = apply_rotary_emb(key[:, :, image_seq_length:], motion_rotary_emb)
|
72 |
+
|
73 |
+
hidden_states = F.scaled_dot_product_attention(
|
74 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
75 |
+
)
|
76 |
+
|
77 |
+
if sp_group is not None:
|
78 |
+
hidden_states = _all_in_all_with_text(hidden_states, text_seq_length, sp_group, sp_size, mode=2)
|
79 |
+
text_seq_length = text_seq_length // sp_size
|
80 |
+
|
81 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
82 |
+
|
83 |
+
# linear proj
|
84 |
+
hidden_states = attn.to_out[0](hidden_states)
|
85 |
+
# dropout
|
86 |
+
hidden_states = attn.to_out[1](hidden_states)
|
87 |
+
|
88 |
+
return hidden_states
|
89 |
+
|
90 |
+
|
91 |
+
class Encoder(nn.Module):
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
in_channels=3,
|
95 |
+
mid_channels=[128, 512],
|
96 |
+
out_channels=3072,
|
97 |
+
downsample_time=[1, 1],
|
98 |
+
downsample_joint=[1, 1],
|
99 |
+
num_attention_heads=8,
|
100 |
+
attention_head_dim=64,
|
101 |
+
dim=3072,
|
102 |
+
):
|
103 |
+
super(Encoder, self).__init__()
|
104 |
+
|
105 |
+
self.conv_in = nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, stride=1, padding=1)
|
106 |
+
self.resnet1 = nn.ModuleList([ResBlock(mid_channels[0], mid_channels[0]) for _ in range(3)])
|
107 |
+
self.downsample1 = Downsample(mid_channels[0], mid_channels[0], downsample_time[0], downsample_joint[0])
|
108 |
+
self.resnet2 = ResBlock(mid_channels[0], mid_channels[1])
|
109 |
+
self.resnet3 = nn.ModuleList([ResBlock(mid_channels[1], mid_channels[1]) for _ in range(3)])
|
110 |
+
self.downsample2 = Downsample(mid_channels[1], mid_channels[1], downsample_time[1], downsample_joint[1])
|
111 |
+
# self.attn = Attention(
|
112 |
+
# query_dim=dim,
|
113 |
+
# dim_head=attention_head_dim,
|
114 |
+
# heads=num_attention_heads,
|
115 |
+
# qk_norm='layer_norm',
|
116 |
+
# eps=1e-6,
|
117 |
+
# bias=True,
|
118 |
+
# out_bias=True,
|
119 |
+
# processor=AttnProcessor(),
|
120 |
+
# )
|
121 |
+
self.conv_out = nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
x = self.conv_in(x)
|
125 |
+
for resnet in self.resnet1:
|
126 |
+
x = resnet(x)
|
127 |
+
x = self.downsample1(x)
|
128 |
+
|
129 |
+
x = self.resnet2(x)
|
130 |
+
for resnet in self.resnet3:
|
131 |
+
x = resnet(x)
|
132 |
+
x = self.downsample2(x)
|
133 |
+
|
134 |
+
# x = x + self.attn(x)
|
135 |
+
x = self.conv_out(x)
|
136 |
+
|
137 |
+
return x
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
class VectorQuantizer(nn.Module):
|
142 |
+
def __init__(self, nb_code, code_dim, is_train=True):
|
143 |
+
super().__init__()
|
144 |
+
self.nb_code = nb_code
|
145 |
+
self.code_dim = code_dim
|
146 |
+
self.mu = 0.99
|
147 |
+
self.reset_codebook()
|
148 |
+
self.reset_count = 0
|
149 |
+
self.usage = torch.zeros((self.nb_code, 1))
|
150 |
+
self.is_train = is_train
|
151 |
+
|
152 |
+
def reset_codebook(self):
|
153 |
+
self.init = False
|
154 |
+
self.code_sum = None
|
155 |
+
self.code_count = None
|
156 |
+
self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).cuda())
|
157 |
+
|
158 |
+
def _tile(self, x):
|
159 |
+
nb_code_x, code_dim = x.shape
|
160 |
+
if nb_code_x < self.nb_code:
|
161 |
+
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
|
162 |
+
std = 0.01 / np.sqrt(code_dim)
|
163 |
+
out = x.repeat(n_repeats, 1)
|
164 |
+
out = out + torch.randn_like(out) * std
|
165 |
+
else:
|
166 |
+
out = x
|
167 |
+
return out
|
168 |
+
|
169 |
+
def init_codebook(self, x):
|
170 |
+
if torch.all(self.codebook == 0):
|
171 |
+
out = self._tile(x)
|
172 |
+
self.codebook = out[:self.nb_code]
|
173 |
+
self.code_sum = self.codebook.clone()
|
174 |
+
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
|
175 |
+
if self.is_train:
|
176 |
+
self.init = True
|
177 |
+
|
178 |
+
@torch.no_grad()
|
179 |
+
def update_codebook(self, x, code_idx):
|
180 |
+
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
|
181 |
+
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
|
182 |
+
|
183 |
+
code_sum = torch.matmul(code_onehot, x) # [nb_code, code_dim]
|
184 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
185 |
+
|
186 |
+
out = self._tile(x)
|
187 |
+
code_rand = out[torch.randperm(out.shape[0])[:self.nb_code]]
|
188 |
+
|
189 |
+
# Update centres
|
190 |
+
self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
|
191 |
+
self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
|
192 |
+
|
193 |
+
usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
|
194 |
+
self.usage = self.usage.to(usage.device)
|
195 |
+
if self.reset_count >= 20: # reset codebook every 20 steps for stability
|
196 |
+
self.reset_count = 0
|
197 |
+
usage = (usage + self.usage >= 1.0).float()
|
198 |
+
else:
|
199 |
+
self.reset_count += 1
|
200 |
+
self.usage = (usage + self.usage >= 1.0).float()
|
201 |
+
usage = torch.ones_like(self.usage, device=x.device)
|
202 |
+
code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
|
203 |
+
|
204 |
+
self.codebook = usage * code_update + (1 - usage) * code_rand
|
205 |
+
prob = code_count / torch.sum(code_count)
|
206 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
207 |
+
|
208 |
+
return perplexity
|
209 |
+
|
210 |
+
def preprocess(self, x):
|
211 |
+
# [bs, c, f, j] -> [bs * f * j, c]
|
212 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
213 |
+
x = x.view(-1, x.shape[-1])
|
214 |
+
return x
|
215 |
+
|
216 |
+
def quantize(self, x):
|
217 |
+
# [bs * f * j, dim=3072]
|
218 |
+
# Calculate latent code x_l
|
219 |
+
k_w = self.codebook.t()
|
220 |
+
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, keepdim=True)
|
221 |
+
_, code_idx = torch.min(distance, dim=-1)
|
222 |
+
return code_idx
|
223 |
+
|
224 |
+
def dequantize(self, code_idx):
|
225 |
+
x = F.embedding(code_idx, self.codebook) # indexing: [bs * f * j, 32]
|
226 |
+
return x
|
227 |
+
|
228 |
+
def forward(self, x, return_vq=False):
|
229 |
+
# import pdb; pdb.set_trace()
|
230 |
+
bs, c, f, j = x.shape # SMPL data frames: [bs, 3072, f, j]
|
231 |
+
|
232 |
+
# Preprocess
|
233 |
+
x = self.preprocess(x)
|
234 |
+
# return x.view(bs, f*j, c).contiguous(), None
|
235 |
+
assert x.shape[-1] == self.code_dim
|
236 |
+
|
237 |
+
# Init codebook if not inited
|
238 |
+
if not self.init and self.is_train:
|
239 |
+
self.init_codebook(x)
|
240 |
+
|
241 |
+
# quantize and dequantize through bottleneck
|
242 |
+
code_idx = self.quantize(x)
|
243 |
+
x_d = self.dequantize(code_idx)
|
244 |
+
|
245 |
+
# Update embeddings
|
246 |
+
if self.is_train:
|
247 |
+
perplexity = self.update_codebook(x, code_idx)
|
248 |
+
|
249 |
+
# Loss
|
250 |
+
commit_loss = F.mse_loss(x, x_d.detach())
|
251 |
+
|
252 |
+
# Passthrough
|
253 |
+
x_d = x + (x_d - x).detach()
|
254 |
+
|
255 |
+
if return_vq:
|
256 |
+
return x_d.view(bs, f*j, c).contiguous(), commit_loss
|
257 |
+
# return (x_d, x_d.view(bs, f, j, c).permute(0, 3, 1, 2).contiguous()), commit_loss, perplexity
|
258 |
+
|
259 |
+
# Postprocess
|
260 |
+
x_d = x_d.view(bs, f, j, c).permute(0, 3, 1, 2).contiguous()
|
261 |
+
|
262 |
+
if self.is_train:
|
263 |
+
return x_d, commit_loss, perplexity
|
264 |
+
else:
|
265 |
+
return x_d, commit_loss
|
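The quantizer above is a plain nearest-neighbour lookup over the codebook, and `x + (x_d - x).detach()` is the standard straight-through estimator so gradients reach the encoder even though the argmin is non-differentiable. A self-contained sketch of the same idea with a tiny made-up codebook (not the class above):

```python
import torch
import torch.nn.functional as F

codebook = torch.randn(8, 4)               # 8 codes of dimension 4, illustrative
x = torch.randn(5, 4, requires_grad=True)  # 5 encoder outputs to quantize

# Nearest code by squared Euclidean distance, as in VectorQuantizer.quantize.
k_w = codebook.t()
distance = (x ** 2).sum(-1, keepdim=True) - 2 * x @ k_w + (k_w ** 2).sum(0, keepdim=True)
code_idx = distance.argmin(dim=-1)
x_d = F.embedding(code_idx, codebook)      # dequantize: look the codes back up

# Straight-through estimator: forward pass uses x_d, backward pass treats it as x.
x_st = x + (x_d - x).detach()
x_st.sum().backward()
print(x.grad.shape)  # torch.Size([5, 4]) -> gradients flow to the encoder side
```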
266 |
+
|
267 |
+
|
268 |
+
class Decoder(nn.Module):
|
269 |
+
def __init__(
|
270 |
+
self,
|
271 |
+
in_channels=3072,
|
272 |
+
mid_channels=[512, 128],
|
273 |
+
out_channels=3,
|
274 |
+
upsample_rate=None,
|
275 |
+
frame_upsample_rate=[1.0, 1.0],
|
276 |
+
joint_upsample_rate=[1.0, 1.0],
|
277 |
+
dim=128,
|
278 |
+
attention_head_dim=64,
|
279 |
+
num_attention_heads=8,
|
280 |
+
):
|
281 |
+
super(Decoder, self).__init__()
|
282 |
+
|
283 |
+
self.conv_in = nn.Conv2d(in_channels, mid_channels[0], kernel_size=3, stride=1, padding=1)
|
284 |
+
self.resnet1 = nn.ModuleList([ResBlock(mid_channels[0], mid_channels[0]) for _ in range(3)])
|
285 |
+
self.upsample1 = Upsample(mid_channels[0], mid_channels[0], frame_upsample_rate=frame_upsample_rate[0], joint_upsample_rate=joint_upsample_rate[0])
|
286 |
+
self.resnet2 = ResBlock(mid_channels[0], mid_channels[1])
|
287 |
+
self.resnet3 = nn.ModuleList([ResBlock(mid_channels[1], mid_channels[1]) for _ in range(3)])
|
288 |
+
self.upsample2 = Upsample(mid_channels[1], mid_channels[1], frame_upsample_rate=frame_upsample_rate[1], joint_upsample_rate=joint_upsample_rate[1])
|
289 |
+
# self.attn = Attention(
|
290 |
+
# query_dim=dim,
|
291 |
+
# dim_head=attention_head_dim,
|
292 |
+
# heads=num_attention_heads,
|
293 |
+
# qk_norm='layer_norm',
|
294 |
+
# eps=1e-6,
|
295 |
+
# bias=True,
|
296 |
+
# out_bias=True,
|
297 |
+
# processor=AttnProcessor(),
|
298 |
+
# )
|
299 |
+
self.conv_out = nn.Conv2d(mid_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
|
300 |
+
|
301 |
+
def forward(self, x):
|
302 |
+
x = self.conv_in(x)
|
303 |
+
for resnet in self.resnet1:
|
304 |
+
x = resnet(x)
|
305 |
+
x = self.upsample1(x)
|
306 |
+
|
307 |
+
x = self.resnet2(x)
|
308 |
+
for resnet in self.resnet3:
|
309 |
+
x = resnet(x)
|
310 |
+
x = self.upsample2(x)
|
311 |
+
|
312 |
+
# x = x + self.attn(x)
|
313 |
+
x = self.conv_out(x)
|
314 |
+
|
315 |
+
return x
|
316 |
+
|
317 |
+
|
318 |
+
class Upsample(nn.Module):
|
319 |
+
def __init__(
|
320 |
+
self,
|
321 |
+
in_channels,
|
322 |
+
out_channels,
|
323 |
+
upsample_rate=None,
|
324 |
+
frame_upsample_rate=None,
|
325 |
+
joint_upsample_rate=None,
|
326 |
+
):
|
327 |
+
super(Upsample, self).__init__()
|
328 |
+
|
329 |
+
self.upsampler = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
330 |
+
self.upsample_rate = upsample_rate
|
331 |
+
self.frame_upsample_rate = frame_upsample_rate
|
332 |
+
self.joint_upsample_rate = joint_upsample_rate
|
333 |
+
self.upsample_rate = upsample_rate
|
334 |
+
|
335 |
+
def forward(self, inputs):
|
336 |
+
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
337 |
+
# split first frame
|
338 |
+
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
339 |
+
|
340 |
+
if self.upsample_rate is not None:
|
341 |
+
# import pdb; pdb.set_trace()
|
342 |
+
x_first = F.interpolate(x_first, scale_factor=self.upsample_rate)
|
343 |
+
x_rest = F.interpolate(x_rest, scale_factor=self.upsample_rate)
|
344 |
+
else:
|
345 |
+
# import pdb; pdb.set_trace()
|
346 |
+
# x_first = F.interpolate(x_first, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True)
|
347 |
+
x_rest = F.interpolate(x_rest, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True)
|
348 |
+
x_first = x_first[:, :, None, :]
|
349 |
+
inputs = torch.cat([x_first, x_rest], dim=2)
|
350 |
+
elif inputs.shape[2] > 1:
|
351 |
+
if self.upsample_rate is not None:
|
352 |
+
inputs = F.interpolate(inputs, scale_factor=self.upsample_rate)
|
353 |
+
else:
|
354 |
+
inputs = F.interpolate(inputs, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="bilinear", align_corners=True)
|
355 |
+
else:
|
356 |
+
inputs = inputs.squeeze(2)
|
357 |
+
if self.upsample_rate is not None:
|
358 |
+
inputs = F.interpolate(inputs, scale_factor=self.upsample_rate)
|
359 |
+
else:
|
360 |
+
inputs = F.interpolate(inputs, scale_factor=(self.frame_upsample_rate, self.joint_upsample_rate), mode="linear", align_corners=True)
|
361 |
+
inputs = inputs[:, :, None, :, :]
|
362 |
+
|
363 |
+
b, c, t, j = inputs.shape
|
364 |
+
inputs = inputs.permute(0, 2, 1, 3).reshape(b * t, c, j)
|
365 |
+
inputs = self.upsampler(inputs)
|
366 |
+
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3)
|
367 |
+
|
368 |
+
return inputs
|
369 |
+
|
370 |
+
|
371 |
+
class Downsample(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        frame_downsample_rate,
        joint_downsample_rate
    ):
        super(Downsample, self).__init__()

        self.frame_downsample_rate = frame_downsample_rate
        self.joint_downsample_rate = joint_downsample_rate
        self.joint_downsample = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=self.joint_downsample_rate, padding=1)

    def forward(self, x):
        # (batch_size, channels, frames, joints) -> (batch_size * joints, channels, frames)
        if self.frame_downsample_rate > 1:
            batch_size, channels, frames, joints = x.shape
            x = x.permute(0, 3, 1, 2).reshape(batch_size * joints, channels, frames)
            if x.shape[-1] % 2 == 1:
                # keep the first frame, average-pool the remaining frames
                x_first, x_rest = x[..., 0], x[..., 1:]
                if x_rest.shape[-1] > 0:
                    # (batch_size * joints, channels, frames - 1) -> (batch_size * joints, channels, (frames - 1) // 2)
                    x_rest = F.avg_pool1d(x_rest, kernel_size=self.frame_downsample_rate, stride=self.frame_downsample_rate)

                x = torch.cat([x_first[..., None], x_rest], dim=-1)
                # (batch_size * joints, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, joints)
                x = x.reshape(batch_size, joints, channels, x.shape[-1]).permute(0, 2, 3, 1)
            else:
                # (batch_size * joints, channels, frames) -> (batch_size * joints, channels, frames // 2)
                x = F.avg_pool1d(x, kernel_size=2, stride=2)
                # (batch_size * joints, channels, frames // 2) -> (batch_size, channels, frames // 2, joints)
                x = x.reshape(batch_size, joints, channels, x.shape[-1]).permute(0, 2, 3, 1)

        batch_size, channels, frames, joints = x.shape
        # (batch_size, channels, frames, joints) -> (batch_size * frames, channels, joints)
        x = x.permute(0, 2, 1, 3).reshape(batch_size * frames, channels, joints)
        x = self.joint_downsample(x)
        # (batch_size * frames, channels, joints) -> (batch_size, channels, frames, joints)
        x = x.reshape(batch_size, frames, x.shape[1], x.shape[2]).permute(0, 2, 1, 3)
        return x

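Likewise, a small sketch of the temporal/joint compression performed by Downsample; sizes and import path are assumptions for illustration only.

import torch
from models.motion4d.vqvae import Downsample

down = Downsample(in_channels=32, out_channels=64, frame_downsample_rate=2, joint_downsample_rate=2)
x = torch.randn(1, 32, 8, 24)   # (batch, channels, frames, joints), hypothetical sizes
y = down(x)                     # frames halved by avg-pooling, joints halved by the strided Conv1d
print(y.shape)                  # torch.Size([1, 64, 4, 12])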
class ResBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 group_num=32,
                 max_channels=512):
        super(ResBlock, self).__init__()
        skip = max(1, max_channels // out_channels - 1)
        self.block = nn.Sequential(
            nn.GroupNorm(group_num, in_channels, eps=1e-06, affine=True),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=skip, dilation=skip),
            nn.GroupNorm(group_num, out_channels, eps=1e-06, affine=True),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )
        self.conv_short = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        hidden_states = self.block(x)
        if hidden_states.shape != x.shape:
            x = self.conv_short(x)
        x = x + hidden_states
        return x

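And a quick check of the dilated residual block; again the sizes are purely illustrative.

import torch
from models.motion4d.vqvae import ResBlock

block = ResBlock(in_channels=64, out_channels=128)
x = torch.randn(1, 64, 8, 24)
y = block(x)                    # 1x1 shortcut projects 64 -> 128 channels before the residual add
print(y.shape)                  # torch.Size([1, 128, 8, 24])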
class SMPL_VQVAE(nn.Module):
    def __init__(self, encoder, decoder, vq):
        super(SMPL_VQVAE, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.vq = vq

    def to(self, device):
        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)
        self.vq = self.vq.to(device)
        self.device = device
        return self

    def encdec_slice_frames(self, x, frame_batch_size, encdec, return_vq):
        num_frames = x.shape[2]
        remaining_frames = num_frames % frame_batch_size
        x_output = []
        loss_output = []
        perplexity_output = []
        for i in range(num_frames // frame_batch_size):
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            x_intermediate = x[:, :, start_frame:end_frame]
            x_intermediate = encdec(x_intermediate)
            # if encdec == self.encoder and self.vq is not None:
            #     x_intermediate, loss, perplexity = self.vq(x_intermediate)
            #     x_output.append(x_intermediate)
            #     loss_output.append(loss)
            #     perplexity_output.append(perplexity)
            # else:
            #     x_output.append(x_intermediate)
            x_output.append(x_intermediate)
        if encdec == self.encoder and self.vq is not None and not self.vq.is_train:
            x_output, loss = self.vq(torch.cat(x_output, dim=2), return_vq=return_vq)
            return x_output, loss
        elif encdec == self.encoder and self.vq is not None and self.vq.is_train:
            x_output, loss, perplexity = self.vq(torch.cat(x_output, dim=2))
            return x_output, loss, perplexity
        else:
            return torch.cat(x_output, dim=2), None, None

    def forward(self, x, return_vq=False):
        # x: (batch_size, frames, joints, channels) -> (batch_size, channels, frames, joints)
        x = x.permute(0, 3, 1, 2)
        if not self.vq.is_train:
            x, loss = self.encdec_slice_frames(x, frame_batch_size=8, encdec=self.encoder, return_vq=return_vq)
        else:
            x, loss, perplexity = self.encdec_slice_frames(x, frame_batch_size=8, encdec=self.encoder, return_vq=return_vq)
        if return_vq:
            return x, loss
        x, _, _ = self.encdec_slice_frames(x, frame_batch_size=2, encdec=self.decoder, return_vq=return_vq)
        x = x.permute(0, 2, 3, 1)
        if self.vq.is_train:
            return x, loss, perplexity
        return x, loss
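To make the wrapper's contract concrete, a rough sketch of inference-time use. It assumes encoder, decoder, and vq (a quantizer exposing an is_train flag) are already constructed from the classes earlier in this file, and that motion arrives as normalized (batch, frames, joints, 3) SMPL joint positions; the 16-frame, 24-joint shape is illustrative only.

import torch
from models.motion4d.vqvae import SMPL_VQVAE

# encoder, decoder, vq: instances of the encoder/decoder/quantizer defined in this module, built elsewhere
vqvae = SMPL_VQVAE(encoder, decoder, vq).to("cuda")
motion = torch.randn(1, 16, 24, 3, device="cuda")     # (batch, frames, joints, xyz), assumed already normalized

with torch.no_grad():
    tokens, _ = vqvae(motion, return_vq=True)          # encode + quantize only (decoding is skipped)
    recon, _ = vqvae(motion)                           # full encode -> quantize -> decode round trip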
motion_extractor.py
ADDED
@@ -0,0 +1,62 @@
# motion_extractor.py
import os
import sys
import cv2
import torch
import pickle
import torchvision

# Load the TorchScript model once at import time
model_path = '/gemini/space/human_guozz2/dyb/MTVCrafter-main/nlf_l_multi_0.3.2.torchscript'
assert os.path.exists(model_path), f"Model file not found at {model_path}"
model = torch.jit.load(model_path).cuda().eval()

def extract_pkl_from_video(video_path):
    output_file = "temp_motion.pkl"
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    pose_results = {
        'joints3d_nonparam': [],
    }

    with torch.inference_mode(), torch.device('cuda'):
        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Convert the BGR uint8 frame to a (1, 3, H, W) tensor on the GPU
            frame_tensor = torch.from_numpy(frame).cuda()
            frame_batch = frame_tensor.unsqueeze(0).permute(0, 3, 1, 2)
            # Model inference
            pred = model.detect_smpl_batched(frame_batch)
            # Collect pose data
            for key in pose_results.keys():
                if key in pred:
                    # pose_results[key].append(pred[key].cpu().numpy())
                    pose_results[key].append(pred[key])
                else:
                    pose_results[key].append(None)

            frame_idx += 1

    cap.release()

    # Prepare output data
    output_data = {
        'video_path': video_path,
        'video_length': frame_count,
        'video_width': video_width,
        'video_height': video_height,
        'pose': pose_results
    }

    # Save to pkl file
    with open(output_file, 'wb') as f:
        pickle.dump(output_data, f)

    return output_file
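A minimal sketch of how this extractor is driven; note that simply importing the module asserts the hard-coded model_path exists and loads the TorchScript model onto the GPU, and dance.mp4 below is a placeholder input.

import pickle
from motion_extractor import extract_pkl_from_video

pkl_path = extract_pkl_from_video("dance.mp4")        # placeholder clip
with open(pkl_path, "rb") as f:
    data = pickle.load(f)
print(data["video_length"], data["video_width"], data["video_height"])
print(len(data["pose"]["joints3d_nonparam"]))         # one entry per decoded frame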
utils.py
ADDED
@@ -0,0 +1,77 @@
import torch
import random
import numpy as np
from PIL import Image


def concat_images(images, direction='horizontal', pad=0, pad_value=0):
    if len(images) == 1:
        return images[0]
    is_pil = isinstance(images[0], Image.Image)
    if is_pil:
        images = [np.array(image) for image in images]
    if direction == 'horizontal':
        height = max([image.shape[0] for image in images])
        width = sum([image.shape[1] for image in images]) + pad * (len(images) - 1)
        new_image = np.full((height, width, images[0].shape[2]), pad_value, dtype=images[0].dtype)
        begin = 0
        for image in images:
            end = begin + image.shape[1]
            new_image[: image.shape[0], begin:end] = image
            begin = end + pad
    elif direction == 'vertical':
        height = sum([image.shape[0] for image in images]) + pad * (len(images) - 1)
        width = max([image.shape[1] for image in images])
        new_image = np.full((height, width, images[0].shape[2]), pad_value, dtype=images[0].dtype)
        begin = 0
        for image in images:
            end = begin + image.shape[0]
            new_image[begin:end, : image.shape[1]] = image
            begin = end + pad
    else:
        assert False
    if is_pil:
        new_image = Image.fromarray(new_image)
    return new_image

def concat_images_grid(images, cols, pad=0, pad_value=0):
    new_images = []
    while len(images) > 0:
        new_image = concat_images(images[:cols], pad=pad, pad_value=pad_value)
        new_images.append(new_image)
        images = images[cols:]
    new_image = concat_images(new_images, direction='vertical', pad=pad, pad_value=pad_value)
    return new_image

def sample_video(video, indexes, method=2):
    if method == 1:
        frames = video.get_batch(indexes)
        frames = frames.numpy() if isinstance(frames, torch.Tensor) else frames.asnumpy()
    elif method == 2:
        max_idx = indexes.max() + 1
        all_indexes = np.arange(max_idx, dtype=int)
        frames = video.get_batch(all_indexes)
        frames = frames.numpy() if isinstance(frames, torch.Tensor) else frames.asnumpy()
        frames = frames[indexes]
    else:
        assert False
    return frames

def get_sample_indexes(video_length, num_frames, stride):
    assert num_frames * stride <= video_length
    sample_length = min(video_length, (num_frames - 1) * stride + 1)
    start_idx = 0 + random.randint(0, video_length - sample_length)
    sample_indexes = np.linspace(start_idx, start_idx + sample_length - 1, num_frames, dtype=int)
    return sample_indexes

def get_new_height_width(data_dict, dst_height, dst_width):
    height = data_dict['video_height']
    width = data_dict['video_width']
    if float(dst_height) / height < float(dst_width) / width:
        new_height = int(round(float(dst_width) / width * height))
        new_width = dst_width
    else:
        new_height = dst_height
        new_width = int(round(float(dst_height) / height * width))
    assert dst_width <= new_width and dst_height <= new_height
    return new_height, new_width
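For reference, a small sketch of how these helpers compose; it assumes a decord.VideoReader (matching the asnumpy() fallback in sample_video, although decord is not imported in this file), a placeholder clip of at least num_frames * stride frames, and a hypothetical output filename.

from PIL import Image
from decord import VideoReader                        # assumption: decord-backed reader
from utils import get_sample_indexes, sample_video, concat_images_grid

vr = VideoReader("dance.mp4")                         # placeholder clip with >= 32 frames
indexes = get_sample_indexes(len(vr), num_frames=16, stride=2)
frames = sample_video(vr, indexes)                    # (16, H, W, 3) uint8 frames
grid = concat_images_grid([Image.fromarray(f) for f in frames], cols=4, pad=2, pad_value=255)
grid.save("preview.png")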