# tts.py
import json
import re
import string
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf
from gruut import sentences
from IPython.display import Audio

# Load models
lightspeech = ort.InferenceSession("./models/lightspeech_quant.onnx")
mbmelgan = ort.InferenceSession("./models/mbmelgan.onnx")

lightspeech_processor_config = Path("./models/lightspeech_processor.json")
with open(lightspeech_processor_config, "r") as f:
    processor = json.load(f)

tokenizer = processor["symbol_to_id"]


class TTS:
    @staticmethod
    def generate(text: str) -> np.ndarray:
        sections = TTS.split_text(text)
        audio_sections = TTS.generate_speech_for_sections(sections)
        concatenated_audio = TTS.concatenate_audio_sections(audio_sections)
        return concatenated_audio

    @staticmethod
    def split_text(text: str) -> list:
        # Split the text into sentences based on punctuation marks
        split_sentences = re.split(r'(?<=[.!?])\s*', text)
        sections = []
        for sentence in split_sentences:
            # Split each sentence by commas for short pauses
            parts = re.split(r',\s*', sentence)
            for i, part in enumerate(parts):
                sections.append(part.strip())
                if i < len(parts) - 1:
                    sections.append('*')  # Short pause marker
            sections.append('**')  # Long pause marker after each sentence
        # Remove empty sections
        sections = [section for section in sections if section]
        return sections

    @staticmethod
    def generate_speech_for_sections(sections: list) -> list:
        audio_sections = []
        for section in sections:
            if section == '**':
                # Long pause
                pause_duration = 1.0
                sample_rate = 44100
                pause = np.zeros(int(pause_duration * sample_rate))
                audio_sections.append(pause)
            elif section == '*':
                # Short pause
                pause_duration = 0.4
                sample_rate = 44100
                pause = np.zeros(int(pause_duration * sample_rate))
                audio_sections.append(pause)
            else:
                mel_output, durations = TTS.text2mel(section)
                audio_array = TTS.mel2wav(mel_output)
                audio_sections.append(audio_array)
        return audio_sections

    @staticmethod
    def concatenate_audio_sections(audio_sections: list) -> np.ndarray:
        concatenated_audio = np.concatenate(audio_sections)
        return concatenated_audio

    @staticmethod
    def phonemize(text: str) -> list:
        ipa = []
        for sentence in sentences(text, lang="sw"):
            for word in sentence:
                if word.is_major_break or word.is_minor_break:
                    ipa += [word.text]
                    continue
                phonemes = word.phonemes[:]
                # Swahili distinguishes the velar nasal ng' /ŋ/ from
                # prenasalized ng /ᵑg/; gruut can emit /ᵑg/ for both,
                # so correct the ng' occurrences here.
                NG_GRAPHEME = "ng'"
                NG_PRENASALIZED_PHONEME = "ᵑg"
                NG_PHONEME = "ŋ"
                if NG_GRAPHEME in word.text:
                    ng_graphemes = re.findall(f"{NG_GRAPHEME}?", word.text)
                    ng_phonemes_idx = [
                        i for i, p in enumerate(phonemes)
                        if p == NG_PRENASALIZED_PHONEME
                    ]
                    assert len(ng_graphemes) == len(ng_phonemes_idx)
                    for i, g in zip(ng_phonemes_idx, ng_graphemes):
                        phonemes[i] = NG_PHONEME if g == NG_GRAPHEME else phonemes[i]
                ipa += phonemes
        return ipa

    @staticmethod
    def tokenize(phonemes: list) -> list:
        # Punctuation is looked up directly; phonemes use the "@" prefix
        # in the processor's symbol table.
        input_ids = []
        for phoneme in phonemes:
            if all(c in string.punctuation for c in phoneme):
                input_ids.append(tokenizer[phoneme])
            else:
                input_ids.append(tokenizer[f"@{phoneme}"])
        return input_ids

    @staticmethod
    def text2mel(text: str) -> tuple:
        phonemes = TTS.phonemize(text)
        input_ids = TTS.tokenize(phonemes)
        inputs = {
            "input_ids": np.array([input_ids], dtype=np.int32),
            "speaker_ids": np.array([0], dtype=np.int32),
            "speed_ratios": np.array([1.0], dtype=np.float32),
            "f0_ratios": np.array([1.0], dtype=np.float32),
            "energy_ratios": np.array([1.0], dtype=np.float32),
        }
        mel_output, durations, _ = lightspeech.run(None, inputs)
        return mel_output, durations

    @staticmethod
    def mel2wav(mel_output: np.ndarray) -> np.ndarray:
        # Prepare input for vocoder model
        inputs = {
            "mels": mel_output,
        }
        # Run inference
        outputs = mbmelgan.run(None, inputs)
        audio_array = outputs[0][0, :, 0]
        return audio_array

    @staticmethod
    def synthesize(text: str) -> np.ndarray:
        mel_output, _ = TTS.text2mel(text)
        audio_array = TTS.mel2wav(mel_output)
        return audio_array

    @staticmethod
    def save_audio(audio_array: np.ndarray, path: str):
        sf.write(path, audio_array, 44100)
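

# Minimal usage sketch (not part of the original module): the sample text
# and output path below are illustrative assumptions. It exercises the
# full pipeline: split_text -> text2mel -> mel2wav -> concatenation.
if __name__ == "__main__":
    sample_text = "Habari ya leo? Karibu sana."  # hypothetical Swahili input
    audio = TTS.generate(sample_text)
    TTS.save_audio(audio, "output.wav")  # hypothetical output path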