# F5-TTS_Space / app.py
# (page header captured with the file: "chenxie95's picture", "Update app.py",
#  commit 7c16cce verified, raw, history blame, 4 kB)
import gradio as gr
import numpy as np
import spaces
import torch
from cached_path import cached_path
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
from f5_tts.model import DiT
# Vocoder and TTS checkpoint are loaded once, at import time, and read as
# module globals by `infer`.
vocoder = load_vocoder()

# Key of the checkpoint used by default (and the only one offered).
tts_model_choice = "v1-base_zh-en"

# Registry of available TTS models, keyed by the names shown in the UI.
# The checkpoint is fetched before the vocab file, and both before load_model runs.
tts_model_collections = {}
tts_model_collections["v1-base_zh-en"] = load_model(
    DiT,
    {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4},
    str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")),
    vocab_file=str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt")),
)
@spaces.GPU
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    seed,
    show_info=gr.Info,
):
    """Synthesize ``gen_text`` in the voice of a reference recording.

    Args:
        ref_audio_orig: Path of the reference audio file (the UI feeds it from
            ``gr.Audio(type="filepath")``).
        ref_text: Transcript of the reference audio.
        gen_text: Text to synthesize.
        model: Key into ``tts_model_collections`` selecting the checkpoint.
        seed: Seed for ``torch.manual_seed``. ``None`` or a value outside
            ``0 ~ 2**31 - 1`` triggers a warning and is replaced by a seed drawn
            with ``np.random.randint(0, 2**31 - 1)``.
        show_info: Callable used to surface informational messages in the UI.

    Returns:
        ``((sample_rate, waveform), ref_text, used_seed)`` on success, where
        ``ref_text`` is the transcript returned by ``preprocess_ref_audio_text``.
        If any required input is missing, a warning is shown and three
        ``gr.update()`` no-ops are returned, so every bound output (audio,
        reference text, seed) keeps its current value.
    """
    # A textbox value may arrive as None; treat it as empty rather than crash on .strip().
    if not ref_audio_orig or not (ref_text or "").strip() or not (gen_text or "").strip():
        gr.Warning("Please ensure [Reference Audio] [Reference Text] [Text to Generate] are all provided.")
        # One no-op per output slot, in the same order as the success return.
        # Returning the transcript here would land in the seed (gr.Number) slot.
        return gr.update(), gr.update(), gr.update()

    # An emptied Number box yields None, which cannot be range-compared.
    if seed is None or not 0 <= seed <= 2**31 - 1:
        gr.Warning("Please set a seed in range 0 ~ 2**31 - 1.")
        seed = np.random.randint(0, 2**31 - 1)
    used_seed = int(seed)  # plain int for torch and for the gr.Number output
    torch.manual_seed(used_seed)

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
    final_wave, final_sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        tts_model_collections[model],  # honor the caller's model selection
        vocoder,
        show_info=show_info,
        progress=gr.Progress(),
    )
    return (final_sample_rate, final_wave), ref_text, used_seed
with gr.Blocks() as app_basic_tts:
    gr.Markdown("# Batched TTS")
    # NOTE(review): the source's indentation was lost; this component nesting is
    # a reconstruction — confirm against the intended layout.
    with gr.Row():
        with gr.Column():
            ref_wav_input = gr.Audio(label="Reference Audio", type="filepath")
            ref_txt_input = gr.Textbox(label="Reference Text")
            gen_txt_input = gr.Textbox(label="Text to Generate")
            generate_btn = gr.Button("Synthesize", variant="primary")
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize Seed",
                    info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
                    value=True,
                    scale=3,
                )
                seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
        audio_output = gr.Audio(label="Synthesized Audio")

    def basic_tts(
        ref_wav_input,
        ref_txt_input,
        gen_txt_input,
        randomize_seed,
        seed_input,
    ):
        """Click handler for "Synthesize": choose the seed, then delegate to ``infer``.

        The parameters shadow the component names but receive the components'
        current values. When ``randomize_seed`` is truthy, ``seed_input`` is
        replaced by ``np.random.randint(0, 2**31 - 1)`` before calling ``infer``
        with the module-level ``tts_model_choice``.

        Returns:
            ``(audio, ref_text, used_seed)`` exactly as produced by ``infer``; they
            are routed to ``audio_output``, ``ref_txt_input`` and ``seed_input``.
        """
        if randomize_seed:
            seed_input = np.random.randint(0, 2**31 - 1)
        audio_out, ref_text_out, used_seed = infer(
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            tts_model_choice,
            seed_input,
        )
        return audio_out, ref_text_out, used_seed

    # Clearing the reference audio also clears its transcript. With a single
    # output component Gradio uses the handler's return value as that
    # component's value, so return the empty string itself, not a one-element list.
    ref_wav_input.clear(
        lambda: "",
        None,
        [ref_txt_input],
    )
    generate_btn.click(
        basic_tts,
        inputs=[
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            randomize_seed,
            seed_input,
        ],
        outputs=[audio_output, ref_txt_input, seed_input],
    )
with gr.Blocks() as demo:

    def switch_tts_model(selected):
        """Rebind the module-level ``tts_model_choice`` to ``selected``.

        NOTE(review): this writes a module global, so the selection is not
        per-session; harmless while only one model is offered.
        """
        global tts_model_choice
        tts_model_choice = selected

    gr.Markdown(
        """
# 🗣️ F5-TTS Online Demo for Dev Test
Upload/record a reference voice, give reference and generation text, and enjoy playing!
"""
    )
    with gr.Row():
        choose_tts_model = gr.Radio(
            choices=["v1-base_zh-en"],
            label="Choose TTS Model",
            value="v1-base_zh-en",
        )
    # No outputs: the handler only updates server-side state.
    choose_tts_model.change(fn=switch_tts_model, inputs=[choose_tts_model])

    # Embed the basic TTS Blocks as the single tab of this page.
    gr.TabbedInterface([app_basic_tts], ["Basic-TTS"])


if __name__ == "__main__":
    demo.launch()