# tts.py
from gruut import sentences
import re
import numpy as np
import onnxruntime as ort
from pathlib import Path
import json
import string
from IPython.display import Audio
import soundfile as sf

# Load the ONNX acoustic model (LightSpeech) and the MB-MelGAN vocoder
lightspeech = ort.InferenceSession("./models/lightspeech_quant.onnx")
mbmelgan = ort.InferenceSession("./models/mbmelgan.onnx")

# The processor config provides the symbol-to-id table used for tokenization
lightspeech_processor_config = Path("./models/lightspeech_processor.json")
with open(lightspeech_processor_config, "r") as f:
    processor = json.load(f)
tokenizer = processor["symbol_to_id"]


class TTS:
    @staticmethod
    def generate(text: str) -> np.ndarray:
        # Full pipeline: split the text into sections, synthesize each one,
        # and join the resulting waveforms.
        sections = TTS.split_text(text)
        audio_sections = TTS.generate_speech_for_sections(sections)
        concatenated_audio = TTS.concatenate_audio_sections(audio_sections)
        return concatenated_audio

    @staticmethod
    def split_text(text: str) -> list:
        # Split the text into sentences based on punctuation marks
        # (a local name is used so gruut's `sentences` import is not shadowed)
        sentence_list = re.split(r'(?<=[.!?])\s*', text)
        sections = []
        for sentence in sentence_list:
            # Skip the empty trailing string produced when the text ends
            # with punctuation
            if not sentence.strip():
                continue
            # Split each sentence by commas for short pauses
            parts = re.split(r',\s*', sentence)
            for i, part in enumerate(parts):
                sections.append(part.strip())
                if i < len(parts) - 1:
                    sections.append('*')  # Short pause marker
            sections.append('**')  # Long pause marker after each sentence
        # Remove empty sections
        sections = [section for section in sections if section]
        return sections
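
    # Illustrative behaviour of split_text (made-up input):
    #   TTS.split_text("Habari, rafiki. Karibu!")
    #   -> ['Habari', '*', 'rafiki.', '**', 'Karibu!', '**']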

    @staticmethod
    def generate_speech_for_sections(sections: list) -> list:
        audio_sections = []
        for section in sections:
            if section == '**':
                # Long pause
                pause_duration = 1.0
                sample_rate = 44100
                pause = np.zeros(int(pause_duration * sample_rate))
                audio_sections.append(pause)
            elif section == '*':
                # Short pause
                pause_duration = 0.4
                sample_rate = 44100
                pause = np.zeros(int(pause_duration * sample_rate))
                audio_sections.append(pause)
            else:
                mel_output, durations = TTS.text2mel(section)
                audio_array = TTS.mel2wav(mel_output)
                audio_sections.append(audio_array)
        return audio_sections
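
    # For reference: at the 44100 Hz rate used here the pause markers become
    #   '*'  -> int(0.4 * 44100) = 17640 zero samples (~0.4 s)
    #   '**' -> int(1.0 * 44100) = 44100 zero samples (1.0 s)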

    @staticmethod
    def concatenate_audio_sections(audio_sections: list) -> np.ndarray:
        concatenated_audio = np.concatenate(audio_sections)
        return concatenated_audio

    @staticmethod
    def phonemize(text: str) -> list:
        ipa = []
        for sent in sentences(text, lang="sw"):
            for word in sent:
                if word.is_major_break or word.is_minor_break:
                    ipa += [word.text]
                    continue
                phonemes = word.phonemes[:]
                # Swahili "ng'" is the plain velar nasal /ŋ/, while "ng" is the
                # prenasalized /ᵑg/; patch the phonemes for "ng'" occurrences.
                NG_GRAPHEME = "ng'"
                NG_PRENASALIZED_PHONEME = "ᵑg"
                NG_PHONEME = "ŋ"
                if NG_GRAPHEME in word.text:
                    ng_graphemes = re.findall(f"{NG_GRAPHEME}?", word.text)
                    ng_phonemes_idx = [i for i, p in enumerate(phonemes) if p == NG_PRENASALIZED_PHONEME]
                    assert len(ng_graphemes) == len(ng_phonemes_idx)
                    for i, g in zip(ng_phonemes_idx, ng_graphemes):
                        phonemes[i] = NG_PHONEME if g == NG_GRAPHEME else phonemes[i]
                ipa += phonemes
        return ipa

    @staticmethod
    def tokenize(phonemes: list) -> list:
        input_ids = []
        for phoneme in phonemes:
            if all(c in string.punctuation for c in phoneme):
                # Punctuation symbols are stored as-is in the symbol table
                input_ids.append(tokenizer[phoneme])
            else:
                # Phonemes are stored with an "@" prefix
                input_ids.append(tokenizer[f"@{phoneme}"])
        return input_ids

    @staticmethod
    def text2mel(text: str) -> tuple:
        phonemes = TTS.phonemize(text)
        input_ids = TTS.tokenize(phonemes)
        # Single-utterance batch with default speaker and prosody settings
        inputs = {
            "input_ids": np.array([input_ids], dtype=np.int32),
            "speaker_ids": np.array([0], dtype=np.int32),
            "speed_ratios": np.array([1.0], dtype=np.float32),
            "f0_ratios": np.array([1.0], dtype=np.float32),
            "energy_ratios": np.array([1.0], dtype=np.float32),
        }
        mel_output, durations, _ = lightspeech.run(None, inputs)
        return mel_output, durations
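
    # Sketch for inspecting the intermediate outputs of one section (exact
    # array shapes depend on the exported models, so this is illustrative):
    #   mel, durations = TTS.text2mel("Habari yako")
    #   wav = TTS.mel2wav(mel)
    #   print(mel.shape, durations.shape, wav.shape)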

    @staticmethod
    def mel2wav(mel_output: np.ndarray) -> np.ndarray:
        # Prepare input for the vocoder model
        inputs = {
            "mels": mel_output,
        }
        # Run inference and drop the batch and channel dimensions
        outputs = mbmelgan.run(None, inputs)
        audio_array = outputs[0][0, :, 0]
        return audio_array

    @staticmethod
    def synthesize(text: str) -> np.ndarray:
        # Synthesize a single piece of text without pause handling
        mel_output, _ = TTS.text2mel(text)
        audio_array = TTS.mel2wav(mel_output)
        return audio_array

    @staticmethod
    def save_audio(audio_array: np.ndarray, path: str):
        # Write a WAV file at the 44100 Hz rate used throughout this module
        sf.write(path, audio_array, 44100)
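

# Minimal usage sketch (assumes the ONNX models referenced above are present
# under ./models; the input sentence and output file name are illustrative).
if __name__ == "__main__":
    waveform = TTS.generate("Habari ya leo. Karibu sana!")
    TTS.save_audio(waveform, "output.wav")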