JotunnBurton commited on
Commit
92259fe
·
verified ·
1 Parent(s): cccafbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -193
app.py CHANGED
@@ -7,28 +7,22 @@ import argparse
7
  import commons
8
  import utils
9
  import gradio as gr
 
 
 
 
10
  from huggingface_hub import hf_hub_download
11
 
12
-
13
  from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
14
  from models import SynthesizerTrn
15
  from text.symbols import symbols
16
  from text import cleaned_text_to_sequence, get_bert
17
  from text.cleaner import clean_text
18
- import numpy as np
19
-
20
- logging.getLogger("numba").setLevel(logging.WARNING)
21
- logging.getLogger("markdown_it").setLevel(logging.WARNING)
22
- logging.getLogger("urllib3").setLevel(logging.WARNING)
23
- logging.getLogger("matplotlib").setLevel(logging.WARNING)
24
 
25
  logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
26
  logger = logging.getLogger(__name__)
27
- limitation = os.getenv("SYSTEM") == "spaces"
28
-
29
 
30
  def get_net_g(model_path: str, version: str, device: str, hps):
31
- # 当前版本模型 net_g
32
  net_g = SynthesizerTrn(
33
  len(symbols),
34
  hps.data.filter_length // 2 + 1,
@@ -42,7 +36,6 @@ def get_net_g(model_path: str, version: str, device: str, hps):
42
 
43
  def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
44
  style_text = None if style_text == "" else style_text
45
- # 在此处实现当前版本的get_text
46
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
47
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
48
  if hps.data.add_blank:
@@ -54,230 +47,113 @@ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7)
54
  word2ph[0] += 1
55
  bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
56
  del word2ph
