Spaces:
Runtime error
Runtime error
| import pickle | |
| import time | |
| import numpy as np | |
| import scipy, cv2, os, sys, argparse | |
| from tqdm import tqdm | |
| import torch | |
| import librosa | |
| from networks import define_G | |
| from pcavs.config.AudioConfig import AudioConfig | |
| sys.path.append('spectre') | |
| from config import cfg as spectre_cfg | |
| from src.spectre import SPECTRE | |
| from audio2mesh_helper import * | |
| from pcavs.models import create_model, networks | |
| torch.manual_seed(0) | |
| from scipy.signal import savgol_filter | |
| class SimpleWrapperV2(nn.Module): | |
| def __init__(self, cfg, use_ref=True, exp_dim=53, noload=False) -> None: | |
| super().__init__() | |
| self.audio_encoder = networks.define_A_sync(cfg) | |
| self.mapping1 = nn.Linear(512+exp_dim, exp_dim) | |
| nn.init.constant_(self.mapping1.weight, 0.) | |
| nn.init.constant_(self.mapping1.bias, 0.) | |
| self.use_ref = use_ref | |
| def forward(self, x, ref, use_tanh=False): | |
| x = self.audio_encoder.forward_feature(x).view(x.size(0), -1) | |
| ref_reshape = ref.reshape(x.size(0), -1) #20, -1 | |
| y = self.mapping1(torch.cat([x, ref_reshape], dim=1)) | |
| if self.use_ref: | |
| out = y.reshape(ref.shape[0], ref.shape[1], -1) + ref # resudial | |
| else: | |
| out = y.reshape(ref.shape[0], ref.shape[1], -1) | |
| if use_tanh: | |
| out[:, :50] = torch.tanh(out[:, :50]) * 3 | |
| return out | |
| class Audio2Mesh(object): | |
| def __init__(self, args) -> None: | |
| self.args = args | |
| spectre_cfg.model.use_tex = True | |
| spectre_cfg.model.mask_type = args.mask_type | |
| spectre_cfg.debug = self.args.debug | |
| spectre_cfg.model.netA_sync = 'ressesync' | |
| spectre_cfg.model.gpu_ids = [0] | |
| self.spectre = SPECTRE(spectre_cfg) | |
| self.spectre.eval() | |
| self.face_tracker = None #FaceTrackerV2() # face landmark detection | |
| self.mel_step_size = 16 | |
| self.fps = args.fps | |
| self.Nw = args.tframes | |
| self.device = self.args.device | |
| self.image_size = self.args.image_size | |
| ### only audio | |
| args.netA_sync = 'ressesync' | |
| args.gpu_ids = [0] | |
| args.exp_dim = 53 | |
| args.use_tanh = False | |
| args.K = 20 | |
| self.audio2exp = 'pcavs' | |
| # | |
| self.avmodel = SimpleWrapperV2(args, exp_dim=args.exp_dim).cuda() | |
| self.avmodel.load_state_dict(torch.load('../packages/pretrained/audio2expression_v2_model.tar')['opt']) | |
| # 5, 160 = 25fps | |
| self.audio = AudioConfig(frame_rate=args.fps, num_frames_per_clip=5, hop_size=160) | |
| with open(os.path.join(args.source_dir, 'deca_infos.pkl'), 'rb') as f: # ? | |
| self.fitting_coeffs = pickle.load(f, encoding='bytes') | |
| self.coeffs_dict = { key: torch.Tensor(self.fitting_coeffs[key]).cuda().squeeze(1) for key in ['cam', 'pose', 'light', 'tex', 'shape', 'exp']} | |
| #### find the close month | |
| exp_tensors = torch.sum(self.coeffs_dict['exp'], dim=1) | |
| ssss, sorted_indices = torch.sort(exp_tensors) | |
| self.exp_id = sorted_indices[0].item() | |
| if '.ts' in args.render_path: | |
| self.render = torch.jit.load(args.render_path).cuda() | |
| self.trt = True | |
| else: | |
| self.render = define_G(self.Nw*6, 3, args.ngf, args.netR).eval().cuda() | |
| self.render.load_state_dict(torch.load(args.render_path)) | |
| self.trt = False | |
| print('loaded cached images...') | |
| def cg2real(self, rendedimages, start_frame=0): | |
| ## load original image and the mask | |
| self.source_images = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_frame'),\ | |
| resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:] | |
| self.source_masks = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_mask'),\ | |
| resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:] | |
| self.source_masks = torch.FloatTensor(np.transpose(self.source_masks,(0,3,1,2))/255.) | |
| self.padded_real_tensor = torch.FloatTensor(np.transpose(self.source_images,(0,3,1,2))/255.) | |
| ## padding the rended_imgs | |
| paded_tensor = torch.cat([rendedimages[0:1]]* (self.Nw // 2) + [rendedimages] + [rendedimages[-1:]]* (self.Nw // 2)).contiguous() | |
| paded_mask_tensor = torch.cat([self.source_masks[0:1]]* (self.Nw // 2) + [self.source_masks] + [self.source_masks[-1:]]* (self.Nw // 2)).contiguous() | |
| paded_real_tensor = torch.cat([self.padded_real_tensor[0:1]]* (self.Nw // 2) + [self.padded_real_tensor] + [self.padded_real_tensor[-1:]]* (self.Nw // 2)).contiguous() | |
| # paded_mask_tensor = maskErosion(paded_mask_tensor, offY=self.args.mask) | |
| padded_input = ((paded_real_tensor-0.5)*2 ) # *(1-paded_mask_tensor) | |
| padded_input = torch.nn.functional.interpolate(padded_input, (self.image_size, self.image_size), mode='bilinear', align_corners=False) | |
| paded_tensor = torch.nn.functional.interpolate(paded_tensor, (self.image_size, self.image_size), mode='bilinear', align_corners=False) | |
| paded_tensor = (paded_tensor-0.5)*2 | |
| result = [] | |
| for index in tqdm(range(0, len(rendedimages), self.args.renderbs), desc='CG2REAL:'): | |
| list_A = [] | |
| list_R = [] | |
| list_M = [] | |
| for i in range(self.args.renderbs): | |
| idx = index + i | |
| if idx+self.Nw > len(padded_input): | |
| list_A.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0)) | |
| list_R.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0)) | |
| list_M.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0)) | |
| else: | |
| list_A.append(padded_input[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0)) | |
| list_R.append(paded_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0)) | |
| list_M.append(paded_mask_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0)) | |
| list_A = torch.cat(list_A) | |
| list_R = torch.cat(list_R) | |
| list_M = torch.cat(list_M) | |
| idx = (self.Nw//2) * 3 | |
| mask = list_M[:, idx:idx+3] | |
| # list_A = padded_input | |
| mask = maskErosion(mask, offY=self.args.mask) | |
| list_A = list_A * (1 - mask[:,0:1]) | |
| A = torch.cat([list_A, list_R], 1) | |
| if self.trt: | |
| B = self.render(A.half().cuda()) | |
| elif self.args.netR == 'unet_256': | |
| # import pdb; pdb.set_trace() | |
| idx = (self.Nw//2) * 3 | |
| mask = list_M[:, idx:idx+3].cuda() | |
| mask = maskErosion(mask, offY=self.args.mask) | |
| B0 = list_A[:, idx:idx+3].cuda() | |
| B = self.render(A.cuda()) * mask[:,0:1] + (1 - mask[:,0:1]) * B0 | |
| elif self.args.netR == 's2am': | |
| # import pdb; pdb.set_trace() | |
| idx = (self.Nw//2) * 3 | |
| mask = list_M[:, idx:idx+3].cuda() | |
| mask = maskErosion(mask, offY=self.args.mask) | |
| B0 = list_A[:, idx:idx+3].cuda() | |
| B = self.render(A.cuda(), mask[:,0:1] ) * mask[:,0:1] + (1 - mask[:,0:1]) * B0 | |
| else: | |
| B = self.render(A.cuda()) | |
| result.append((B.cpu() + 1) * 0.5) # -1,1 -> 0,1 | |
| return torch.cat(result)[:len(rendedimages)] | |
| def coeffs_to_img(self, vertices, coeffs, zero_pose=False, XK = 20): | |
| xlen = vertices.shape[0] | |
| all_shape_images = [] | |
| landmark2d = [] | |
| #### find the most larger pose 51 in the coeffs. | |
| max_pose_51 = torch.max(self.coeffs_dict['pose'][..., 3:4].squeeze(-1)) | |
| for i in tqdm(range(0, xlen, XK)): | |
| if i + XK > xlen: | |
| XK = xlen - i | |
| codedictdecoder = {} | |
| codedictdecoder['shape'] = torch.zeros_like(self.coeffs_dict['shape'][i:i+XK].cuda()) | |
| codedictdecoder['tex'] = self.coeffs_dict['tex'][i:i+XK].cuda() | |
| codedictdecoder['exp'] = torch.zeros_like(self.coeffs_dict['exp'][i:i+XK].cuda()) # all_exps[i:i+XK, :50].cuda() # # # vid_exps[i:i+1].cuda() i:i+XK | |
| codedictdecoder['pose'] = self.coeffs_dict['pose'][i:i+XK] # vid_poses[i:i+1].cuda() | |
| codedictdecoder['cam'] = self.coeffs_dict['cam'][i:i+XK].cuda() # vid_poses[i:i+1].cuda() | |
| codedictdecoder['light'] = self.coeffs_dict['light'][i:i+XK].cuda() # vid_poses[i:i+1].cuda() | |
| codedictdecoder['images'] = torch.zeros((XK,3,256,256)).cuda() | |
| codedictdecoder['pose'][..., 3:4] = torch.clip(coeffs[i:i+XK, 50:51], 0, max_pose_51*0.9) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:]) | |
| codedictdecoder['pose'][..., 4:6] = 0 # coeffs[i:i+XK, 50:]*( - 0.25) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:]) | |
| sub_vertices = vertices[i:i+XK].cuda() | |
| opdict = self.spectre.decode_verts(codedictdecoder, sub_vertices, rendering=True, vis_lmk=False, return_vis=False) | |
| landmark2d.append(opdict['landmarks2d'].cpu()) | |
| all_shape_images.append(opdict['rendered_images'].cpu()) | |
| rendedimages = torch.cat(all_shape_images) | |
| lmk2d = torch.cat(landmark2d) | |
| return rendedimages, lmk2d | |
| def run_spectre_v3(self, wav=None, ds_features=None, L=20): | |
| wav = audio_normalize(wav) | |
| all_mel = self.audio.melspectrogram(wav).astype(np.float32).T | |
| frames_from_audio = np.arange(2, len(all_mel) // self.audio.num_bins_per_frame - 2) # 2,[]mmmmmmmmmmmmmmmmmmmmmmmmmmmm | |
| audio_inds = frame2audio_indexs(frames_from_audio, self.audio.num_frames_per_clip, self.audio.num_bins_per_frame) | |
| vid_exps = self.coeffs_dict['exp'][self.exp_id:self.exp_id+1] | |
| vid_poses = self.coeffs_dict['pose'][self.exp_id:self.exp_id+1] | |
| ref = torch.cat([vid_exps.view(1, 50), vid_poses[:, 3:].view(1, 3)], dim=-1) | |
| ref = ref[...,:self.args.exp_dim] | |
| K = 20 | |
| xlens = len(audio_inds) # len(self.coeffs_dict['exp']) | |
| exps = [] | |
| for i in tqdm(range(0, xlens, K), desc='S2 DECODER:'+ str(xlens) + ' '): | |
| mels = [] | |
| for j in range(K): | |
| if i + j < xlens: | |
| idx = i+j # //3 * 3 | |
| mel = load_spectrogram(all_mel, audio_inds[idx], self.audio.num_frames_per_clip * self.audio.num_bins_per_frame).cuda() | |
| mel = mel.view(-1, 1, 80, self.audio.num_frames_per_clip * self.audio.num_bins_per_frame) | |
| mels.append(mel) | |
| else: | |
| break | |
| mels = torch.cat(mels, dim=0) | |
| new_exp = self.avmodel(mels, ref.repeat(mels.shape[0], 1, 1).cuda(), self.args.use_tanh) # exp 53 | |
| exps+= [new_exp.view(-1, 53)] | |
| all_exps = torch.cat(exps,axis=0) | |
| return all_exps | |
| def test_model(self, wav_path): | |
| sys.path.append('../FaceFormer') | |
| from faceformer import Faceformer | |
| from transformers import Wav2Vec2FeatureExtractor,Wav2Vec2Processor | |
| from faceformer import PeriodicPositionalEncoding, init_biased_mask | |
| #build model | |
| self.args.train_subjects = " ".join(["A"]*8) # suitable for pre-trained faceformer checkpoint | |
| model = Faceformer(self.args) | |
| model.load_state_dict(torch.load('/apdcephfs/private_shadowcun/Avatar2dFF/medias/videos/c8/mask5000_l2/6_model.pth')) # ../packages/pretrained/28_ff_model.pth | |
| model = model.to(torch.device(self.device)) | |
| model.eval() | |
| # hacking for long audio generation | |
| model.PPE = PeriodicPositionalEncoding(self.args.feature_dim, period = self.args.period, max_seq_len=6000).cuda() | |
| model.biased_mask = init_biased_mask(n_head = 4, max_seq_len = 6000, period=self.args.period).cuda() | |
| train_subjects_list = ["A"] * 8 | |
| one_hot_labels = np.eye(len(train_subjects_list)) | |
| one_hot = one_hot_labels[0] | |
| one_hot = np.reshape(one_hot,(-1,one_hot.shape[0])) | |
| one_hot = torch.FloatTensor(one_hot).to(device=self.device) | |
| vertices_npy = np.load(self.args.source_dir + '/mesh_pose0.npy') | |
| vertices_npy = np.array(vertices_npy).reshape(-1, 5023*3) | |
| temp = vertices_npy[33] # 829 | |
| template = temp.reshape((-1)) | |
| template = np.reshape(template,(-1,template.shape[0])) | |
| template = torch.FloatTensor(template).to(device=self.device) | |
| speech_array, sampling_rate = librosa.load(os.path.join(wav_path), sr=16000) | |
| processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") | |
| audio_feature = np.squeeze(processor(speech_array,sampling_rate=16000).input_values) | |
| audio_feature = np.reshape(audio_feature,(-1,audio_feature.shape[0])) | |
| audio_feature = torch.FloatTensor(audio_feature).to(device=self.device) | |
| prediction = model.predict(audio_feature, template, one_hot, 1.0) # (1, seq_len, V*3) | |
| return prediction.squeeze() | |
| def run(self, face, audio, start_frame=0): | |
| wav, sr = librosa.load(audio, sr=16000) # 16*80 ? 20*80 | |
| wav_tensor = torch.FloatTensor(wav).unsqueeze(0) if len(wav.shape) == 1 else torch.FloatTensor(wav) | |
| _, frames = parse_audio_length(wav_tensor.shape[1], 16000, self.args.fps) | |
| ##### audio-guided, only use the jaw movement | |
| all_exps = self.run_spectre_v3(wav) | |
| # #### temp. interpolation | |
| all_exps = torch.nn.functional.interpolate(all_exps.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear') | |
| all_exps = all_exps.permute([0,2,1]).squeeze(0) | |
| # run faceformer for face mesh generation | |
| predicted_vertices = self.test_model(audio) | |
| predicted_vertices = predicted_vertices.view(-1, 5023*3) | |
| #### temp. interpolation | |
| predicted_vertices = torch.nn.functional.interpolate(predicted_vertices.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear') | |
| predicted_vertices = predicted_vertices.permute([0,2,1]).squeeze(0).view(-1, 5023, 3) | |
| all_exps = torch.Tensor(savgol_filter(all_exps.cpu().numpy(), 5, 3, axis=0)).cpu() # smooth GT | |
| rendedimages, lm2d = self.coeffs_to_img(predicted_vertices, all_exps, zero_pose=True) | |
| debug_video_gen(rendedimages, self.args.result_dir+"/debug_before_ff.mp4", wav_tensor, self.args.fps, sr) | |
| # cg2real | |
| debug_video_gen(self.cg2real(rendedimages, start_frame=start_frame), self.args.result_dir+"/debug_cg2real_raw.mp4", wav_tensor, self.args.fps, sr) | |
| exit() | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser(description='Stylization and Seamless Video Dubbing') | |
| parser.add_argument('--face', default='examples', type=str, help='') | |
| parser.add_argument('--audio', default='examples', type=str, help='') | |
| parser.add_argument('--source_dir', default='examples', type=str,help='TODO') | |
| parser.add_argument('--result_dir', default='examples', type=str,help='TODO') | |
| parser.add_argument('--backend', default='wav2lip', type=str,help='wav2lip or pcavs') | |
| parser.add_argument('--result_tag', default='result', type=str,help='TODO') | |
| parser.add_argument('--netR', default='unet_256', type=str,help='TODO') | |
| parser.add_argument('--render_path', default='', type=str,help='TODO') | |
| parser.add_argument('--ngf', default=16, type=int,help='TODO') | |
| parser.add_argument('--fps', default=20, type=int,help='TODO') | |
| parser.add_argument('--mask', default=100, type=int,help='TODO') | |
| parser.add_argument('--mask_type', default='v3', type=str,help='TODO') | |
| parser.add_argument('--image_size', default=256, type=int,help='TODO') | |
| parser.add_argument('--input_nc', default=21, type=int,help='TODO') | |
| parser.add_argument('--output_nc', default=3, type=int,help='TODO') | |
| parser.add_argument('--renderbs', default=16, type=int,help='TODO') | |
| parser.add_argument('--tframes', default=1, type=int,help='TODO') | |
| parser.add_argument('--debug', action='store_true') | |
| parser.add_argument('--enhance', action='store_true') | |
| parser.add_argument('--phone', action='store_true') | |
| #### faceformer | |
| parser.add_argument("--model_name", type=str, default="VOCA") | |
| parser.add_argument("--dataset", type=str, default="vocaset", help='vocaset or BIWI') | |
| parser.add_argument("--feature_dim", type=int, default=64, help='64 for vocaset; 128 for BIWI') | |
| parser.add_argument("--period", type=int, default=30, help='period in PPE - 30 for vocaset; 25 for BIWI') | |
| parser.add_argument("--vertice_dim", type=int, default=5023*3, help='number of vertices - 5023*3 for vocaset; 23370*3 for BIWI') | |
| parser.add_argument("--device", type=str, default="cuda") | |
| parser.add_argument("--train_subjects", type=str, default="FaceTalk_170728_03272_TA ") | |
| parser.add_argument("--test_subjects", type=str, default="FaceTalk_170809_00138_TA FaceTalk_170731_00024_TA") | |
| parser.add_argument("--condition", type=str, default="FaceTalk_170904_00128_TA", help='select a conditioning subject from train_subjects') | |
| parser.add_argument("--subject", type=str, default="FaceTalk_170731_00024_TA", help='select a subject from test_subjects or train_subjects') | |
| parser.add_argument("--background_black", type=bool, default=True, help='whether to use black background') | |
| parser.add_argument("--template_path", type=str, default="templates.pkl", help='path of the personalized templates') | |
| parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI/FLAME topology') | |
| opt = parser.parse_args() | |
| opt.img_size = 96 | |
| opt.static = True | |
| opt.device = torch.device("cuda") | |
| a2m = Audio2Mesh(opt) | |
| print('link start!') | |
| t = time.time() | |
| # 02780 | |
| a2m.run(opt.face, opt.audio, 0) | |
| print(time.time() - t) |