from pathlib import Path
import json
import logging
import re
import string

import numpy as np
import onnxruntime as ort
import soundfile as sf
from gruut import sentences
from IPython.display import Audio

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sample rate (Hz) shared by the synthesized audio, the inserted pauses,
# and the saved files.
SAMPLE_RATE = 44100

lightspeech = ort.InferenceSession("./models/lightspeech_quant.onnx")
mbmelgan = ort.InferenceSession("./models/mbmelgan.onnx")

lightspeech_processor_config = Path("./models/processor.json")
with open(lightspeech_processor_config, "r") as f:
    processor = json.load(f)
tokenizer = processor["symbol_to_id"]

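# Quick sanity check: the vocabulary maps phoneme symbols (stored with an
# "@" prefix) and bare punctuation to integer ids; TTS.tokenize below
# relies on this convention.
logger.info("Loaded %d symbols from %s", len(tokenizer), lightspeech_processor_config)
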
class TTS:
    @staticmethod
    def generate(text: str) -> np.ndarray:
        # Split the text into speakable chunks and pause markers, synthesize
        # each chunk, then stitch everything into one waveform.
        sections = TTS.split_text(text)
        audio_sections = TTS.generate_speech_for_sections(sections)
        concatenated_audio = TTS.concatenate_audio_sections(audio_sections)
        return concatenated_audio

    @staticmethod
    def split_text(text: str) -> list:
        # Split into sentences, keeping the terminal punctuation with each one.
        sentences = re.split(r'(?<=[.!?])\s*', text)
        # Drop empty strings left over by the split (e.g. after the final
        # sentence) so they don't add spurious pause markers below.
        sentences = [s for s in sentences if s.strip()]

        # NOTE: only the first three sentences are synthesized.
        sentences = sentences[:3]

        sections = []
        for sentence in sentences:
            # Commas become short pauses ('*'); sentence ends become long
            # pauses ('**').
            parts = re.split(r',\s*', sentence)
            for i, part in enumerate(parts):
                sections.append(part.strip())
                if i < len(parts) - 1:
                    sections.append('*')
            sections.append('**')

        # Remove any sections that ended up empty.
        sections = [section for section in sections if section]

        # No pause is needed after the final sentence.
        if sections and sections[-1] == '**':
            sections = sections[:-1]

        logger.info(f"Split text into sections: {sections}")
        return sections

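    # Worked example (illustrative input):
    #   TTS.split_text("Habari, rafiki. Karibu!")
    #   -> ['Habari', '*', 'rafiki.', '**', 'Karibu!']
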
    @staticmethod
    def generate_speech_for_sections(sections: list) -> list:
        audio_sections = []
        for section in sections:
            if section == '**':
                # Long pause at sentence boundaries:
                # int(0.4 * 44100) = 17640 samples of silence.
                pause_duration = 0.4
                pause = np.zeros(int(pause_duration * SAMPLE_RATE))
                audio_sections.append(pause)
            elif section == '*':
                # Short pause at comma boundaries:
                # int(0.2 * 44100) = 8820 samples of silence.
                pause_duration = 0.2
                pause = np.zeros(int(pause_duration * SAMPLE_RATE))
                audio_sections.append(pause)
            else:
                mel_output, _ = TTS.text2mel(section)
                audio_array = TTS.mel2wav(mel_output)
                audio_sections.append(audio_array)
        return audio_sections

    @staticmethod
    def concatenate_audio_sections(audio_sections: list) -> np.ndarray:
        concatenated_audio = np.concatenate(audio_sections)
        return concatenated_audio

    @staticmethod
    def phonemize(text: str) -> list:
        ipa = []
        for words in sentences(text, lang="sw"):
            for word in words:
                if word.is_major_break or word.is_minor_break:
                    ipa += [word.text]
                    continue

                phonemes = word.phonemes[:]
                # Swahili distinguishes "ng'" (velar nasal /ŋ/) from "ng"
                # (prenasalized /ᵑg/), but gruut emits the prenasalized
                # phoneme for both spellings; remap the "ng'" occurrences.
                NG_GRAPHEME = "ng'"
                NG_PRENASALIZED_PHONEME = "ᵑg"
                NG_PHONEME = "ŋ"
                if NG_GRAPHEME in word.text:
                    # "ng'?" matches both "ng" and "ng'", so the graphemes
                    # align one-to-one with the prenasalized phonemes.
                    ng_graphemes = re.findall(f"{NG_GRAPHEME}?", word.text)
                    ng_phonemes_idx = [i for i, p in enumerate(phonemes) if p == NG_PRENASALIZED_PHONEME]
                    assert len(ng_graphemes) == len(ng_phonemes_idx)
                    for i, g in zip(ng_phonemes_idx, ng_graphemes):
                        phonemes[i] = NG_PHONEME if g == NG_GRAPHEME else phonemes[i]

                ipa += phonemes
        return ipa

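    # Illustrative only (actual phonemes depend on gruut's Swahili lexicon):
    #   TTS.phonemize("ng'ombe") is expected to contain 'ŋ' where the
    #   unpatched gruut output would have 'ᵑg'.
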
    @staticmethod
    def tokenize(phonemes: list) -> list:
        input_ids = []
        for phoneme in phonemes:
            if all(c in string.punctuation for c in phoneme):
                # Punctuation (break symbols) is stored under its literal form.
                input_ids.append(tokenizer[phoneme])
            else:
                # Phonemes are stored with an "@" prefix in the vocabulary.
                input_ids.append(tokenizer[f"@{phoneme}"])
        return input_ids

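    # Example (ids are model-specific):
    #   TTS.tokenize(['h', 'a', '.'])
    #   -> [tokenizer['@h'], tokenizer['@a'], tokenizer['.']]
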
    @staticmethod
    def text2mel(text: str) -> tuple:
        phonemes = TTS.phonemize(text)
        input_ids = TTS.tokenize(phonemes)

        # Ratios of 1.0 keep the model's predicted duration, pitch, and
        # energy unchanged; speaker id 0 selects the voice to use.
        inputs = {
            "input_ids": np.array([input_ids], dtype=np.int32),
            "speaker_ids": np.array([0], dtype=np.int32),
            "speed_ratios": np.array([1.0], dtype=np.float32),
            "f0_ratios": np.array([1.0], dtype=np.float32),
            "energy_ratios": np.array([1.0], dtype=np.float32),
        }

        mel_output, durations, _ = lightspeech.run(None, inputs)
        return mel_output, durations

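    # The mel output is handed to the vocoder as-is; its exact shape is
    # model-specific (for TensorFlowTTS-style exports, typically
    # (1, n_frames, n_mels)).
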
    @staticmethod
    def mel2wav(mel_output: np.ndarray) -> np.ndarray:
        inputs = {
            "mels": mel_output,
        }
        outputs = mbmelgan.run(None, inputs)
        # The vocoder returns (batch, samples, channels); take the mono
        # waveform of the first (and only) batch item.
        audio_array = outputs[0][0, :, 0]
        return audio_array

    @staticmethod
    def synthesize(text: str) -> np.ndarray:
        # Synthesize a single chunk directly, without pause handling.
        mel_output, _ = TTS.text2mel(text)
        audio_array = TTS.mel2wav(mel_output)
        return audio_array

    @staticmethod
    def save_audio(audio_array: np.ndarray, path: str):
        sf.write(path, audio_array, SAMPLE_RATE)
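
# Usage sketch (text and output path are illustrative):
if __name__ == "__main__":
    audio = TTS.generate("Habari za asubuhi. Karibu sana!")
    TTS.save_audio(audio, "output.wav")
    # In a notebook, play the result inline instead:
    # Audio(audio, rate=SAMPLE_RATE)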