JotunnBurton committed
Commit d5b3961 · verified · 1 Parent(s): d3907c0

Update text/japanese_bert.py

Files changed (1)
  1. text/japanese_bert.py +42 -57
text/japanese_bert.py CHANGED
@@ -10,7 +10,6 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 models = dict()
 
-
 def get_bert_feature(
     text,
     word2ph,
@@ -19,75 +18,61 @@ def get_bert_feature(
     style_weight=0.7,
 ):
     sep_text, _ = text2sep_kata(text)
-    sep_tokens = [tokenizer.tokenize(t) for t in sep_text]
-    sep_ids = [tokenizer.convert_tokens_to_ids(t) for t in sep_tokens]
-    sep_ids = [2] + [item for sublist in sep_ids for item in sublist] + [3]
+    text = "".join(sep_text)
 
-    style_ids = None
     if style_text:
-        sep_style_text, _ = text2sep_kata(style_text)
-        style_tokens = [tokenizer.tokenize(t) for t in sep_style_text]
-        style_ids = [tokenizer.convert_tokens_to_ids(t) for t in style_tokens]
-        style_ids = [2] + [item for sublist in style_ids for item in sublist] + [3]
-
-    return get_bert_feature_with_token(
-        sep_ids, word2ph, device, style_ids, style_weight
-    )
-
+        style_text = "".join(text2sep_kata(style_text)[0])
 
-def get_bert_feature_with_token(tokens, word2ph, device=None, style_tokens=None, style_weight=0.7):
-    if (
-        sys.platform == "darwin"
-        and torch.backends.mps.is_available()
-        and device == "cpu"
-    ):
+    if sys.platform == "darwin" and torch.backends.mps.is_available() and device == "cpu":
         device = "mps"
     if not device:
         device = "cuda"
-    if device not in models.keys():
-        models[device] = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(device)
-
-    def encode(tokens_):
-        inputs = torch.tensor(tokens_).to(device).unsqueeze(0)
-        token_type_ids = torch.zeros_like(inputs).to(device)
-        attention_mask = torch.ones_like(inputs).to(device)
-        inputs = {
-            "input_ids": inputs,
-            "token_type_ids": token_type_ids,
-            "attention_mask": attention_mask,
-        }
-        with torch.no_grad():
-            res = models[device](**inputs, output_hidden_states=True)
-        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
-        return res, inputs["input_ids"].shape[-1]
-
-    res, main_len = encode(tokens)
+
+    if device not in models:
+        if config.webui_config.fp16_run:
+            models[device] = AutoModelForMaskedLM.from_pretrained(
+                MODEL_ID, torch_dtype=torch.float16
+            ).to(device)
+        else:
+            models[device] = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(device)
+
+    with torch.no_grad():
+        # Tokenize text into subwords for correct alignment
+        tokens = [tokenizer.tokenize(t) for t in sep_text]
+        flat_tokens = [item for sublist in tokens for item in sublist]
+        word2ph_token = [len(t) for t in tokens]
+        word2ph_token = [1] + word2ph_token + [1]  # Account for [CLS] and [SEP]
+
+        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
+        for k in inputs:
+            inputs[k] = inputs[k].to(device)
+
+        res = models[device](**inputs, output_hidden_states=True)
+        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].float().cpu()
+
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for k in style_inputs:
+                style_inputs[k] = style_inputs[k].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].float().cpu()
+            style_res_mean = style_res.mean(0)
 
-    if main_len != len(word2ph):
-        print(">> DEBUG length mismatch:")
-        print("token len:", main_len)
-        print("word2ph len:", len(word2ph))
-        raise ValueError("Mismatch between token length and word2ph length.")
-
-    if style_tokens:
-        style_res, _ = encode(style_tokens)
-        style_res_mean = style_res.mean(0)
+    if len(word2ph_token) != res.shape[0]:
+        print(f"[ERROR] len(word2ph_token) = {len(word2ph_token)}, but BERT output = {res.shape[0]}")
+        print(f"[DEBUG] input text: {text}")
+        raise ValueError("Mismatch between tokenized word2ph and BERT output length.")
 
     phone_level_feature = []
-    for i in range(len(word2ph)):
-        if style_tokens:
+    for i in range(len(word2ph_token)):
+        if style_text:
             blended = (
-                res[i].repeat(word2ph[i], 1) * (1 - style_weight)
-                + style_res_mean.repeat(word2ph[i], 1) * style_weight
+                res[i].repeat(word2ph_token[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2ph_token[i], 1) * style_weight
             )
         else:
-            blended = res[i].repeat(word2ph[i], 1)
+            blended = res[i].repeat(word2ph_token[i], 1)
         phone_level_feature.append(blended)
 
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
-
     return phone_level_feature.T
-
-
-if __name__ == "__main__":
-    print(get_bert_feature("観覧車", [4, 2]))
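
For reference, a minimal usage sketch of the revised get_bert_feature. The hunk does not show the full parameter list, so the device, style_text, and style_weight keyword arguments below are inferred from the function body, the import path is assumed from the file location, and all input values are illustrative only:

# Hypothetical usage sketch; values are illustrative, not from the commit.
import torch

from text.japanese_bert import get_bert_feature  # assumed import path

# One phone count per aligned position, including [CLS]/[SEP] slots
# (compare word2ph_token in the diff); illustrative numbers only.
word2ph = [1, 2, 2, 2, 1]

feature = get_bert_feature(
    "観覧車",                  # main text
    word2ph,
    device="cuda" if torch.cuda.is_available() else "cpu",
    style_text="楽しい",       # optional style reference text
    style_weight=0.7,          # blend: 70% style embedding, 30% content
)
print(feature.shape)  # (hidden_size, total phone-level positions)

Note that when style_text is given, each position is blended as res * (1 - style_weight) + style_mean * style_weight, so style_weight=0.7 weights the style embedding more heavily than the content embedding.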