TaiYa1 committed on
Commit cfe1020 · verified · 1 Parent(s): 6bd360f

Update tools/webui.py

Files changed (1)
  1. tools/webui.py +546 -621
tools/webui.py CHANGED
@@ -1,621 +1,546 @@
1
- import gc
2
- import html
3
- import io
4
- import os
5
- import queue
6
- import wave
7
- from argparse import ArgumentParser
8
- from functools import partial
9
- from pathlib import Path
10
- from tools.api import decode_vq_tokens, encode_reference
11
- import gradio as gr
12
- import librosa
13
- import numpy as np
14
- import pyrootutils
15
- import torch
16
- from loguru import logger
17
- from transformers import AutoTokenizer
18
- import spaces
19
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
-
21
-
22
- from fish_speech.i18n import i18n
23
- from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
24
- from fish_speech.utils import autocast_exclude_mps
25
- from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
26
- from tools.llama.generate import (
27
- GenerateRequest,
28
- GenerateResponse,
29
- WrappedGenerateResponse,
30
- launch_thread_safe_queue,
31
- )
32
- from tools.vqgan.inference import load_model as load_decoder_model
33
- import torchaudio
34
-
35
-
36
- torchaudio.set_audio_backend("soundfile")
37
-
38
- # Make einx happy
39
- os.environ["EINX_FILTER_TRACEBACK"] = "false"
40
-
41
-
42
- HEADER_MD = f"""# Fish Speech
43
-
44
- {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
45
-
46
- {i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
47
-
48
- {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
49
-
50
- {i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
51
- """
52
-
53
- TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
54
- SPACE_IMPORTED = False
55
-
56
- def build_html_error_message(error):
57
- return f"""
58
- <div style="color: red;
59
- font-weight: bold;">
60
- {html.escape(str(error))}
61
- </div>
62
- """
63
-
64
-
65
- @torch.inference_mode()
66
- def inference(
67
- text,
68
- enable_reference_audio=False,
69
- reference_audio=None,
70
- reference_text="",
71
- max_new_tokens=0,
72
- chunk_length=100,
73
- top_p=0.7,
74
- repetition_penalty=1.2,
75
- temperature=0.7,
76
- streaming=False,
77
- ):
78
- if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
79
- return (
80
- None,
81
- None,
82
- i18n("Text is too long, please keep it under {} characters.").format(
83
- args.max_gradio_length
84
- ),
85
- )
86
-
87
- # Parse reference audio aka prompt
88
- prompt_tokens = encode_reference(
89
- decoder_model=decoder_model,
90
- reference_audio=reference_audio,
91
- enable_reference_audio=enable_reference_audio,
92
- )
93
-
94
- # LLAMA Inference
95
- request = dict(
96
- device=decoder_model.device,
97
- max_new_tokens=max_new_tokens,
98
- text=text,
99
- top_p=top_p,
100
- repetition_penalty=repetition_penalty,
101
- temperature=temperature,
102
- compile=args.compile,
103
- iterative_prompt=chunk_length > 0,
104
- chunk_length=chunk_length,
105
- max_length=2048,
106
- prompt_tokens=prompt_tokens if enable_reference_audio else None,
107
- prompt_text=reference_text if enable_reference_audio else None,
108
- )
109
-
110
- response_queue = queue.Queue()
111
- llama_queue.put(
112
- GenerateRequest(
113
- request=request,
114
- response_queue=response_queue,
115
- )
116
- )
117
-
118
- if streaming:
119
- yield wav_chunk_header(), None, None
120
-
121
- segments = []
122
-
123
- while True:
124
- result: WrappedGenerateResponse = response_queue.get()
125
- if result.status == "error":
126
- yield None, None, build_html_error_message(result.response)
127
- break
128
-
129
- result: GenerateResponse = result.response
130
- if result.action == "next":
131
- break
132
-
133
- with autocast_exclude_mps(
134
- device_type=decoder_model.device.type, dtype=args.precision
135
- ):
136
- fake_audios = decode_vq_tokens(
137
- decoder_model=decoder_model,
138
- codes=result.codes,
139
- )
140
-
141
- fake_audios = fake_audios.float().cpu().numpy()
142
- segments.append(fake_audios)
143
-
144
- if streaming:
145
- yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
146
-
147
- if len(segments) == 0:
148
- return (
149
- None,
150
- None,
151
- build_html_error_message(
152
- i18n("No audio generated, please check the input text.")
153
- ),
154
- )
155
-
156
- # No matter streaming or not, we need to return the final audio
157
- audio = np.concatenate(segments, axis=0)
158
- yield None, (decoder_model.spec_transform.sample_rate, audio), None
159
-
160
- if torch.cuda.is_available():
161
- torch.cuda.empty_cache()
162
- gc.collect()
163
-
164
-
165
- def inference_with_auto_rerank(
166
- text,
167
- enable_reference_audio,
168
- reference_audio,
169
- reference_text,
170
- max_new_tokens,
171
- chunk_length,
172
- top_p,
173
- repetition_penalty,
174
- temperature,
175
- use_auto_rerank,
176
- streaming=False,
177
- ):
178
-
179
- max_attempts = 2 if use_auto_rerank else 1
180
- best_wer = float("inf")
181
- best_audio = None
182
- best_sample_rate = None
183
-
184
- for attempt in range(max_attempts):
185
- audio_generator = inference(
186
- text,
187
- enable_reference_audio,
188
- reference_audio,
189
- reference_text,
190
- max_new_tokens,
191
- chunk_length,
192
- top_p,
193
- repetition_penalty,
194
- temperature,
195
- streaming=False,
196
- )
197
-
198
- # 获取音频数据
199
- for _ in audio_generator:
200
- pass
201
- _, (sample_rate, audio), message = _
202
-
203
- if audio is None:
204
- return None, None, message
205
-
206
- if not use_auto_rerank:
207
- return None, (sample_rate, audio), None
208
-
209
- asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
210
- wer = calculate_wer(text, asr_result["text"])
211
- if wer <= 0.3 and not asr_result["huge_gap"]:
212
- return None, (sample_rate, audio), None
213
-
214
- if wer < best_wer:
215
- best_wer = wer
216
- best_audio = audio
217
- best_sample_rate = sample_rate
218
-
219
- if attempt == max_attempts - 1:
220
- break
221
-
222
- return None, (best_sample_rate, best_audio), None
223
-
224
-
225
- inference_stream = partial(inference, streaming=True)
226
-
227
- n_audios = 4
228
-
229
- global_audio_list = []
230
- global_error_list = []
231
-
232
-
233
- def inference_wrapper(
234
- text,
235
- enable_reference_audio,
236
- reference_audio,
237
- reference_text,
238
- max_new_tokens,
239
- chunk_length,
240
- top_p,
241
- repetition_penalty,
242
- temperature,
243
- batch_infer_num,
244
- if_load_asr_model,
245
- ):
246
- audios = []
247
- errors = []
248
-
249
- for _ in range(batch_infer_num):
250
- result = inference_with_auto_rerank(
251
- text,
252
- enable_reference_audio,
253
- reference_audio,
254
- reference_text,
255
- max_new_tokens,
256
- chunk_length,
257
- top_p,
258
- repetition_penalty,
259
- temperature,
260
- if_load_asr_model,
261
- )
262
-
263
- _, audio_data, error_message = result
264
-
265
- audios.append(
266
- gr.Audio(value=audio_data if audio_data else None, visible=True),
267
- )
268
- errors.append(
269
- gr.HTML(value=error_message if error_message else None, visible=True),
270
- )
271
-
272
- for _ in range(batch_infer_num, n_audios):
273
- audios.append(
274
- gr.Audio(value=None, visible=False),
275
- )
276
- errors.append(
277
- gr.HTML(value=None, visible=False),
278
- )
279
-
280
- return None, *audios, *errors
281
-
282
-
283
- def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
284
- buffer = io.BytesIO()
285
-
286
- with wave.open(buffer, "wb") as wav_file:
287
- wav_file.setnchannels(channels)
288
- wav_file.setsampwidth(bit_depth // 8)
289
- wav_file.setframerate(sample_rate)
290
-
291
- wav_header_bytes = buffer.getvalue()
292
- buffer.close()
293
- return wav_header_bytes
294
-
295
-
296
- def normalize_text(user_input, use_normalization):
297
- if use_normalization:
298
- return ChnNormedText(raw_text=user_input).normalize()
299
- else:
300
- return user_input
301
-
302
-
303
- asr_model = None
304
-
305
-
306
- def change_if_load_asr_model(if_load):
307
- global asr_model
308
-
309
- if if_load:
310
- gr.Warning("Loading faster whisper model...")
311
- if asr_model is None:
312
- asr_model = load_model()
313
- return gr.Checkbox(label="Unload faster whisper model", value=if_load)
314
-
315
- if if_load is False:
316
- gr.Warning("Unloading faster whisper model...")
317
- del asr_model
318
- asr_model = None
319
- if torch.cuda.is_available():
320
- torch.cuda.empty_cache()
321
- gc.collect()
322
- return gr.Checkbox(label="Load faster whisper model", value=if_load)
323
-
324
-
325
- def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
326
- if if_load and asr_model is not None:
327
- if (
328
- if_auto_label
329
- and enable_ref
330
- and ref_audio is not None
331
- and ref_text.strip() == ""
332
- ):
333
- data, sample_rate = librosa.load(ref_audio)
334
- res = batch_asr(asr_model, [data], sample_rate)[0]
335
- ref_text = res["text"]
336
- else:
337
- gr.Warning("Whisper model not loaded!")
338
-
339
- return gr.Textbox(value=ref_text)
340
-
341
-
342
- def build_app():
343
- with gr.Blocks(theme=gr.themes.Base()) as app:
344
- gr.Markdown(HEADER_MD)
345
-
346
- # Use light theme by default
347
- app.load(
348
- None,
349
- None,
350
- js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
351
- % args.theme,
352
- )
353
-
354
- # Inference
355
- with gr.Row():
356
- with gr.Column(scale=3):
357
- text = gr.Textbox(
358
- label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
359
- )
360
- refined_text = gr.Textbox(
361
- label=i18n("Realtime Transform Text"),
362
- placeholder=i18n(
363
- "Normalization Result Preview (Currently Only Chinese)"
364
- ),
365
- lines=5,
366
- interactive=False,
367
- )
368
-
369
- with gr.Row():
370
- if_refine_text = gr.Checkbox(
371
- label=i18n("Text Normalization"),
372
- value=False,
373
- scale=1,
374
- )
375
-
376
- if_load_asr_model = gr.Checkbox(
377
- label=i18n("Load / Unload ASR model for auto-reranking"),
378
- value=False,
379
- scale=3,
380
- )
381
-
382
- with gr.Row():
383
- with gr.Tab(label=i18n("Advanced Config")):
384
- chunk_length = gr.Slider(
385
- label=i18n("Iterative Prompt Length, 0 means off"),
386
- minimum=50,
387
- maximum=300,
388
- value=200,
389
- step=8,
390
- )
391
-
392
- max_new_tokens = gr.Slider(
393
- label=i18n("Maximum tokens per batch, 0 means no limit"),
394
- minimum=0,
395
- maximum=2048,
396
- value=1024, # 0 means no limit
397
- step=8,
398
- )
399
-
400
- top_p = gr.Slider(
401
- label="Top-P",
402
- minimum=0.6,
403
- maximum=0.9,
404
- value=0.7,
405
- step=0.01,
406
- )
407
-
408
- repetition_penalty = gr.Slider(
409
- label=i18n("Repetition Penalty"),
410
- minimum=1,
411
- maximum=1.5,
412
- value=1.2,
413
- step=0.01,
414
- )
415
-
416
- temperature = gr.Slider(
417
- label="Temperature",
418
- minimum=0.6,
419
- maximum=0.9,
420
- value=0.7,
421
- step=0.01,
422
- )
423
-
424
- with gr.Tab(label=i18n("Reference Audio")):
425
- gr.Markdown(
426
- i18n(
427
- "5 to 10 seconds of reference audio, useful for specifying speaker."
428
- )
429
- )
430
-
431
- enable_reference_audio = gr.Checkbox(
432
- label=i18n("Enable Reference Audio"),
433
- )
434
- reference_audio = gr.Audio(
435
- label=i18n("Reference Audio"),
436
- type="filepath",
437
- )
438
- with gr.Row():
439
- if_auto_label = gr.Checkbox(
440
- label=i18n("Auto Labeling"),
441
- min_width=100,
442
- scale=0,
443
- value=False,
444
- )
445
- reference_text = gr.Textbox(
446
- label=i18n("Reference Text"),
447
- lines=1,
448
- placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
449
- value="",
450
- )
451
- with gr.Tab(label=i18n("Batch Inference")):
452
- batch_infer_num = gr.Slider(
453
- label="Batch infer nums",
454
- minimum=1,
455
- maximum=n_audios,
456
- step=1,
457
- value=1,
458
- )
459
-
460
- with gr.Column(scale=3):
461
- for _ in range(n_audios):
462
- with gr.Row():
463
- error = gr.HTML(
464
- label=i18n("Error Message"),
465
- visible=True if _ == 0 else False,
466
- )
467
- global_error_list.append(error)
468
- with gr.Row():
469
- audio = gr.Audio(
470
- label=i18n("Generated Audio"),
471
- type="numpy",
472
- interactive=False,
473
- visible=True if _ == 0 else False,
474
- )
475
- global_audio_list.append(audio)
476
-
477
- with gr.Row():
478
- stream_audio = gr.Audio(
479
- label=i18n("Streaming Audio"),
480
- streaming=True,
481
- autoplay=True,
482
- interactive=False,
483
- show_download_button=True,
484
- )
485
- with gr.Row():
486
- with gr.Column(scale=3):
487
- generate = gr.Button(
488
- value="\U0001F3A7 " + i18n("Generate"), variant="primary"
489
- )
490
- generate_stream = gr.Button(
491
- value="\U0001F3A7 " + i18n("Streaming Generate"),
492
- variant="primary",
493
- )
494
-
495
- text.input(
496
- fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
497
- )
498
-
499
- if_load_asr_model.change(
500
- fn=change_if_load_asr_model,
501
- inputs=[if_load_asr_model],
502
- outputs=[if_load_asr_model],
503
- )
504
-
505
- if_auto_label.change(
506
- fn=lambda: gr.Textbox(value=""),
507
- inputs=[],
508
- outputs=[reference_text],
509
- ).then(
510
- fn=change_if_auto_label,
511
- inputs=[
512
- if_load_asr_model,
513
- if_auto_label,
514
- enable_reference_audio,
515
- reference_audio,
516
- reference_text,
517
- ],
518
- outputs=[reference_text],
519
- )
520
-
521
- # # Submit
522
- generate.click(
523
- inference_wrapper,
524
- [
525
- refined_text,
526
- enable_reference_audio,
527
- reference_audio,
528
- reference_text,
529
- max_new_tokens,
530
- chunk_length,
531
- top_p,
532
- repetition_penalty,
533
- temperature,
534
- batch_infer_num,
535
- if_load_asr_model,
536
- ],
537
- [stream_audio, *global_audio_list, *global_error_list],
538
- concurrency_limit=1,
539
- )
540
-
541
- generate_stream.click(
542
- inference_stream,
543
- [
544
- refined_text,
545
- enable_reference_audio,
546
- reference_audio,
547
- reference_text,
548
- max_new_tokens,
549
- chunk_length,
550
- top_p,
551
- repetition_penalty,
552
- temperature,
553
- ],
554
- [stream_audio, global_audio_list[0], global_error_list[0]],
555
- concurrency_limit=10,
556
- )
557
- return app
558
-
559
-
560
- def parse_args():
561
- parser = ArgumentParser()
562
- parser.add_argument(
563
- "--llama-checkpoint-path",
564
- type=Path,
565
- default="checkpoints/fish-speech-1.4",
566
- )
567
- parser.add_argument(
568
- "--decoder-checkpoint-path",
569
- type=Path,
570
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
571
- )
572
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
573
- parser.add_argument("--device", type=str, default="cuda")
574
- parser.add_argument("--half", action="store_true")
575
- parser.add_argument("--compile", action="store_true")
576
- parser.add_argument("--max-gradio-length", type=int, default=0)
577
- parser.add_argument("--theme", type=str, default="light")
578
-
579
- return parser.parse_args()
580
-
581
-
582
- if __name__ == "__main__":
583
- args = parse_args()
584
- args.precision = torch.half if args.half else torch.bfloat16
585
-
586
- logger.info("Loading Llama model...")
587
- llama_queue = launch_thread_safe_queue(
588
- checkpoint_path=args.llama_checkpoint_path,
589
- device=args.device,
590
- precision=args.precision,
591
- compile=args.compile,
592
- )
593
- logger.info("Llama model loaded, loading VQ-GAN model...")
594
-
595
- decoder_model = load_decoder_model(
596
- config_name=args.decoder_config_name,
597
- checkpoint_path=args.decoder_checkpoint_path,
598
- device=args.device,
599
- )
600
-
601
- logger.info("Decoder model loaded, warming up...")
602
-
603
- # Dry run to check if the model is loaded correctly and avoid the first-time latency
604
- list(
605
- inference(
606
- text="",
607
- enable_reference_audio=False,
608
- reference_audio=None,
609
- reference_text="",
610
- max_new_tokens=0,
611
- chunk_length=100,
612
- top_p=0.7,
613
- repetition_penalty=1.2,
614
- temperature=0.7,
615
- )
616
- )
617
-
618
- logger.info("Warming up done, launching the web UI...")
619
-
620
- app = build_app()
621
- app.launch(show_api=True)
 
1
+ import gc
2
+ import html
3
+ import io
4
+ import os
5
+ import queue
6
+ import wave
7
+ from argparse import ArgumentParser
8
+ from functools import partial
9
+ from pathlib import Path
10
+
11
+ import gradio as gr
12
+ import librosa
13
+ import numpy as np
14
+ import pyrootutils
15
+ import torch
16
+ from loguru import logger
17
+ from transformers import AutoTokenizer
18
+
19
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
+
21
+
22
+ from fish_speech.i18n import i18n
23
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
24
+ from fish_speech.utils import autocast_exclude_mps, set_seed
25
+ from tools.api import decode_vq_tokens, encode_reference
26
+ from tools.file import AUDIO_EXTENSIONS, list_files
27
+ from tools.llama.generate import (
28
+ GenerateRequest,
29
+ GenerateResponse,
30
+ WrappedGenerateResponse,
31
+ launch_thread_safe_queue,
32
+ )
33
+ from tools.vqgan.inference import load_model as load_decoder_model
34
+
35
+ # Make einx happy
36
+ os.environ["EINX_FILTER_TRACEBACK"] = "false"
37
+
38
+
39
+ HEADER_MD = f"""# Fish Speech
40
+
41
+ {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
42
+
43
+ {i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
44
+
45
+ {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
46
+
47
+ {i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
48
+ """
49
+
50
+ TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
51
+ SPACE_IMPORTED = False
52
+
53
+
54
+ def build_html_error_message(error):
55
+ return f"""
56
+ <div style="color: red;
57
+ font-weight: bold;">
58
+ {html.escape(str(error))}
59
+ </div>
60
+ """
61
+
62
+
63
+ @torch.inference_mode()
64
+ def inference(
65
+ text,
66
+ enable_reference_audio,
67
+ reference_audio,
68
+ reference_text,
69
+ max_new_tokens,
70
+ chunk_length,
71
+ top_p,
72
+ repetition_penalty,
73
+ temperature,
74
+ seed="0",
75
+ streaming=False,
76
+ ):
77
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
78
+ return (
79
+ None,
80
+ None,
81
+ i18n("Text is too long, please keep it under {} characters.").format(
82
+ args.max_gradio_length
83
+ ),
84
+ )
85
+
86
+ seed = int(seed)
87
+ if seed != 0:
88
+ set_seed(seed)
89
+ logger.warning(f"set seed: {seed}")
90
+
91
+ # Parse reference audio aka prompt
92
+ prompt_tokens = encode_reference(
93
+ decoder_model=decoder_model,
94
+ reference_audio=reference_audio,
95
+ enable_reference_audio=enable_reference_audio,
96
+ )
97
+
98
+ # LLAMA Inference
99
+ request = dict(
100
+ device=decoder_model.device,
101
+ max_new_tokens=max_new_tokens,
102
+ text=text,
103
+ top_p=top_p,
104
+ repetition_penalty=repetition_penalty,
105
+ temperature=temperature,
106
+ compile=args.compile,
107
+ iterative_prompt=chunk_length > 0,
108
+ chunk_length=chunk_length,
109
+ max_length=2048,
110
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
111
+ prompt_text=reference_text if enable_reference_audio else None,
112
+ )
113
+
114
+ response_queue = queue.Queue()
115
+ llama_queue.put(
116
+ GenerateRequest(
117
+ request=request,
118
+ response_queue=response_queue,
119
+ )
120
+ )
121
+
122
+ if streaming:
123
+ yield wav_chunk_header(), None, None
124
+
125
+ segments = []
126
+
127
+ while True:
128
+ result: WrappedGenerateResponse = response_queue.get()
129
+ if result.status == "error":
130
+ yield None, None, build_html_error_message(result.response)
131
+ break
132
+
133
+ result: GenerateResponse = result.response
134
+ if result.action == "next":
135
+ break
136
+
137
+ with autocast_exclude_mps(
138
+ device_type=decoder_model.device.type, dtype=args.precision
139
+ ):
140
+ fake_audios = decode_vq_tokens(
141
+ decoder_model=decoder_model,
142
+ codes=result.codes,
143
+ )
144
+
145
+ fake_audios = fake_audios.float().cpu().numpy()
146
+ segments.append(fake_audios)
147
+
148
+ if streaming:
149
+ yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
150
+
151
+ if len(segments) == 0:
152
+ return (
153
+ None,
154
+ None,
155
+ build_html_error_message(
156
+ i18n("No audio generated, please check the input text.")
157
+ ),
158
+ )
159
+
160
+ # No matter streaming or not, we need to return the final audio
161
+ audio = np.concatenate(segments, axis=0)
162
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
163
+
164
+ if torch.cuda.is_available():
165
+ torch.cuda.empty_cache()
166
+ gc.collect()
167
+
168
+
169
+ inference_stream = partial(inference, streaming=True)
170
+
171
+ n_audios = 4
172
+
173
+ global_audio_list = []
174
+ global_error_list = []
175
+
176
+
177
+ def inference_wrapper(
178
+ text,
179
+ enable_reference_audio,
180
+ reference_audio,
181
+ reference_text,
182
+ max_new_tokens,
183
+ chunk_length,
184
+ top_p,
185
+ repetition_penalty,
186
+ temperature,
187
+ seed,
188
+ batch_infer_num,
189
+ ):
190
+ audios = []
191
+ errors = []
192
+
193
+ for _ in range(batch_infer_num):
194
+ result = inference(
195
+ text,
196
+ enable_reference_audio,
197
+ reference_audio,
198
+ reference_text,
199
+ max_new_tokens,
200
+ chunk_length,
201
+ top_p,
202
+ repetition_penalty,
203
+ temperature,
204
+ seed,
205
+ )
206
+
207
+ _, audio_data, error_message = next(result)
208
+
209
+ audios.append(
210
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
211
+ )
212
+ errors.append(
213
+ gr.HTML(value=error_message if error_message else None, visible=True),
214
+ )
215
+
216
+ for _ in range(batch_infer_num, n_audios):
217
+ audios.append(
218
+ gr.Audio(value=None, visible=False),
219
+ )
220
+ errors.append(
221
+ gr.HTML(value=None, visible=False),
222
+ )
223
+
224
+ return None, *audios, *errors
225
+
226
+
227
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
228
+ buffer = io.BytesIO()
229
+
230
+ with wave.open(buffer, "wb") as wav_file:
231
+ wav_file.setnchannels(channels)
232
+ wav_file.setsampwidth(bit_depth // 8)
233
+ wav_file.setframerate(sample_rate)
234
+
235
+ wav_header_bytes = buffer.getvalue()
236
+ buffer.close()
237
+ return wav_header_bytes
238
+
239
+
240
+ def normalize_text(user_input, use_normalization):
241
+ if use_normalization:
242
+ return ChnNormedText(raw_text=user_input).normalize()
243
+ else:
244
+ return user_input
245
+
246
+
247
+ def update_examples():
248
+ examples_dir = Path("references")
249
+ examples_dir.mkdir(parents=True, exist_ok=True)
250
+ example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
251
+ return gr.Dropdown(choices=example_audios + [""])
252
+
253
+
254
+ def build_app():
255
+ with gr.Blocks(theme=gr.themes.Base()) as app:
256
+ gr.Markdown(HEADER_MD)
257
+
258
+ # Use light theme by default
259
+ app.load(
260
+ None,
261
+ None,
262
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
263
+ % args.theme,
264
+ )
265
+
266
+ # Inference
267
+ with gr.Row():
268
+ with gr.Column(scale=3):
269
+ text = gr.Textbox(
270
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
271
+ )
272
+ refined_text = gr.Textbox(
273
+ label=i18n("Realtime Transform Text"),
274
+ placeholder=i18n(
275
+ "Normalization Result Preview (Currently Only Chinese)"
276
+ ),
277
+ lines=5,
278
+ interactive=False,
279
+ )
280
+
281
+ with gr.Row():
282
+ if_refine_text = gr.Checkbox(
283
+ label=i18n("Text Normalization"),
284
+ value=False,
285
+ scale=1,
286
+ )
287
+
288
+ with gr.Row():
289
+ with gr.Column():
290
+ with gr.Tab(label=i18n("Advanced Config")):
291
+ with gr.Row():
292
+ chunk_length = gr.Slider(
293
+ label=i18n("Iterative Prompt Length, 0 means off"),
294
+ minimum=50,
295
+ maximum=300,
296
+ value=200,
297
+ step=8,
298
+ )
299
+
300
+ max_new_tokens = gr.Slider(
301
+ label=i18n(
302
+ "Maximum tokens per batch, 0 means no limit"
303
+ ),
304
+ minimum=0,
305
+ maximum=2048,
306
+ value=0, # 0 means no limit
307
+ step=8,
308
+ )
309
+
310
+ with gr.Row():
311
+ top_p = gr.Slider(
312
+ label="Top-P",
313
+ minimum=0.6,
314
+ maximum=0.9,
315
+ value=0.7,
316
+ step=0.01,
317
+ )
318
+
319
+ repetition_penalty = gr.Slider(
320
+ label=i18n("Repetition Penalty"),
321
+ minimum=1,
322
+ maximum=1.5,
323
+ value=1.2,
324
+ step=0.01,
325
+ )
326
+
327
+ with gr.Row():
328
+ temperature = gr.Slider(
329
+ label="Temperature",
330
+ minimum=0.6,
331
+ maximum=0.9,
332
+ value=0.7,
333
+ step=0.01,
334
+ )
335
+ seed = gr.Textbox(
336
+ label="Seed",
337
+ info="0 means randomized inference, otherwise deterministic",
338
+ placeholder="any 32-bit-integer",
339
+ value="0",
340
+ )
341
+
342
+ with gr.Tab(label=i18n("Reference Audio")):
343
+ with gr.Row():
344
+ gr.Markdown(
345
+ i18n(
346
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
347
+ )
348
+ )
349
+ with gr.Row():
350
+ enable_reference_audio = gr.Checkbox(
351
+ label=i18n("Enable Reference Audio"),
352
+ )
353
+
354
+ with gr.Row():
355
+ example_audio_dropdown = gr.Dropdown(
356
+ label=i18n("Select Example Audio"),
357
+ choices=[""],
358
+ value="",
359
+ interactive=True,
360
+ allow_custom_value=True,
361
+ )
362
+ with gr.Row():
363
+ reference_audio = gr.Audio(
364
+ label=i18n("Reference Audio"),
365
+ type="filepath",
366
+ )
367
+ with gr.Row():
368
+ reference_text = gr.Textbox(
369
+ label=i18n("Reference Text"),
370
+ lines=1,
371
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
372
+ value="",
373
+ )
374
+ with gr.Tab(label=i18n("Batch Inference")):
375
+ with gr.Row():
376
+ batch_infer_num = gr.Slider(
377
+ label="Batch infer nums",
378
+ minimum=1,
379
+ maximum=n_audios,
380
+ step=1,
381
+ value=1,
382
+ )
383
+
384
+ with gr.Column(scale=3):
385
+ for _ in range(n_audios):
386
+ with gr.Row():
387
+ error = gr.HTML(
388
+ label=i18n("Error Message"),
389
+ visible=True if _ == 0 else False,
390
+ )
391
+ global_error_list.append(error)
392
+ with gr.Row():
393
+ audio = gr.Audio(
394
+ label=i18n("Generated Audio"),
395
+ type="numpy",
396
+ interactive=False,
397
+ visible=True if _ == 0 else False,
398
+ )
399
+ global_audio_list.append(audio)
400
+
401
+ with gr.Row():
402
+ stream_audio = gr.Audio(
403
+ label=i18n("Streaming Audio"),
404
+ streaming=True,
405
+ autoplay=True,
406
+ interactive=False,
407
+ show_download_button=True,
408
+ )
409
+ with gr.Row():
410
+ with gr.Column(scale=3):
411
+ generate = gr.Button(
412
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
413
+ )
414
+ generate_stream = gr.Button(
415
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
416
+ variant="primary",
417
+ )
418
+
419
+ text.input(
420
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
421
+ )
422
+
423
+ def select_example_audio(audio_path):
424
+ audio_path = Path(audio_path)
425
+ if audio_path.is_file():
426
+ lab_file = Path(audio_path.with_suffix(".lab"))
427
+
428
+ if lab_file.exists():
429
+ lab_content = lab_file.read_text(encoding="utf-8").strip()
430
+ else:
431
+ lab_content = ""
432
+
433
+ return str(audio_path), lab_content, True
434
+ return None, "", False
435
+
436
+ # Connect the dropdown to update reference audio and text
437
+ example_audio_dropdown.change(
438
+ fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
439
+ ).then(
440
+ fn=select_example_audio,
441
+ inputs=[example_audio_dropdown],
442
+ outputs=[reference_audio, reference_text, enable_reference_audio],
443
+ )
444
+
445
+ # # Submit
446
+ generate.click(
447
+ inference_wrapper,
448
+ [
449
+ refined_text,
450
+ enable_reference_audio,
451
+ reference_audio,
452
+ reference_text,
453
+ max_new_tokens,
454
+ chunk_length,
455
+ top_p,
456
+ repetition_penalty,
457
+ temperature,
458
+ seed,
459
+ batch_infer_num,
460
+ ],
461
+ [stream_audio, *global_audio_list, *global_error_list],
462
+ concurrency_limit=1,
463
+ )
464
+
465
+ generate_stream.click(
466
+ inference_stream,
467
+ [
468
+ refined_text,
469
+ enable_reference_audio,
470
+ reference_audio,
471
+ reference_text,
472
+ max_new_tokens,
473
+ chunk_length,
474
+ top_p,
475
+ repetition_penalty,
476
+ temperature,
477
+ seed,
478
+ ],
479
+ [stream_audio, global_audio_list[0], global_error_list[0]],
480
+ concurrency_limit=1,
481
+ )
482
+ return app
483
+
484
+
485
+ def parse_args():
486
+ parser = ArgumentParser()
487
+ parser.add_argument(
488
+ "--llama-checkpoint-path",
489
+ type=Path,
490
+ default="checkpoints/fish-speech-1.4",
491
+ )
492
+ parser.add_argument(
493
+ "--decoder-checkpoint-path",
494
+ type=Path,
495
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
496
+ )
497
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
498
+ parser.add_argument("--device", type=str, default="cuda")
499
+ parser.add_argument("--half", action="store_true")
500
+ parser.add_argument("--compile", action="store_true")
501
+ parser.add_argument("--max-gradio-length", type=int, default=0)
502
+ parser.add_argument("--theme", type=str, default="light")
503
+
504
+ return parser.parse_args()
505
+
506
+
507
+ if __name__ == "__main__":
508
+ args = parse_args()
509
+ args.precision = torch.half if args.half else torch.bfloat16
510
+
511
+ logger.info("Loading Llama model...")
512
+ llama_queue = launch_thread_safe_queue(
513
+ checkpoint_path=args.llama_checkpoint_path,
514
+ device=args.device,
515
+ precision=args.precision,
516
+ compile=args.compile,
517
+ )
518
+ logger.info("Llama model loaded, loading VQ-GAN model...")
519
+
520
+ decoder_model = load_decoder_model(
521
+ config_name=args.decoder_config_name,
522
+ checkpoint_path=args.decoder_checkpoint_path,
523
+ device=args.device,
524
+ )
525
+
526
+ logger.info("Decoder model loaded, warming up...")
527
+
528
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
529
+ list(
530
+ inference(
531
+ text="Hello, world!",
532
+ enable_reference_audio=False,
533
+ reference_audio=None,
534
+ reference_text="",
535
+ max_new_tokens=0,
536
+ chunk_length=200,
537
+ top_p=0.7,
538
+ repetition_penalty=1.2,
539
+ temperature=0.7,
540
+ )
541
+ )
542
+
543
+ logger.info("Warming up done, launching the web UI...")
544
+
545
+ app = build_app()
546
+ app.launch(show_api=True)