Update text/japanese_bert.py
text/japanese_bert.py +42 -57
text/japanese_bert.py
CHANGED
@@ -10,7 +10,6 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 models = dict()
 
-
 def get_bert_feature(
     text,
     word2ph,
@@ -19,75 +18,61 @@ def get_bert_feature(
     style_weight=0.7,
 ):
     sep_text, _ = text2sep_kata(text)
-    sep_tokens = [tokenizer.tokenize(t) for t in sep_text]
-    sep_ids = [tokenizer.convert_tokens_to_ids(t) for t in sep_tokens]
-    sep_ids = [2] + [item for sublist in sep_ids for item in sublist] + [3]
+    text = "".join(sep_text)
 
-    style_ids = None
     if style_text:
-        sep_style_text, _ = text2sep_kata(style_text)
-        style_tokens = [tokenizer.tokenize(t) for t in sep_style_text]
-        style_ids = [tokenizer.convert_tokens_to_ids(t) for t in style_tokens]
-        style_ids = [2] + [item for sublist in style_ids for item in sublist] + [3]
-
-    return get_bert_feature_with_token(
-        sep_ids, word2ph, device, style_ids, style_weight
-    )
-
+        style_text = "".join(text2sep_kata(style_text)[0])
 
-
-    if (
-        sys.platform == "darwin"
-        and torch.backends.mps.is_available()
-        and device == "cpu"
-    ):
+    if sys.platform == "darwin" and torch.backends.mps.is_available() and device == "cpu":
         device = "mps"
     if not device:
         device = "cuda"
-    if device not in models.keys():
-        models[device] = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(device)
-
-    def encode(tokens_):
-        inputs = torch.tensor(tokens_).to(device).unsqueeze(0)
-        token_type_ids = torch.zeros_like(inputs).to(device)
-        attention_mask = torch.ones_like(inputs).to(device)
-        inputs = {
-            "input_ids": inputs,
-            "token_type_ids": token_type_ids,
-            "attention_mask": attention_mask,
-        }
-        with torch.no_grad():
-            res = models[device](**inputs, output_hidden_states=True)
-            res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
-        return res, inputs["input_ids"].shape[-1]
-
-    res, main_len = encode(tokens)
 
-    if
+    if device not in models:
+        if config.webui_config.fp16_run:
+            models[device] = AutoModelForMaskedLM.from_pretrained(
+                MODEL_ID, torch_dtype=torch.float16
+            ).to(device)
+        else:
+            models[device] = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(device)
+
+    with torch.no_grad():
+        # Tokenize text into subwords for correct alignment
+        tokens = [tokenizer.tokenize(t) for t in sep_text]
+        flat_tokens = [item for sublist in tokens for item in sublist]
+        word2ph_token = [len(t) for t in tokens]
+        word2ph_token = [1] + word2ph_token + [1]  # Account for [CLS] and [SEP]
+
+        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
+        for k in inputs:
+            inputs[k] = inputs[k].to(device)
+
+        res = models[device](**inputs, output_hidden_states=True)
+        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].float().cpu()
+
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for k in style_inputs:
+                style_inputs[k] = style_inputs[k].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].float().cpu()
+            style_res_mean = style_res.mean(0)
+
+    if len(word2ph_token) != res.shape[0]:
+        print(f"[ERROR] len(word2ph_token) = {len(word2ph_token)}, but BERT output = {res.shape[0]}")
+        print(f"[DEBUG] input text: {text}")
+        raise ValueError("Mismatch between tokenized word2ph and BERT output length.")
 
     phone_level_feature = []
-    for i in range(len(
-        if
+    for i in range(len(word2ph_token)):
+        if style_text:
             blended = (
-                res[i].repeat(
-                + style_res_mean.repeat(
+                res[i].repeat(word2ph_token[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2ph_token[i], 1) * style_weight
             )
         else:
-            blended = res[i].repeat(
+            blended = res[i].repeat(word2ph_token[i], 1)
         phone_level_feature.append(blended)
 
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
-
     return phone_level_feature.T
-
-
-if __name__ == "__main__":
-    print(get_bert_feature("観覧車", [4, 2]))
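For reference, a minimal, self-contained sketch of the phone-level expansion and style blending performed by the loop added above. The values of hidden_size and counts and the random tensors are illustrative placeholders standing in for the BERT hidden states and word2ph_token in the real file; only torch is needed to run it.

import torch

# Illustrative stand-ins (not values from the real model):
hidden_size = 4                               # the real BERT hidden size is 768 or 1024
counts = [1, 2, 3, 1]                         # plays the role of word2ph_token: repeats per row
res = torch.randn(len(counts), hidden_size)   # one feature vector per position
style_res_mean = torch.randn(hidden_size)     # mean feature of the style text
style_weight = 0.7

phone_level_feature = []
for i, n in enumerate(counts):
    # Copy the i-th vector n times and mix it with the style mean:
    # (1 - style_weight) * token_feature + style_weight * style_mean
    blended = (
        res[i].repeat(n, 1) * (1 - style_weight)
        + style_res_mean.repeat(n, 1) * style_weight
    )
    phone_level_feature.append(blended)

phone_level_feature = torch.cat(phone_level_feature, dim=0)
print(phone_level_feature.T.shape)  # (hidden_size, sum(counts))

Each row of res is repeated counts[i] times and, when a style text is given, linearly interpolated toward the style mean by style_weight; the final transpose yields the (hidden_size, total_repeats) layout that get_bert_feature returns. Without a style text the loop reduces to a plain res[i].repeat(n, 1).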