JotunnBurton committed
Commit 6ccb2b2 · verified · 1 Parent(s): 92259fe

Upload infer.py

Files changed (1)
  1. infer.py +261 -0
infer.py ADDED
@@ -0,0 +1,261 @@
"""
Version management, compatible inference, and model loading.
Version notes:
1. Version numbers correspond to the GitHub release versions; a model trained with a given
   release carries that release's version number.
2. Please declare the version explicitly in the model's config.json by adding a field
   "version": "your version number".
Special versions:
1.1.1-fix: a model trained with version 1.1.1, but inferred with the dev branch's Japanese fix.
2.2: the current version.
"""
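# For example (illustrative), a model's config.json would then carry an explicit
# version field alongside whatever contents it already has:
#
#     {
#         "version": "2.4",
#         ...
#     }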
import torch
import commons
from text import cleaned_text_to_sequence, get_bert

from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from text.cleaner import clean_text
import utils
import numpy as np

from models import SynthesizerTrn
from text.symbols import symbols

# Current version info
latest_version = "2.4"


# def get_emo_(reference_audio, emotion, sid):
#     emo = (
#         torch.from_numpy(get_emo(reference_audio))
#         if reference_audio and emotion == -1
#         else torch.FloatTensor(
#             np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
#         )
#     )
#     return emo


def get_net_g(model_path: str, version: str, device: str, hps):
    # net_g for the current model version
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
    return net_g


def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    style_text = None if style_text == "" else style_text
    # get_text implementation for the current version
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1
    bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
    del word2ph

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, phone, tone, language


def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    emotion,
    reference_audio=None,
    skip_start=False,
    skip_end=False,
    style_text=None,
    style_weight=0.7,
    text_mode="Text",
):
    # The parameter order changed in version 2.2.
    # Version 2.1 added the parameters emotion, reference_audio, skip_start, skip_end.
    version = hps.version if hasattr(hps, "version") else latest_version
    language = "JP"  # force Japanese regardless of the language argument
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)

    bert, phones, tones, lang_ids = get_text(
        text,
        language,
        hps,
        device,
        style_text=style_text,
        style_weight=style_weight,
    )
    if skip_start:
        phones = phones[3:]
        tones = tones[3:]
        lang_ids = lang_ids[3:]
        bert = bert[:, 3:]
    if skip_end:
        phones = phones[:-2]
        tones = tones[:-2]
        lang_ids = lang_ids[:-2]
        bert = bert[:, :-2]
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        emo = emo.to(device).unsqueeze(0)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        print(text)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del (
            x_tst,
            tones,
            lang_ids,
            bert,
            x_tst_lengths,
            speakers,
            emo,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio


def infer_multilang(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    emotion=None,
    skip_start=False,
    skip_end=False,
    style_text=None,
    style_weight=0.7,
):
    bert, phones, tones, lang_ids = [], [], [], []
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)
    for idx, (txt, lang) in enumerate(zip(text, language)):
        _skip_start = (idx != 0) or (skip_start and idx == 0)
        _skip_end = (idx != len(language) - 1) or skip_end
        (
            temp_bert,
            temp_phones,
            temp_tones,
            temp_lang_ids,
        ) = get_text(
            txt,
            lang,
            hps,
            device,
            style_text=style_text,
            style_weight=style_weight,
        )
        if _skip_start:
            temp_bert = temp_bert[:, 3:]
            temp_phones = temp_phones[3:]
            temp_tones = temp_tones[3:]
            temp_lang_ids = temp_lang_ids[3:]
        if _skip_end:
            temp_bert = temp_bert[:, :-2]
            temp_phones = temp_phones[:-2]
            temp_tones = temp_tones[:-2]
            temp_lang_ids = temp_lang_ids[:-2]
        bert.append(temp_bert)
        phones.append(temp_phones)
        tones.append(temp_tones)
        lang_ids.append(temp_lang_ids)
    bert = torch.concatenate(bert, dim=1)
    phones = torch.concatenate(phones, dim=0)
    tones = torch.concatenate(tones, dim=0)
    lang_ids = torch.concatenate(lang_ids, dim=0)
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        emo = emo.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del (
            x_tst,
            tones,
            lang_ids,
            bert,
            x_tst_lengths,
            speakers,
            emo,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio
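
For reference, a minimal usage sketch of this module (illustrative only): the config path, checkpoint path, speaker name, emotion prompt, and scale values below are placeholders, and it assumes the repo's utils module provides get_hparams_from_file for loading config.json, as in upstream Bert-VITS2.

import torch
import utils
from scipy.io import wavfile
from infer import get_net_g, infer, latest_version

# Placeholder paths; point these at a real config.json and generator checkpoint.
hps = utils.get_hparams_from_file("configs/config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
net_g = get_net_g(
    "logs/your_model/G_latest.pth",
    getattr(hps, "version", latest_version),
    device,
    hps,
)

# Speaker name must exist in hps.data.spk2id; emotion is a CLAP text prompt here.
audio = infer(
    "こんにちは、テストです。",
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="your_speaker",
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
    emotion="happy",
)
wavfile.write("out.wav", hps.data.sampling_rate, audio)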