# F5-TTS Gradio demo (HuggingFace Spaces, running on ZeroGPU).
# NOTE(review): scraped page chrome (Space status, file size, commit hashes,
# line-number gutter) removed so the file parses as Python.
import gradio as gr
import numpy as np
import spaces
import torch
from cached_path import cached_path
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
from f5_tts.model import DiT
# Module-level setup: runs at import time and downloads checkpoints via
# cached_path on first run (network access required; cached afterwards).
vocoder = load_vocoder()

# common usage
# DiT hyperparameter presets shared by the checkpoints below.
v1_base_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
v1_small_cfg = dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)

# Vocab files matching each checkpoint's tokenizer.
zh_en_vocab_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt"))
alg_vocab_path = str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/vocab.txt"))

# Two-level registry: language -> {model name -> loaded model}.
# All models are loaded eagerly here; keys drive the UI dropdowns below.
tts_lang_model_collections = {
    "Mandarin-English": {
        "v1-base_zh-en": load_model(
            DiT,
            v1_base_cfg,
            str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")),
            vocab_file=zh_en_vocab_path,
        ),
    },
    "Algerian": {
        "v1-small_alg-64h_400k": load_model(
            DiT,
            v1_small_cfg,
            str(cached_path("hf://chenxie95/F5-TTS_v1_Small_Algerian/64h_model_400000.safetensors")),
            vocab_file=alg_vocab_path,
        ),
    },
}

# Mutable module-level selection state, updated by the dropdown callbacks
# in the `demo` Blocks below (dicts preserve insertion order, so "first"
# is the first key written above).
tts_lang_choice = next(iter(tts_lang_model_collections))  # first as default
tts_model_choice = next(iter(tts_lang_model_collections[tts_lang_choice]))  # first as default
@spaces.GPU
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    seed,
    show_info=gr.Info,
):
    """Synthesize `gen_text` in the voice of the reference audio.

    Args:
        ref_audio_orig: Filepath of the reference audio clip.
        ref_text: Transcription of the reference audio.
        gen_text: Text to synthesize.
        model: Model key within the currently selected language collection.
        seed: RNG seed; values outside 0 ~ 2**31 - 1 trigger a random seed.
        show_info: Callable used by the pipeline to surface progress messages.

    Returns:
        Tuple of ((sample_rate, waveform), possibly-updated ref_text, used seed).
        On missing inputs, returns (gr.update(), ref_text, seed) unchanged.
    """
    # Guard clause: all three inputs are required; leave outputs untouched.
    if not ref_audio_orig or not ref_text.strip() or not gen_text.strip():
        gr.Warning("Please ensure [Reference Audio] [Reference Text] [Text to Generate] are all provided.")
        return gr.update(), ref_text, seed

    if seed < 0 or seed > 2**31 - 1:
        gr.Warning("Please set a seed in range 0 ~ 2**31 - 1.")
        # High bound is exclusive in np.random.randint, so 2**31 yields the
        # full advertised range 0 ~ 2**31 - 1 inclusive (was 2**31 - 1,
        # which silently excluded the maximum seed).
        seed = np.random.randint(0, 2**31)
    torch.manual_seed(seed)
    used_seed = seed

    # May transcribe/normalize the reference text; returns processed audio path.
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    final_wave, final_sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        # Fix: honor the `model` argument. Previously the global
        # tts_model_choice was read here, leaving the parameter dead.
        tts_lang_model_collections[tts_lang_choice][model],
        vocoder,
        show_info=show_info,
        progress=gr.Progress(),
    )

    return (final_sample_rate, final_wave), ref_text, used_seed
# Basic TTS tab: reference audio + transcript + target text -> synthesized audio.
with gr.Blocks() as app_basic_tts:
    with gr.Row():
        with gr.Column():
            ref_wav_input = gr.Audio(label="Reference Audio", type="filepath")
            ref_txt_input = gr.Textbox(label="Reference Text")
            gen_txt_input = gr.Textbox(label="Text to Generate")

    generate_btn = gr.Button("Synthesize", variant="primary")

    with gr.Row():
        randomize_seed = gr.Checkbox(
            label="Randomize Seed",
            info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
            value=True,
            scale=3,
        )
        seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)

    audio_output = gr.Audio(label="Synthesized Audio")

    def basic_tts(
        ref_wav_input,
        ref_txt_input,
        gen_txt_input,
        randomize_seed,
        seed_input,
    ):
        """Click handler: optionally randomize the seed, then run inference.

        Returns (audio, ref_text, seed) so the UI can show the transcript the
        pipeline actually used and the seed that produced the output.
        """
        if randomize_seed:
            # High bound exclusive, so this draws from 0 ~ 2**31 - 2.
            seed_input = np.random.randint(0, 2**31 - 1)

        # Uses the module-level tts_model_choice set by the model dropdown.
        audio_out, ref_text_out, used_seed = infer(
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            tts_model_choice,
            seed_input,
        )

        return audio_out, ref_text_out, used_seed

    # ref_txt_input and seed_input are both inputs and outputs: inference may
    # rewrite the transcript and reports the seed actually used.
    generate_btn.click(
        basic_tts,
        inputs=[
            ref_wav_input,
            ref_txt_input,
            gen_txt_input,
            randomize_seed,
            seed_input,
        ],
        outputs=[audio_output, ref_txt_input, seed_input],
    )
# Top-level app: header, language/model selectors, and the TTS tab.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🗣️ F5-TTS Online Demo for Dev Test

Upload or record a reference voice, give its transcription text, then order the text to generate and have fun!
"""
    )

    def switch_tts_lang(new_lang_choice):
        """Dropdown callback: switch language and reset the model dropdown.

        Updates the module-level selection state read by basic_tts/infer,
        then repopulates the model dropdown with that language's models.
        """
        global tts_lang_choice, tts_model_choice
        tts_lang_choice = new_lang_choice
        tts_model_choice = next(iter(tts_lang_model_collections[tts_lang_choice]))  # first as default
        return gr.update(choices=[t for t in tts_lang_model_collections[tts_lang_choice]], value=tts_model_choice)

    def switch_tts_model(new_model_choice):
        """Dropdown callback: record the selected model in module state."""
        global tts_model_choice
        tts_model_choice = new_model_choice

    with gr.Row():
        choose_tts_lang = gr.Dropdown(
            choices=[t for t in tts_lang_model_collections],
            label="Choose TTS Language",
            value=tts_lang_choice,
        )
        choose_tts_model = gr.Dropdown(
            choices=[t for t in tts_lang_model_collections[tts_lang_choice]],
            label="Choose TTS Model",
            value=tts_model_choice,
        )

    # Changing the language rewrites the model dropdown's choices/value.
    choose_tts_lang.change(
        switch_tts_lang,
        inputs=[choose_tts_lang],
        outputs=[choose_tts_model],
    )
    # Model changes only mutate module state; no UI output needed.
    choose_tts_model.change(
        switch_tts_model,
        inputs=[choose_tts_model],
    )

    gr.TabbedInterface(
        [app_basic_tts],
        ["Basic-TTS"],
    )
def _main():
    """Launch the Gradio demo server when run as a script."""
    demo.launch()


if __name__ == "__main__":
    _main()