barghavani committed on
Commit
678aab1
·
1 Parent(s): 79393d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -97
app.py CHANGED
@@ -1,112 +1,65 @@
1
- import os
2
  import tempfile
 
 
3
  import gradio as gr
4
- from TTS.api import TTS
 
5
  from TTS.utils.synthesizer import Synthesizer
6
- from huggingface_hub import hf_hub_download
7
- import json
8
- os.environ["COQUI_TOS_AGREED"] = "1"
9
 
 
 
 
10
 
11
- # Define constants
12
- MODEL_INFO = [
13
- ["Xtts-Farsi", "best_model.pth", "config.json", "saillab/xtts_v2_fa_revision1","speakers.pth"],
14
- ]
15
 
16
- MODEL_NAMES = [info[0] for info in MODEL_INFO]
17
-
18
- MAX_TXT_LEN = 400
19
- TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
20
-
21
- model_files = {}
22
- config_files = {}
23
- speaker_files = {}
24
-
25
- synthesizers = {}
26
-
27
- def update_config_speakers_file_recursive(config_dict, speakers_path):
28
- if "speakers_file" in config_dict:
29
- config_dict["speakers_file"] = speakers_path
30
- for key, value in config_dict.items():
31
- if isinstance(value, dict):
32
- update_config_speakers_file_recursive(value, speakers_path)
33
-
34
- def update_config_speakers_file(config_path, speakers_path):
35
- with open(config_path, 'r') as f:
36
- config = json.load(f)
37
- update_config_speakers_file_recursive(config, speakers_path)
38
- with open(config_path, 'w') as f:
39
- json.dump(config, f, indent=4)
40
-
41
- for info in MODEL_INFO:
42
- model_name, model_file, config_file, repo_name = info[:4]
43
- speaker_file = info[4] if len(info) == 5 else None
44
- print(f"|> Downloading: {model_name}")
45
- model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
46
- config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
47
-
48
- if speaker_file:
49
- speaker_files[model_name] = hf_hub_download(repo_id=repo_name, filename=speaker_file, use_auth_token=TOKEN)
50
- update_config_speakers_file(config_files[model_name], speaker_files[model_name])
51
- print(speaker_files[model_name])
52
- synthesizer = Synthesizer(
53
- tts_checkpoint=model_files[model_name],
54
- tts_config_path=config_files[model_name],
55
- tts_speakers_file=speaker_files[model_name],
56
- use_cuda=False
57
- )
58
- elif speaker_file is None:
59
- synthesizer = Synthesizer(
60
- tts_checkpoint=model_files[model_name],
61
- tts_config_path=config_files[model_name],
62
- use_cuda=False
63
- )
64
- synthesizers[model_name] = synthesizer
65
-
66
- def synthesize(text: str, model_name: str, speaker_name=None) -> str:
67
  if len(text) > MAX_TXT_LEN:
68
  text = text[:MAX_TXT_LEN]
69
- print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
70
- synthesizer = synthesizers[model_name]
 
 
 
 
 
 
 
71
  if synthesizer is None:
72
- raise NameError("Model not found")
73
-
74
- if not synthesizer.tts_speakers_file:
75
- wavs = synthesizer.tts(text)
76
- elif synthesizer.tts_speakers_file:
77
- if not speaker_name:
78
- wavs = synthesizer.tts(text, speaker_name=None)
79
- else:
80
- wavs = synthesizer.tts(text, speaker_name=speaker_name)
81
-
82
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
83
  synthesizer.save_wav(wavs, fp)
84
  return fp.name
85
 
86
- def update_options(model_name):
87
- synthesizer = synthesizers[model_name]
88
- if model_name is MODEL_NAMES[1]:
89
- speakers = synthesizer.tts_model.speaker_manager.speaker_names
90
- return speakers
91
- else:
92
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- iface = gr.Interface(
95
- fn=synthesize,
96
- inputs=[
97
- gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
98
- gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
99
- gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", default=None)
100
- ],
101
- outputs=gr.Audio(label="Output", type='filepath'),
102
- examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], ""]],
103
- title='Persian TTS Playground',
104
- description="""
105
- ### Persian text to speech model demo.
106
- #### Pick a speaker for MultiSpeaker models. (for single speaker go for speaker-0)
107
- """,
108
- article="",
109
- live=False
110
- )
111
 
