wuwa-bert-vits2 / app.py
JotunnBurton's picture
Update app.py
3f3d5c8 verified
raw
history blame
9.47 kB
import sys
import logging
import os
import json
import torch
import argparse
import commons
import utils
import gradio as gr
from huggingface_hub import hf_hub_download
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from models import SynthesizerTrn
from text.symbols import symbols
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
import numpy as np
# Quieten chatty third-party loggers so only their warnings and errors surface.
for _noisy_logger in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

# Root logging: INFO and above, rendered as "| <logger> | <LEVEL> | <message>".
logging.basicConfig(
    level=logging.INFO,
    format="| %(name)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)

# True when the SYSTEM environment variable equals "spaces" (presumably set by the
# hosting environment, e.g. Hugging Face Spaces); used to cap the input text length.
limitation = os.environ.get("SYSTEM") == "spaces"
def get_net_g(model_path: str, version: str, device: str, hps):
    """Build a ``SynthesizerTrn`` for the current model version and load its weights.

    Args:
        model_path: Path to the ``.pth`` checkpoint to load.
        version: Model version string. Accepted for call compatibility but not
            used by this function.
        device: Torch device the network is moved to (e.g. "cuda:0", "cpu").
        hps: Hyper-parameters; ``hps.data``, ``hps.train`` and ``hps.model``
            are read here.

    Returns:
        The synthesizer in eval mode with checkpoint weights loaded
        (optimizer state skipped).
    """
    # Presumably the linear-spectrogram bin count for an FFT of filter_length.
    spec_channels = hps.data.filter_length // 2 + 1
    # Training segment size divided by hop length (presumably frames per segment).
    segment_frames = hps.train.segment_size // hps.data.hop_length

    model = SynthesizerTrn(
        len(symbols),
        spec_channels,
        segment_frames,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    model = model.to(device)
    model.eval()
    utils.load_checkpoint(model_path, model, None, skip_optimizer=True)
    return model
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    """Turn raw text into model inputs: BERT features plus phone/tone/language ids.

    Args:
        text: Raw input text.
        language_str: Language code passed to the text cleaner, the id
            converter and ``get_bert``.
        hps: Hyper-parameters; ``hps.data.add_blank`` is read here.
        device: Device forwarded to ``get_bert``.
        style_text: Optional style prompt forwarded to ``get_bert``; an empty
            string is treated as ``None``.
        style_weight: Weight forwarded to ``get_bert`` together with
            ``style_text``.

    Returns:
        Tuple ``(bert, phone, tone, language)``; the last three are
        ``torch.LongTensor`` id sequences and ``bert.shape[-1] == len(phone)``.

    Raises:
        ValueError: If the BERT feature length does not match the phone count.
            (This used to be an ``assert``, which ``python -O`` strips.)
    """
    # An empty style prompt means "no style blending".
    style_text = None if style_text == "" else style_text

    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        # Blank id 0 is interspersed into every id sequence. Presumably this
        # doubles each word's phone span and adds one extra leading blank,
        # which is credited to the first word so word2ph stays aligned with
        # `phone` for get_bert.
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        word2ph = [count * 2 for count in word2ph]
        word2ph[0] += 1

    bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
    if bert.shape[-1] != len(phone):
        raise ValueError(f"Bert seq len {bert.shape[-1]} != {len(phone)}")

    return (
        bert,
        torch.LongTensor(phone),
        torch.LongTensor(tone),
        torch.LongTensor(language),
    )
def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    emotion,
    reference_audio=None,
    skip_start=False,
    skip_end=False,
    style_text=None,
    style_weight=0.7,
    text_mode="Text",
):
    """Synthesize ``text`` with ``net_g`` and return the waveform as a numpy array.

    Args:
        text: Text to synthesize.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Sampling controls
            forwarded unchanged to ``net_g.infer``.
        sid: Speaker, given either as a key of ``hps.data.spk2id`` (speaker
            name) or as an already-resolved integer speaker id.
        language: Ignored. This app always synthesizes "JP".
        hps: Hyper-parameters; ``hps.data.spk2id`` is read here.
        net_g: Loaded synthesizer (see ``get_net_g``).
        device: Torch device used for all tensors.
        emotion: Text prompt for the CLAP emotion embedding, used when no
            numpy ``reference_audio`` is given.
        reference_audio: If a ``np.ndarray``, its CLAP audio embedding is used
            instead of ``emotion``.
        skip_start: Drop the first 3 phone positions (and matching features).
        skip_end: Drop the last 2 phone positions (and matching features).
        style_text, style_weight: Forwarded to ``get_text``.
        text_mode: Accepted for call compatibility; unused.

    Returns:
        The generated audio as a float32 numpy array (``[0][0, 0]`` of the
        ``net_g.infer`` output, moved to CPU).

    Raises:
        KeyError: If ``sid`` is a name that is not in ``hps.data.spk2id``.
    """
    # This app always synthesizes Japanese; the `language` argument is overridden.
    language = "JP"

    # Accept either a speaker name or an integer id. tts_fn historically passed
    # an already-resolved id, and re-looking that id up in spk2id crashed.
    if isinstance(sid, (int, np.integer)):
        speaker_id = int(sid)
    else:
        speaker_id = hps.data.spk2id[sid]

    # Emotion embedding: from reference audio if provided as an array,
    # otherwise from the emotion text prompt.
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)

    bert, phones, tones, lang_ids = get_text(
        text,
        language,
        hps,
        device,
        style_text=style_text,
        style_weight=style_weight,
    )
    if skip_start:
        phones = phones[3:]
        tones = tones[3:]
        lang_ids = lang_ids[3:]
        bert = bert[:, 3:]
    if skip_end:
        phones = phones[:-2]
        tones = tones[:-2]
        lang_ids = lang_ids[:-2]
        bert = bert[:, :-2]

    with torch.no_grad():
        # Add a batch dimension of 1 to every model input.
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        emo = emo.to(device).unsqueeze(0)
        del phones
        speakers = torch.LongTensor([speaker_id]).to(device)
        print(text)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        # Drop device tensor references before emptying the CUDA cache below.
        del (
            x_tst,
            tones,
            lang_ids,
            bert,
            x_tst_lengths,
            speakers,
            emo,
        )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return audio
