import os
import sys
import json
import time
from importlib.metadata import version
from enum import Enum

from huggingface_hub import hf_hub_download

import spaces  # registers ZeroGPU support on Hugging Face Spaces
import gradio as gr

import torch
import numpy as np

# Local modules from the RAD-TTS codebase
from radtts import RADTTS
from data import Data
from common import update_params
from inference import load_vocoder

use_cuda = torch.cuda.is_available()

if use_cuda:
    print("CUDA is available, using it as the inference device.")
    device = "cuda"
else:
    device = "cpu"

def download_file_from_repo(
    repo_id: str,
    filename: str,
    local_dir: str = ".",
    repo_type: str = "model",
) -> str:
    try:
        os.makedirs(local_dir, exist_ok=True)

        file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=local_dir,
            cache_dir=None,
            force_download=False,
            repo_type=repo_type,
        )

        return file_path
    except Exception as e:
        raise RuntimeError(f"An error occurred during download: {e}") from e
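
# With force_download=False, files already present in local_dir are reused
# when they are up to date, so app restarts do not re-download the checkpoints.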
download_file_from_repo(
    "Yehor/radtts-uk",
    "radtts-pp-dap-model/model_dap_84000.pt",
    "./models/",
)
download_file_from_repo(
    "Yehor/radtts-uk",
    "hifigan/hifigan.pt",
    "./models/",
)
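
# hf_hub_download keeps the repository sub-paths, so the checkpoints land at
# models/radtts-pp-dap-model/model_dap_84000.pt and models/hifigan/hifigan.pt.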

# Init the model
seed = 1234

config_path = "configs/radtts-pp-dap-model.json"
radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"

params = []

# Load the config
with open(config_path) as f:
    config = json.loads(f.read())

update_params(config, params)

data_config = config["data_config"]
model_config = config["model_config"]

vocoder_path = "models/hifigan/hifigan.pt"
vocoder_config_path = "configs/hifigan_22khz_config.json"

# Seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Load vocoder
vocoder, denoiser = load_vocoder(vocoder_path, vocoder_config_path, use_cuda)
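# The returned denoiser removes vocoder artifacts; it is applied after
# synthesis with a configurable strength (see denoising_strength in inference).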

# Load RAD-TTS
if use_cuda:
    radtts = RADTTS(**model_config).cuda()
else:
    radtts = RADTTS(**model_config)

radtts.enable_inverse_cache()  # cache inverse matrix for 1x1 invertible convs

# Load the weights to CPU first; load_state_dict then copies them onto the
# device the model already lives on
checkpoint_dict = torch.load(radtts_path, map_location="cpu")
radtts.load_state_dict(checkpoint_dict["state_dict"], strict=False)
radtts.eval()

print(f"Loaded checkpoint '{radtts_path}'")

ignore_keys = ["training_files", "validation_files"]
trainset = Data(
    data_config["training_files"],
    **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
)
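
# trainset is not used for training here: it only provides the text and
# speaker-id preprocessing (get_text, get_speaker_id) that inference needs.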

# Config
concurrency_limit = 5

title = "RAD-TTS++ Ukrainian"

# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow the author on social networks and **contact** him if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw on Telegram                                                                   |
| https://x.com/yehor_smoliakov on X                                                               |
| https://github.com/egorsmkv on GitHub                                                            |
| https://huggingface.co/Yehor on Hugging Face                                                     |
| or use [email protected]                                                                   |
""".strip()

description_head = f"""
# {title}

## Overview

Type your text in Ukrainian and select a voice to synthesize speech using [the RAD-TTS++ model](https://huggingface.co/Yehor/radtts-uk) and a HiFiGAN vocoder at 22050 Hz.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries

- gradio: {version("gradio")}
- torch: {version("torch")}
- scipy: {version("scipy")}
- numba: {version("numba")}
- librosa: {version("librosa")}
- unidecode: {version("unidecode")}
- inflect: {version("inflect")}
""".strip()

class VoiceOption(Enum):
    Tetiana = "Tetiana (female) 👩"
    Mykyta = "Mykyta (male) 👨"
    Lada = "Lada (female) 👩"


voice_mapping = {
    VoiceOption.Tetiana.value: "tetiana",
    VoiceOption.Mykyta.value: "mykyta",
    VoiceOption.Lada.value: "lada",
}
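
# The radio buttons show the display labels; voice_mapping translates the
# label picked in the UI back to the speaker name the model was trained on.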

examples = [
    [
        "Прок+инувся ґазд+а вр+анці. Піш+ов, в+ичистив з-під кон+я, в+ичистив з-під бик+а, в+ичистив з-під овеч+ок, в+ибрав молодн+як, відн+іс йог+о н+абік.",
        VoiceOption.Mykyta.value,
    ],
    [
        "Піш+ов вз+яв с+іна, д+ав кор+ові. Піш+ов вз+яв с+іна, д+ав бик+ові. Ячмен+ю коняц+і нас+ипав. Зайш+ов поч+истив кор+ову, зайш+ов поч+истив бик+а, зайш+ов поч+истив к+оня, за +яйця йог+о мацн+ув.",
        VoiceOption.Lada.value,
    ],
    [
        "К+інь ного+ю здригну+в, на хазя+їна ласк+авим +оком подиви+вся. Тод+і д+ядько піш+ов відкр+ив кур+ей, гус+ей, кач+ок, повинос+ив їм з+ерна, огірк+ів нарі+заних, нагодув+ав. Кол+и ч+ує – з х+ати друж+ина кл+иче. Зайш+ов. Д+ітки повмив+ані, сид+ять за стол+ом, вс+і чек+ають т+ата. Взяв він л+ожку, перехрест+ив діт+ей, перехрест+ив л+оба, поч+али сн+ідати. Посн+ідали, він діст+ав пр+яників, розд+ав д+ітям. Д+іти зібр+алися, пішл+и в шк+олу. Д+ядько в+ийшов, сів на пр+и+зьбі, взяв с+апку, поч+ав мант+ачити. Мант+ачив-мант+ачив, кол+и – ж+інка вих+одить. Він їй ту с+апку да+є, ласк+аво за ср+аку вщипн+ув, ж+інка до ньог+о л+агідно всміхн+улася, пішл+а на гор+од – сап+ати. Кол+и – йде паст+ух і тов+ар кл+иче в чер+еду. Повідмик+ав д+ядько ов+ечок, кор+овку, бик+а, кон+я, все відпуст+ив. Сів п+опри х+ати, діст+ав таб+аку, відірв+ав шмат газ+ети, нас+ипав, наслин+ив соб+і г+арну так+у циг+арку. Благод+ать б+ожа – і с+онечко вже здійнял+ося над дерев+ами. Д+ядько встром+ив циг+арку в р+ота, діст+ав сірник+и, т+ільки чирк+ати – кол+и р+аптом з х+ати: Д+оброе +утро! Моск+овское вр+емя – ш+есть час+ов +утра! В+итяг д+ядько циг+арку с р+ота, сплюн+ув наб+ік, і сам соб+і к+аже: +Ана м+аєш. Прок+инул+ись, бл+яді!",
        VoiceOption.Tetiana.value,
    ],
]
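
# Note: in these inputs a "+" is placed before the stressed vowel; the model
# appears to rely on this explicit accentuation, so keep it in your own text.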

def inference(text, voice):
    if not text:
        raise gr.Error("Please paste your text.")

    gr.Info("Starting...", duration=0.5)

    speaker = voice_mapping[voice]
    # Use the same voice for decoding, text alignment, and attribute prediction
    speaker_text = speaker_attributes = speaker

    n_takes = 1

    sigma = 0.8  # sampling sigma for decoder
    sigma_tkndur = 0.666  # sampling sigma for duration
    sigma_f0 = 1.0  # sampling sigma for f0
    sigma_energy = 1.0  # sampling sigma for energy avg

    token_dur_scaling = 1.0

    f0_mean = 0
    f0_std = 0
    energy_mean = 0
    energy_std = 0

    denoising_strength = 0

    speaker_id = trainset.get_speaker_id(speaker).to(device)
    speaker_id_text = trainset.get_speaker_id(speaker_text).to(device)
    speaker_id_attributes = trainset.get_speaker_id(speaker_attributes).to(device)
    tensor_text = trainset.get_text(text).to(device)[None]
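
    # RAD-TTS++ can decouple the voice used for text alignment
    # (speaker_id_text) from the one used for attribute prediction
    # (speaker_id_attributes); here all three ids refer to the same speaker.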

    inference_start = time.time()

    for _take in range(n_takes):
        with torch.autocast(device, enabled=False):
            with torch.inference_mode():
                outputs = radtts.infer(
                    speaker_id,
                    tensor_text,
                    sigma,
                    sigma_tkndur,
                    sigma_f0,
                    sigma_energy,
                    token_dur_scaling,
                    token_duration_max=100,
                    speaker_id_text=speaker_id_text,
                    speaker_id_attributes=speaker_id_attributes,
                    f0_mean=f0_mean,
                    f0_std=f0_std,
                    energy_mean=energy_mean,
                    energy_std=energy_std,
                    use_cuda=use_cuda,
                )

                mel = outputs["mel"]

                gr.Info("Synthesized MEL spectrogram, converting to WAVE.", duration=0.5)

                audio = vocoder(mel).float()[0]
                audio_denoised = denoiser(audio, strength=denoising_strength)[0].float()

                audio = audio[0].cpu().numpy()
                audio_denoised = audio_denoised[0].cpu().numpy()

                # Peak-normalize, guarding against an all-zero buffer
                peak = np.max(np.abs(audio_denoised))
                if peak > 0:
                    audio_denoised = audio_denoised / peak

    audio_data = (22_050, audio_denoised)

    duration = len(audio) / 22_050

    elapsed_time = time.time() - inference_start
    rtf = elapsed_time / duration
    speed_ratio = duration / elapsed_time
    speech_rate = len(text.split()) / duration

    rtf_value = f"Real-Time Factor: {round(rtf, 4)}, time: {round(elapsed_time, 4)} seconds, audio duration: {round(duration, 4)} seconds. Speed ratio: {round(speed_ratio, 2)}x. Speech rate: {round(speech_rate, 4)} words per second."

    gr.Success("Finished!", duration=0.5)

    return [gr.Audio(audio_data), rtf_value]
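
# A minimal sketch of calling the pipeline directly, without the UI (note
# that gr.Info/gr.Error/gr.Success expect a running Gradio request context):
#
#   audio_out, stats = inference("Сл+ава Укра+їні!", VoiceOption.Tetiana.value)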

demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Synthesized audio")
            rtf = gr.Markdown(
                label="Real-Time Factor",
                value="Here you will see how fast the model and the speaker are.",
            )

    with gr.Row():
        with gr.Column():
            text = gr.Text(
                label="Text",
                value="Сл+ава Укра+їні! — українське вітання, національне гасло.",
            )
            voice = gr.Radio(
                label="Voice",
                choices=[option.value for option in VoiceOption],
                value=VoiceOption.Tetiana.value,
            )

            gr.Button("Run").click(
                inference,
                concurrency_limit=concurrency_limit,
                inputs=[text, voice],
                outputs=[audio, rtf],
            )

    with gr.Row():
        gr.Examples(
            label="Choose an example",
            inputs=[text, voice],
            examples=examples,
        )

    gr.Markdown(description_foot)

    gr.Markdown("### This Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()