mrfakename committed on
Commit
9941624
·
verified ·
1 Parent(s): f23ce9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -202
app.py CHANGED
@@ -3,11 +3,9 @@
3
 
4
  import gc
5
  import json
6
- import os
7
  import re
8
  import tempfile
9
  from collections import OrderedDict
10
- from functools import lru_cache
11
  from importlib.resources import files
12
 
13
  import click
@@ -19,7 +17,6 @@ import torchaudio
19
  from cached_path import cached_path
20
  from transformers import AutoModelForCausalLM, AutoTokenizer
21
 
22
-
23
  try:
24
  import spaces
25
 
@@ -35,16 +32,15 @@ def gpu_decorator(func):
35
  return func
36
 
37
 
 
38
  from f5_tts.infer.utils_infer import (
39
- infer_process,
40
- load_model,
41
  load_vocoder,
 
42
  preprocess_ref_audio_text,
 
43
  remove_silence_for_generated_wav,
44
  save_spectrogram,
45
- tempfile_kwargs,
46
  )
47
- from f5_tts.model import DiT, UNetT
48
 
49
 
50
  DEFAULT_TTS_MODEL = "F5-TTS_v1"
@@ -82,8 +78,6 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
82
  vocab_path = str(cached_path(vocab_path))
83
  if model_cfg is None:
84
  model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
85
- elif isinstance(model_cfg, str):
86
- model_cfg = json.loads(model_cfg)
87
  return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
88
 
89
 
@@ -128,7 +122,6 @@ def load_text_from_file(file):
128
  return gr.update(value=text)
129
 
130
 
131
- @lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
132
  @gpu_decorator
