Spaces:
Runtime error
Runtime error
Update inference_engine.py
Browse files
- inference_engine.py +7 -1
inference_engine.py
CHANGED
@@ -9,6 +9,7 @@ from torchvision.transforms import ToPILImage, transforms, InterpolationMode, fu
|
|
9 |
import numpy as np
|
10 |
import pickle
|
11 |
import copy
|
|
|
12 |
from draw_pose import get_pose_images
|
13 |
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
|
14 |
|
@@ -18,7 +19,7 @@ def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, ds
|
|
18 |
normalize = transforms.Normalize([0.5], [0.5])
|
19 |
pretrained_model_path = "THUDM/CogVideoX-5b"
|
20 |
transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
|
21 |
-
tokenizer_path = "mp_rank_00_model_states.pt"
|
22 |
|
23 |
with open(motion_data_path, 'rb') as f:
|
24 |
data_list = pickle.load(f)
|
@@ -38,6 +39,11 @@ def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, ds
|
|
38 |
pipe.vae.enable_slicing()
|
39 |
|
40 |
# load VQVAE
|
|
|
|
|
|
|
|
|
|
|
41 |
state_dict = torch.load(tokenizer_path, map_location="cpu")
|
42 |
motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
|
43 |
motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
|
|
|
9 |
import numpy as np
|
10 |
import pickle
|
11 |
import copy
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
from draw_pose import get_pose_images
|
14 |
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
|
15 |
|
|
|
19 |
normalize = transforms.Normalize([0.5], [0.5])
|
20 |
pretrained_model_path = "THUDM/CogVideoX-5b"
|
21 |
transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
|
22 |
+
tokenizer_path = "4DMoT/mp_rank_00_model_states.pt"
|
23 |
|
24 |
with open(motion_data_path, 'rb') as f:
|
25 |
data_list = pickle.load(f)
|
|
|
39 |
pipe.vae.enable_slicing()
|
40 |
|
41 |
# load VQVAE
|
42 |
+
|
43 |
+
vqvae_model_path = hf_hub_download(
|
44 |
+
repo_id="yanboding/MTVCrafter",
|
45 |
+
filename="4DMoT/mp_rank_00_model_states.pt"
|
46 |
+
)
|
47 |
state_dict = torch.load(tokenizer_path, map_location="cpu")
|
48 |
motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
|
49 |
motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
|