# ✅ Patched full version of app.py with isolated tts_split per model

import sys
import logging
import os
import json
import torch
import argparse
import commons
import utils
import gradio as gr
import numpy as np
import librosa
import re_matching
from tools.sentence import split_by_language
from huggingface_hub import hf_hub_download, list_repo_files

from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from models import SynthesizerTrn
from text.symbols import symbols
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text

logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

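# Build the synthesizer for one model and load its checkpoint weights onto the target device.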
def get_net_g(model_path: str, version: str, device: str, hps):
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
    return net_g

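# Convert raw text into phone/tone/language id tensors plus BERT features for inference.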
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    style_text = None if style_text == "" else style_text
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1
    bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
    del word2ph
    assert bert.shape[-1] == len(phone)
    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, phone, tone, language

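# Thin wrapper around the project's infer() so it is imported lazily at call time
# rather than at module import.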
def infer(*args, **kwargs):
    from infer import infer as real_infer
    return real_infer(*args, **kwargs)

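# Split long text into paragraphs (and optionally sentences), synthesize each chunk,
# and insert silence of the configured length between chunks. hps/net_g/device are
# passed in explicitly so each model's tab uses its own isolated state.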
def tts_split(
    text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale,
    language, cut_by_sent, interval_between_para, interval_between_sent,
    reference_audio, emotion, style_text, style_weight,
    hps, net_g, device
):
    if style_text == "":
        style_text = None
    if language == "mix":
        return ("'mix' not supported in this simplified split function", None)
    while text.find("\n\n") != -1:
        text = text.replace("\n\n", "\n")
    para_list = re_matching.cut_para(text)
    audio_list = []
    with torch.no_grad():
        if cut_by_sent:
            for pidx, p in enumerate(para_list):
                sent_list = re_matching.cut_sent(p)
                for sidx, s in enumerate(sent_list):
                    skip_start = not (pidx == 0 and sidx == 0)
                    skip_end = not (pidx == len(para_list) - 1 and sidx == len(sent_list) - 1)
                    audio = infer(
                        s,
                        reference_audio=reference_audio,
                        emotion=emotion,
                        sdp_ratio=sdp_ratio,
                        noise_scale=noise_scale,
                        noise_scale_w=noise_scale_w,
                        length_scale=length_scale,
                        sid=speaker,
                        language=language,
                        hps=hps,
                        net_g=net_g,
                        device=device,
                        style_text=style_text,
                        style_weight=style_weight,
                        skip_start=skip_start,
                        skip_end=skip_end,
                    )
                    audio_list.append(audio)
                    audio_list.append(np.zeros(int(hps.data.sampling_rate * interval_between_sent), dtype=np.int16))
                if (interval_between_para - interval_between_sent) > 0:
                    audio_list.append(np.zeros(int(hps.data.sampling_rate * (interval_between_para - interval_between_sent)), dtype=np.int16))
        else:
            for idx, p in enumerate(para_list):
                skip_start = idx != 0
                skip_end = idx != len(para_list) - 1
                audio = infer(
                    p,
                    reference_audio=reference_audio,
                    emotion=emotion,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                    sid=speaker,
                    language=language,
                    hps=hps,
                    net_g=net_g,
                    device=device,
                    style_text=style_text,
                    style_weight=style_weight,
                    skip_start=skip_start,
                    skip_end=skip_end,
                )
                audio_list.append(audio)
                audio_list.append(np.zeros(int(hps.data.sampling_rate * interval_between_para), dtype=np.int16))
    final_audio = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, final_audio)

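# Bind a specific model's hps/net_g/device into a closure so each tab's
# "Split and Generate" button always calls tts_split with that model.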
def create_split_fn(hps, net_g, device):
    def split_fn(
        text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale,
        language, cut_by_sent, interval_between_para, interval_between_sent,
        reference_audio, emotion, style_text, style_weight
    ):
        return tts_split(
            text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale,
            language, cut_by_sent, interval_between_para, interval_between_sent,
            reference_audio, emotion, style_text, style_weight,
            hps=hps, net_g=net_g, device=device
        )
    return split_fn

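# Load a reference audio file, resampled to 48 kHz by librosa.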
def load_audio(path):
    audio, sr = librosa.load(path, sr=48000)
    return sr, audio

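# Toggle visibility between the text-prompt and audio-prompt inputs.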
def gr_util(item):
    if item == "Text prompt":
        return {"visible": True, "__type__": "update"}, {"visible": False, "__type__": "update"}
    else:
        return {"visible": False, "__type__": "update"}, {"visible": True, "__type__": "update"}

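# Bind a specific model's hps/net_g/device into a closure for the main "Generate Audio" button.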
def create_tts_fn(hps, net_g, device):
    def tts_fn(
        text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language,
        reference_audio, emotion, prompt_mode, style_text, style_weight
    ):
        if style_text == "":
            style_text = None
        if prompt_mode == "Audio prompt":
            if reference_audio is None:
                return ("Invalid audio prompt", None)
            else:
                reference_audio = load_audio(reference_audio)[1]
        else:
            reference_audio = None

        audio = infer(
            text=text,
            reference_audio=reference_audio,
            emotion=emotion,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
            sid=speaker,
            language=language,
            hps=hps,
            net_g=net_g,
            device=device,
            style_text=style_text,
            style_weight=style_weight,
        )
        return "Success", (hps.data.sampling_rate, audio)
    return tts_fn


# Function to build a single tab per model
def create_tab(name, title, example, speakers, tts_fn, split_fn, repid):
    with gr.TabItem(name):
        gr.Markdown(
            '<div align="center">'
            f'<a><strong>{repid}</strong></a>'
            f'<br>'
            f'<a><strong>{title}</strong></a>'
            f'</div>'
        )
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input text", lines=5, value=example)
                speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
                prompt_mode = gr.Radio(["Text prompt", "Audio prompt"], label="Prompt Mode", value="Text prompt")
                text_prompt = gr.Textbox(label="Text prompt", value="Happy", visible=True)
                audio_prompt = gr.Audio(label="Audio prompt", type="filepath", visible=False)
                sdp_ratio = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio")
                noise_scale = gr.Slider(minimum=0.1, maximum=2.0, value=0.6, step=0.1, label="Noise")
                noise_scale_w = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Noise_W")
                length_scale = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Length")
                language = gr.Dropdown(choices=["JP", "ZH", "EN", "mix", "auto"], value="JP", label="Language")
                btn = gr.Button("Generate Audio", variant="primary")

            with gr.Column():
                with gr.Accordion("Semantic Fusion", open=False):
                    gr.Markdown(
                        value="Use auxiliary text semantics to assist speech generation (language remains same as main text)\n\n"
                              "**Note**: Avoid using *command-style text* (e.g., 'Happy'). Use *emotionally rich text* (e.g., 'I'm so happy!!!')\n\n"
                              "Leave it blank to disable. \n\n"
                              "**If mispronunciations occur, try replacing characters and inputting the original here with weight set to 1.0 for semantic retention.**"
                    )
                    style_text = gr.Textbox(label="Auxiliary Text")
                    style_weight = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Weight", info="Ratio between main and auxiliary BERT embeddings")

                with gr.Row():
                    with gr.Column():
                        interval_between_sent = gr.Slider(minimum=0, maximum=5, value=0.2, step=0.1, label="Pause between sentences (sec)")
                        interval_between_para = gr.Slider(minimum=0, maximum=10, value=1, step=0.1, label="Pause between paragraphs (sec)")
                        opt_cut_by_sent = gr.Checkbox(label="Split by sentence")
                        slicer = gr.Button("Split and Generate", variant="primary")

            with gr.Column():
                output_msg = gr.Textbox(label="Output Message")
                output_audio = gr.Audio(label="Output Audio")

        prompt_mode.change(lambda x: gr_util(x), inputs=[prompt_mode], outputs=[text_prompt, audio_prompt])
        audio_prompt.upload(lambda x: load_audio(x), inputs=[audio_prompt], outputs=[audio_prompt])
        btn.click(
            tts_fn,
            inputs=[
                input_text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language,
                audio_prompt, text_prompt, prompt_mode, style_text, style_weight
            ],
            outputs=[output_msg, output_audio],
        )
        slicer.click(
            split_fn,
            inputs=[
                input_text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language,
                opt_cut_by_sent, interval_between_para, interval_between_sent,
                audio_prompt, text_prompt, style_text, style_weight
            ],
            outputs=[output_msg, output_audio],
        )

# --- Main entry point ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", default=False, help="make link public", action="store_true")
    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
    args = parser.parse_args()

    if args.debug:
        logger.setLevel(logging.DEBUG)

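    # pretrained_models/info.json is expected to map an arbitrary key to the fields read
    # below. An illustrative (not real) entry might look roughly like:
    # {
    #   "model_a": {
    #     "enable": true,
    #     "name": "Model A",
    #     "title": "Example voice model",
    #     "repid": "user/model-repo",
    #     "example": "Sample input text",
    #     "filename": "G_10000.pth"
    #   }
    # }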
    with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    models = []
    for _, info in models_info.items():
        if not info['enable']:
            continue
        name, title, repid, example, filename = info['name'], info['title'], info['repid'], info['example'], info['filename']

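        # Locate the checkpoint inside the repo: if it sits in a subfolder, remember that
        # path so hf_hub_download can fetch both the weights and config.json from it.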
        files = list_repo_files(repo_id=repid)
        model_subfolder = None
        for f in files:
            if f.endswith(filename):
                parts = f.split("/")
                if len(parts) > 1:
                    model_subfolder = "/".join(parts[:-1])
                break

        if model_subfolder:
            model_path = hf_hub_download(repo_id=repid, filename=filename, subfolder=model_subfolder)
            config_path = hf_hub_download(repo_id=repid, filename="config.json", subfolder=model_subfolder)
        else:
            model_path = hf_hub_download(repo_id=repid, filename=filename)
            config_path = hf_hub_download(repo_id=repid, filename="config.json")

        hps = utils.get_hparams_from_file(config_path)
        version = hps.version if hasattr(hps, "version") else "v2"
        net_g = get_net_g(model_path, version, device, hps)
        tts_fn = create_tts_fn(hps, net_g, device)
        split_fn = create_split_fn(hps, net_g, device)
        models.append((name, title, example, list(hps.data.spk2id.keys()), tts_fn, split_fn, repid))

    with gr.Blocks(theme='NoCrypt/miku') as app:
        gr.Markdown("## βœ… All models loaded successfully. Ready to use.")
        with gr.Tabs():
            for (name, title, example, speakers, tts_fn, split_fn, repid) in models:
                create_tab(name, title, example, speakers, tts_fn, split_fn, repid)

    app.queue().launch(share=args.share)

# Note: create_tab accepts split_fn and wires it to slicer.click, and the model loop
# builds both tts_fn and split_fn per model before passing them into create_tab,
# so split_fn is isolated per model, just like tts_fn.
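
# Example launch commands (assuming the argparse flags defined above):
#   python app.py            # local only
#   python app.py --share    # create a public Gradio share link
#   python app.py --debug    # enable DEBUG-level logging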