Den4ikAI committed
Commit 923bdc3 · verified · 1 parent: 64127bc

Create app.py

Files changed (1):
app.py (+271, -0)
app.py ADDED
@@ -0,0 +1,271 @@
import gc
import json
import tempfile

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from ruaccent import RUAccent
import onnx_asr

from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    save_spectrogram,
    tempfile_kwargs,
)
from f5_tts.model import DiT

# --- Model configuration ---
# DiT hyperparameters (matches the F5-TTS Base configuration)
MODEL_CFG = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

# Paths for all model checkpoints (fill in your own paths)
MODEL_PATHS = {
    "ESpeech-TTS-1 [RL] V2": "stripped_states/espeech_tts_rlv2.pt",
    "ESpeech-TTS-1 [RL] V1": "stripped_states/espeech_tts_rlv1.pt",
    "ESpeech-TTS-1 [SFT] 95K": "stripped_states/espeech_tts_95k.pt",
    "ESpeech-TTS-1 [SFT] 265K": "stripped_states/espeech_tts_256k.pt",
    "ESpeech-TTS-1 PODCASTER [SFT]": "stripped_states/espeech_tts_podcaster.pt",
}

# Shared vocabulary path (the same file is used by all models)
VOCAB_PATH = "/media/denis/work/f5tts/F5-TTS/base_checkpoint1/vocab.txt"

# Load vocoder (shared across models)
vocoder = load_vocoder()

# Loaded TTS models, keyed by the display names from MODEL_PATHS
loaded_models = {}

# Initialize RUAccent (adds stress marks to Russian text before synthesis)
print("Loading RUAccent...")
accentizer = RUAccent()
accentizer.load(omograph_model_size='turbo3.1', use_dictionary=True, tiny_mode=False)
print("RUAccent loaded successfully.")

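# Rough illustration of what RUAccent produces (stress is marked with "+"): a call like
# accentizer.process_all("Привет, как дела?") is expected to return something like
# "Прив+ет, как дел+а?"; this accented text is what gets passed to the TTS model.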
# Initialize ASR model (used to transcribe the reference audio when no reference text is given)
print("Loading ASR model...")
asr_model = onnx_asr.load_model("nemo-fastconformer-ru-rnnt")
print("ASR model loaded successfully.")

# Load all models at startup
print("Loading models...")
for model_name, model_path in MODEL_PATHS.items():
    print(f"Loading {model_name}...")
    loaded_models[model_name] = load_model(
        DiT,
        MODEL_CFG,
        model_path,
        vocab_file=VOCAB_PATH,
    )
    print(f"{model_name} loaded successfully.")

print("All models loaded successfully.")


def synthesize(
    model_choice,
    ref_audio,
    ref_text,
    gen_text,
    remove_silence,
    seed,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1.0,
):
    if not ref_audio:
        gr.Warning("Please provide reference audio.")
        return None, None, ref_text

    # A negative or out-of-range seed means "pick a random one"
    if seed < 0 or seed > 2**31 - 1:
        seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)

    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return None, None, ref_text

    # If reference text is empty, use ASR to transcribe the reference audio
    if not ref_text.strip():
        gr.Info("Reference text is empty. Running ASR to transcribe reference audio...")
        try:
            # Load audio from the Gradio file path (torchaudio returns waveform first, then sample rate)
            waveform, sample_rate = torchaudio.load(ref_audio)

            # Convert tensor to numpy
            waveform = waveform.numpy()

            # Convert to the float range expected by onnx-asr
            # (torchaudio already returns float32 in [-1, 1]; the integer branches are a safeguard)
            if waveform.dtype == np.int16:
                waveform = waveform / 2**15
            elif waveform.dtype == np.int32:
                waveform = waveform / 2**31
            elif waveform.dtype == np.float32 or waveform.dtype == np.float64:
                pass  # already in the right range

            # Convert to mono if stereo
            if waveform.ndim == 2:
                waveform = waveform.mean(axis=0)  # average across channels (first dimension)
            elif waveform.ndim == 1:
                pass  # already mono
            else:
                waveform = waveform.squeeze()

            # Run ASR on the audio data directly
            transcribed_text = asr_model.recognize(waveform, sample_rate=sample_rate)
            ref_text = transcribed_text
            gr.Info(f"ASR transcription: {ref_text}")

        except Exception as e:
            gr.Warning(f"ASR transcription failed: {str(e)}")
            return None, None, ref_text

    # Apply accent marks to reference text and generation text
    processed_ref_text = accentizer.process_all(ref_text) if ref_text.strip() else ref_text
    processed_gen_text = accentizer.process_all(gen_text)

    # Select model based on choice
    model = loaded_models[model_choice]

    # Preprocess reference audio and text
    ref_audio, processed_ref_text = preprocess_ref_audio_text(
        ref_audio,
        processed_ref_text,
        show_info=gr.Info,
    )

    # Generate speech
    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        processed_ref_text,
        processed_gen_text,
        model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=gr.Info,
        progress=gr.Progress(),
    )

    # Remove silence if requested
    if remove_silence:
        with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
            temp_path = f.name
            sf.write(temp_path, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(temp_path)
            final_wave, _ = torchaudio.load(temp_path)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Save spectrogram to a temporary PNG and return its path
    with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, processed_ref_text


# --- Gradio interface ---
with gr.Blocks(title="ESpeech-TTS") as app:
    gr.Markdown("# ESpeech-TTS")
    gr.Markdown("Text-to-Speech synthesis system with multiple model variants")
    gr.Markdown("💡 **Tip:** If you leave the Reference Text empty, it will be automatically transcribed using ASR and then processed with accent marks!")

    with gr.Row():
        model_choice = gr.Dropdown(
            choices=list(MODEL_PATHS.keys()),
            label="Select Model",
            value="ESpeech-TTS-1 [RL] V2",
            interactive=True,
        )

    with gr.Row():
        with gr.Column():
            ref_audio_input = gr.Audio(
                label="Reference Audio",
                type="filepath",
            )
            ref_text_input = gr.Textbox(
                label="Reference Text",
                lines=2,
                placeholder="Enter the transcription of the reference audio... (leave empty for automatic ASR transcription)",
            )

        with gr.Column():
            gen_text_input = gr.Textbox(
                label="Text to Generate",
                lines=5,
                max_lines=20,
                placeholder="Enter the text you want to synthesize...",
            )

    with gr.Row():
        with gr.Column():
            with gr.Accordion("Advanced Settings", open=False):
                seed_input = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1,
                    precision=0,
                )
                remove_silence = gr.Checkbox(
                    label="Remove Silences",
                    value=False,
                )
                speed_slider = gr.Slider(
                    label="Speed",
                    minimum=0.3,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                )
                # NFE = number of function evaluations (denoising steps) used by the
                # flow-matching sampler; more steps trade speed for quality.
                nfe_slider = gr.Slider(
                    label="NFE Steps (higher = better quality, slower)",
                    minimum=4,
                    maximum=64,
                    value=48,
                    step=2,
                )
                cross_fade_slider = gr.Slider(
                    label="Cross-Fade Duration (s)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.15,
                    step=0.01,
                )

    generate_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg")

    with gr.Row():
        audio_output = gr.Audio(label="Generated Audio", type="numpy")
        spectrogram_output = gr.Image(label="Spectrogram", type="filepath")

    generate_btn.click(
        synthesize,
        inputs=[
            model_choice,
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            remove_silence,
            seed_input,
            cross_fade_slider,
            nfe_slider,
            speed_slider,
        ],
        outputs=[audio_output, spectrogram_output, ref_text_input],
    )


if __name__ == "__main__":
    # app.launch(server_name="0.0.0.0", server_port=7860)
    app.launch()
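
Usage note: running the file directly (python app.py) calls app.launch(), which serves the interface on Gradio's default local address; the commented-out launch line shows how to bind to 0.0.0.0:7860 for access from other machines. The checkpoint paths in MODEL_PATHS and the vocabulary file in VOCAB_PATH are local placeholders and need to point to files available in the target environment.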