# tts.py
from gruut import sentences
import re
import numpy as np
import onnxruntime as ort
from pathlib import Path
import json
import string
import soundfile as sf
import logging

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the acoustic model (LightSpeech) and the vocoder (MB-MelGAN)
lightspeech = ort.InferenceSession("./models/lightspeech_quant.onnx")
mbmelgan = ort.InferenceSession("./models/mbmelgan.onnx")

# Load the processor config, which maps phoneme symbols to model input IDs
lightspeech_processor_config = Path("./models/processor.json")
with open(lightspeech_processor_config, "r") as f:
    processor = json.load(f)
tokenizer = processor["symbol_to_id"]
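
# The processor file is assumed to follow the TensorFlowTTS convention: a
# "symbol_to_id" table in which phoneme symbols carry an "@" prefix while
# punctuation symbols appear bare, roughly (illustrative, not the actual file):
#   {"symbol_to_id": {"@a": 2, "@b": 3, ".": 40, ...}, ...}
# TTS.tokenize() below relies on exactly this layout.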

class TTS:
    @staticmethod
    def generate(text: str) -> np.ndarray:
        # Full pipeline: split the text into sections, synthesize each
        # section, then join everything into one waveform
        sections = TTS.split_text(text)
        audio_sections = TTS.generate_speech_for_sections(sections)
        concatenated_audio = TTS.concatenate_audio_sections(audio_sections)
        return concatenated_audio

    @staticmethod
    def split_text(text: str) -> list:
        # Split the text into sentences on terminal punctuation
        sentence_list = re.split(r'(?<=[.!?])\s*', text)
        # Drop empty strings left over from the split (e.g. after a trailing ".")
        sentence_list = [s for s in sentence_list if s.strip()]
        # For testing, use only the first 3 sentences
        sentence_list = sentence_list[:3]
        sections = []
        for sentence in sentence_list:
            # Split each sentence on commas for short pauses
            parts = re.split(r',\s*', sentence)
            for i, part in enumerate(parts):
                sections.append(part.strip())
                if i < len(parts) - 1:
                    sections.append('*')  # Short pause marker
            sections.append('**')  # Long pause marker after each sentence
        # Remove empty sections
        sections = [section for section in sections if section]
        # Trim the trailing long pause marker
        if sections and sections[-1] == '**':
            sections = sections[:-1]
        logger.info(f"Split text into sections: {sections}")
        return sections
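
    # Illustrative example of the splitting above (note the comma itself is
    # consumed by the split):
    #   split_text("Habari, rafiki. Karibu!")
    #   -> ['Habari', '*', 'rafiki.', '**', 'Karibu!']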

    @staticmethod
    def generate_speech_for_sections(sections: list) -> list:
        audio_sections = []
        sample_rate = 44100
        for section in sections:
            if section == '**':
                # Long pause: 0.4 s of silence
                audio_sections.append(np.zeros(int(0.4 * sample_rate)))
            elif section == '*':
                # Short pause: 0.2 s of silence
                audio_sections.append(np.zeros(int(0.2 * sample_rate)))
            else:
                # Synthesize speech for this text section
                mel_output, _ = TTS.text2mel(section)
                audio_array = TTS.mel2wav(mel_output)
                audio_sections.append(audio_array)
        return audio_sections

    @staticmethod
    def concatenate_audio_sections(audio_sections: list) -> np.ndarray:
        return np.concatenate(audio_sections)

    @staticmethod
    def phonemize(text: str) -> list:
        ipa = []
        for sentence in sentences(text, lang="sw"):
            for word in sentence:
                if word.is_major_break or word.is_minor_break:
                    # Keep punctuation breaks as their own tokens
                    ipa += [word.text]
                    continue
                phonemes = word.phonemes[:]
                # In Swahili orthography, "ng'" (with apostrophe) is the plain
                # velar nasal /ŋ/, while "ng" is prenasalized /ᵑg/. gruut emits
                # the prenasalized /ᵑg/ in both cases here, so remap the
                # occurrences written with an apostrophe back to /ŋ/.
                NG_GRAPHEME = "ng'"
                NG_PRENASALIZED_PHONEME = "ᵑg"
                NG_PHONEME = "ŋ"
                if NG_GRAPHEME in word.text:
                    ng_graphemes = re.findall(f"{NG_GRAPHEME}?", word.text)
                    ng_phonemes_idx = [i for i, p in enumerate(phonemes) if p == NG_PRENASALIZED_PHONEME]
                    assert len(ng_graphemes) == len(ng_phonemes_idx)
                    for i, g in zip(ng_phonemes_idx, ng_graphemes):
                        phonemes[i] = NG_PHONEME if g == NG_GRAPHEME else phonemes[i]
                ipa += phonemes
        return ipa
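
    # To sanity-check the remapping above, gruut's raw output can be inspected
    # directly (illustrative; results depend on the gruut version and its
    # Swahili lexicon):
    #   for sent in sentences("ng'ombe ngoma", lang="sw"):
    #       for w in sent:
    #           print(w.text, w.phonemes)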

    @staticmethod
    def tokenize(phonemes: list) -> list:
        input_ids = []
        for phoneme in phonemes:
            if all(c in string.punctuation for c in phoneme):
                # Punctuation/break tokens are stored bare in the symbol table
                input_ids.append(tokenizer[phoneme])
            else:
                # Phoneme symbols are stored with an "@" prefix
                input_ids.append(tokenizer[f"@{phoneme}"])
        return input_ids

    @staticmethod
    def text2mel(text: str) -> tuple:
        phonemes = TTS.phonemize(text)
        input_ids = TTS.tokenize(phonemes)
        inputs = {
            "input_ids": np.array([input_ids], dtype=np.int32),
            "speaker_ids": np.array([0], dtype=np.int32),
            "speed_ratios": np.array([1.0], dtype=np.float32),
            "f0_ratios": np.array([1.0], dtype=np.float32),
            "energy_ratios": np.array([1.0], dtype=np.float32),
        }
        mel_output, durations, _ = lightspeech.run(None, inputs)
        return mel_output, durations
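
    # The input/output names above must match the ONNX export. If in doubt,
    # they can be listed with onnxruntime:
    #   print([i.name for i in lightspeech.get_inputs()])
    #   print([o.name for o in lightspeech.get_outputs()])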

    @staticmethod
    def mel2wav(mel_output: np.ndarray) -> np.ndarray:
        # Prepare input for the vocoder model
        inputs = {
            "mels": mel_output,
        }
        # Run inference; the output shape is (batch, samples, channels)
        outputs = mbmelgan.run(None, inputs)
        audio_array = outputs[0][0, :, 0]
        return audio_array

    @staticmethod
    def synthesize(text: str) -> np.ndarray:
        # One-shot synthesis of a single section, without pause handling
        mel_output, _ = TTS.text2mel(text)
        audio_array = TTS.mel2wav(mel_output)
        return audio_array

    @staticmethod
    def save_audio(audio_array: np.ndarray, path: str):
        # Write the waveform to disk at 44100 Hz
        sf.write(path, audio_array, 44100)
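
# Example usage: a minimal sketch. The input sentence is just a sample, and
# the 44100 Hz rate used for pauses and saving is assumed to match the
# vocoder's native sample rate.
if __name__ == "__main__":
    audio = TTS.generate("Habari za asubuhi. Karibu sana!")
    TTS.save_audio(audio, "output.wav")
    logger.info(f"Wrote {audio.shape[0]} samples to output.wav")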