57
-
58
- assert bert.shape[-1] == len(
59
- phone
60
- ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
61
-
62
  phone = torch.LongTensor(phone)
63
  tone = torch.LongTensor(tone)
64
  language = torch.LongTensor(language)
65
  return bert, phone, tone, language
66
 
 
 
 
67
 
68
- def infer(
69
- text,
70
- sdp_ratio,
71
- noise_scale,
72
- noise_scale_w,
73
- length_scale,
74
- sid,
75
- language,
76
- hps,
77
- net_g,
78
- device,
79
- emotion,
80
- reference_audio=None,
81
- skip_start=False,
82
- skip_end=False,
83
- style_text=None,
84
- style_weight=0.7,
85
- text_mode="Text",
86
- ):
87
- # 2.2版本参数位置变了
88
- # 2.1 参数新增 emotion reference_audio skip_start skip_end
89
- version = hps.version if hasattr(hps, "version") else latest_version
90
- language = "JP"
91
- if isinstance(reference_audio, np.ndarray):
92
- emo = get_clap_audio_feature(reference_audio, device)
93
- else:
94
- emo = get_clap_text_feature(emotion, device)
95
- emo = torch.squeeze(emo, dim=1)
96
-
97
- bert, phones, tones, lang_ids = get_text(
98
- text,
99
- language,
100
- hps,
101
- device,
102
- style_text=style_text,
103
- style_weight=style_weight,
104
- )
105
- if skip_start:
106
- phones = phones[3:]
107
- tones = tones[3:]
108
- lang_ids = lang_ids[3:]
109
- bert = bert[:, 3:]
110
- if skip_end:
111
- phones = phones[:-2]
112
- tones = tones[:-2]
113
- lang_ids = lang_ids[:-2]
114
- bert = bert[:, :-2]
115
- with torch.no_grad():
116
- x_tst = phones.to(device).unsqueeze(0)
117
- tones = tones.to(device).unsqueeze(0)
118
- lang_ids = lang_ids.to(device).unsqueeze(0)
119
- bert = bert.to(device).unsqueeze(0)
120
- x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
121
- emo = emo.to(device).unsqueeze(0)
122
- del phones
123
- spk2id_dict = {k: v for k, v in hps.data["spk2id"].items()}
124
 
125
- # ถ้า sid เป็น index (เช่น 0) → แปลงเป็นชื่อ
126
- if isinstance(sid, int) or sid.isdigit():
127
- sid_int = int(sid)
128
- name_map = {v: k for k, v in spk2id_dict.items()}
129
- if sid_int not in name_map:
130
- raise ValueError(f"Speaker index {sid_int} not found.")
131
- sid = name_map[sid_int]
 
 
 
 
 
 
 
 
 
 
 
132
  else:
133
- sid = str(sid).upper()
134
-
135
- if sid not in spk2id_dict:
136
- raise ValueError(f"Speaker ID '{sid}' not found. Available: {list(spk2id_dict.keys())}")
137
-
138
- speaker_id = spk2id_dict[sid]
139
- speakers = torch.LongTensor([speaker_id]).to(device)
140
- print(text)
141
- audio = (
142
- net_g.infer(
143
- x_tst,
144
- x_tst_lengths,
145
- speakers,
146
- tones,
147
- lang_ids,
148
- bert,
149
- emo,
150
- sdp_ratio=sdp_ratio,
151
- noise_scale=noise_scale,
152
- noise_scale_w=noise_scale_w,
153
- length_scale=length_scale,
154
- )[0][0, 0]
155
- .data.cpu()
156
- .float()
157
- .numpy()
158
- )
159
- del (
160
- x_tst,
161
- tones,
162
- lang_ids,
163
- bert,
164
- x_tst_lengths,
165
- speakers,
166
- emo,
167
- ) # , emo
168
- if torch.cuda.is_available():
169
- torch.cuda.empty_cache()
170
- return audio
171
-
172
 
173
- def create_tts_fn(net_g_ms, hps):
174
- def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale,language,
175
- reference_audio,
176
- emotion,
177
- prompt_mode,
178
- style_text=None,
179
- style_weight=0):
180
- print(f"{text} | {speaker}")
181
- sid = hps.data.spk2id[speaker]
182
- text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
183
- if limitation:
184
- max_len = 100
185
- if len(text) > max_len:
186
- return "Error: Text is too long", None
187
  audio = infer(
188
  text=text,
 
 
189
  sdp_ratio=sdp_ratio,
190
  noise_scale=noise_scale,
191
  noise_scale_w=noise_scale_w,
192
  length_scale=length_scale,
193
- sid=sid,
194
- language="JP", # หรือตามที่ user เลือก
195
  hps=hps,
196
- net_g=net_g_ms,
197
  device=device,
198
- emotion="neutral", # หรือตาม dropdown ที่ผู้ใช้เลือก
199
- reference_audio=None,
200
- skip_start=False,
201
- skip_end=False,
202
- style_text=None,
203
- style_weight=0.7,
204
- text_mode="Text"
205
  )
206
  return "Success", (hps.data.sampling_rate, audio)
207
  return tts_fn
208
 
209
-
210
  if __name__ == "__main__":
211
- device = (
212
- "cuda:0"
213
- if torch.cuda.is_available()
214
- else (
215
- "mps"
216
- if sys.platform == "darwin" and torch.backends.mps.is_available()
217
- else "cpu"
218
- )
219
- )
220
-
221
  parser = argparse.ArgumentParser()
222
  parser.add_argument("--share", default=False, help="make link public", action="store_true")
223
  parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
224
  args = parser.parse_args()
225
- if args.debug:
226
- logger.info("Enable DEBUG-LEVEL log")
227
- logging.basicConfig(level=logging.DEBUG)
228
 
229
- models = []
 
230
 
231
  with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
232
  models_info = json.load(f)
233
 
234
- # โหลดโมเดลทั้งหมดล่วงหน้า
235
- for i, info in models_info.items():
 
236
  if not info['enable']:
237
  continue
238
- name = info['name']
239
- title = info['title']
240
- link = info['link']
241
- example = info['example']
242
-
243
- print(f"🔄 Loading model: {name} from {link}")
244
  config_path = hf_hub_download(repo_id=link, filename="config.json")
245
  model_path = hf_hub_download(repo_id=link, filename=f"{name}.pth")
246
  hps = utils.get_hparams_from_file(config_path)
