# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import glob
import json
import os
import shutil
import time

import torch
from tqdm import tqdm

from models.svc.diffusion.diffusion_inference import DiffusionInference
from models.svc.comosvc.comosvc_inference import ComoSVCInference
from models.svc.transformer.transformer_inference import TransformerInference
from utils.util import load_config
from utils.audio_slicer import split_audio, merge_segments_encodec
from processors import acoustic_extractor, content_extractor


def build_inference(args, cfg, infer_type="from_dataset"):
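    """Select and instantiate the inference class matching cfg.model_type."""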
    supported_inference = {
        "DiffWaveNetSVC": DiffusionInference,
        "DiffComoSVC": ComoSVCInference,
        "TransformerSVC": TransformerInference,
    }

    inference_class = supported_inference[cfg.model_type]
    return inference_class(args, cfg, infer_type)


def prepare_for_audio_file(args, cfg, num_workers=1):
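    """Slice the source audio into segments and extract the acoustic and
    content features needed for from-file inference. Returns the updated
    (args, cfg) and the temporary cache directory holding the features."""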
    preprocess_path = cfg.preprocess.processed_dir
    audio_name = cfg.inference.source_audio_name
    temp_audio_dir = os.path.join(preprocess_path, audio_name)

    ### eval file
    t = time.time()
    eval_file = prepare_source_eval_file(cfg, temp_audio_dir, audio_name)
    args.source = eval_file
    with open(eval_file, "r") as f:
        metadata = json.load(f)
    print("Prepare for meta eval data: {:.1f}s".format(time.time() - t))

    ### acoustic features
    t = time.time()
    acoustic_extractor.extract_utt_acoustic_features_serial(
        metadata, temp_audio_dir, cfg
    )
    acoustic_extractor.cal_mel_min_max(
        dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
    )
    acoustic_extractor.cal_pitch_statistics_svc(
        dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
    )
    print("Prepare for acoustic features: {:.1f}s".format(time.time() - t))

    ### content features
    t = time.time()
    content_extractor.extract_utt_content_features_dataloader(
        cfg, metadata, num_workers
    )
    print("Prepare for content features: {:.1f}s".format(time.time() - t))

    return args, cfg, temp_audio_dir


def merge_for_audio_segments(audio_files, args, cfg):
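    """Merge the converted overlapping segments back into a single wav file
    and delete the per-segment outputs."""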
    audio_name = cfg.inference.source_audio_name
    target_singer_name = args.target_singer

    merge_segments_encodec(
        wav_files=audio_files,
        fs=cfg.preprocess.sample_rate,
        output_path=os.path.join(
            args.output_dir, "{}_{}.wav".format(audio_name, target_singer_name)
        ),
        overlap_duration=cfg.inference.segments_overlap_duration,
    )

    for tmp_file in audio_files:
        os.remove(tmp_file)


def prepare_source_eval_file(cfg, temp_audio_dir, audio_name):
    """
    Prepare the eval file (JSON) for a single audio file.

    The source audio is split into overlapping segments of at most
    cfg.inference.segments_max_duration seconds, and one metadata entry
    (Dataset / Singer / Uid / index, plus whatever split_audio attaches)
    is written per segment.
    """
    audio_chunks_results = split_audio(
        wav_file=cfg.inference.source_audio_path,
        target_sr=cfg.preprocess.sample_rate,
        output_dir=os.path.join(temp_audio_dir, "wavs"),
        max_duration_of_segment=cfg.inference.segments_max_duration,
        overlap_duration=cfg.inference.segments_overlap_duration,
    )

    metadata = []
    for i, res in enumerate(audio_chunks_results):
        res["index"] = i
        res["Dataset"] = audio_name
        res["Singer"] = audio_name
        res["Uid"] = "{}_{}".format(audio_name, res["Uid"])
        metadata.append(res)

    eval_file = os.path.join(temp_audio_dir, "eval.json")
    with open(eval_file, "w") as f:
        json.dump(metadata, f, indent=4, ensure_ascii=False, sort_keys=True)

    return eval_file


def cuda_relevant(deterministic=False):
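    """Configure CUDA for inference: free cached memory, enable TF32 and
    cuDNN, and optionally force deterministic behavior."""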
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def infer(args, cfg, infer_type):
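    """Build the inference pipeline and run it, returning the paths of the
    generated audio files."""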
    # Build inference
    t = time.time()
    trainer = build_inference(args, cfg, infer_type)
    print("Model Init: {:.1f}s".format(time.time() - t))

    # Run inference
    t = time.time()
    output_audio_files = trainer.inference()
    print("Model inference: {:.1f}s".format(time.time() - t))

    return output_audio_files


def build_parser():
    r"""Build argument parser for inference.py.

    Anything else should be put in an extra config YAML file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--acoustics_dir",
        type=str,
        help="Acoustics model checkpoint directory. If a directory is given, "
        "search for the latest checkpoint dir in the directory. If a specific "
        "checkpoint dir is given, directly load the checkpoint.",
    )
    parser.add_argument(
        "--vocoder_dir",
        type=str,
        required=True,
        help="Vocoder checkpoint directory. Searching behavior is the same as "
        "the acoustics one.",
    )
    parser.add_argument(
        "--target_singer",
        type=str,
        required=True,
        help="Convert to a specific singer (e.g. --target_singer singer_id).",
    )
    parser.add_argument(
        "--trans_key",
        default=0,
        help="0: no pitch shift; 'autoshift': automatic pitch shift; "
        "an integer: shift the key by that amount.",
    )
    parser.add_argument(
        "--source",
        type=str,
        default="source_audio",
        help="Source audio file or directory. If a JSON file is given, "
        "inference from dataset is applied. If a directory is given, "
        "inference from all wav/flac/mp3 audio files in the directory is applied. "
        "Default: inference from all wav/flac/mp3 audio files in ./source_audio",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="conversion_results",
        help="Output directory. Default: ./conversion_results",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="warning",
        help="Logging level. Default: warning",
    )
    parser.add_argument(
        "--keep_cache",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Keep cache files (pass --no-keep_cache to remove them). "
        "Only applicable to inference from files.",
    )
    parser.add_argument(
        "--diffusion_inference_steps",
        type=int,
        default=50,
        help="Number of inference steps. Only applicable to diffusion inference.",
    )
    return parser
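

# An illustrative invocation (the checkpoint paths and singer id below are
# placeholders, not values shipped with the repository):
#
#   python inference.py \
#       --config ckpts/svc/my_exp/args.json \
#       --acoustics_dir ckpts/svc/my_exp \
#       --vocoder_dir pretrained/vocoder \
#       --target_singer vocalist_1 \
#       --trans_key autoshift \
#       --source ./source_audio \
#       --output_dir ./conversion_results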


def main():
    ### Parse arguments and config
    args = build_parser().parse_args()
    cfg = load_config(args.config)

    # CUDA settings
    cuda_relevant()

    if os.path.isdir(args.source):
        ### Infer from file

        # Get all the source audio files (.wav, .flac, .mp3)
        source_audio_dir = args.source
        audio_list = []
        for suffix in ["wav", "flac", "mp3"]:
            audio_list += glob.glob(
                os.path.join(source_audio_dir, "**/*.{}".format(suffix)), recursive=True
            )
        print("There are {} source audio files.".format(len(audio_list)))

        # Infer for every file as dataset
        output_root_path = args.output_dir
        for audio_path in tqdm(audio_list):
            audio_name = os.path.splitext(os.path.basename(audio_path))[0]
            args.output_dir = os.path.join(output_root_path, audio_name)
            print("\n{}\nConversion for {}...\n".format("*" * 10, audio_name))

            cfg.inference.source_audio_path = audio_path
            cfg.inference.source_audio_name = audio_name
            cfg.inference.segments_max_duration = 10.0
            cfg.inference.segments_overlap_duration = 1.0

            # Prepare metadata and features
            args, cfg, cache_dir = prepare_for_audio_file(args, cfg)

            # Infer from file
            output_audio_files = infer(args, cfg, infer_type="from_file")

            # Merge the split segments
            merge_for_audio_segments(output_audio_files, args, cfg)

            # Keep or remove caches
            if not args.keep_cache:
                shutil.rmtree(cache_dir)

    else:
        ### Infer from dataset
        infer(args, cfg, infer_type="from_dataset")


if __name__ == "__main__":
    main()