import gradio as gr
import torch
from TTS.api import TTS
import numpy as np
import tempfile
import os
import json
import scipy.io.wavfile as wavfile
from huggingface_hub import hf_hub_download
# Model configurations
MODELS = {
    "Hausa": {
        "model_repo": "CLEAR-Global/TWB-Voice-Hausa-TTS-1.0",
        "model_name": "best_model_498283.pth",
        "config_name": "config.json",
        "speakers": {
            "spk_f_1": "Female",
            "spk_m_1": "Male 1",
            "spk_m_2": "Male 2"
        },
        "examples": [
            "Lokacin damuna shuka kan koriya shar.",
            "Lafiyarku tafi kuΙ—inku muhimmanci.",
            "A kiyayi inda ake samun labarun magani ko kariya da cututtuka."
        ]
    }
}
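# Only Hausa is wired up above, although the demo text below also mentions
# Kanuri. A Kanuri entry would follow the same shape as the Hausa one; this
# commented sketch is illustrative only, and the checkpoint filename and
# speaker ID are assumptions to verify against the actual model repo:
# "Kanuri": {
#     "model_repo": "CLEAR-Global/TWB-Voice-Kanuri-TTS-1.0",
#     "model_name": "<checkpoint>.pth",  # hypothetical filename
#     "config_name": "config.json",
#     "speakers": {"spk_f_1": "Female"},  # assumed ID; one female speaker per the demo text
#     "examples": ["..."],
# },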
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_models = {}
def load_model(language):
    """Load and cache the TTS model for the specified language."""
    if language not in loaded_models:
        model_repo = MODELS[language]["model_repo"]
        model_name = MODELS[language]["model_name"]
        config_name = MODELS[language]["config_name"]
        try:
            # First download and read the config to get the required filenames
            config_path = hf_hub_download(repo_id=model_repo, filename=config_name)
            with open(config_path, 'r') as f:
                config = json.load(f)
            # Extract filenames from the config (just the filename, not the full path)
            speakers_filename = os.path.basename(config.get("speakers_file", "speakers.pth"))
            language_ids_filename = os.path.basename(config.get("language_ids_file", "language_ids.json"))
            # "d_vector_file" may be a string or a list of paths depending on the
            # config version; normalize to a list before indexing.
            d_vector_entry = config.get("d_vector_file", ["d_vector.pth"])
            if isinstance(d_vector_entry, str):
                d_vector_entry = [d_vector_entry]
            d_vector_filename = os.path.basename(d_vector_entry[0])
            config_se_filename = os.path.basename(config.get("model_args", {}).get("speaker_encoder_config_path", "config_se.json"))
            model_se_filename = os.path.basename(config.get("model_args", {}).get("speaker_encoder_model_path", "model_se.pth"))
            # Download the model and auxiliary files from the HuggingFace repo
            model_path = hf_hub_download(repo_id=model_repo, filename=model_name)
            speakers_file = hf_hub_download(repo_id=model_repo, filename=speakers_filename)
            language_ids_file = hf_hub_download(repo_id=model_repo, filename=language_ids_filename)
            d_vector_file = hf_hub_download(repo_id=model_repo, filename=d_vector_filename)
            config_se_file = hf_hub_download(repo_id=model_repo, filename=config_se_filename)
            model_se_file = hf_hub_download(repo_id=model_repo, filename=model_se_filename)
            # Update the config paths to point to the downloaded files
            config["speakers_file"] = speakers_file
            config["language_ids_file"] = language_ids_file
            config["d_vector_file"] = [d_vector_file]
            config["model_args"]["speakers_file"] = speakers_file
            config["model_args"]["language_ids_file"] = language_ids_file
            config["model_args"]["d_vector_file"] = [d_vector_file]
            config["model_args"]["speaker_encoder_config_path"] = config_se_file
            config["model_args"]["speaker_encoder_model_path"] = model_se_file
            # Save the updated config to a temporary file
            temp_config = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
            json.dump(config, temp_config, indent=2)
            temp_config.close()
            print(f"Loading {language} model with config:")
            print(f"- language_ids_file: {config.get('language_ids_file')}")
            print(f"- use_speaker_embedding: {config.get('use_speaker_embedding')}")
            print(f"- speakers_file: {config.get('speakers_file')}")
            print(f"- d_vector_file: {config.get('d_vector_file')}")
            # Load the TTS model with the rewritten config
            loaded_models[language] = TTS(model_path=model_path,
                                          config_path=temp_config.name,
                                          gpu=(device == "cuda"))
        except Exception as e:
            print(f"Error loading {language} model: {e}")
            import traceback
            traceback.print_exc()
            return None
    return loaded_models[language]
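# Optional warm-up: calling load_model at import time hides the first-request
# download/initialization latency. A sketch, disabled by default:
# load_model("Hausa")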
def update_speakers(language):
    """Update speaker dropdown based on selected language"""
    if language in MODELS:
        speakers = MODELS[language]["speakers"]
        choices = [(f"{speaker_id}: {description}", speaker_id)
                   for speaker_id, description in speakers.items()]
        return gr.Dropdown(choices=choices, value=choices[0][1], interactive=True)
    return gr.Dropdown(choices=[], interactive=False)
def get_example_text(language, example_idx):
    """Get example text for the selected language"""
    if language in MODELS and 0 <= example_idx < len(MODELS[language]["examples"]):
        return MODELS[language]["examples"][example_idx]
    return ""
def synthesize_speech(text, language, speaker):
    """Synthesize speech from text"""
    if not text.strip():
        return None, "Please enter some text to synthesize."
    # Load the model
    tts_model = load_model(language)
    if tts_model is None:
        return None, f"Failed to load {language} model."
    try:
        # The models expect lowercase input text
        text = text.lower().strip()
        print(f"DEBUG: Processing text: '{text}'")
        print(f"DEBUG: Speaker name: '{speaker}'")
        synthesizer = tts_model.synthesizer
        try:
            wav = synthesizer.tts(text=text, speaker_name=speaker)
        except TypeError:
            # Fall back to a plain call for synthesizers whose tts() signature
            # does not accept speaker_name (e.g. single-speaker models)
            wav = synthesizer.tts(text=text)
        print("DEBUG: synthesizer.tts() completed successfully")
        # Convert to a numpy array and save to a temporary WAV file
        wav_array = np.array(wav, dtype=np.float32)
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        # Save audio using the synthesizer's output sample rate
        wavfile.write(temp_file.name, synthesizer.output_sample_rate, wav_array)
        print("Speech synthesized successfully!")
        return temp_file.name, "Speech synthesized successfully!"
    except Exception as e:
        return None, f"Error during synthesis: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="TWB Voice TTS Demo") as demo:
    gr.Markdown("""
# TWB Voice Text-to-Speech Demo Space
This demo showcases neural Text-to-Speech models developed within the TWB Voice project by CLEAR Global.
Currently it supports **Hausa** and **Kanuri** languages, developed as part of the first phase of the project.
### Features:
- **Hausa**: 3 speakers (1 female, 2 male)
- **Kanuri**: 1 female speaker
- High-quality 24kHz audio output
- Based on YourTTS architecture
### Links:
- πŸ€— [Hausa Model](https://huggingface.co/CLEAR-Global/TWB-Voice-Hausa-TTS-1.0)
- πŸ€— [Kanuri Model](https://huggingface.co/CLEAR-Global/TWB-Voice-Kanuri-TTS-1.0)
- πŸ“Š [Hausa Dataset](https://huggingface.co/datasets/CLEAR-Global/TWB-voice-TTS-Hausa-1.0-sampleset)
- πŸ“Š [Kanuri Dataset](https://huggingface.co/datasets/CLEAR-Global/TWB-voice-TTS-Kanuri-1.0-sampleset)
- 🌐 [TWB Voice Project](https://twbvoice.org/)
---
""")
    with gr.Row():
        with gr.Column():
            # Language selection
            language_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="Hausa",
                label="Language",
                info="Select the language for synthesis"
            )
            # Speaker selection (same label format as update_speakers produces)
            speaker_dropdown = gr.Dropdown(
                choices=[(f"{speaker_id}: {description}", speaker_id)
                         for speaker_id, description in MODELS["Hausa"]["speakers"].items()],
                value="spk_f_1",
                label="Speaker",
                info="Select the voice speaker"
            )
            # Text input
            text_input = gr.Textbox(
                label="Text to synthesize",
                placeholder="Enter text in the selected language (will be converted to lowercase)",
                lines=3,
                info="Note: Text will be automatically converted to lowercase as required by the models"
            )
            # Example buttons
            gr.Markdown("**Press to load an example sentence in the selected language:**")
            with gr.Row():
                example_btn_1 = gr.Button("Example 1", size="sm")
                example_btn_2 = gr.Button("Example 2", size="sm")
                example_btn_3 = gr.Button("Example 3", size="sm")
            # Synthesize button
            synthesize_btn = gr.Button("🎀 Synthesize Speech", variant="primary")
        with gr.Column():
            # Audio output
            audio_output = gr.Audio(
                label="Generated Speech",
                type="filepath"
            )
            # Status message
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )
    # Event handlers
    language_dropdown.change(
        fn=update_speakers,
        inputs=[language_dropdown],
        outputs=[speaker_dropdown]
    )
    example_btn_1.click(
        fn=lambda lang: get_example_text(lang, 0),
        inputs=[language_dropdown],
        outputs=[text_input]
    )
    example_btn_2.click(
        fn=lambda lang: get_example_text(lang, 1),
        inputs=[language_dropdown],
        outputs=[text_input]
    )
    example_btn_3.click(
        fn=lambda lang: get_example_text(lang, 2),
        inputs=[language_dropdown],
        outputs=[text_input]
    )
    synthesize_btn.click(
        fn=synthesize_speech,
        inputs=[text_input, language_dropdown, speaker_dropdown],
        outputs=[audio_output, status_output]
    )
    gr.Markdown("""
---
### Notes:
- Models work with **lowercase input text** (automatically converted)
- Audio output is generated at a 24kHz sample rate
### License:
This app and the models are released under the **CC-BY-NC-4.0** license (non-commercial use only).
**Created by:** CLEAR Global with support from the Patrick J. McGovern Foundation
""")
if __name__ == "__main__":
    demo.launch()