LatentSync / eval /eval_sync_conf.py
fffiloni's picture
Migrated from GitHub
3650c12 verified
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import tqdm
from statistics import fmean
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from latentsync.utils.util import red_text
import torch
def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
syncnet_detector(video_path=video_path, min_track=50)
crop_videos = os.listdir(os.path.join(detect_results_dir, "crop"))
if crop_videos == []:
raise Exception(red_text(f"Face not detected in {video_path}"))
av_offset_list = []
conf_list = []
for video in crop_videos:
av_offset, _, conf = syncnet.evaluate(
video_path=os.path.join(detect_results_dir, "crop", video), temp_dir=temp_dir
)
av_offset_list.append(av_offset)
conf_list.append(conf)
av_offset = int(fmean(av_offset_list))
conf = fmean(conf_list)
print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}")
return av_offset, conf
def main():
parser = argparse.ArgumentParser(description="SyncNet")
parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="")
parser.add_argument("--video_path", type=str, default=None, help="")
parser.add_argument("--videos_dir", type=str, default="/root/processed")
parser.add_argument("--temp_dir", type=str, default="temp", help="")
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
syncnet = SyncNetEval(device=device)
syncnet.loadParameters(args.initial_model)
syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")
if args.video_path is not None:
syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir)
else:
sync_conf_list = []
video_names = sorted([f for f in os.listdir(args.videos_dir) if f.endswith(".mp4")])
for video_name in tqdm.tqdm(video_names):
try:
_, conf = syncnet_eval(
syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir
)
sync_conf_list.append(conf)
except Exception as e:
print(e)
print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")
if __name__ == "__main__":
main()