112
- iface.launch()
 
 
1
  import tempfile
2
+ from typing import Optional
3
+ from TTS.config import load_config
4
  import gradio as gr
5
+ import numpy as np
6
+ from TTS.utils.manage import ModelManager
7
  from TTS.utils.synthesizer import Synthesizer
 
 
 
8
 
9
+ MODELS = {}
10
+ SPEAKERS = {}
11
+ MAX_TXT_LEN = 100
12
 
13
+ manager = ModelManager()
14
+ MODEL_NAMES = ["saillab/xtts_v2_fa_revision1"]
 
 
15
 
16
def _get_synthesizer(model_name: str):
    """Return a Synthesizer for *model_name*, building it only on first use.

    Built synthesizers are stored in the module-level ``MODELS`` dict so that
    later requests reuse them instead of re-downloading and re-loading the
    checkpoint on every call.
    """
    cached = MODELS.get(model_name)
    if cached is not None:
        return cached

    # NOTE(review): ModelManager.download_model presumably expects catalogue
    # names (e.g. "tts_models/<lang>/<dataset>/<model>"); confirm that a Hub
    # repo id such as "saillab/xtts_v2_fa_revision1" resolves here.
    model_path, config_path, model_item = manager.download_model(model_name)

    # Not every model entry declares a vocoder; a missing key is treated the
    # same as an explicit None instead of raising KeyError.
    vocoder_name: Optional[str] = model_item.get("default_vocoder")
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)

    # Positional order kept exactly as before; the two None slots are
    # presumably the speakers file and languages file -- TODO confirm.
    synthesizer = Synthesizer(
        model_path, config_path, None, None, vocoder_path, vocoder_config_path
    )
    MODELS[model_name] = synthesizer
    return synthesizer


def tts(text: str, model_name: str = "saillab/xtts_v2_fa_revision1"):
    """Synthesize *text* to speech and return the path of the written WAV file.

    Args:
        text: Text to speak. Anything beyond ``MAX_TXT_LEN`` characters is
            dropped (a notice is printed).
        model_name: Identifier passed to ``manager.download_model``. Defaults
            to the single model this app was written for.

    Returns:
        Path of a temporary ``.wav`` file. The file is created with
        ``delete=False``, so it is not removed automatically.

    Raises:
        ValueError: If *text* is ``None``, empty, or only whitespace.
    """
    # Reject empty input before paying for a model download/load.
    if text is None or not text.strip():
        raise ValueError("Input text is empty.")

    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cutoff since it went over the {MAX_TXT_LEN} character limit.")
    print(text, model_name)

    synthesizer = _get_synthesizer(model_name)
    wavs = synthesizer.tts(text, None)

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    # Returned after the handle is closed so the audio is fully flushed.
    return fp.name
35
 
36
# Page heading markup. It is defined here but not referenced by the layout
# below.
title = """<h1 align="center">🐸💬 CoquiTTS Playground </h1>"""

# Layout: three placeholder Markdown cells, then an input column (textbox and
# send button) next to an output column (audio player).
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("GitHub Markdown Details")
        with gr.Column():
            gr.Markdown("GitHub Markdown Details")

    with gr.Row():
        gr.Markdown("GitHub Markdown Details")

    with gr.Row():
        with gr.Column():
            # gr.Textbox(value=...) replaces the deprecated
            # gr.inputs.Textbox(default=...) form.
            input_text = gr.Textbox(
                label="Input Text",
                value="This sentence has been generated by a speech synthesis system.",
            )
            tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

        with gr.Column():
            # gr.Audio replaces the deprecated gr.outputs.Audio form.
            output_audio = gr.Audio(label="Output", type="filepath")

    # Clicking the button runs tts() on the textbox content and plays the
    # returned file path in the audio component.
    tts_button.click(
        tts,
        inputs=[input_text],
        outputs=[output_audio],
    )

# NOTE(review): `concurrency_count` is presumably only accepted by Gradio 3.x
# queue(); confirm against the pinned Gradio version before upgrading.
demo.queue(concurrency_count=16).launch(debug=True)