247
- version = hps.version if hasattr(hps, "version") else latest_version
248
- net_g_ms = get_net_g(model_path, version, device, hps)
249
- models.append((name, title, example, list(hps.data.spk2id.keys()), net_g_ms, create_tts_fn(net_g_ms, hps)))
 
250
 
251
- # ✅ Gradio UI แบบพร้อมใช้กับ Spaces
252
  with gr.Blocks(theme='NoCrypt/miku') as app:
253
  gr.Markdown("## ✅ All models loaded successfully. Ready to use.")
254
-
255
  with gr.Tabs():
256
- for (name, title, example, speakers, net_g_ms, tts_fn) in models:
257
- with gr.TabItem(name):
258
- with gr.Row():
259
- gr.Markdown(
260
- '<div align="center">'
261
- f'<a><strong>{title}</strong></a>'
262
- f'</div>'
263
- )
264
  with gr.Row():
265
  with gr.Column():
266
- input_text = gr.Textbox(label="Text (100 words limitation)" if limitation else "Text", lines=5, value=example)
267
- btn = gr.Button(value="Generate", variant="primary")
268
- with gr.Row():
269
- sp = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
270
- with gr.Row():
271
- sdpr = gr.Slider(label="SDP Ratio", minimum=0, maximum=1, step=0.1, value=0.2)
272
- ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6)
273
- nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.8)
274
- ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1)
275
- lang = gr.Dropdown(choices=["JP"], value=["JP"], label="Lanaguage")
276
- ref_a = gr.Audio(label="Upload your audio", type="filepath")
277
-
 
 
278
  with gr.Column():
279
- o1 = gr.Textbox(label="Output Message")
280
- o2 = gr.Audio(label="Output Audio")
281
- btn.click(tts_fn, inputs=[input_text, sp, sdpr, ns, nsw, ls, lang,ref_a], outputs=[o1, o2])
 
 
 
282
 
283
  app.queue().launch(share=args.share)
 
7
  import commons
8
  import utils
9
  import gradio as gr
10
+ import numpy as np
11
+ import librosa
12
+ import re_matching
13
+ from tools.sentence import split_by_language
14
  from huggingface_hub import hf_hub_download
15
 
 
16
  from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
17
  from models import SynthesizerTrn
18
  from text.symbols import symbols
19
  from text import cleaned_text_to_sequence, get_bert
20
  from text.cleaner import clean_text
 
 
 
 
 
 
21
 
22
  logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
23
  logger = logging.getLogger(__name__)
 
 
24
 
25
  def get_net_g(model_path: str, version: str, device: str, hps):
 
