# TWB-Voice-TTS / app.py
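"""Gradio demo app for the TWB Voice text-to-speech models (Hausa and Kanuri)."""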
import gradio as gr
import torch
from TTS.api import TTS
import numpy as np
import tempfile
import os
import json
import traceback

import scipy.io.wavfile as wavfile
from huggingface_hub import hf_hub_download

# Model configurations
MODELS = {
"Hausa": {
"model_repo": "CLEAR-Global/TWB-Voice-Hausa-TTS-1.0",
"model_name": "best_model_498283.pth",
"config_name": "config.json",
"speakers_pth_name": "speakers.pth",
"speakers": {
"spk_f_1": "Female",
"spk_m_1": "Male 1",
"spk_m_2": "Male 2"
},
"examples": [
"Lokacin damuna shuka kan koriya shar.",
"Lafiyarku tafi kuɗinku muhimmanci.",
"A kiyayi inda ake samun labarun magani ko kariya da cututtuka."
]
},
"Kanuri": {
"model_repo": "CLEAR-Global/TWB-Voice-Kanuri-TTS-1.0",
"model_name": "best_model_264313.pth",
"config_name": "config.json",
"speakers": {
"spk1": "Female"
},
"examples": [
"Loktu nǝngriyi ye lan, nǝyama kulo ye dǝ so shawwa ro wurazen.",
"Nǝlewa nǝm dǝ, kunguna nǝm wa faidan kozǝna.",
"Na done hawar kattu ye so kǝla kurun nǝlewa ye tarzeyen so dǝa wane."
]
}
}
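# Each MODELS entry follows the same shape, so adding a language is a matter of
# registering another config dict. A hypothetical sketch (the repo and file
# names below are placeholders, not a real model):
#
# MODELS["NewLanguage"] = {
#     "model_repo": "CLEAR-Global/TWB-Voice-NewLanguage-TTS-1.0",
#     "model_name": "best_model.pth",
#     "config_name": "config.json",
#     "speakers": {"spk1": "Female"},
#     "examples": ["An example sentence in the new language."],
# }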
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_models = {}
def load_model(language):
"""Load TTS model for the specified language"""
if language not in loaded_models:
model_repo = MODELS[language]["model_repo"]
model_name = MODELS[language]["model_name"]
config_name = MODELS[language]["config_name"]
try:
# First download and read the config to get the required filenames
config_path = hf_hub_download(repo_id=model_repo, filename=config_name)
with open(config_path, 'r') as f:
config = json.load(f)
# Extract filenames from config (get just the filename, not the full path)
speakers_filename = os.path.basename(config.get("speakers_file", "speakers.pth"))
language_ids_filename = os.path.basename(config.get("language_ids_file", "language_ids.json"))
        # d_vector_file may be stored as a single path or a list of paths
        d_vector_entry = config.get("d_vector_file", ["d_vector.pth"])
        if isinstance(d_vector_entry, str):
            d_vector_entry = [d_vector_entry]
        d_vector_filename = os.path.basename(d_vector_entry[0])
config_se_filename = os.path.basename(config.get("model_args", {}).get("speaker_encoder_config_path", "config_se.json"))
model_se_filename = os.path.basename(config.get("model_args", {}).get("speaker_encoder_model_path", "model_se.pth"))
# Download specific model and config files from HuggingFace repo
model_path = hf_hub_download(repo_id=model_repo, filename=model_name)
speakers_file = hf_hub_download(repo_id=model_repo, filename=speakers_filename)
language_ids_file = hf_hub_download(repo_id=model_repo, filename=language_ids_filename)
d_vector_file = hf_hub_download(repo_id=model_repo, filename=d_vector_filename)
config_se_file = hf_hub_download(repo_id=model_repo, filename=config_se_filename)
model_se_file = hf_hub_download(repo_id=model_repo, filename=model_se_filename)
# Update the config paths to point to the downloaded files
config["speakers_file"] = speakers_file
config["language_ids_file"] = language_ids_file
config["d_vector_file"] = [d_vector_file]
config["model_args"]["speakers_file"] = speakers_file
config["model_args"]["language_ids_file"] = language_ids_file
config["model_args"]["d_vector_file"] = [d_vector_file]
config["model_args"]["speaker_encoder_config_path"] = config_se_file
config["model_args"]["speaker_encoder_model_path"] = model_se_file
# Save the updated config to a temporary file
temp_config = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
json.dump(config, temp_config, indent=2)
temp_config.close()
print(f"Loading {language} model with config:")
print(f"- language_ids_file: {config.get('language_ids_file')}")
print(f"- use_speaker_embedding: {config.get('use_speaker_embedding')}")
print(f"- speakers_file: {config.get('speakers_file')}")
print(f"- d_vector_file: {config.get('d_vector_file')}")
# Load TTS model with specific model and config paths
            loaded_models[language] = TTS(model_path=model_path,
                                          config_path=temp_config.name,
                                          gpu=(device == "cuda"))
except Exception as e:
print(f"Error loading {language} model: {e}")
traceback.print_exc()
return None
return loaded_models[language]
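
# Optional warm-up sketch: load_model() caches models in loaded_models, so both
# models could be fetched once at startup rather than on the first request:
#
# for lang in MODELS:
#     load_model(lang)
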
def update_speakers(language):
"""Update speaker dropdown based on selected language"""
if language in MODELS:
speakers = MODELS[language]["speakers"]
choices = [(f"{speaker_id}: {description}", speaker_id)
for speaker_id, description in speakers.items()]
return gr.Dropdown(choices=choices, value=choices[0][1], interactive=True)
return gr.Dropdown(choices=[], interactive=False)
def get_example_text(language, example_idx):
"""Get example text for the selected language"""
if language in MODELS and 0 <= example_idx < len(MODELS[language]["examples"]):
return MODELS[language]["examples"][example_idx]
return ""
def synthesize_speech(text, language, speaker):
"""Synthesize speech from text"""
if not text.strip():
return None, "Please enter some text to synthesize."
# Load the model
tts_model = load_model(language)
if tts_model is None:
return None, f"Failed to load {language} model."
try:
# Convert text to lowercase as required by the models
text = text.lower().strip()
        # Call the underlying synthesizer directly so the speaker name can be
        # passed explicitly
synthesizer = tts_model.synthesizer
wav = synthesizer.tts(text=text, speaker_name=speaker)
# Convert to numpy array and save to temporary file
wav_array = np.array(wav, dtype=np.float32)
# Create temporary file
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
# Save audio using the synthesizer's sample rate
wavfile.write(temp_file.name, synthesizer.output_sample_rate, wav_array)
return temp_file.name, "Speech synthesized successfully!"
except Exception as e:
return None, f"Error during synthesis: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="TWB Voice TTS Demo") as demo:
gr.Markdown("""
# TWB Voice Text-to-Speech Demo Space
This demo showcases neural Text-to-Speech models developed within the TWB Voice project by CLEAR Global.
It currently supports **Hausa** and **Kanuri**, the two languages developed in the first phase of the project.
### Features:
- **Hausa**: 3 speakers (1 female, 2 male)
- **Kanuri**: 1 female speaker
- High-quality 24kHz audio output
- Based on YourTTS architecture
### Links:
- 🤗 [Hausa Model](https://huggingface.co/CLEAR-Global/TWB-Voice-Hausa-TTS-1.0)
- 🤗 [Kanuri Model](https://huggingface.co/CLEAR-Global/TWB-Voice-Kanuri-TTS-1.0)
- 📊 [Hausa Dataset](https://huggingface.co/datasets/CLEAR-Global/TWB-voice-TTS-Hausa-1.0-sampleset)
- 📊 [Kanuri Dataset](https://huggingface.co/datasets/CLEAR-Global/TWB-voice-TTS-Kanuri-1.0-sampleset)
- 🌐 [TWB Voice Project](https://twbvoice.org/)
---
""")
with gr.Row():
with gr.Column():
# Language selection
language_dropdown = gr.Dropdown(
choices=list(MODELS.keys()),
value="Hausa",
label="Language",
info="Select the language for synthesis"
)
# Speaker selection
speaker_dropdown = gr.Dropdown(
choices=list(MODELS["Hausa"]["speakers"].keys()),
value="spk_f_1",
label="Speaker",
info="Select the voice speaker"
)
# Text input
text_input = gr.Textbox(
label="Text to synthesize",
placeholder="Enter text in the selected language (will be converted to lowercase)",
lines=3,
info="Note: Text will be automatically converted to lowercase as required by the models"
)
# Example buttons
gr.Markdown("**Quick examples (press to load):**")
with gr.Row():
example_btn_1 = gr.Button("Example 1", size="sm")
example_btn_2 = gr.Button("Example 2", size="sm")
example_btn_3 = gr.Button("Example 3", size="sm")
# Synthesize button
synthesize_btn = gr.Button("🎤 Synthesize Speech", variant="primary")
with gr.Column():
# Audio output
audio_output = gr.Audio(
label="Generated Speech",
type="filepath"
)
# Status message
status_output = gr.Textbox(
label="Status",
interactive=False
)
# Event handlers
language_dropdown.change(
fn=update_speakers,
inputs=[language_dropdown],
outputs=[speaker_dropdown]
)
example_btn_1.click(
fn=lambda lang: get_example_text(lang, 0),
inputs=[language_dropdown],
outputs=[text_input]
)
example_btn_2.click(
fn=lambda lang: get_example_text(lang, 1),
inputs=[language_dropdown],
outputs=[text_input]
)
example_btn_3.click(
fn=lambda lang: get_example_text(lang, 2),
inputs=[language_dropdown],
outputs=[text_input]
)
synthesize_btn.click(
fn=synthesize_speech,
inputs=[text_input, language_dropdown, speaker_dropdown],
outputs=[audio_output, status_output]
)
gr.Markdown("""
---
### Notes:
- Models expect **lowercase input text** (the app converts input automatically)
- Audio output is generated at 24kHz sample rate
### License:
This app and the models are released under the **CC-BY-NC-4.0** license (non-commercial use only).
**Created by:** CLEAR Global with support from the Patrick J. McGovern Foundation
""")
if __name__ == "__main__":
demo.launch()