133
  def infer(
134
  ref_audio_orig,
@@ -147,11 +140,7 @@ def infer(
147
  return gr.update(), gr.update(), ref_text
148
 
149
  # Set inference seed
150
- if seed < 0 or seed > 2**31 - 1:
151
- gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
152
- seed = np.random.randint(0, 2**31 - 1)
153
  torch.manual_seed(seed)
154
- used_seed = seed
155
 
156
  if not gen_text.strip():
157
  gr.Warning("Please enter text to generate or upload a text file.")
@@ -167,7 +156,7 @@ def infer(
167
  show_info("Loading E2-TTS model...")
168
  E2TTS_ema_model = load_e2tts()
169
  ema_model = E2TTS_ema_model
170
- elif isinstance(model, tuple) and model[0] == "Custom":
171
  assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
172
  global custom_ema_model, pre_custom_path
173
  if pre_custom_path != model[1]:
@@ -191,24 +180,28 @@ def infer(
191
 
192
  # Remove silence
193
  if remove_silence:
194
- with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
195
- temp_path = f.name
196
- try:
197
- sf.write(temp_path, final_wave, final_sample_rate)
198
  remove_silence_for_generated_wav(f.name)
199
  final_wave, _ = torchaudio.load(f.name)
200
- finally:
201
- os.unlink(temp_path)
202
  final_wave = final_wave.squeeze().cpu().numpy()
203
 
204
  # Save the spectrogram
205
- with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
206
  spectrogram_path = tmp_spectrogram.name
207
- save_spectrogram(combined_spectrogram, spectrogram_path)
208
 
209
- return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
210
 
211
 
 
 
 
 
 
 
 
 
212
  with gr.Blocks() as app_tts:
213
  gr.Markdown("# Batched TTS")
214
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
@@ -229,7 +222,9 @@ with gr.Blocks() as app_tts:
229
  lines=2,
230
  scale=4,
231
  )
232
- ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
 
 
233
  with gr.Row():
234
  randomize_seed = gr.Checkbox(
235
  label="Randomize Seed",
@@ -284,21 +279,27 @@ with gr.Blocks() as app_tts:
284
  nfe_slider,
285
  speed_slider,
286
  ):
 
287
  if randomize_seed:
288
- seed_input = np.random.randint(0, 2**31 - 1)
 
 
 
 
 
289
 
290
- audio_out, spectrogram_path, ref_text_out, used_seed = infer(
291
  ref_audio_input,
292
  ref_text_input,
293
  gen_text_input,
294
  tts_model_choice,
295
  remove_silence,
296
- seed=seed_input,
297
  cross_fade_duration=cross_fade_duration_slider,
298
  nfe_step=nfe_slider,
299
  speed=speed_slider,
300
  )
301
- return audio_out, spectrogram_path, ref_text_out, used_seed
302
 
303
  gen_text_file.upload(
304
  load_text_from_file,
@@ -312,12 +313,6 @@ with gr.Blocks() as app_tts:
312
  outputs=[ref_text_input],
313
  )
314
 
315
- ref_audio_input.clear(
316
- lambda: [None, None],
317
- None,
318
- [ref_text_input, ref_text_file],
319
- )
320
-
321
  generate_btn.click(
322
  basic_tts,
323
  inputs=[
@@ -336,35 +331,26 @@ with gr.Blocks() as app_tts:
336
 
337
 
338
  def parse_speechtypes_text(gen_text):
339
- # Pattern to find {str} or {"name": str, "seed": int, "speed": float}
340
- pattern = r"(\{.*?\})"
341
 
342
  # Split the text by the pattern
343
  tokens = re.split(pattern, gen_text)
344
 
345
  segments = []
346
 
347
- current_type_dict = {
348
- "name": "Regular",
349
- "seed": -1,
350
- "speed": 1.0,
351
- }
352
 
353
  for i in range(len(tokens)):
354
  if i % 2 == 0:
355
  # This is text
356
  text = tokens[i].strip()
357
  if text:
358
- current_type_dict["text"] = text
359
- segments.append(current_type_dict)
360
  else:
361
- # This is type
362
- type_str = tokens[i].strip()
363
- try: # if type dict
364
- current_type_dict = json.loads(type_str)
365
- except json.decoder.JSONDecodeError:
366
- type_str = type_str[1:-1] # remove brace {}
367
- current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
368
 
369
  return segments
370
 
@@ -382,48 +368,40 @@ with gr.Blocks() as app_multistyle:
382
  with gr.Row():
383
  gr.Markdown(
384
  """
385
- **Example Input:** <br>
386
- {Regular} Hello, I'd like to order a sandwich please. <br>
387
- {Surprised} What do you mean you're out of bread? <br>
388
- {Sad} I really wanted a sandwich though... <br>
389
- {Angry} You know what, darn you and your little shop! <br>
390
- {Whisper} I'll just go back home and cry now. <br>
391
  {Shouting} Why me?!
392
  """
393
  )
394
 
395
  gr.Markdown(
396
  """
397
- **Example Input 2:** <br>
398
- {"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
399
- {"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
400
- {"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
401
- {"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
402
  """
403
  )
404
 
405
  gr.Markdown(
406
- 'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
407
  )
408
 
409
  # Regular speech type (mandatory)
410
- with gr.Row(variant="compact") as regular_row:
411
  with gr.Column(scale=1, min_width=160):
412
  regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
413
  regular_insert = gr.Button("Insert Label", variant="secondary")
414
  with gr.Column(scale=3):
415
  regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
416
- with gr.Column(scale=3):
417
- regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
418
- with gr.Row():
419
- regular_seed_slider = gr.Slider(
420
- show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
421
- )
422
- regular_speed_slider = gr.Slider(
423
- show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
424
- )
425
- with gr.Column(scale=1, min_width=160):
426
- regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
427
 
428
  # Regular speech type (max 100)
429
  max_speech_types = 100
@@ -432,54 +410,31 @@ with gr.Blocks() as app_multistyle:
432
  speech_type_audios = [regular_audio]
433
  speech_type_ref_texts = [regular_ref_text]
434
  speech_type_ref_text_files = [regular_ref_text_file]
435
- speech_type_seeds = [regular_seed_slider]
436
- speech_type_speeds = [regular_speed_slider]
437
  speech_type_delete_btns = [None]
438
  speech_type_insert_btns = [regular_insert]
439
 
440
  # Additional speech types (99 more)
441
  for i in range(max_speech_types - 1):
442
- with gr.Row(variant="compact", visible=False) as row:
443
  with gr.Column(scale=1, min_width=160):
444
  name_input = gr.Textbox(label="Speech Type Name")
 
445
  insert_btn = gr.Button("Insert Label", variant="secondary")
446
- delete_btn = gr.Button("Delete Type", variant="stop")
447
  with gr.Column(scale=3):
448
  audio_input = gr.Audio(label="Reference Audio", type="filepath")
449
- with gr.Column(scale=3):
450
- ref_text_input = gr.Textbox(label="Reference Text", lines=4)
451
- with gr.Row():
452
- seed_input = gr.Slider(
453
- show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
454
- )
455
- speed_input = gr.Slider(
456
- show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
457
- )
458
- with gr.Column(scale=1, min_width=160):
459
- ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
460
  speech_type_rows.append(row)
461
  speech_type_names.append(name_input)
462
  speech_type_audios.append(audio_input)
463
  speech_type_ref_texts.append(ref_text_input)
464
  speech_type_ref_text_files.append(ref_text_file_input)
465
- speech_type_seeds.append(seed_input)
466
- speech_type_speeds.append(speed_input)
467
  speech_type_delete_btns.append(delete_btn)
468
  speech_type_insert_btns.append(insert_btn)
469
 
470
- # Global logic for all speech types
471
- for i in range(max_speech_types):
472
- speech_type_audios[i].clear(
473
- lambda: [None, None],
474
- None,
475
- [speech_type_ref_texts[i], speech_type_ref_text_files[i]],
476
- )
477
- speech_type_ref_text_files[i].upload(
478
- load_text_from_file,
479
- inputs=[speech_type_ref_text_files[i]],
480
- outputs=[speech_type_ref_texts[i]],
481
- )
482
-
483
  # Button to add speech type
484
  add_speech_type_btn = gr.Button("Add Speech Type")
485
 
@@ -515,6 +470,18 @@ with gr.Blocks() as app_multistyle:
515
  speech_type_ref_text_files[i],
516
  ],
517
  )
 
 
 
 
 
 
 
 
 
 
 
 
518
 
519
  # Text input for the prompt
520
  with gr.Row():
@@ -528,17 +495,10 @@ with gr.Blocks() as app_multistyle:
528
  gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
529
 
530
  def make_insert_speech_type_fn(index):
531
- def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
532
  current_text = current_text or ""
533
- if not speech_type_name:
534
- gr.Warning("Please enter speech type name before insert.")
535
- return current_text
536
- speech_type_dict = {
537
- "name": speech_type_name,
538
- "seed": speech_type_seed,
539
- "speed": speech_type_speed,
540
- }
541
- updated_text = current_text + json.dumps(speech_type_dict) + " "
542
  return updated_text
543
 
544
  return insert_speech_type_fn
@@ -547,24 +507,16 @@ with gr.Blocks() as app_multistyle:
547
  insert_fn = make_insert_speech_type_fn(i)
548
  insert_btn.click(
549
  insert_fn,
550
- inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
551
  outputs=gen_text_input_multistyle,
552
  )
553
 
554
- with gr.Accordion("Advanced Settings", open=True):
555
- with gr.Row():
556
- with gr.Column():
557
- show_cherrypick_multistyle = gr.Checkbox(
558
- label="Show Cherry-pick Interface",
559
- info="Turn on to show interface, picking seeds from previous generations.",
560
- value=False,
561
- )
562
- with gr.Column():
563
- remove_silence_multistyle = gr.Checkbox(
564
- label="Remove Silences",
565
- info="Turn on to automatically detect and crop long silences.",
566
- value=True,
567
- )
568
 
569
  # Generate button
570
  generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
@@ -572,24 +524,6 @@ with gr.Blocks() as app_multistyle:
572
  # Output audio
573
  audio_output_multistyle = gr.Audio(label="Synthesized Audio")
574
 
575
- # Used seed gallery
576
- cherrypick_interface_multistyle = gr.Textbox(
577
- label="Cherry-pick Interface",
578
- lines=10,
579
- max_lines=40,
580
- show_copy_button=True,
581
- interactive=False,
582
- visible=False,
583
- )
584
-
585
- # Logic control to show/hide the cherrypick interface
586
- show_cherrypick_multistyle.change(
587
- lambda is_visible: gr.update(visible=is_visible),
588
- show_cherrypick_multistyle,
589
- cherrypick_interface_multistyle,
590
- )
591
-
592
- # Function to load text to generate from file
593
  gen_text_file_multistyle.upload(
594
  load_text_from_file,
595
  inputs=[gen_text_file_multistyle],
@@ -623,60 +557,44 @@ with gr.Blocks() as app_multistyle:
623
 
624
  # For each segment, generate speech
625
  generated_audio_segments = []
626
- current_type_name = "Regular"
627
- inference_meta_data = ""
628
 
629
  for segment in segments:
630
- name = segment["name"]
631
- seed_input = segment["seed"]
632
- speed = segment["speed"]
633
  text = segment["text"]
634
 
635
- if name in speech_types:
636
- current_type_name = name
637
  else:
638
- gr.Warning(f"Type {name} is not available, will use Regular as default.")
639
- current_type_name = "Regular"
640
 
641
  try:
642
- ref_audio = speech_types[current_type_name]["audio"]
643
  except KeyError:
644
- gr.Warning(f"Please provide reference audio for type {current_type_name}.")
645
- return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
646
- ref_text = speech_types[current_type_name].get("ref_text", "")
647
 
648
- if seed_input == -1:
649
- seed_input = np.random.randint(0, 2**31 - 1)
650
 
651
- # Generate or retrieve speech for this segment
652
- audio_out, _, ref_text_out, used_seed = infer(
653
- ref_audio,
654
- ref_text,
655
- text,
656
- tts_model_choice,
657
- remove_silence,
658
- seed=seed_input,
659
- cross_fade_duration=0,
660
- speed=speed,
661
- show_info=print, # no pull to top when generating
662
- )
663
  sr, audio_data = audio_out
664
 
665
  generated_audio_segments.append(audio_data)
666
- speech_types[current_type_name]["ref_text"] = ref_text_out
667
- inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
668
 
669
  # Concatenate all audio segments
670
  if generated_audio_segments:
671
  final_audio_data = np.concatenate(generated_audio_segments)
672
- return (
673
- [(sr, final_audio_data)]
674
- + [speech_types[name]["ref_text"] for name in speech_types]
675
- + [inference_meta_data]
676
- )
677
  else:
678
  gr.Warning("No audio generated.")
679
- return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
680
 
681
  generate_multistyle_btn.click(
682
  generate_multistyle_speech,
@@ -689,7 +607,7 @@ with gr.Blocks() as app_multistyle:
689
  + [
690
  remove_silence_multistyle,
691
  ],
692
- outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
693
  )
694
 
695
  # Validation function to disable Generate button if speech types are missing
@@ -706,7 +624,7 @@ with gr.Blocks() as app_multistyle:
706
 
707
  # Parse the gen_text to get the speech types used
708
  segments = parse_speechtypes_text(gen_text)
709
- speech_types_in_text = set(segment["name"] for segment in segments)
710
 
711
  # Check if all speech types in text are available
712
  missing_speech_types = speech_types_in_text - speech_types_available
@@ -870,21 +788,27 @@ Have a conversation with an AI using your reference voice!
870
  if not last_ai_response or conv_state[-1]["role"] != "assistant":
871
  return None, ref_text, seed_input
872
 
 
873
  if randomize_seed:
874
- seed_input = np.random.randint(0, 2**31 - 1)
 
 
 
 
 
875
 
876
- audio_result, _, ref_text_out, used_seed = infer(
877
  ref_audio,
878
  ref_text,
879
  last_ai_response,
880
  tts_model_choice,
881
  remove_silence,
882
- seed=seed_input,
883
  cross_fade_duration=0.15,
884
  speed=1.0,
885
  show_info=print, # show_info=print no pull to top when generating
886
  )
887
- return audio_result, ref_text_out, used_seed
888
 
889
  def clear_conversation():
890
  """Reset the conversation"""
@@ -930,16 +854,6 @@ Have a conversation with an AI using your reference voice!
930
  )
931
 
932
 
933
- with gr.Blocks() as app_credits:
934
- gr.Markdown("""
935
- # Credits
936
-
937
- * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
938
- * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
939
- * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
940
- """)
941
-
942
-
943
  with gr.Blocks() as app:
944
  gr.Markdown(
945
  f"""
@@ -975,7 +889,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
975
  global tts_model_choice
976
  if new_choice == "Custom": # override in case webpage is refreshed
977
  custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
978
- tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
979
  return (
980
  gr.update(visible=True, value=custom_ckpt_path),
981
  gr.update(visible=True, value=custom_vocab_path),
@@ -987,7 +901,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
987
 
988
  def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
989
  global tts_model_choice
990
- tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
991
  with open(last_used_custom, "w", encoding="utf-8") as f:
992
  f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
993
 
@@ -1118,4 +1032,4 @@ if __name__ == "__main__":
1118
  if not USING_SPACES:
1119
  main()
1120
  else:
1121
- app.queue().launch()
 
3
 
4
  import gc
5
  import json
 
6
  import re
7
  import tempfile
8
  from collections import OrderedDict
 
9
  from importlib.resources import files
10
 
11
  import click
 
17
  from cached_path import cached_path
18
  from transformers import AutoModelForCausalLM, AutoTokenizer
19
 
 
20
  try:
21
  import spaces
22
 
 
32
  return func
33
 
34
 
35
+ from f5_tts.model import DiT, UNetT
36
  from f5_tts.infer.utils_infer import (
 
 
37
  load_vocoder,
38
+ load_model,
39
  preprocess_ref_audio_text,
40
+ infer_process,
41
  remove_silence_for_generated_wav,
42
  save_spectrogram,
 
43
  )
 
44
 
45
 
46
  DEFAULT_TTS_MODEL = "F5-TTS_v1"
 
78
  vocab_path = str(cached_path(vocab_path))
79
  if model_cfg is None:
80
  model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
 
 
81
  return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
82
 
83
 
 
122
  return gr.update(value=text)
123
 
124
 
 
125
  @gpu_decorator
126
  def infer(
127
  ref_audio_orig,
 
140
  return gr.update(), gr.update(), ref_text
141
 
142
  # Set inference seed
 
 
 
143
  torch.manual_seed(seed)
 
144
 
145
  if not gen_text.strip():
146
  gr.Warning("Please enter text to generate or upload a text file.")
 
156
  show_info("Loading E2-TTS model...")
157
  E2TTS_ema_model = load_e2tts()
158
  ema_model = E2TTS_ema_model
159
+ elif isinstance(model, list) and model[0] == "Custom":
160
  assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
161
  global custom_ema_model, pre_custom_path
162
  if pre_custom_path != model[1]:
 
180
 
181
  # Remove silence
182
  if remove_silence:
183
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
184
+ sf.write(f.name, final_wave, final_sample_rate)
 
 
185
  remove_silence_for_generated_wav(f.name)
186
  final_wave, _ = torchaudio.load(f.name)
 
 
187
  final_wave = final_wave.squeeze().cpu().numpy()
188
 
189
  # Save the spectrogram
190
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
191
  spectrogram_path = tmp_spectrogram.name
192
+ save_spectrogram(combined_spectrogram, spectrogram_path)
193
 
194
+ return (final_sample_rate, final_wave), spectrogram_path, ref_text
195
 
196
 
197
+ with gr.Blocks() as app_credits:
198
+ gr.Markdown("""
199
+ # Credits
200
+
201
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
202
+ * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
203
+ * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
204
+ """)
205
  with gr.Blocks() as app_tts:
206
  gr.Markdown("# Batched TTS")
207
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
 
222
  lines=2,
223
  scale=4,
224
  )
225
+ ref_text_file = gr.File(
226
+ label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1, height=1
227
+ )
228
  with gr.Row():
229
  randomize_seed = gr.Checkbox(
230
  label="Randomize Seed",
 
279
  nfe_slider,
280
  speed_slider,
281
  ):
282
+ # Determine the seed to use
283
  if randomize_seed:
284
+ seed = np.random.randint(0, 2**31 - 1)
285
+ else:
286
+ seed = seed_input
287
+ if seed < 0 or seed > 2**31 - 1:
288
+ gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
289
+ seed = np.random.randint(0, 2**31 - 1)
290
 
291
+ audio_out, spectrogram_path, ref_text_out = infer(
292
  ref_audio_input,
293
  ref_text_input,
294
  gen_text_input,
295
  tts_model_choice,
296
  remove_silence,
297
+ seed=seed,
298
  cross_fade_duration=cross_fade_duration_slider,
299
  nfe_step=nfe_slider,
300
  speed=speed_slider,
301
  )
302
+ return audio_out, spectrogram_path, ref_text_out, seed
303
 
304
  gen_text_file.upload(
305
  load_text_from_file,
 
313
  outputs=[ref_text_input],
314
  )
315
 
 
 
 
 
 
 
316
  generate_btn.click(
317
  basic_tts,
318
  inputs=[
 
331
 
332
 
333
  def parse_speechtypes_text(gen_text):
334
+ # Pattern to find {speechtype}
335
+ pattern = r"\{(.*?)\}"
336
 
337
  # Split the text by the pattern
338
  tokens = re.split(pattern, gen_text)
339
 
340
  segments = []
341
 
342
+ current_style = "Regular"
 
 
 
 
343
 
344
  for i in range(len(tokens)):
345
  if i % 2 == 0:
346
  # This is text
347
  text = tokens[i].strip()
348
  if text:
349
+ segments.append({"style": current_style, "text": text})
 
350
  else:
351
+ # This is style
352
+ style = tokens[i].strip()
353
+ current_style = style
 
 
 
 
354
 
355
  return segments
356
 
 
368
  with gr.Row():
369
  gr.Markdown(
370
  """
371
+ **Example Input:**
372
+ {Regular} Hello, I'd like to order a sandwich please.
373
+ {Surprised} What do you mean you're out of bread?
374
+ {Sad} I really wanted a sandwich though...
375
+ {Angry} You know what, darn you and your little shop!
376
+ {Whisper} I'll just go back home and cry now.
377
  {Shouting} Why me?!
378
  """
379
  )
380
 
381
  gr.Markdown(
382
  """
383
+ **Example Input 2:**
384
+ {Speaker1_Happy} Hello, I'd like to order a sandwich please.
385
+ {Speaker2_Regular} Sorry, we're out of bread.
386
+ {Speaker1_Sad} I really wanted a sandwich though...
387
+ {Speaker2_Whisper} I'll give you the last one I was hiding.
388
  """
389
  )
390
 
391
  gr.Markdown(
392
+ "Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
393
  )
394
 
395
  # Regular speech type (mandatory)
396
+ with gr.Row() as regular_row:
397
  with gr.Column(scale=1, min_width=160):
398
  regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
399
  regular_insert = gr.Button("Insert Label", variant="secondary")
400
  with gr.Column(scale=3):
401
  regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
402
+ with gr.Column(scale=4):
403
+ regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=8, scale=3)
404
+ regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
 
 
 
 
 
 
 
 
405
 
406
  # Regular speech type (max 100)
407
  max_speech_types = 100
 
410
  speech_type_audios = [regular_audio]
411
  speech_type_ref_texts = [regular_ref_text]
412
  speech_type_ref_text_files = [regular_ref_text_file]
 
 
413
  speech_type_delete_btns = [None]
414
  speech_type_insert_btns = [regular_insert]
415
 
416
  # Additional speech types (99 more)
417
  for i in range(max_speech_types - 1):
418
+ with gr.Row(visible=False) as row:
419
  with gr.Column(scale=1, min_width=160):
420
  name_input = gr.Textbox(label="Speech Type Name")
421
+ delete_btn = gr.Button("Delete Type", variant="secondary")
422
  insert_btn = gr.Button("Insert Label", variant="secondary")
 
423
  with gr.Column(scale=3):
424
  audio_input = gr.Audio(label="Reference Audio", type="filepath")
425
+ with gr.Column(scale=4):
426
+ ref_text_input = gr.Textbox(label="Reference Text", lines=8, scale=3)
427
+ ref_text_file_input = gr.File(
428
+ label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
429
+ )
 
 
 
 
 
 
430
  speech_type_rows.append(row)
431
  speech_type_names.append(name_input)
432
  speech_type_audios.append(audio_input)
433
  speech_type_ref_texts.append(ref_text_input)
434
  speech_type_ref_text_files.append(ref_text_file_input)
 
 
435
  speech_type_delete_btns.append(delete_btn)
436
  speech_type_insert_btns.append(insert_btn)
437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  # Button to add speech type
439
  add_speech_type_btn = gr.Button("Add Speech Type")
440
 
 
470
  speech_type_ref_text_files[i],
471
  ],
472
  )
473
+ speech_type_ref_text_files[i].upload(
474
+ load_text_from_file,
475
+ inputs=[speech_type_ref_text_files[i]],
476
+ outputs=[speech_type_ref_texts[i]],
477
+ )
478
+
479
+ # Update regular speech type ref text file
480
+ regular_ref_text_file.upload(
481
+ load_text_from_file,
482
+ inputs=[regular_ref_text_file],
483
+ outputs=[regular_ref_text],
484
+ )
485
 
486
  # Text input for the prompt
487
  with gr.Row():
 
495
  gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
496
 
497
  def make_insert_speech_type_fn(index):
498
+ def insert_speech_type_fn(current_text, speech_type_name):
499
  current_text = current_text or ""
500
+ speech_type_name = speech_type_name or "None"
501
+ updated_text = current_text + f"{{{speech_type_name}}} "
 
 
 
 
 
 
 
502
  return updated_text
503
 
504
  return insert_speech_type_fn
 
507
  insert_fn = make_insert_speech_type_fn(i)
508
  insert_btn.click(
509
  insert_fn,
510
+ inputs=[gen_text_input_multistyle, speech_type_names[i]],
511
  outputs=gen_text_input_multistyle,
512
  )
513
 
514
+ with gr.Accordion("Advanced Settings", open=False):
515
+ remove_silence_multistyle = gr.Checkbox(
516
+ label="Remove Silences",
517
+ info="Turn on to automatically detect and crop long silences.",
518
+ value=True,
519
+ )
 
 
 
 
 
 
 
 
520
 
521
  # Generate button
522
  generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
 
524
  # Output audio
525
  audio_output_multistyle = gr.Audio(label="Synthesized Audio")
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  gen_text_file_multistyle.upload(
528
  load_text_from_file,
529
  inputs=[gen_text_file_multistyle],
 
557
 
558
  # For each segment, generate speech
559
  generated_audio_segments = []
560
+ current_style = "Regular"
 
561
 
562
  for segment in segments:
563
+ style = segment["style"]
 
 
564
  text = segment["text"]
565
 
566
+ if style in speech_types:
567
+ current_style = style
568
  else:
569
+ gr.Warning(f"Type {style} is not available, will use Regular as default.")
570
+ current_style = "Regular"
571
 
572
  try:
573
+ ref_audio = speech_types[current_style]["audio"]
574
  except KeyError:
575
+ gr.Warning(f"Please provide reference audio for type {current_style}.")
576
+ return [None] + [speech_types[style]["ref_text"] for style in speech_types]
577
+ ref_text = speech_types[current_style].get("ref_text", "")
578
 
579
+ # TODO. Attribute each type a unique seed (maybe also speed, pseudo-feature for #730 #813)
580
+ seed = np.random.randint(0, 2**31 - 1)
581
 
582
+ # Generate speech for this segment
583
+ audio_out, _, ref_text_out = infer(
584
+ ref_audio, ref_text, text, tts_model_choice, remove_silence, seed, 0, show_info=print
585
+ ) # show_info=print no pull to top when generating
 
 
 
 
 
 
 
 
586
  sr, audio_data = audio_out
587
 
588
  generated_audio_segments.append(audio_data)
589
+ speech_types[current_style]["ref_text"] = ref_text_out
 
590
 
591
  # Concatenate all audio segments
592
  if generated_audio_segments:
593
  final_audio_data = np.concatenate(generated_audio_segments)
594
+ return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
 
 
 
 
595
  else:
596
  gr.Warning("No audio generated.")
597
+ return [None] + [speech_types[style]["ref_text"] for style in speech_types]
598
 
599
  generate_multistyle_btn.click(
600
  generate_multistyle_speech,
 
607
  + [
608
  remove_silence_multistyle,
609
  ],
610
+ outputs=[audio_output_multistyle] + speech_type_ref_texts,
611
  )
612
 
613
  # Validation function to disable Generate button if speech types are missing
 
624
 
625
  # Parse the gen_text to get the speech types used
626
  segments = parse_speechtypes_text(gen_text)
627
+ speech_types_in_text = set(segment["style"] for segment in segments)
628
 
629
  # Check if all speech types in text are available
630
  missing_speech_types = speech_types_in_text - speech_types_available
 
788
  if not last_ai_response or conv_state[-1]["role"] != "assistant":
789
  return None, ref_text, seed_input
790
 
791
+ # Determine the seed to use
792
  if randomize_seed:
793
+ seed = np.random.randint(0, 2**31 - 1)
794
+ else:
795
+ seed = seed_input
796
+ if seed < 0 or seed > 2**31 - 1:
797
+ gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
798
+ seed = np.random.randint(0, 2**31 - 1)
799
 
800
+ audio_result, _, ref_text_out = infer(
801
  ref_audio,
802
  ref_text,
803
  last_ai_response,
804
  tts_model_choice,
805
  remove_silence,
806
+ seed=seed,
807
  cross_fade_duration=0.15,
808
  speed=1.0,
809
  show_info=print, # show_info=print no pull to top when generating
810
  )
811
+ return audio_result, ref_text_out, seed
812
 
813
  def clear_conversation():
814
  """Reset the conversation"""
 
854
  )
855
 
856
 
 
 
 
 
 
 
 
 
 
 
857
  with gr.Blocks() as app:
858
  gr.Markdown(
859
  f"""
 
889
  global tts_model_choice
890
  if new_choice == "Custom": # override in case webpage is refreshed
891
  custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
892
+ tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
893
  return (
894
  gr.update(visible=True, value=custom_ckpt_path),
895
  gr.update(visible=True, value=custom_vocab_path),
 
901
 
902
  def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
903
  global tts_model_choice
904
+ tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
905
  with open(last_used_custom, "w", encoding="utf-8") as f:
906
  f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
907
 
 
1032
  if not USING_SPACES:
1033
  main()
1034
  else:
1035
+ app.queue().launch()