import os

import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt

CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,
        "_pad": "_",
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

CACHE["get_vits_phoneme_ids"]["symbols"] = (
    [CACHE["get_vits_phoneme_ids"]["_pad"]]
    + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
    + list(CACHE["get_vits_phoneme_ids"]["_special"])
)
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
    s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
}


def get_vits_phoneme_ids(config, dl_output, metadata):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids"
    clean_text = metadata["phonemes"]
    sequence = []

    for symbol in clean_text:
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    # Interleave the sequence with zeros (blank tokens) and append a trailing zero
    inserted_zero_sequence = [0] * (len(sequence) * 2)
    inserted_zero_sequence[1::2] = sequence
    inserted_zero_sequence = inserted_zero_sequence + [0]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))}


def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids"
    clean_text = metadata["phonemes"] + "⚠"
    sequence = []

    for symbol in clean_text:
        if symbol not in _symbol_to_id.keys():
            print("%s is not in the vocabulary. %s" % (symbol, clean_text))
            symbol = "_"
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    sequence = sequence[:pad_length]

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))}
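

# Hedged usage sketch (not part of the original add-ons): a minimal example of
# how the two VITS phoneme add-ons above could be exercised. The IPA string
# below is made up, and "config"/"dl_output" are passed as empty dicts only
# because these two functions do not read them.
def _demo_vits_phoneme_ids():
    metadata = {"phonemes": "ɡʊd mɔːnɪŋ!"}  # hypothetical IPA phoneme string
    interleaved = get_vits_phoneme_ids({}, {}, metadata)["phoneme_idx"]
    plain = get_vits_phoneme_ids_no_padding({}, {}, metadata)["phoneme_idx"]
    # Both outputs are padded/truncated to PAD_LENGTH (310) integer ids; the
    # first interleaves blank (zero) tokens between symbols, the second does not.
    print(interleaved.shape, plain.shape)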
%s" % (symbol, clean_text)) symbol = "_" symbol_id = _symbol_to_id[symbol] sequence += [symbol_id] def _pad_phonemes(phonemes_list): return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list)) sequence = sequence[:pad_length] return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))} def calculate_relative_bandwidth(config, dl_output, metadata): assert "stft" in dl_output.keys() # The last dimension of the stft feature is the frequency dimension freq_dimensions = dl_output["stft"].size(-1) freq_energy_dist = torch.sum(dl_output["stft"], dim=0) freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0) total_energy = freq_energy_dist[-1] percentile_5th = total_energy * 0.05 percentile_95th = total_energy * 0.95 lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist)) higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist)) lower_idx = int((lower_idx / freq_dimensions) * 1000) higher_idx = int((higher_idx / freq_dimensions) * 1000) return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])} def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata): assert "stft" in dl_output.keys() linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10)) # The last dimension of the stft feature is the frequency dimension freq_dimensions = linear_mel_spec.size(-1) freq_energy_dist = torch.sum(linear_mel_spec, dim=0) freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0) total_energy = freq_energy_dist[-1] percentile_5th = total_energy * 0.05 percentile_95th = total_energy * 0.95 lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist)) higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist)) latent_t_size = config["model"]["params"]["latent_t_size"] latent_f_size = config["model"]["params"]["latent_f_size"] lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions))) higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions))) bandwidth_condition = torch.zeros((latent_t_size, latent_f_size)) bandwidth_condition[:, lower_idx:higher_idx] += 1.0 return { "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition, "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]), } def waveform_rs_48k(config, dl_output, metadata): waveform = dl_output["waveform"] # [1, samples] sampling_rate = dl_output["sampling_rate"] if sampling_rate != 48000: waveform_48k = torchaudio.functional.resample( waveform, orig_freq=sampling_rate, new_freq=48000 ) else: waveform_48k = waveform return {"waveform_48k": waveform_48k} def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata): assert ( "phoneme" not in metadata.keys() ), "The metadata of speech you use seems belong to fastspeech. 
Please check dataset_root.json" if "phonemes" in metadata.keys(): new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata) new_item["text"] = "" # We assume TTS data does not have text description else: fake_metadata = {"phonemes": ""} # Add empty phoneme sequence new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata) return new_item def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata): if "phoneme" in metadata.keys(): new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata) new_item["text"] = "" else: fake_metadata = {"phoneme": []} new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata) return new_item def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata): PAD_LENGTH = 135 phonemes_lookup_dict = { "K": 0, "IH2": 1, "NG": 2, "OW2": 3, "AH2": 4, "F": 5, "AE0": 6, "IY0": 7, "SH": 8, "G": 9, "W": 10, "UW1": 11, "AO2": 12, "AW2": 13, "UW0": 14, "EY2": 15, "UW2": 16, "AE2": 17, "IH0": 18, "P": 19, "D": 20, "ER1": 21, "AA1": 22, "EH0": 23, "UH1": 24, "N": 25, "V": 26, "AY1": 27, "EY1": 28, "UH2": 29, "EH1": 30, "L": 31, "AA2": 32, "R": 33, "OY1": 34, "Y": 35, "ER2": 36, "S": 37, "AE1": 38, "AH1": 39, "JH": 40, "ER0": 41, "EH2": 42, "IY2": 43, "OY2": 44, "AW1": 45, "IH1": 46, "IY1": 47, "OW0": 48, "AO0": 49, "AY0": 50, "EY0": 51, "AY2": 52, "UH0": 53, "M": 54, "TH": 55, "T": 56, "OY0": 57, "AW0": 58, "DH": 59, "Z": 60, "spn": 61, "AH0": 62, "sp": 63, "AO1": 64, "OW1": 65, "ZH": 66, "B": 67, "AA0": 68, "CH": 69, "HH": 70, } pad_token_id = len(phonemes_lookup_dict.keys()) assert ( "phoneme" in metadata.keys() ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset" phonemes = [ phonemes_lookup_dict[x] for x in metadata["phoneme"] if (x in phonemes_lookup_dict.keys()) ] if (len(phonemes) / PAD_LENGTH) > 5: print( "Warning: Phonemes length is too long and is truncated too much! %s" % metadata ) phonemes = phonemes[:PAD_LENGTH] def _pad_phonemes(phonemes_list): return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list)) return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))} def extract_phoneme_g2p_en_feature(config, dl_output, metadata): PAD_LENGTH = 250 phonemes_lookup_dict = { " ": 0, "AA": 1, "AE": 2, "AH": 3, "AO": 4, "AW": 5, "AY": 6, "B": 7, "CH": 8, "D": 9, "DH": 10, "EH": 11, "ER": 12, "EY": 13, "F": 14, "G": 15, "HH": 16, "IH": 17, "IY": 18, "JH": 19, "K": 20, "L": 21, "M": 22, "N": 23, "NG": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "T": 31, "TH": 32, "UH": 33, "UW": 34, "V": 35, "W": 36, "Y": 37, "Z": 38, "ZH": 39, } pad_token_id = len(phonemes_lookup_dict.keys()) assert ( "phoneme" in metadata.keys() ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset" phonemes = [ phonemes_lookup_dict[x] for x in metadata["phoneme"] if (x in phonemes_lookup_dict.keys()) ] if (len(phonemes) / PAD_LENGTH) > 5: print( "Warning: Phonemes length is too long and is truncated too much! 
%s" % metadata ) phonemes = phonemes[:PAD_LENGTH] def _pad_phonemes(phonemes_list): return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list)) return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))} def extract_kaldi_fbank_feature(config, dl_output, metadata): norm_mean = -4.2677393 norm_std = 4.5689974 waveform = dl_output["waveform"] # [1, samples] sampling_rate = dl_output["sampling_rate"] log_mel_spec_hifigan = dl_output["log_mel_spec"] if sampling_rate != 16000: waveform_16k = torchaudio.functional.resample( waveform, orig_freq=sampling_rate, new_freq=16000 ) else: waveform_16k = waveform waveform_16k = waveform_16k - waveform_16k.mean() fbank = torchaudio.compliance.kaldi.fbank( waveform_16k, htk_compat=True, sample_frequency=16000, use_energy=False, window_type="hanning", num_mel_bins=128, dither=0.0, frame_shift=10, ) TARGET_LEN = log_mel_spec_hifigan.size(0) # cut and pad n_frames = fbank.shape[0] p = TARGET_LEN - n_frames if p > 0: m = torch.nn.ZeroPad2d((0, 0, 0, p)) fbank = m(fbank) elif p < 0: fbank = fbank[:TARGET_LEN, :] fbank = (fbank - norm_mean) / (norm_std * 2) return {"ta_kaldi_fbank": fbank} # [1024, 128] def extract_kaldi_fbank_feature_32k(config, dl_output, metadata): norm_mean = -4.2677393 norm_std = 4.5689974 waveform = dl_output["waveform"] # [1, samples] sampling_rate = dl_output["sampling_rate"] log_mel_spec_hifigan = dl_output["log_mel_spec"] if sampling_rate != 32000: waveform_32k = torchaudio.functional.resample( waveform, orig_freq=sampling_rate, new_freq=32000 ) else: waveform_32k = waveform waveform_32k = waveform_32k - waveform_32k.mean() fbank = torchaudio.compliance.kaldi.fbank( waveform_32k, htk_compat=True, sample_frequency=32000, use_energy=False, window_type="hanning", num_mel_bins=128, dither=0.0, frame_shift=10, ) TARGET_LEN = log_mel_spec_hifigan.size(0) # cut and pad n_frames = fbank.shape[0] p = TARGET_LEN - n_frames if p > 0: m = torch.nn.ZeroPad2d((0, 0, 0, p)) fbank = m(fbank) elif p < 0: fbank = fbank[:TARGET_LEN, :] fbank = (fbank - norm_mean) / (norm_std * 2) return {"ta_kaldi_fbank": fbank} # [1024, 128] # Use the beat and downbeat information as music conditions def extract_drum_beat(config, dl_output, metadata): def visualization(conditional_signal, mel_spectrogram, filename): import soundfile as sf sf.write( os.path.basename(dl_output["fname"]), np.array(dl_output["waveform"])[0], dl_output["sampling_rate"], ) plt.figure(figsize=(10, 10)) plt.subplot(211) plt.imshow(np.array(conditional_signal).T, aspect="auto") plt.title("Conditional Signal") plt.subplot(212) plt.imshow(np.array(mel_spectrogram).T, aspect="auto") plt.title("Mel Spectrogram") plt.savefig(filename) plt.close() assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata sampling_rate = metadata["sample_rate"] duration = dl_output["duration"] # The dataloader segment length before performing torch resampling original_segment_length_before_resample = int(sampling_rate * duration) random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"]) # The sample idx for beat and downbeat, relatively to the segmented audio beat = [ x - random_start_sample for x in metadata["beat"] if ( x - random_start_sample >= 0 and x - random_start_sample <= original_segment_length_before_resample ) ] downbeat = [ x - random_start_sample for x in metadata["downbeat"] if ( x - random_start_sample >= 0 and x - random_start_sample <= original_segment_length_before_resample ) ] latent_shape = ( 
config["model"]["params"]["latent_t_size"], config["model"]["params"]["latent_f_size"], ) conditional_signal = torch.zeros(latent_shape) # beat: -0.5 # downbeat: +1.0 # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat for each in beat: beat_index = int( (each / original_segment_length_before_resample) * latent_shape[0] ) beat_index = min(beat_index, conditional_signal.size(0) - 1) conditional_signal[beat_index, :] -= 0.5 for each in downbeat: beat_index = int( (each / original_segment_length_before_resample) * latent_shape[0] ) beat_index = min(beat_index, conditional_signal.size(0) - 1) conditional_signal[beat_index, :] += 1.0 # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png") return {"cond_beat_downbeat": conditional_signal}