import os
from os.path import join as pjoin

import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator
from options.eval_option import EvalT2MOptions
from utils.get_opt import get_opt
from utils.fixseed import fixseed
from utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motion
from utils.paramUtil import t2m_kinematic_chain
from visualization.joints2bvh import Joint2BVHConvertor

clip_version = 'ViT-B/32'

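# A note on the checkpoint loading below: 'net_best_fid.tar' is presumably the
# RVQ-VAE checkpoint that scored the best FID during training, and the
# 'vq_model' / 'net' key fallback accommodates two checkpoint layouts.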
def load_vq_model(vq_opt):
    vq_model = RVQVAE(vq_opt,
                      vq_opt.dim_pose,
                      vq_opt.nb_code,
                      vq_opt.code_dim,
                      vq_opt.output_emb_width,
                      vq_opt.down_t,
                      vq_opt.stride_t,
                      vq_opt.width,
                      vq_opt.depth,
                      vq_opt.dilation_growth_rate,
                      vq_opt.vq_act,
                      vq_opt.vq_norm)
    ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
                      map_location='cpu')
    model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
    vq_model.load_state_dict(ckpt[model_key])
    print(f'Loading VQ Model {vq_opt.name} Completed!')
    return vq_model, vq_opt

def load_trans_model(model_opt, opt, which_model):
    t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
                                      cond_mode='text',
                                      latent_dim=model_opt.latent_dim,
                                      ff_size=model_opt.ff_size,
                                      num_layers=model_opt.n_layers,
                                      num_heads=model_opt.n_heads,
                                      dropout=model_opt.dropout,
                                      clip_dim=512,
                                      cond_drop_prob=model_opt.cond_drop_prob,
                                      clip_version=clip_version,
                                      opt=model_opt)
    ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),
                      map_location='cpu')
    model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
    # CLIP weights are not stored in the checkpoint, so with strict=False the
    # only missing keys should be the frozen 'clip_model.*' parameters.
    missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
    assert len(unexpected_keys) == 0
    assert all([k.startswith('clip_model.') for k in missing_keys])
    print(f'Loading Transformer {opt.name} from epoch {ckpt["ep"]}!')
    return t2m_transformer

def load_res_model(res_opt, vq_opt, opt):
    res_opt.num_quantizers = vq_opt.num_quantizers
    res_opt.num_tokens = vq_opt.nb_code
    res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
                                          cond_mode='text',
                                          latent_dim=res_opt.latent_dim,
                                          ff_size=res_opt.ff_size,
                                          num_layers=res_opt.n_layers,
                                          num_heads=res_opt.n_heads,
                                          dropout=res_opt.dropout,
                                          clip_dim=512,
                                          shared_codebook=vq_opt.shared_codebook,
                                          cond_drop_prob=res_opt.cond_drop_prob,
                                          share_weight=res_opt.share_weight,
                                          clip_version=clip_version,
                                          opt=res_opt)
    ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),
                      map_location=opt.device)
    missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)
    assert len(unexpected_keys) == 0
    assert all([k.startswith('clip_model.') for k in missing_keys])
    print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')
    return res_transformer

def load_len_estimator(opt):
    # Maps a 512-d CLIP text embedding to logits over 50 token-length bins.
    model = LengthEstimator(512, 50)
    ckpt = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_estimator', 'model', 'finest.tar'),
                      map_location=opt.device)
    model.load_state_dict(ckpt['estimator'])
    print(f'Loading Length Estimator from epoch {ckpt["epoch"]}!')
    return model

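# A sketch of how this script might be invoked (script name assumed; the flag
# names come from EvalT2MOptions below, but the run names are placeholders for
# whatever checkpoints exist under your checkpoints directory):
#   python gen_t2m.py --gpu_id 0 --ext exp1 \
#       --name <mask_transformer_run> --res_name <res_transformer_run> \
#       --text_prompt "a person is running on a treadmill" --motion_length 120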
if __name__ == '__main__':
    parser = EvalT2MOptions()
    opt = parser.parse()
    fixseed(opt.seed)

    opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
    torch.autograd.set_detect_anomaly(True)

    dim_pose = 251 if opt.dataset_name == 'kit' else 263

    root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    model_dir = pjoin(root_dir, 'model')
    result_dir = pjoin('./generation', opt.ext)
    joints_dir = pjoin(result_dir, 'joints')
    animation_dir = pjoin(result_dir, 'animations')
    os.makedirs(joints_dir, exist_ok=True)
    os.makedirs(animation_dir, exist_ok=True)

    model_opt_path = pjoin(root_dir, 'opt.txt')
    model_opt = get_opt(model_opt_path, device=opt.device)

    #######################
    ######Loading RVQ######
    #######################
    vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
    vq_opt = get_opt(vq_opt_path, device=opt.device)
    vq_opt.dim_pose = dim_pose
    vq_model, vq_opt = load_vq_model(vq_opt)

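    # Keep the mask transformer's vocabulary in sync with the VQ codebook:
    # token count, quantizer count, and code dimension all come from the
    # loaded VQ options.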
    model_opt.num_tokens = vq_opt.nb_code
    model_opt.num_quantizers = vq_opt.num_quantizers
    model_opt.code_dim = vq_opt.code_dim

    #################################
    ######Loading R-Transformer######
    #################################
    res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
    res_opt = get_opt(res_opt_path, device=opt.device)
    res_model = load_res_model(res_opt, vq_opt, opt)

    assert res_opt.vq_name == model_opt.vq_name

    #################################
    ######Loading M-Transformer######
    #################################
    t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')

    ##################################
    #####Loading Length Predictor#####
    ##################################
    length_estimator = load_len_estimator(model_opt)

    t2m_transformer.eval()
    vq_model.eval()
    res_model.eval()
    length_estimator.eval()

    res_model.to(opt.device)
    t2m_transformer.to(opt.device)
    vq_model.to(opt.device)
    length_estimator.to(opt.device)

    ##### ---- Dataloader ---- #####
    opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22

    mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
    std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))

    def inv_transform(data):
        # Undo the z-score normalization applied to the training features.
        return data * std + mean

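    # Prompts come from --text_prompt (one caption) or --text_path (one caption
    # per line, optionally suffixed with '#<length in frames>'). A hypothetical
    # prompts file:
    #   a person walks forward and sits down#120
    #   someone waves with their right hand
    # If any line omits the length (or it is non-numeric), lengths are
    # estimated for the whole batch instead.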
    prompt_list = []
    length_list = []

    est_length = False
    if opt.text_prompt != "":
        prompt_list.append(opt.text_prompt)
        if opt.motion_length == 0:
            est_length = True
        else:
            length_list.append(opt.motion_length)
    elif opt.text_path != "":
        with open(opt.text_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                # strip() keeps the trailing newline from defeating isdigit().
                infos = line.strip().split('#')
                prompt_list.append(infos[0])
                if len(infos) == 1 or (not infos[1].isdigit()):
                    est_length = True
                    length_list = []
                else:
                    length_list.append(int(infos[-1]))
    else:
        # Raising a bare string is invalid in Python 3; raise a real exception.
        raise ValueError("Either a text prompt (--text_prompt) or a file of text prompts (--text_path) is required!")

    if est_length:
        print("Since no motion length is specified, estimated motion lengths will be used!")
        text_embedding = t2m_transformer.encode_text(prompt_list)
        pred_dis = length_estimator(text_embedding)
        probs = F.softmax(pred_dis, dim=-1)  # (b, ntoken)
        token_lens = Categorical(probs).sample()  # (b,)
    else:
        # The VQ stage downsamples time by a factor of 4 (m_length = token_lens * 4
        # below), so frame counts are converted to token counts here.
        token_lens = torch.LongTensor(length_list) // 4
    token_lens = token_lens.to(opt.device).long()

    m_length = token_lens * 4
    captions = prompt_list

    sample = 0
    kinematic_chain = t2m_kinematic_chain
    converter = Joint2BVHConvertor()

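    # Generation runs in three stages: the masked transformer predicts the
    # base-layer motion tokens from text, the residual transformer fills in the
    # remaining quantization layers, and the RVQ-VAE decoder maps the token
    # sequence back to motion features.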
    for r in range(opt.repeat_times):
        print("-->Repeat %d" % r)
        with torch.no_grad():
            mids = t2m_transformer.generate(captions, token_lens,
                                            timesteps=opt.time_steps,
                                            cond_scale=opt.cond_scale,
                                            temperature=opt.temperature,
                                            topk_filter_thres=opt.topkr,
                                            gsample=opt.gumbel_sample)
            mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)
            pred_motions = vq_model.forward_decoder(mids)
            pred_motions = pred_motions.detach().cpu().numpy()
            data = inv_transform(pred_motions)

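        # Each decoded motion is trimmed to its requested length, converted to
        # 3D joint positions, exported to BVH (once with and once without foot
        # IK), rendered to MP4, and saved as .npy under ./generation/<ext>/.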
        for k, (caption, joint_data) in enumerate(zip(captions, data)):
            print("---->Sample %d: %s %d" % (k, caption, m_length[k]))
            animation_path = pjoin(animation_dir, str(k))
            joint_path = pjoin(joints_dir, str(k))

            os.makedirs(animation_path, exist_ok=True)
            os.makedirs(joint_path, exist_ok=True)

            joint_data = joint_data[:m_length[k]]
            # The hardcoded 22 assumes the HumanML3D skeleton; t2m_kinematic_chain
            # used for plotting is the 22-joint HumanML3D chain as well.
            joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()

            bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh" % (k, r, m_length[k]))
            _, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)

            bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))
            _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)

            save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4" % (k, r, m_length[k]))
            ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4" % (k, r, m_length[k]))

            plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=caption, fps=20)
            plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
            np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy" % (k, r, m_length[k])), joint)
            np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy" % (k, r, m_length[k])), ik_joint)