Spaces:
Runtime error
Runtime error
Update modules/sadtalker_test.py
Browse files- modules/sadtalker_test.py +10 -5
modules/sadtalker_test.py
CHANGED
|
@@ -18,7 +18,7 @@ class SadTalker():
|
|
| 18 |
device = "cuda"
|
| 19 |
else:
|
| 20 |
device = "cpu"
|
| 21 |
-
|
| 22 |
current_code_path = sys.argv[0]
|
| 23 |
modules_path = os.path.split(current_code_path)[0]
|
| 24 |
|
|
@@ -53,7 +53,7 @@ class SadTalker():
|
|
| 53 |
facerender_yaml_path, device)
|
| 54 |
self.device = device
|
| 55 |
|
| 56 |
-
def test(self, source_image, driven_audio, result_dir):
|
| 57 |
|
| 58 |
time_tag = strftime("%Y_%m_%d_%H.%M.%S")
|
| 59 |
save_dir = os.path.join(result_dir, time_tag)
|
|
@@ -87,9 +87,14 @@ class SadTalker():
|
|
| 87 |
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
|
| 88 |
#coeff2video
|
| 89 |
batch_size = 4
|
| 90 |
-
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size)
|
| 91 |
-
self.animate_from_coeff.generate(data, save_dir)
|
| 92 |
video_name = data['video_name']
|
| 93 |
print(f'The generated video is named {video_name} in {save_dir}')
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
|
|
|
| 18 |
device = "cuda"
|
| 19 |
else:
|
| 20 |
device = "cpu"
|
| 21 |
+
|
| 22 |
current_code_path = sys.argv[0]
|
| 23 |
modules_path = os.path.split(current_code_path)[0]
|
| 24 |
|
|
|
|
| 53 |
facerender_yaml_path, device)
|
| 54 |
self.device = device
|
| 55 |
|
| 56 |
+
def test(self, source_image, driven_audio, still_mode, use_enhancer, result_dir):
|
| 57 |
|
| 58 |
time_tag = strftime("%Y_%m_%d_%H.%M.%S")
|
| 59 |
save_dir = os.path.join(result_dir, time_tag)
|
|
|
|
| 87 |
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
|
| 88 |
#coeff2video
|
| 89 |
batch_size = 4
|
| 90 |
+
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode)
|
| 91 |
+
self.animate_from_coeff.generate(data, save_dir, enhancer='gfpgan' if use_enhancer else None)
|
| 92 |
video_name = data['video_name']
|
| 93 |
print(f'The generated video is named {video_name} in {save_dir}')
|
| 94 |
+
|
| 95 |
+
if use_enhancer:
|
| 96 |
+
return os.path.join(save_dir, video_name+'_enhanced.mp4'), os.path.join(save_dir, video_name+'_enhanced.mp4')
|
| 97 |
+
|
| 98 |
+
else:
|
| 99 |
+
return os.path.join(save_dir, video_name+'.mp4'), os.path.join(save_dir, video_name+'.mp4')
|
| 100 |
|