JotunnBurton commited on
Commit
3f3d5c8
·
verified ·
1 Parent(s): 36863d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -40,11 +40,11 @@ def get_net_g(model_path: str, version: str, device: str, hps):
40
  _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
41
  return net_g
42
 
43
- def get_text(text, hps):
44
- language_str = "JP"
 
45
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
46
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
47
-
48
  if hps.data.add_blank:
49
  phone = commons.intersperse(phone, 0)
50
  tone = commons.intersperse(tone, 0)
@@ -52,18 +52,17 @@ def get_text(text, hps):
52
  for i in range(len(word2ph)):
53
  word2ph[i] = word2ph[i] * 2
54
  word2ph[0] += 1
55
-
56
- bert = get_bert(norm_text, word2ph, language_str, device)
57
  del word2ph
58
- assert bert.shape[-1] == len(phone), phone
59
 
60
- ja_bert = bert
61
- bert = torch.zeros(1024, len(phone))
 
62
 
63
  phone = torch.LongTensor(phone)
64
  tone = torch.LongTensor(tone)
65
  language = torch.LongTensor(language)
66
- return bert, ja_bert, phone, tone, language
67
 
68
 
69
  def infer(
 
40
  _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
41
  return net_g
42
 
43
+ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
44
+ style_text = None if style_text == "" else style_text
45
+ # 在此处实现当前版本的get_text
46
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
47
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
48
  if hps.data.add_blank:
49
  phone = commons.intersperse(phone, 0)
50
  tone = commons.intersperse(tone, 0)
 
52
  for i in range(len(word2ph)):
53
  word2ph[i] = word2ph[i] * 2
54
  word2ph[0] += 1
55
+ bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
 
56
  del word2ph
 
57
 
58
+ assert bert.shape[-1] == len(
59
+ phone
60
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
61
 
62
  phone = torch.LongTensor(phone)
63
  tone = torch.LongTensor(tone)
64
  language = torch.LongTensor(language)
65
+ return bert, phone, tone, language
66
 
67
 
68
  def infer(