import itertools
import os
import warnings
from typing import cast

import librosa
import matplotlib.pyplot as plt
from matplotlib import font_manager as fm, rcParams
import pyloudnorm
import sounddevice
import soundfile
import torch

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from audioseal.builder import create_generator
    from omegaconf import DictConfig
    from omegaconf import OmegaConf
    from speechbrain.pretrained import EncoderClassifier
    from torchaudio.transforms import Resample

from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Utility.storage_config import MODELS_DIR
from Utility.utils import cumsum_durations
from Utility.utils import float2pcm


class ToucanTTSInterface(torch.nn.Module):

    def __init__(
        self,
        device="cpu",  # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
        tts_model_path=os.path.join(MODELS_DIR, "ToucanTTS_Shan", "best.pt"),  # path to the ToucanTTS checkpoint or just a shorthand if run standalone
        vocoder_model_path=os.path.join(MODELS_DIR, "Vocoder", "best.pt"),  # path to the Vocoder checkpoint
        language="eng",  # initial language of the model, can be changed later with the setter methods
        enhance=None,  # legacy argument
    ):
        super().__init__()
        self.device = device
        if not tts_model_path.endswith(".pt"):
            # default to the shorthand system
            tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")
        if "USER" not in os.environ:
            os.environ["USER"] = ""  # that's the case under Windows, but omegaconf needs this

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            watermark_conf = cast(
                DictConfig,
                OmegaConf.load("InferenceInterfaces/audioseal_wm_16bits.yaml"),
            )
            self.watermark = create_generator(watermark_conf)
            self.watermark.load_state_dict(
                torch.load("Models/audioseal/generator.pth", map_location="cpu")["model"]
            )  # downloaded from https://dl.fbaipublicfiles.com/audioseal/6edcf62f/generator.pth originally

        ################################
        #   build text to phone        #
        ################################
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)

        #####################################
        #   load phone to features model    #
        #####################################
        checkpoint = torch.load(tts_model_path, map_location="cpu")
        self.phone2mel = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
        with torch.no_grad():
            self.phone2mel.store_inverse_all()  # this also removes weight norm
        self.phone2mel = self.phone2mel.to(torch.device(device))

        ######################################
        #   load features to style models    #
        ######################################
        self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": str(device)},
            savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"),
        )

        ################################
        #   load mel to wave model     #
        ################################
        vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
        self.vocoder = HiFiGAN()
        self.vocoder.load_state_dict(vocoder_checkpoint)
        self.vocoder = self.vocoder.to(device).eval()
        self.vocoder.remove_weight_norm()
        self.meter = pyloudnorm.Meter(24000)
        ################################
        #   set defaults               #
        ################################
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, device=device)
        self.phone2mel.eval()
        self.vocoder.eval()
        self.lang_id = get_language_id(language)
        self.to(torch.device(device))
        self.eval()

    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return
        if type(path_to_reference_audio) != list:
            path_to_reference_audio = [path_to_reference_audio]
        if len(path_to_reference_audio) > 0:
            for path in path_to_reference_audio:
                assert os.path.exists(path)
            speaker_embs = list()
            for path in path_to_reference_audio:
                wave, sr = soundfile.read(path)
                if len(wave.shape) > 1:  # oh no, we found a stereo audio!
                    if len(wave[0]) == 2:  # let's figure out whether we need to switch the axes
                        wave = wave.transpose()  # if yes, we switch the axes.
                    wave = librosa.to_mono(wave)
                wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(
                    torch.tensor(wave, device=self.device, dtype=torch.float32)
                )
                speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(
                    wavs=wave.to(self.device).squeeze().unsqueeze(0)
                ).squeeze()
                speaker_embs.append(speaker_embedding)
            self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)

    def set_language(self, lang_id):
        """
        The id parameter actually refers to the shorthand. This has become ambiguous
        with the introduction of the actual language IDs.
        """
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True)

    def set_accent_language(self, lang_id):
        if lang_id in [
            "ajp",
            "ajt",
            "lak",
            "lno",
            "nul",
            "pii",
            "plj",
            "slq",
            "smd",
            "snb",
            "tpw",
            "wya",
            "zua",
            "en-us",
            "en-sc",
            "fr-be",
            "fr-sw",
            "pt-br",
            "spa-lat",
            "vi-ctr",
            "vi-so",
        ]:
            if lang_id == "vi-so" or lang_id == "vi-ctr":
                lang_id = "vie"
            elif lang_id == "spa-lat":
                lang_id = "spa"
            elif lang_id == "pt-br":
                lang_id = "por"
            elif lang_id == "fr-sw" or lang_id == "fr-be":
                lang_id = "fra"
            elif lang_id == "en-sc" or lang_id == "en-us":
                lang_id = "eng"
            else:
                # no clue where these others are even coming from, they are not in ISO 639-2
                lang_id = "eng"
        self.lang_id = get_language_id(lang_id).to(self.device)

    def forward(
        self,
        text,
        view=False,
        duration_scaling_factor=1.0,
        pitch_variance_scale=1.0,
        energy_variance_scale=1.0,
        pause_duration_scaling_factor=1.0,
        durations=None,
        pitch=None,
        energy=None,
        input_is_phones=False,
        return_plot_as_filepath=False,
        loudness_in_db=-24.0,
        glow_sampling_temperature=0.2,
    ):
        """
        duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                 1.0 means no scaling happens, higher values increase durations for the whole
                                 utterance, lower values decrease durations for the whole utterance.
        pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                              1.0 means no scaling happens, higher values increase variance of the pitch curve,
                              lower values decrease variance of the pitch curve.
        energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                               1.0 means no scaling happens, higher values increase variance of the energy curve,
                               lower values decrease variance of the energy curve.
""" with torch.inference_mode(): phones = self.text2phone.string_to_tensor( text, input_phonemes=input_is_phones ).to(torch.device(self.device)) mel, durations, pitch, energy = self.phone2mel( phones, return_duration_pitch_energy=True, utterance_embedding=self.default_utterance_embedding, durations=durations, pitch=pitch, energy=energy, lang_id=self.lang_id, duration_scaling_factor=duration_scaling_factor, pitch_variance_scale=pitch_variance_scale, energy_variance_scale=energy_variance_scale, pause_duration_scaling_factor=pause_duration_scaling_factor, glow_sampling_temperature=glow_sampling_temperature, ) wave, _, _ = self.vocoder(mel.unsqueeze(0)) wave = wave.squeeze().cpu() wave = wave.numpy() sr = 24000 try: loudness = self.meter.integrated_loudness(wave) wave = pyloudnorm.normalize.loudness(wave, loudness, loudness_in_db) except ValueError: # if the audio is too short, a value error will arise pass with torch.inference_mode(): wave = ( ( torch.tensor(wave) + 0.1 * self.watermark.get_watermark( torch.tensor(wave).to(self.device).unsqueeze(0).unsqueeze(0) ) .squeeze() .detach() .cpu() ) .detach() .numpy() ) if view or return_plot_as_filepath: fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5)) # fpath = "./src/fonts/Shan.ttf" fpath = os.path.join(os.path.dirname(__file__), "src/fonts/Shan.ttf") prop = fm.FontProperties(fname=fpath) ax.imshow(mel.cpu().numpy(), origin="lower", cmap="GnBu") ax.yaxis.set_visible(False) duration_splits, label_positions = cumsum_durations(durations.cpu().numpy()) ax.xaxis.grid(True, which="minor") ax.set_xticks(label_positions, minor=False) if input_is_phones: phones = text.replace(" ", "|") else: phones = self.text2phone.get_phone_string(text, for_plot_labels=True) ax.set_xticklabels(phones) word_boundaries = list() for label_index, phone in enumerate(phones): if phone == "|": word_boundaries.append(label_positions[label_index]) try: prev_word_boundary = 0 word_label_positions = list() for word_boundary in word_boundaries: word_label_positions.append( (word_boundary + prev_word_boundary) / 2 ) prev_word_boundary = word_boundary word_label_positions.append( (duration_splits[-1] + prev_word_boundary) / 2 ) secondary_ax = ax.secondary_xaxis("bottom") secondary_ax.tick_params(axis="x", direction="out", pad=24) secondary_ax.set_xticks(word_label_positions, minor=False) secondary_ax.set_xticklabels(text.split(), fontproperties=prop) secondary_ax.tick_params(axis="x", colors="orange") secondary_ax.xaxis.label.set_color("orange") except ValueError: ax.set_title(text) except IndexError: ax.set_title(text) ax.vlines( x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5, ) ax.vlines( x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0, ) plt.subplots_adjust( left=0.02, bottom=0.2, right=0.98, top=0.9, wspace=0.0, hspace=0.0 ) ax.set_aspect("auto") if return_plot_as_filepath: plt.savefig("tmp.png") return wave, sr, "tmp.png" return wave, sr def read_to_file( self, text_list, file_location, duration_scaling_factor=1.0, pitch_variance_scale=1.0, energy_variance_scale=1.0, pause_duration_scaling_factor=1.0, silent=False, dur_list=None, pitch_list=None, energy_list=None, glow_sampling_temperature=0.2, ): """ Args: silent: Whether to be verbose about the process text_list: A list of strings to be read file_location: The path and name of the file it should be saved to energy_list: list of energy tensors to be used for the texts pitch_list: list of pitch tensors to be used for the texts dur_list: list 
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
        """
        if not dur_list:
            dur_list = []
        if not pitch_list:
            pitch_list = []
        if not energy_list:
            energy_list = []
        silence = torch.zeros([14300])
        wav = silence.clone()
        for text, durations, pitch, energy in itertools.zip_longest(
            text_list, dur_list, pitch_list, energy_list
        ):
            if text.strip() != "":
                if not silent:
                    print("Now synthesizing: {}".format(text))
                spoken_sentence, sr = self(
                    text,
                    durations=durations.to(self.device) if durations is not None else None,
                    pitch=pitch.to(self.device) if pitch is not None else None,
                    energy=energy.to(self.device) if energy is not None else None,
                    duration_scaling_factor=duration_scaling_factor,
                    pitch_variance_scale=pitch_variance_scale,
                    energy_variance_scale=energy_variance_scale,
                    pause_duration_scaling_factor=pause_duration_scaling_factor,
                    glow_sampling_temperature=glow_sampling_temperature,
                )
                spoken_sentence = torch.tensor(spoken_sentence).cpu()
                wav = torch.cat((wav, spoken_sentence, silence), 0)
        soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")

    def read_aloud(
        self,
        text,
        view=False,
        duration_scaling_factor=1.0,
        pitch_variance_scale=1.0,
        energy_variance_scale=1.0,
        blocking=False,
        glow_sampling_temperature=0.2,
    ):
        if text.strip() == "":
            return
        wav, sr = self(
            text,
            view,
            duration_scaling_factor=duration_scaling_factor,
            pitch_variance_scale=pitch_variance_scale,
            energy_variance_scale=energy_variance_scale,
            glow_sampling_temperature=glow_sampling_temperature,
        )
        silence = torch.zeros([sr // 2])
        wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
        sounddevice.play(float2pcm(wav), samplerate=sr)
        if view:
            plt.show()
        if blocking:
            sounddevice.wait()
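

# Minimal usage sketch (not part of the interface itself), assuming the ToucanTTS_Shan and
# Vocoder checkpoints exist under MODELS_DIR and that the script is launched from the
# repository root, since the watermark config and font paths above are relative. The text,
# output filename and reference-audio path are illustrative placeholders.
if __name__ == "__main__":
    tts = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu")
    # Optionally clone the voice of one or more reference recordings (placeholder path):
    # tts.set_utterance_embedding(path_to_reference_audio=["reference_speaker.wav"])
    tts.read_to_file(
        text_list=["This is a test sentence."],
        file_location="test_output.wav",
        duration_scaling_factor=1.0,
        pitch_variance_scale=1.0,
        energy_variance_scale=1.0,
    )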