JotunnBurton committed
Commit 7709d54 · verified · 1 Parent(s): c069afe

Update app.py

Files changed (1):
  1. app.py (+64, -9)
app.py CHANGED
@@ -50,26 +50,72 @@ def get_text(text, hps):
     return bert, ja_bert, phone, tone, language
 
 
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, net_g_ms, hps):
-    bert, ja_bert, phones, tones, lang_ids = get_text(text, hps)
+def infer(
+    text,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    sid,
+    language,
+    hps,
+    net_g,
+    device,
+    emotion,
+    reference_audio=None,
+    skip_start=False,
+    skip_end=False,
+    style_text=None,
+    style_weight=0.7,
+    text_mode="Text",
+):
+    # Parameter positions changed in version 2.2
+    # Version 2.1 added the parameters emotion, reference_audio, skip_start, skip_end
+    version = hps.version if hasattr(hps, "version") else latest_version
+    language = "JP"
+    if isinstance(reference_audio, np.ndarray):
+        emo = get_clap_audio_feature(reference_audio, device)
+    else:
+        emo = get_clap_text_feature(emotion, device)
+    emo = torch.squeeze(emo, dim=1)
+
+    bert, phones, tones, lang_ids = get_text(
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
+    )
+    if skip_start:
+        phones = phones[3:]
+        tones = tones[3:]
+        lang_ids = lang_ids[3:]
+        bert = bert[:, 3:]
+    if skip_end:
+        phones = phones[:-2]
+        tones = tones[:-2]
+        lang_ids = lang_ids[:-2]
+        bert = bert[:, :-2]
     with torch.no_grad():
         x_tst = phones.to(device).unsqueeze(0)
         tones = tones.to(device).unsqueeze(0)
         lang_ids = lang_ids.to(device).unsqueeze(0)
         bert = bert.to(device).unsqueeze(0)
-        ja_bert = ja_bert.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+        emo = emo.to(device).unsqueeze(0)
         del phones
-        sid = torch.LongTensor([sid]).to(device)
+        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+        print(text)
         audio = (
-            net_g_ms.infer(
+            net_g.infer(
                 x_tst,
                 x_tst_lengths,
-                sid,
+                speakers,
                 tones,
                 lang_ids,
                 bert,
-                ja_bert,
+                emo,
                 sdp_ratio=sdp_ratio,
                 noise_scale=noise_scale,
                 noise_scale_w=noise_scale_w,
@@ -79,8 +125,17 @@ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, net_g_
             .float()
             .numpy()
         )
-        del x_tst, tones, lang_ids, bert, x_tst_lengths, sid
-        torch.cuda.empty_cache()
+        del (
+            x_tst,
+            tones,
+            lang_ids,
+            bert,
+            x_tst_lengths,
+            speakers,
+            emo,
+        )  # , emo
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         return audio
 
 
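A minimal sketch of how the updated infer could be called after this commit. The hps and net_g objects, the speaker name "speaker_a", and the emotion prompt are illustrative assumptions, not part of the diff; in the app they would come from the repo's own config and checkpoint loading:

    import numpy as np
    import torch

    # Hypothetical setup: hps (hyperparameters with hps.data.spk2id) and
    # net_g (the loaded synthesizer) are assumed to exist already.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    audio = infer(
        text="こんにちは",        # text to synthesize
        sdp_ratio=0.2,
        noise_scale=0.6,
        noise_scale_w=0.8,
        length_scale=1.0,
        sid="speaker_a",          # resolved via hps.data.spk2id[sid]
        language="JP",            # infer overrides language to "JP" regardless
        hps=hps,
        net_g=net_g,
        device=device,
        emotion="happy",          # text prompt for get_clap_text_feature
        reference_audio=None,     # an np.ndarray here switches to get_clap_audio_feature
    )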