def create_tts_fn(net_g_ms, hps):
    """Build the Gradio click callback that synthesizes speech with one model.

    Args:
        net_g_ms: Loaded synthesizer for this model.
        hps: That model's hyper-parameters (``hps.data.sampling_rate`` is
            used for the returned audio).

    Returns:
        ``tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w,
        length_scale, language, reference_audio, ...)`` returning
        ``(message, (sampling_rate, audio))`` or ``(error_message, None)``.
    """

    def tts_fn(
        text,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        language,
        reference_audio=None,
        emotion="neutral",
        prompt_mode=None,
        style_text=None,
        style_weight=0,
    ):
        """Synthesize ``text`` for ``speaker``.

        The UI wires exactly 8 inputs (up to ``reference_audio``), so every
        later parameter must have a default. Without defaults on ``emotion``
        and ``prompt_mode``, every click raised ``TypeError``.
        ``language``, ``reference_audio``, ``emotion``, ``prompt_mode``,
        ``style_text`` and ``style_weight`` are currently ignored: synthesis
        always uses Japanese, a "neutral" emotion prompt, and no style text.
        """
        print(f"{text} | {speaker}")
        # Strip line breaks and all spaces before synthesis.
        text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
        if limitation:
            # Character cap applied when running in the limited environment.
            max_len = 100
            if len(text) > max_len:
                return "Error: Text is too long", None
        audio = infer(
            text=text,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
            # Pass the speaker *name*; infer resolves it via hps.data.spk2id.
            # Passing a pre-resolved id made infer look the id up again and crash.
            sid=speaker,
            language="JP",  # or whatever the user selects
            hps=hps,
            net_g=net_g_ms,
            device=device,
            emotion="neutral",  # or whatever the user picks from a dropdown
            reference_audio=None,
            skip_start=False,
            skip_end=False,
            style_text=None,
            style_weight=0.7,
            text_mode="Text",
        )
        return "Success", (hps.data.sampling_rate, audio)

    return tts_fn
if __name__ == "__main__":
    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda:0"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument("--share", default=False, help="make link public", action="store_true")
    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() is a no-op once the root logger has handlers
        # (it is configured at import time), so set the root level directly.
        logging.getLogger().setLevel(logging.DEBUG)

    models = []
    with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    # Download and load every enabled model up front.
    for info in models_info.values():
        if not info['enable']:
            continue
        name = info['name']
        title = info['title']
        link = info['link']
        example = info['example']
        print(f"🔄 Loading model: {name} from {link}")
        config_path = hf_hub_download(repo_id=link, filename="config.json")
        model_path = hf_hub_download(repo_id=link, filename=f"{name}.pth")
        hps = utils.get_hparams_from_file(config_path)
        # `latest_version` is not defined or imported anywhere visible, so the
        # old fallback raised NameError for configs without `version`.
        # get_net_g does not use the version, so None is a safe fallback.
        version = getattr(hps, "version", None)
        net_g_ms = get_net_g(model_path, version, device, hps)
        speakers = list(hps.data.spk2id.keys())
        models.append((name, title, example, speakers, net_g_ms, create_tts_fn(net_g_ms, hps)))

    # Gradio UI, ready for Spaces: one tab per loaded model.
    with gr.Blocks(theme='NoCrypt/miku') as app:
        gr.Markdown("## ✅ All models loaded successfully. Ready to use.")
        with gr.Tabs():
            for (name, title, example, speakers, net_g_ms, tts_fn) in models:
                with gr.TabItem(name):
                    with gr.Row():
                        gr.Markdown(
                            '<div align="center">'
                            f'<a><strong>{title}</strong></a>'
                            f'</div>'
                        )
                    with gr.Row():
                        with gr.Column():
                            # NOTE(review): the limit is enforced on characters
                            # (after removing spaces), not words.
                            input_text = gr.Textbox(label="Text (100 words limitation)" if limitation else "Text", lines=5, value=example)
                            btn = gr.Button(value="Generate", variant="primary")
                            with gr.Row():
                                sp = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
                            with gr.Row():
                                sdpr = gr.Slider(label="SDP Ratio", minimum=0, maximum=1, step=0.1, value=0.2)
                                ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6)
                                nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.8)
                                ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1)
                                # Single-select dropdown: its value must be one choice, not a list.
                                lang = gr.Dropdown(choices=["JP"], value="JP", label="Language")
                                ref_a = gr.Audio(label="Upload your audio", type="filepath")
                        with gr.Column():
                            o1 = gr.Textbox(label="Output Message")
                            o2 = gr.Audio(label="Output Audio")
                    btn.click(tts_fn, inputs=[input_text, sp, sdpr, ns, nsw, ls, lang, ref_a], outputs=[o1, o2])
    app.queue().launch(share=args.share)