File size: 4,851 Bytes
09dc90f
 
 
f6a407b
09dc90f
 
 
f6a407b
09dc90f
 
 
 
 
 
 
 
f6a407b
09dc90f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874c079
c311c29
09dc90f
 
 
874c079
fd98fad
874c079
 
 
 
 
09dc90f
fd98fad
 
09dc90f
 
 
 
 
 
fd98fad
09dc90f
874c079
e8f2c20
874c079
09dc90f
 
 
 
 
 
 
 
 
 
 
 
 
874c079
09dc90f
 
 
 
fd98fad
 
09dc90f
 
874c079
 
 
 
09dc90f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874c079
fd98fad
874c079
 
09dc90f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import spaces
import os
from huggingface_hub import login
import gradio as gr
from cached_path import cached_path
import tempfile
from vinorm import TTSnorm

from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    preprocess_ref_audio_text,
    load_vocoder,
    load_model,
    infer_process,
    save_spectrogram,
)

# Retrieve token from secrets
# The token is expected to be configured as a Space secret so that private
# Hub repos (model checkpoint, reference audio) can be downloaded.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Log in to Hugging Face
# Only authenticate when a token is actually present; public repos still
# resolve anonymously, so a missing token is not fatal.
if hf_token:
    login(token=hf_token)

def post_process(text):
    """Clean punctuation artifacts left by text normalization.

    Collapses duplicated periods/commas (e.g. ``" . . "`` -> ``" . "``),
    strips double quotes, and normalizes all whitespace to single spaces.

    Args:
        text: Input text (typically the output of ``TTSnorm``).

    Returns:
        The cleaned text with no leading/trailing whitespace.
    """
    # Pad once so the space-delimited patterns below can match punctuation
    # at the very start/end of the text. (The original padded before every
    # replace, but none of the replacements ever removes the edge spaces,
    # so a single pad is equivalent.)
    text = " " + text + " "
    replacements = (
        (" . . ", " . "),
        (" .. ", " . "),
        (" , , ", " , "),
        (" ,, ", " , "),
    )
    for duplicated, single in replacements:
        text = text.replace(duplicated, single)
    # Double quotes are not speakable; drop them entirely.
    text = text.replace('"', "")
    # Collapse runs of whitespace and trim the padding.
    return " ".join(text.split())

# Load models
# Both loads happen once at import time so every request reuses the same
# weights. The DiT hyperparameters here must match the checkpoint being
# loaded — presumably the F5-TTS "base" configuration; verify against the
# training config if the checkpoint is swapped.
vocoder = load_vocoder()
model = load_model(
    DiT,
    dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
    # Vietnamese fine-tuned checkpoint; cached_path downloads from the Hub
    # on first run and reuses the local cache afterwards.
    ckpt_path=str(cached_path("hf://jackkie99/f5-tts-vnese/model_reduce_109000.pt")),
    # NOTE(review): vocab comes from a different repo than the checkpoint —
    # confirm the two were trained together.
    vocab_file=str(cached_path("hf://hynt/F5-TTS-Vietnamese-100h/vocab.txt")),
)

