# tts.py
import json
import logging
import re
import string
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf
from gruut import sentences

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sample rate shared by pause generation and save_audio
SAMPLE_RATE = 44100

# Load the acoustic model (phonemes -> mel spectrogram) and the vocoder (mel -> waveform)
lightspeech = ort.InferenceSession("./models/lightspeech_quant.onnx")
mbmelgan = ort.InferenceSession("./models/mbmelgan.onnx")

# The processor config holds the symbol table mapping phoneme symbols to input IDs
lightspeech_processor_config = Path("./models/processor.json")
with open(lightspeech_processor_config, "r") as f:
    processor = json.load(f)
tokenizer = processor["symbol_to_id"]


class TTS:
    @staticmethod
    def generate(text: str) -> np.ndarray:
        """Synthesize text into a single waveform, inserting pauses at punctuation."""
        sections = TTS.split_text(text)
        audio_sections = TTS.generate_speech_for_sections(sections)
        return TTS.concatenate_audio_sections(audio_sections)

    @staticmethod
    def split_text(text: str) -> list:
        # Split the text into sentences at terminal punctuation
        sents = re.split(r"(?<=[.!?])\s*", text)
        # For testing, use only the first 3 sentences
        sents = sents[:3]
        sections = []
        for sentence in sents:
            if not sentence.strip():
                continue
            # Split each sentence at commas to insert short pauses
            parts = re.split(r",\s*", sentence)
            for i, part in enumerate(parts):
                sections.append(part.strip())
                if i < len(parts) - 1:
                    sections.append("*")  # short pause marker
            sections.append("**")  # long pause marker after each sentence
        # Remove empty sections
        sections = [section for section in sections if section]
        # Trim a trailing long-pause marker
        if sections and sections[-1] == "**":
            sections = sections[:-1]
        logger.info(f"Split text into sections: {sections}")
        return sections

    @staticmethod
    def generate_speech_for_sections(sections: list) -> list:
        audio_sections = []
        for section in sections:
            if section == "**":  # long pause
                pause_duration = 0.4
                # float32 to match the vocoder's output dtype on concatenation
                audio_sections.append(
                    np.zeros(int(pause_duration * SAMPLE_RATE), dtype=np.float32)
                )
            elif section == "*":  # short pause
                pause_duration = 0.2
                audio_sections.append(
                    np.zeros(int(pause_duration * SAMPLE_RATE), dtype=np.float32)
                )
            else:
                mel_output, _ = TTS.text2mel(section)
                audio_sections.append(TTS.mel2wav(mel_output))
        return audio_sections

    @staticmethod
    def concatenate_audio_sections(audio_sections: list) -> np.ndarray:
        return np.concatenate(audio_sections)

    @staticmethod
    def phonemize(text: str) -> list:
        """Convert text to a list of IPA phonemes, keeping punctuation breaks."""
        NG_GRAPHEME = "ng'"
        NG_PRENASALIZED_PHONEME = "ᵑg"
        NG_PHONEME = "ŋ"
        ipa = []
        for sentence in sentences(text, lang="sw"):
            for word in sentence:
                # Keep punctuation breaks as literal text so tokenize() can map them
                if word.is_major_break or word.is_minor_break:
                    ipa.append(word.text)
                    continue
                phonemes = word.phonemes[:]
                # gruut renders both "ng" and "ng'" as prenasalized ᵑg;
                # remap the occurrences spelled "ng'" to the plain velar nasal ŋ
                if NG_GRAPHEME in word.text:
                    ng_graphemes = re.findall(f"{NG_GRAPHEME}?", word.text)  # "ng" or "ng'"
                    ng_phonemes_idx = [
                        i for i, p in enumerate(phonemes) if p == NG_PRENASALIZED_PHONEME
                    ]
                    assert len(ng_graphemes) == len(ng_phonemes_idx)
                    for i, g in zip(ng_phonemes_idx, ng_graphemes):
                        if g == NG_GRAPHEME:
                            phonemes[i] = NG_PHONEME
                ipa += phonemes
        return ipa

    @staticmethod
    def tokenize(phonemes: list) -> list:
        input_ids = []
        for phoneme in phonemes:
            if all(c in string.punctuation for c in phoneme):
                # Punctuation appears in the symbol table as-is
                input_ids.append(tokenizer[phoneme])
            else:
                # Phoneme symbols are stored with an "@" prefix
                input_ids.append(tokenizer[f"@{phoneme}"])
        return input_ids

    @staticmethod
    def text2mel(text: str) -> tuple:
        phonemes = TTS.phonemize(text)
        input_ids = TTS.tokenize(phonemes)
        inputs = {
            "input_ids": np.array([input_ids], dtype=np.int32),
            "speaker_ids": np.array([0], dtype=np.int32),
            "speed_ratios": np.array([1.0], dtype=np.float32),
            "f0_ratios": np.array([1.0], dtype=np.float32),
            "energy_ratios": np.array([1.0], dtype=np.float32),
        }
        mel_output, durations, _ = lightspeech.run(None, inputs)
        return mel_output, durations

    @staticmethod
    def mel2wav(mel_output: np.ndarray) -> np.ndarray:
        # Run the MB-MelGAN vocoder on the mel spectrogram
        inputs = {"mels": mel_output}
        outputs = mbmelgan.run(None, inputs)
        # Output shape is (batch, samples, 1); take the single batch item's waveform
        audio_array = outputs[0][0, :, 0]
        return audio_array

    @staticmethod
    def synthesize(text: str) -> np.ndarray:
        """Synthesize one section of text directly, without pause handling."""
        mel_output, _ = TTS.text2mel(text)
        return TTS.mel2wav(mel_output)

    @staticmethod
    def save_audio(audio_array: np.ndarray, path: str):
        sf.write(path, audio_array, SAMPLE_RATE)
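

# Example usage: a minimal sketch, assuming the model files loaded above are in
# place. The sample text and output filename are illustrative, not part of the
# module.
if __name__ == "__main__":
    sample_text = "Habari ya leo. Karibu sana, rafiki yangu!"
    audio = TTS.generate(sample_text)
    TTS.save_audio(audio, "output.wav")
    logger.info(f"Wrote {len(audio) / SAMPLE_RATE:.2f}s of audio to output.wav")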