Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -40,11 +40,11 @@ def get_net_g(model_path: str, version: str, device: str, hps):
|
|
40 |
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
41 |
return net_g
|
42 |
|
43 |
-
def get_text(text, hps):
|
44 |
-
|
|
|
45 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
46 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
47 |
-
|
48 |
if hps.data.add_blank:
|
49 |
phone = commons.intersperse(phone, 0)
|
50 |
tone = commons.intersperse(tone, 0)
|
@@ -52,18 +52,17 @@ def get_text(text, hps):
|
|
52 |
for i in range(len(word2ph)):
|
53 |
word2ph[i] = word2ph[i] * 2
|
54 |
word2ph[0] += 1
|
55 |
-
|
56 |
-
bert = get_bert(norm_text, word2ph, language_str, device)
|
57 |
del word2ph
|
58 |
-
assert bert.shape[-1] == len(phone), phone
|
59 |
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
phone = torch.LongTensor(phone)
|
64 |
tone = torch.LongTensor(tone)
|
65 |
language = torch.LongTensor(language)
|
66 |
-
return bert,
|
67 |
|
68 |
|
69 |
def infer(
|
|
|
40 |
_ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
|
41 |
return net_g
|
42 |
|
43 |
+
def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
|
44 |
+
style_text = None if style_text == "" else style_text
|
45 |
+
# 在此处实现当前版本的get_text
|
46 |
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
47 |
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
|
|
48 |
if hps.data.add_blank:
|
49 |
phone = commons.intersperse(phone, 0)
|
50 |
tone = commons.intersperse(tone, 0)
|
|
|
52 |
for i in range(len(word2ph)):
|
53 |
word2ph[i] = word2ph[i] * 2
|
54 |
word2ph[0] += 1
|
55 |
+
bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
|
|
|
56 |
del word2ph
|
|
|
57 |
|
58 |
+
assert bert.shape[-1] == len(
|
59 |
+
phone
|
60 |
+
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
|
61 |
|
62 |
phone = torch.LongTensor(phone)
|
63 |
tone = torch.LongTensor(tone)
|
64 |
language = torch.LongTensor(language)
|
65 |
+
return bert, phone, tone, language
|
66 |
|
67 |
|
68 |
def infer(
|