@spaces.GPU
def infer_tts(
        gen_text: str, speed: float = 1.0, 
        nfe_steps: float = 64.0, target_rms: float = 0.1,
        cross_fade_duration: float = 0,
        sway_sampling_coef: float = -1,
        request: gr.Request = None
    ):
    """Synthesize Vietnamese speech for ``gen_text`` with the loaded F5-TTS model.

    Args:
        gen_text: Text to synthesize (at most 1000 words).
        speed: Synthesis speed multiplier.
        nfe_steps: Number of function evaluations for the ODE sampler
            (higher = better quality, slower). Arrives as a float from the
            Gradio slider — assumed acceptable downstream; TODO confirm
            infer_process tolerates a float step count.
        target_rms: Target RMS loudness of the generated waveform.
        cross_fade_duration: Cross-fade duration between generated chunks.
        sway_sampling_coef: Sway sampling coefficient for the sampler.
        request: Gradio request object injected by the framework (unused).

    Returns:
        A ``(sample_rate, waveform)`` tuple for the audio component and the
        path of a spectrogram PNG for the image component.

    Raises:
        gr.Error: On empty/over-long input or any synthesis failure.
    """
    # Validate before any GPU work. Also guard against gen_text being None,
    # which would otherwise raise AttributeError instead of a user-facing error.
    if not gen_text or not gen_text.strip():
        raise gr.Error("Please enter the text content to generate voice.")
    if len(gen_text.split()) > 1000:
        raise gr.Error("Please enter text content with less than 1000 words.")
    
    try:
        # Fixed reference voice fetched from the Hub; the empty-string ref
        # text makes preprocess_ref_audio_text transcribe the audio itself.
        ref_audio, ref_text = preprocess_ref_audio_text(cached_path("hf://jackkie99/f5-tts-vnese/segment_59.wav"), "")
        # Normalize the Vietnamese text (numbers, dates, ...) with TTSnorm,
        # then clean punctuation artifacts before synthesis.
        final_wave, final_sample_rate, spectrogram = infer_process(
            ref_audio, ref_text.lower(), post_process(TTSnorm(gen_text)).lower(), model, vocoder, speed=speed,
            nfe_step=nfe_steps, target_rms=target_rms, cross_fade_duration=cross_fade_duration,
            sway_sampling_coef=sway_sampling_coef
        )
        # delete=False on purpose: Gradio reads the PNG after this function
        # returns, so the file must outlive the context manager.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
            spectrogram_path = tmp_spectrogram.name
            save_spectrogram(spectrogram, spectrogram_path)

        return (final_sample_rate, final_wave), spectrogram_path
    except Exception as e:
        # Chain the original exception so the real cause shows up in logs
        # instead of being flattened into the message string only.
        raise gr.Error(f"Error generating voice: {e}") from e

# Gradio UI
# Build the Gradio UI. Components are declared in layout order inside the
# Blocks context; statement order here defines the page layout, so it must
# not be rearranged.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎤 F5-TTS: Vietnamese Text-to-Speech Synthesis.
    # The model was trained for 500.000 steps with approximately 200 hours of data on an RTX 3090 GPU. 
    Enter text and upload a sample voice to generate natural speech.
    """)
    
    with gr.Row():
        # ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
        # Input text to synthesize (placeholder text is Vietnamese on purpose).
        gen_text = gr.Textbox(label="📝 Text", placeholder="Nhập văn bản để tổng hợp giọng", lines=3)
    
    # Synthesis controls; defaults mirror infer_tts' parameter defaults.
    speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
    nfe_steps = gr.Slider(16, 64, value=64, step=16, label="NFE Steps")
    target_rms = gr.Slider(0, 1, value=0.1, step=0.1, label="Target RMS")
    cross_fade_duration = gr.Slider(0, 1, value=0, step=0.05, label="Cross Fade Duration")
    sway_sampling_coef = gr.Slider(-1, 3, value=-1, step=0.5, label="Sway Sampling Coef")
    btn_synthesize = gr.Button("🔥 Generate Voice")
    
    with gr.Row():
        # type="numpy" matches infer_tts' (sample_rate, waveform) return value.
        output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
        output_spectrogram = gr.Image(label="📊 Spectrogram")
    
    # Read-only disclaimer shown below the outputs.
    model_limitations = gr.Textbox(
        value="""1. This model may not perform well with numerical characters, dates, special characters, etc. => A text normalization module is needed.
2. The rhythm of some generated audios may be inconsistent or choppy => It is recommended to select clearly pronounced sample audios with minimal pauses for better synthesis quality.
3. Default, reference audio text uses the whisper-large-v3-turbo model, which may not always accurately recognize Vietnamese, resulting in poor voice synthesis quality.
4. Checkpoint is stopped at step 500.000, trained with 150 hours of public data => Voice cloning for non-native voices may not be perfectly accurate.
5. Inference with overly long paragraphs may produce poor results.""", 
        label="❗ Model Limitations",
        lines=5,
        interactive=False
    )

    # Wire the button to the synthesis function; inputs must stay in the
    # same order as infer_tts' positional parameters.
    btn_synthesize.click(infer_tts,
        inputs=[gen_text, speed, nfe_steps, target_rms, cross_fade_duration, sway_sampling_coef],
        outputs=[output_audio, output_spectrogram]
    )

# Run Gradio with share=True to get a gradio.live link
# NOTE(review): launch() is called without share=True despite the comment;
# queue() enables request queuing for the GPU-bound handler.
demo.queue().launch()