JotunnBurton committed on
Commit
cccafbc
·
verified ·
1 Parent(s): d5b3961

Update text/japanese_bert.py

Browse files
Files changed (1) hide show
  1. text/japanese_bert.py +15 -24
text/japanese_bert.py CHANGED
@@ -1,15 +1,15 @@
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForMaskedLM
3
- import sys
4
- import os
5
- from text.japanese import text2sep_kata
6
  from config import config
 
7
 
8
  MODEL_ID = "ku-nlp/deberta-v2-large-japanese-char-wwm"
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
10
 
11
  models = dict()
12
 
 
13
  def get_bert_feature(
14
  text,
15
  word2ph,
@@ -17,9 +17,7 @@ def get_bert_feature(
17
  style_text=None,
18
  style_weight=0.7,
19
  ):
20
- sep_text, _ = text2sep_kata(text)
21
- text = "".join(sep_text)
22
-
23
  if style_text:
24
  style_text = "".join(text2sep_kata(style_text)[0])
25
 
@@ -37,42 +35,35 @@ def get_bert_feature(
37
  models[device] = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(device)
38
 
39
  with torch.no_grad():
40
- # Tokenize text into subwords for correct alignment
41
- tokens = [tokenizer.tokenize(t) for t in sep_text]
42
- flat_tokens = [item for sublist in tokens for item in sublist]
43
- word2ph_token = [len(t) for t in tokens]
44
- word2ph_token = [1] + word2ph_token + [1] # Account for [CLS] and [SEP]
45
-
46
  inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
47
  for k in inputs:
48
  inputs[k] = inputs[k].to(device)
49
-
50
  res = models[device](**inputs, output_hidden_states=True)
51
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].float().cpu()
52
 
53
  if style_text:
54
- style_inputs = tokenizer(style_text, return_tensors="pt")
55
  for k in style_inputs:
56
  style_inputs[k] = style_inputs[k].to(device)
57
  style_res = models[device](**style_inputs, output_hidden_states=True)
58
  style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].float().cpu()
59
  style_res_mean = style_res.mean(0)
60
 
61
- if len(word2ph_token) != res.shape[0]:
62
- print(f"[ERROR] len(word2ph_token) = {len(word2ph_token)}, but BERT output = {res.shape[0]}")
63
- print(f"[DEBUG] input text: {text}")
64
- raise ValueError("Mismatch between tokenized word2ph and BERT output length.")
65
 
66
  phone_level_feature = []
67
- for i in range(len(word2ph_token)):
68
  if style_text:
69
- blended = (
70
- res[i].repeat(word2ph_token[i], 1) * (1 - style_weight)
71
- + style_res_mean.repeat(word2ph_token[i], 1) * style_weight
72
  )
73
  else:
74
- blended = res[i].repeat(word2ph_token[i], 1)
75
- phone_level_feature.append(blended)
76
 
77
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
78
  return phone_level_feature.T
 
1
+ import sys
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForMaskedLM
 
 
 
4
  from config import config
5
+ from text.japanese import text2sep_kata
6
 
7
  MODEL_ID = "ku-nlp/deberta-v2-large-japanese-char-wwm"
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
9
 
10
  models = dict()
11
 
12
+
13
  def get_bert_feature(
14
  text,
15
  word2ph,
 
17
  style_text=None,
18
  style_weight=0.7,
19
  ):
20
+ text = "".join(text2sep_kata(text)[0])
 
 
21
  if style_text:
22
  style_text = "".join(text2sep_kata(style_text)[0])
23
 
 
35
  models[device] = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(device)
36
 
37
  with torch.no_grad():
 
 
 
 
 
 
38
  inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
39
  for k in inputs:
40
  inputs[k] = inputs[k].to(device)
 
41
  res = models[device](**inputs, output_hidden_states=True)
42
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].float().cpu()
43
 
44
  if style_text:
45
+ style_inputs = tokenizer(style_text, return_tensors="pt", add_special_tokens=True)
46
  for k in style_inputs:
47
  style_inputs[k] = style_inputs[k].to(device)
48
  style_res = models[device](**style_inputs, output_hidden_states=True)
49
  style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].float().cpu()
50
  style_res_mean = style_res.mean(0)
51
 
52
+ # Force truncate so the length matches word2ph
53
+ min_len = min(len(word2ph), res.shape[0])
54
+ word2phone = word2ph[:min_len]
55
+ res = res[:min_len]
56
 
57
  phone_level_feature = []
58
+ for i in range(len(word2phone)):
59
  if style_text:
60
+ repeat_feature = (
61
+ res[i].repeat(word2phone[i], 1) * (1 - style_weight)
62
+ + style_res_mean.repeat(word2phone[i], 1) * style_weight
63
  )
64
  else:
65
+ repeat_feature = res[i].repeat(word2phone[i], 1)
66
+ phone_level_feature.append(repeat_feature)
67
 
68
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
69
  return phone_level_feature.T