import os
import random
import time

import soundfile as sf
import torch
import wandb

from Architectures.Vocoder.HiFiGAN_Dataset import HiFiGANDataset
from Architectures.Vocoder.HiFiGAN_Discriminators import AvocodoHiFiGANJointDiscriminator
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Architectures.Vocoder.HiFiGAN_train_loop import train_loop
from Utility.path_to_transcript_dicts import *
from Utility.storage_config import MODELS_DIR


def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb_resume_id, gpu_count):
    if gpu_id == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")

    if gpu_count > 1:
        print("Multi GPU training not supported for HiFiGAN!")
        import sys
        sys.exit()

    print("Preparing")

    if model_dir is not None:
        model_save_dir = model_dir
    else:
        model_save_dir = os.path.join(MODELS_DIR, "HiFiGAN_clean_data_and_augmentation")
    os.makedirs(model_save_dir, exist_ok=True)

    print("Preparing new data...")

    take_all = False  # if False, use only datasets with a high enough samplerate; if True, just use all of them

    file_lists_for_this_run_combined = list()

    # every dataset is screened the same way: read its first file, and if the
    # samplerate is high enough (or take_all is set), include the whole dataset.
    dataset_builders = [build_path_to_transcript_dict_mls_italian,
                        build_path_to_transcript_dict_mls_english,
                        build_path_to_transcript_dict_mls_french,
                        build_path_to_transcript_dict_mls_dutch,
                        build_path_to_transcript_dict_mls_polish,
                        build_path_to_transcript_dict_mls_spanish,
                        build_path_to_transcript_dict_mls_portuguese,
                        build_path_to_transcript_dict_karlsson,
                        build_path_to_transcript_dict_eva,
                        build_path_to_transcript_dict_bernd,
                        build_path_to_transcript_dict_friedrich,
                        build_path_to_transcript_dict_hokus,
                        build_path_to_transcript_dict_hui_others,
                        build_path_to_transcript_dict_elizabeth,
                        build_path_to_transcript_dict_nancy,
                        build_path_to_transcript_dict_vctk,
                        build_path_to_transcript_dict_libritts_all_clean,
                        build_path_to_transcript_dict_ljspeech,
                        build_path_to_transcript_dict_css10cmn,
                        build_path_to_transcript_dict_thorsten_emotional,
                        build_path_to_transcript_dict_thorsten_2022_10,
                        build_path_to_transcript_dict_thorsten_neutral,
                        build_path_to_transcript_dict_css10el,
                        build_path_to_transcript_dict_css10nl,
                        build_path_to_transcript_dict_css10fi,
                        build_path_to_transcript_dict_css10ru,
                        build_path_to_transcript_dict_css10hu,
                        build_path_to_transcript_dict_css10es,
                        build_path_to_transcript_dict_css10fr,
                        build_path_to_transcript_dict_nvidia_hifitts,
                        build_path_to_transcript_dict_spanish_blizzard_train,
                        build_path_to_transcript_dict_aishell3,
                        build_path_to_transcript_dict_VIVOS_viet,
                        build_path_to_transcript_dict_ESDS,
                        build_path_to_transcript_dict_CREMA_D]
    for build_dataset_dict in dataset_builders:
        fl = list(build_dataset_dict().keys())
        wav, sr = sf.read(fl[0])
        if sr >= 24000 or take_all:
            file_lists_for_this_run_combined += fl
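    # hedged side note, not wired into the recipe above: soundfile can report a
    # file's samplerate without decoding any audio, so the per-dataset probe
    # could skip loading a full waveform, e.g.
    #     sr = sf.info(fl[0]).samplerate
    # sf.info is part of the soundfile API, but this swap is untested here.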
    # additionally scan a local dataset dump (at most two directory levels deep)
    # for loose audio files; note that isdir needs the joined path, not just the
    # bare entry name.
    root = "/mount/arbeitsdaten61/studenten3/advanced-ml/2022/dhyanitr/projects/speech_datasets/Datasets/"
    for a in os.listdir(root):
        if os.path.isdir(os.path.join(root, a)):
            for b in os.listdir(os.path.join(root, a)):
                if os.path.isdir(os.path.join(root, a, b)):
                    for c in os.listdir(os.path.join(root, a, b)):
                        if os.path.isdir(os.path.join(root, a, b, c)):
                            print("If you can read this, the directories were more nested than I thought.")
                        else:
                            if c.endswith(".wav") or c.endswith(".flac"):
                                _, sr = sf.read(os.path.join(root, a, b, c))
                                if sr >= 24000 or take_all:
                                    file_lists_for_this_run_combined.append(os.path.join(root, a, b, c))
                else:
                    if b.endswith(".wav") or b.endswith(".flac"):
                        _, sr = sf.read(os.path.join(root, a, b))
                        if sr >= 24000 or take_all:
                            file_lists_for_this_run_combined.append(os.path.join(root, a, b))
        else:
            if a.endswith(".wav") or a.endswith(".flac"):
                _, sr = sf.read(os.path.join(root, a))
                if sr >= 24000 or take_all:
                    file_lists_for_this_run_combined.append(os.path.join(root, a))
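    # a sketch of an alternative to the hand-rolled scan above: os.walk from the
    # standard library handles arbitrary nesting, so the two-level depth limit
    # (and its "more nested than I thought" warning) would go away. Untested in
    # this recipe, kept as a comment on purpose:
    #
    # for dirpath, _, filenames in os.walk(root):
    #     for name in filenames:
    #         if name.endswith((".wav", ".flac")):
    #             path = os.path.join(dirpath, name)
    #             _, sr = sf.read(path)
    #             if sr >= 24000 or take_all:
    #                 file_lists_for_this_run_combined.append(path)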
    print("filepaths collected")

    fisher_yates_shuffle(file_lists_for_this_run_combined)
    fisher_yates_shuffle(file_lists_for_this_run_combined)
    fisher_yates_shuffle(file_lists_for_this_run_combined)
    print("filepaths randomized")

    selection = file_lists_for_this_run_combined[:250000]  # adjust the sample size until it fits into RAM

    # these two datasets are kind of important to represent some out-of-distribution data for what we expect.
    fl = list(build_path_to_transcript_dict_RAVDESS().keys())
    wav, sr = sf.read(fl[0])
    if sr >= 24000 or take_all:
        selection += fl
    fl = build_file_list_singing_voice_audio_database()
    wav, sr = sf.read(fl[0])
    if sr >= 24000 or take_all:
        selection += fl

    fisher_yates_shuffle(selection)
    fisher_yates_shuffle(selection)

    train_set = HiFiGANDataset(list_of_paths=selection, use_random_corruption=True)
    generator = HiFiGAN()
    discriminator = AvocodoHiFiGANJointDiscriminator()

    print("Training model")
    if use_wandb:
        wandb.init(
            name=f"{__name__.split('.')[-1]}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None,
            id=wandb_resume_id,  # this is None if not specified in the command line arguments.
            resume="must" if wandb_resume_id is not None else None)
    train_loop(batch_size=64,
               epochs=180000,
               generator=generator,
               discriminator=discriminator,
               train_dataset=train_set,
               device=device,
               epochs_per_save=1,
               model_save_dir=model_save_dir,
               path_to_checkpoint=resume_checkpoint,
               resume=resume,
               use_wandb=use_wandb,
               finetune=finetune)
    if use_wandb:
        wandb.finish()


def fisher_yates_shuffle(lst):
    # shuffle in place: walk backwards through the list and swap each element
    # with a uniformly chosen element at or before its position.
    for i in range(len(lst) - 1, 0, -1):
        j = random.randint(0, i)
        lst[i], lst[j] = lst[j], lst[i]
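# Minimal smoke test for this recipe (a sketch: in the full toolkit, run() is
# normally dispatched from a central training script, and the argument values
# below are illustrative assumptions, not a recommended configuration):
if __name__ == "__main__":
    run(gpu_id="cpu",
        resume_checkpoint=None,
        finetune=False,
        resume=False,
        model_dir=None,
        use_wandb=False,
        wandb_resume_id=None,
        gpu_count=1)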