26
  net_g = SynthesizerTrn(
27
  len(symbols),
28
  hps.data.filter_length // 2 + 1,
 
36
 
37
  def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
38
  style_text = None if style_text == "" else style_text
 
39
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
40
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
41
  if hps.data.add_blank:
 
47
  word2ph[0] += 1
48
  bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
49
  del word2ph
50
+ assert bert.shape[-1] == len(phone)
 
 
 
 
51
  phone = torch.LongTensor(phone)
52
  tone = torch.LongTensor(tone)
53
  language = torch.LongTensor(language)
54
  return bert, phone, tone, language
55
 
56
+ def infer(*args, **kwargs):
57
+ from infer import infer as real_infer
58
+ return real_infer(*args, **kwargs)
59
 
60
+ def load_audio(path):
61
+ audio, sr = librosa.load(path, 48000)
62
+ return sr, audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ def gr_util(item):
65
+ if item == "Text prompt":
66
+ return {"visible": True, "__type__": "update"}, {"visible": False, "__type__": "update"}
67
+ else:
68
+ return {"visible": False, "__type__": "update"}, {"visible": True, "__type__": "update"}
69
+
70
+ def create_tts_fn(hps, net_g, device):
71
+ def tts_fn(
72
+ text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language,
73
+ reference_audio, emotion, prompt_mode, style_text, style_weight
74
+ ):
75
+ if style_text == "":
76
+ style_text = None
77
+ if prompt_mode == "Audio prompt":
78
+ if reference_audio is None:
79
+ return ("Invalid audio prompt", None)
80
+ else:
81
+ reference_audio = load_audio(reference_audio)[1]
82
  else:
83
+ reference_audio = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  audio = infer(
86
  text=text,
87
+ reference_audio=reference_audio,
88
+ emotion=emotion,
89
  sdp_ratio=sdp_ratio,
90
  noise_scale=noise_scale,
91
  noise_scale_w=noise_scale_w,
92
  length_scale=length_scale,
93
+ sid=speaker,
94
+ language=language,
95
  hps=hps,
96
+ net_g=net_g,
97
  device=device,
98
+ style_text=style_text,
99
+ style_weight=style_weight,
 
 
 
 
 
100
  )
101
  return "Success", (hps.data.sampling_rate, audio)
102
  return tts_fn
103
 
 
104
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
105
  parser = argparse.ArgumentParser()
106
  parser.add_argument("--share", default=False, help="make link public", action="store_true")
107
  parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
108
  args = parser.parse_args()
 
 
 
109
 
110
+ if args.debug:
111
+ logger.setLevel(logging.DEBUG)
112
 
113
  with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
114
  models_info = json.load(f)
115
 
116
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
117
+ models = []
118
+ for _, info in models_info.items():
119
  if not info['enable']:
120
  continue
121
+ name, title, link, example = info['name'], info['title'], info['link'], info['example']
 
 
 
 
 
122
  config_path = hf_hub_download(repo_id=link, filename="config.json")
123
  model_path = hf_hub_download(repo_id=link, filename=f"{name}.pth")
124
  hps = utils.get_hparams_from_file(config_path)
125
+ version = hps.version if hasattr(hps, "version") else "v2"
126
+ net_g = get_net_g(model_path, version, device, hps)
127
+ fn = create_tts_fn(hps, net_g, device)
128
+ models.append((title, example, list(hps.data.spk2id.keys()), fn))
129
 
 
130
  with gr.Blocks(theme='NoCrypt/miku') as app:
131
  gr.Markdown("## ✅ All models loaded successfully. Ready to use.")
 
132
  with gr.Tabs():
133
+ for (title, example, speakers, tts_fn) in models:
134
+ with gr.TabItem(title):
 
 
 
 
 
 
135
  with gr.Row():
136
  with gr.Column():
137
+ input_text = gr.Textbox(label="Input text", lines=5, value=example)
138
+ speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
139
+ prompt_mode = gr.Radio(["Text prompt", "Audio prompt"], label="Prompt Mode", value="Text prompt")
140
+ text_prompt = gr.Textbox(label="Text prompt", value="Happy", visible=True)
141
+ audio_prompt = gr.Audio(label="Audio prompt", type="filepath", visible=False)
142
+ sdp_ratio = gr.Slider(0, 1, 0.2, 0.1, label="SDP Ratio")
143
+ noise_scale = gr.Slider(0.1, 2.0, 0.6, 0.1, label="Noise")
144
+ noise_scale_w = gr.Slider(0.1, 2.0, 0.8, 0.1, label="Noise_W")
145
+ length_scale = gr.Slider(0.1, 2.0, 1.0, 0.1, label="Length")
146
+ language = gr.Dropdown(choices=["JP", "ZH", "EN", "mix", "auto"], value="JP", label="Language")
147
+ style_text = gr.Textbox(label="Style Text", placeholder="辅助文本 (留空为无)")
148
+ style_weight = gr.Slider(0, 1, 0.7, 0.1, label="Style Weight")
149
+ btn = gr.Button("Generate Audio", variant="primary")
150
+
151
  with gr.Column():
152
+ output_msg = gr.Textbox(label="Output Message")
153
+ output_audio = gr.Audio(label="Output Audio")
154
+
155
+ prompt_mode.change(lambda x: gr_util(x), inputs=[prompt_mode], outputs=[text_prompt, audio_prompt])
156
+ audio_prompt.upload(lambda x: load_audio(x), inputs=[audio_prompt], outputs=[audio_prompt])
157
+ btn.click(tts_fn, inputs=[input_text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language, audio_prompt, text_prompt, prompt_mode, style_text, style_weight], outputs=[output_msg, output_audio])
158
 
159
  app.queue().launch(share=args.share)