yanboding committed on
Commit
9d9257a
·
verified ·
1 Parent(s): c3941ce

Update inference_engine.py

Browse files
Files changed (1) hide show
  1. inference_engine.py +7 -1
inference_engine.py CHANGED
@@ -9,6 +9,7 @@ from torchvision.transforms import ToPILImage, transforms, InterpolationMode, fu
9
  import numpy as np
10
  import pickle
11
  import copy
 
12
  from draw_pose import get_pose_images
13
  from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
14
 
@@ -18,7 +19,7 @@ def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, ds
18
  normalize = transforms.Normalize([0.5], [0.5])
19
  pretrained_model_path = "THUDM/CogVideoX-5b"
20
  transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
21
- tokenizer_path = "mp_rank_00_model_states.pt"
22
 
23
  with open(motion_data_path, 'rb') as f:
24
  data_list = pickle.load(f)
@@ -38,6 +39,11 @@ def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, ds
38
  pipe.vae.enable_slicing()
39
 
40
  # load VQVAE
 
 
 
 
 
41
  state_dict = torch.load(tokenizer_path, map_location="cpu")
42
  motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
43
  motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
 
9
  import numpy as np
10
  import pickle
11
  import copy
12
+ from huggingface_hub import hf_hub_download
13
  from draw_pose import get_pose_images
14
  from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
15
 
 
19
  normalize = transforms.Normalize([0.5], [0.5])
20
  pretrained_model_path = "THUDM/CogVideoX-5b"
21
  transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
22
+ tokenizer_path = "4DMoT/mp_rank_00_model_states.pt"
23
 
24
  with open(motion_data_path, 'rb') as f:
25
  data_list = pickle.load(f)
 
39
  pipe.vae.enable_slicing()
40
 
41
  # load VQVAE
42
+
43
+ vqvae_model_path = hf_hub_download(
44
+ repo_id="yanboding/MTVCrafter",
45
+ filename="4DMoT/mp_rank_00_model_states.pt"
46
+ )
47
  state_dict = torch.load(tokenizer_path, map_location="cpu")
48
  motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
49
  motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)