import argparse
import json
import logging
import os
import sys

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

import commons
import utils
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from models import SynthesizerTrn
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
from text.symbols import symbols

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

logging.basicConfig(
    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
)

logger = logging.getLogger(__name__)

limitation = os.getenv("SYSTEM") == "spaces"

# Fallback when a checkpoint config carries no "version" field.
# Assumed to be "2.2", per the version notes in infer() below; the
# original script referenced this name without defining it.
latest_version = "2.2"


def get_net_g(model_path: str, version: str, device: str, hps):
    # Build net_g for the current model version
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
    return net_g


def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    style_text = None if style_text == "" else style_text
    # get_text implementation for the current model version
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1
    bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
    del word2ph
    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, phone, tone, language


def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    emotion,
    reference_audio=None,
    skip_start=False,
    skip_end=False,
    style_text=None,
    style_weight=0.7,
    text_mode="Text",
):
    # Parameter positions changed in version 2.2;
    # version 2.1 added emotion, reference_audio, skip_start, skip_end.
    version = hps.version if hasattr(hps, "version") else latest_version
    language = "JP"  # this app is Japanese-only, regardless of the argument
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)

    bert, phones, tones, lang_ids = get_text(
        text,
        language,
        hps,
        device,
        style_text=style_text,
        style_weight=style_weight,
    )
    if skip_start:
        phones = phones[3:]
        tones = tones[3:]
        lang_ids = lang_ids[3:]
        bert = bert[:, 3:]
    if skip_end:
        phones = phones[:-2]
        tones = tones[:-2]
        lang_ids = lang_ids[:-2]
        bert = bert[:, :-2]
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        emo = emo.to(device).unsqueeze(0)
        del phones

        spk2id_dict = dict(hps.data["spk2id"].items())
        # If sid is an index (e.g. 0), map it back to the speaker name
        if isinstance(sid, int) or sid.isdigit():
            sid_int = int(sid)
            name_map = {v: k for k, v in spk2id_dict.items()}
            if sid_int not in name_map:
                raise ValueError(f"Speaker index {sid_int} not found.")
            sid = name_map[sid_int]
        else:
            sid = str(sid).upper()
        if sid not in spk2id_dict:
            raise ValueError(
                f"Speaker ID '{sid}' not found. Available: {list(spk2id_dict.keys())}"
            )

        speaker_id = spk2id_dict[sid]
        speakers = torch.LongTensor([speaker_id]).to(device)
        print(text)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio


def create_tts_fn(net_g_ms, hps):
    def tts_fn(
        text,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        language,
        reference_audio=None,
        emotion="neutral",
        prompt_mode="Text",
        style_text=None,
        style_weight=0.7,
    ):
        print(f"{text} | {speaker}")
        sid = hps.data.spk2id[speaker]  # infer() maps the index back to a name
        # Collapse newlines and strip spaces (Japanese text does not need them)
        text = text.replace("\n", " ").replace("\r", "").replace(" ", "")
        if limitation:
            max_len = 100
            if len(text) > max_len:
                return "Error: Text is too long", None
        audio = infer(
            text=text,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
            sid=sid,
            language="JP",  # or per the user's selection
            hps=hps,
            net_g=net_g_ms,
            device=device,
            emotion=emotion,  # "neutral" by default; or per a dropdown
            reference_audio=reference_audio,  # only used when it is an np.ndarray
            skip_start=False,
            skip_end=False,
            style_text=style_text,
            style_weight=style_weight,
            text_mode=prompt_mode,
        )
        return "Success", (hps.data.sampling_rate, audio)

    return tts_fn
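# The __main__ block below reads a model registry from
# pretrained_models/info.json. A minimal sketch of the expected schema,
# inferred from the fields the loading loop consumes (the key, repo id,
# checkpoint name, title, and example text are hypothetical placeholders):
#
#   {
#     "0": {
#       "enable": true,
#       "name": "G_10000",
#       "title": "Example Voice",
#       "link": "someuser/example-bert-vits2",
#       "example": "こんにちは。"
#     }
#   }
#
# "enable" gates loading, "name" resolves to "{name}.pth" in the Hub repo
# "link", "title" is shown in the model's tab, and "example" prefills the
# text box.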
if __name__ == "__main__":
    device = (
        "cuda:0"
        if torch.cuda.is_available()
        else (
            "mps"
            if sys.platform == "darwin" and torch.backends.mps.is_available()
            else "cpu"
        )
    )
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--share", default=False, action="store_true", help="make link public"
    )
    parser.add_argument(
        "-d", "--debug", action="store_true", help="enable DEBUG-level log"
    )
    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-LEVEL log")
        logging.basicConfig(level=logging.DEBUG)

    models = []
    with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    # ✅ Preload every enabled model up front
    for i, info in models_info.items():
        if not info["enable"]:
            continue
        name = info["name"]
        title = info["title"]
        link = info["link"]
        example = info["example"]
        print(f"🔄 Loading model: {name} from {link}")
        config_path = hf_hub_download(repo_id=link, filename="config.json")
        model_path = hf_hub_download(repo_id=link, filename=f"{name}.pth")
        hps = utils.get_hparams_from_file(config_path)
        version = hps.version if hasattr(hps, "version") else latest_version
        net_g_ms = get_net_g(model_path, version, device, hps)
        models.append(
            (
                name,
                title,
                example,
                list(hps.data.spk2id.keys()),
                net_g_ms,
                create_tts_fn(net_g_ms, hps),
            )
        )

    # ✅ Gradio UI, ready for Hugging Face Spaces
    with gr.Blocks(theme="NoCrypt/miku") as app:
        gr.Markdown("## ✅ All models loaded successfully. Ready to use.")
        with gr.Tabs():
            for name, title, example, speakers, net_g_ms, tts_fn in models:
                with gr.TabItem(name):
                    with gr.Row():
                        gr.Markdown(
                            f'<div align="center"><strong>{title}</strong></div>'
                        )
                    with gr.Row():
                        with gr.Column():
                            input_text = gr.Textbox(
                                label="Text (max 100 characters)"
                                if limitation
                                else "Text",
                                lines=5,
                                value=example,
                            )
                            btn = gr.Button(value="Generate", variant="primary")
                            with gr.Row():
                                sp = gr.Dropdown(
                                    choices=speakers, value=speakers[0], label="Speaker"
                                )
                            with gr.Row():
                                sdpr = gr.Slider(
                                    label="SDP Ratio", minimum=0, maximum=1, step=0.1, value=0.2
                                )
                                ns = gr.Slider(
                                    label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6
                                )
                                nsw = gr.Slider(
                                    label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.8
                                )
                                ls = gr.Slider(
                                    label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1
                                )
                            lang = gr.Dropdown(choices=["JP"], value="JP", label="Language")
                            ref_a = gr.Audio(label="Upload your audio", type="filepath")
                        with gr.Column():
                            o1 = gr.Textbox(label="Output Message")
                            o2 = gr.Audio(label="Output Audio")
                    btn.click(
                        tts_fn,
                        inputs=[input_text, sp, sdpr, ns, nsw, ls, lang, ref_a],
                        outputs=[o1, o2],
                    )

    app.queue().launch(share=args.share)
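# Launching (a sketch, assuming this file is saved as app.py; the flags
# come from the argparse setup above):
#   python app.py            # local Gradio launch
#   python app.py --share    # also create a public share link
#   python app.py --debug    # enable DEBUG-level logging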