JotunnBurton committed on
Commit e5d8a74 · verified · 1 Parent(s): 82f5f52

Update infer.py

Files changed (1)
  1. infer.py +264 -261
infer.py CHANGED
@@ -1,261 +1,264 @@
- """
- Implements version management, compatibility inference, and model loading.
- Version notes:
- 1. Version numbers match the GitHub release versions: a model trained with a given release carries that release's version number.
- 2. Please declare the version explicitly in the model's config.json by adding a field "version": "your version number".
- Special versions:
- 1.1.1-fix: models trained with version 1.1.1, but inferred with the dev branch's Japanese fix.
- 2.2: the current version.
- """
- import torch
- import commons
- from text import cleaned_text_to_sequence, get_bert
-
- from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
- from text.cleaner import clean_text
- import utils
- import numpy as np
-
- from models import SynthesizerTrn
- from text.symbols import symbols
-
- # Current version info
- latest_version = "2.4"
-
-
- # def get_emo_(reference_audio, emotion, sid):
- #     emo = (
- #         torch.from_numpy(get_emo(reference_audio))
- #         if reference_audio and emotion == -1
- #         else torch.FloatTensor(
- #             np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
- #         )
- #     )
- #     return emo
-
-
- def get_net_g(model_path: str, version: str, device: str, hps):
-     # net_g for the current model version
-     net_g = SynthesizerTrn(
-         len(symbols),
-         hps.data.filter_length // 2 + 1,
-         hps.train.segment_size // hps.data.hop_length,
-         n_speakers=hps.data.n_speakers,
-         **hps.model,
-     ).to(device)
-     _ = net_g.eval()
-     _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
-     return net_g
-
-
- def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
-     style_text = None if style_text == "" else style_text
-     # The current version of get_text is implemented here
-     norm_text, phone, tone, word2ph = clean_text(text, language_str)
-     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
-     if hps.data.add_blank:
-         phone = commons.intersperse(phone, 0)
-         tone = commons.intersperse(tone, 0)
-         language = commons.intersperse(language, 0)
-         for i in range(len(word2ph)):
-             word2ph[i] = word2ph[i] * 2
-         word2ph[0] += 1
-     bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
-     del word2ph
-
-     assert bert.shape[-1] == len(
-         phone
-     ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
-
-     phone = torch.LongTensor(phone)
-     tone = torch.LongTensor(tone)
-     language = torch.LongTensor(language)
-     return bert, phone, tone, language
-
-
- def infer(
-     text,
-     sdp_ratio,
-     noise_scale,
-     noise_scale_w,
-     length_scale,
-     sid,
-     language,
-     hps,
-     net_g,
-     device,
-     emotion,
-     reference_audio=None,
-     skip_start=False,
-     skip_end=False,
-     style_text=None,
-     style_weight=0.7,
-     text_mode="Text",
- ):
-     # Parameter order changed in version 2.2
-     # Version 2.1 added the parameters emotion, reference_audio, skip_start, skip_end
-     version = hps.version if hasattr(hps, "version") else latest_version
-     language = "JP"
-     if isinstance(reference_audio, np.ndarray):
-         emo = get_clap_audio_feature(reference_audio, device)
-     else:
-         emo = get_clap_text_feature(emotion, device)
-     emo = torch.squeeze(emo, dim=1)
-
-     bert, phones, tones, lang_ids = get_text(
-         text,
-         language,
-         hps,
-         device,
-         style_text=style_text,
-         style_weight=style_weight,
-     )
-     if skip_start:
-         phones = phones[3:]
-         tones = tones[3:]
-         lang_ids = lang_ids[3:]
-         bert = bert[:, 3:]
-     if skip_end:
-         phones = phones[:-2]
-         tones = tones[:-2]
-         lang_ids = lang_ids[:-2]
-         bert = bert[:, :-2]
-     with torch.no_grad():
-         x_tst = phones.to(device).unsqueeze(0)
-         tones = tones.to(device).unsqueeze(0)
-         lang_ids = lang_ids.to(device).unsqueeze(0)
-         bert = bert.to(device).unsqueeze(0)
-         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-         emo = emo.to(device).unsqueeze(0)
-         del phones
-         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
-         print(text)
-         audio = (
-             net_g.infer(
-                 x_tst,
-                 x_tst_lengths,
-                 speakers,
-                 tones,
-                 lang_ids,
-                 bert,
-                 emo,
-                 sdp_ratio=sdp_ratio,
-                 noise_scale=noise_scale,
-                 noise_scale_w=noise_scale_w,
-                 length_scale=length_scale,
-             )[0][0, 0]
-             .data.cpu()
-             .float()
-             .numpy()
-         )
-         del (
-             x_tst,
-             tones,
-             lang_ids,
-             bert,
-             x_tst_lengths,
-             speakers,
-             emo,
-         )  # , emo
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         return audio
-
-
- def infer_multilang(
-     text,
-     sdp_ratio,
-     noise_scale,
-     noise_scale_w,
-     length_scale,
-     sid,
-     language,
-     hps,
-     net_g,
-     device,
-     reference_audio=None,
-     emotion=None,
-     skip_start=False,
-     skip_end=False,
-     style_text=None,
-     style_weight=0.7,
- ):
-     bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
-     if isinstance(reference_audio, np.ndarray):
-         emo = get_clap_audio_feature(reference_audio, device)
-     else:
-         emo = get_clap_text_feature(emotion, device)
-     emo = torch.squeeze(emo, dim=1)
-     for idx, (txt, lang) in enumerate(zip(text, language)):
-         _skip_start = (idx != 0) or (skip_start and idx == 0)
-         _skip_end = (idx != len(language) - 1) or skip_end
-         (
-             temp_bert,
-             temp_phones,
-             temp_tones,
-             temp_lang_ids,
-         ) = get_text(
-             txt,
-             lang,
-             hps,
-             device,
-             style_text=style_text,
-             style_weight=style_weight,
-         )
-         if _skip_start:
-             temp_bert = temp_bert[:, 3:]
-             temp_phones = temp_phones[3:]
-             temp_tones = temp_tones[3:]
-             temp_lang_ids = temp_lang_ids[3:]
-         if _skip_end:
-             temp_bert = temp_bert[:, :-2]
-             temp_phones = temp_phones[:-2]
-             temp_tones = temp_tones[:-2]
-             temp_lang_ids = temp_lang_ids[:-2]
-         bert.append(temp_bert)
-         phones.append(temp_phones)
-         tones.append(temp_tones)
-         lang_ids.append(temp_lang_ids)
-     bert = torch.concatenate(bert, dim=1)
-     phones = torch.concatenate(phones, dim=0)
-     tones = torch.concatenate(tones, dim=0)
-     lang_ids = torch.concatenate(lang_ids, dim=0)
-     with torch.no_grad():
-         x_tst = phones.to(device).unsqueeze(0)
-         tones = tones.to(device).unsqueeze(0)
-         lang_ids = lang_ids.to(device).unsqueeze(0)
-         bert = bert.to(device).unsqueeze(0)
-         emo = emo.to(device).unsqueeze(0)
-         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-         del phones
-         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
-         audio = (
-             net_g.infer(
-                 x_tst,
-                 x_tst_lengths,
-                 speakers,
-                 tones,
-                 lang_ids,
-                 bert,
-                 emo,
-                 sdp_ratio=sdp_ratio,
-                 noise_scale=noise_scale,
-                 noise_scale_w=noise_scale_w,
-                 length_scale=length_scale,
-             )[0][0, 0]
-             .data.cpu()
-             .float()
-             .numpy()
-         )
-         del (
-             x_tst,
-             tones,
-             lang_ids,
-             bert,
-             x_tst_lengths,
-             speakers,
-             emo,
-         )  # , emo
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         return audio
+ """
+ Implements version management, compatibility inference, and model loading.
+ Version notes:
+ 1. Version numbers match the GitHub release versions: a model trained with a given release carries that release's version number.
+ 2. Please declare the version explicitly in the model's config.json by adding a field "version": "your version number".
+ Special versions:
+ 1.1.1-fix: models trained with version 1.1.1, but inferred with the dev branch's Japanese fix.
+ 2.2: the current version.
+ """
+ import torch
+ import commons
+ from text import cleaned_text_to_sequence, get_bert
+
+ from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
+ from text.cleaner import clean_text
+ import utils
+ import numpy as np
+
+ from models import SynthesizerTrn
+ from text.symbols import symbols
+
+ # Current version info
+ latest_version = "2.4"
+
+
+ # def get_emo_(reference_audio, emotion, sid):
+ #     emo = (
+ #         torch.from_numpy(get_emo(reference_audio))
+ #         if reference_audio and emotion == -1
+ #         else torch.FloatTensor(
+ #             np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
+ #         )
+ #     )
+ #     return emo
+
+
+ def get_net_g(model_path: str, version: str, device: str, hps):
+     # net_g for the current model version
+     net_g = SynthesizerTrn(
+         len(symbols),
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         **hps.model,
+     ).to(device)
+     _ = net_g.eval()
+     _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
+     return net_g
+
+
+ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
+     style_text = None if style_text == "" else style_text
+     # The current version of get_text is implemented here
+     norm_text, phone, tone, word2ph = clean_text(text, language_str)
+     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
+     if hps.data.add_blank:
+         phone = commons.intersperse(phone, 0)
+         tone = commons.intersperse(tone, 0)
+         language = commons.intersperse(language, 0)
+         for i in range(len(word2ph)):
+             word2ph[i] = word2ph[i] * 2
+         word2ph[0] += 1
+     bert = get_bert(norm_text, word2ph, language_str, device, style_text, style_weight)
+     del word2ph
+
+     assert bert.shape[-1] == len(
+         phone
+     ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
+
+     phone = torch.LongTensor(phone)
+     tone = torch.LongTensor(tone)
+     language = torch.LongTensor(language)
+     return bert, phone, tone, language
+
+
+ def infer(
+     text,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     sid,
+     language,
+     hps,
+     net_g,
+     device,
+     emotion,
+     reference_audio=None,
+     skip_start=False,
+     skip_end=False,
+     style_text=None,
+     style_weight=0.7,
+     text_mode="Text",
+ ):
+     # Parameter order changed in version 2.2
+     # Version 2.1 added the parameters emotion, reference_audio, skip_start, skip_end
+     version = hps.version if hasattr(hps, "version") else latest_version
+     language = "JP"
+     if isinstance(reference_audio, np.ndarray):
+         emo = get_clap_audio_feature(reference_audio, device)
+     else:
+         emo = get_clap_text_feature(emotion, device)
+     emo = torch.squeeze(emo, dim=1)
+
+     bert, phones, tones, lang_ids = get_text(
+         text,
+         language,
+         hps,
+         device,
+         style_text=style_text,
+         style_weight=style_weight,
+     )
+     if skip_start:
+         phones = phones[3:]
+         tones = tones[3:]
+         lang_ids = lang_ids[3:]
+         bert = bert[:, 3:]
+     if skip_end:
+         phones = phones[:-2]
+         tones = tones[:-2]
+         lang_ids = lang_ids[:-2]
+         bert = bert[:, :-2]
+     with torch.no_grad():
+         x_tst = phones.to(device).unsqueeze(0)
+         tones = tones.to(device).unsqueeze(0)
+         lang_ids = lang_ids.to(device).unsqueeze(0)
+         bert = bert.to(device).unsqueeze(0)
+         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+         emo = emo.to(device).unsqueeze(0)
+         del phones
+
+         print([hps.data.spk2id[sid]]);
+
+         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+         print(text)
+         audio = (
+             net_g.infer(
+                 x_tst,
+                 x_tst_lengths,
+                 speakers,
+                 tones,
+                 lang_ids,
+                 bert,
+                 emo,
+                 sdp_ratio=sdp_ratio,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=length_scale,
+             )[0][0, 0]
+             .data.cpu()
+             .float()
+             .numpy()
+         )
+         del (
+             x_tst,
+             tones,
+             lang_ids,
+             bert,
+             x_tst_lengths,
+             speakers,
+             emo,
+         )  # , emo
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return audio
+
+
+ def infer_multilang(
+     text,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     sid,
+     language,
+     hps,
+     net_g,
+     device,
+     reference_audio=None,
+     emotion=None,
+     skip_start=False,
+     skip_end=False,
+     style_text=None,
+     style_weight=0.7,
+ ):
+     bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
+     if isinstance(reference_audio, np.ndarray):
+         emo = get_clap_audio_feature(reference_audio, device)
+     else:
+         emo = get_clap_text_feature(emotion, device)
+     emo = torch.squeeze(emo, dim=1)
+     for idx, (txt, lang) in enumerate(zip(text, language)):
+         _skip_start = (idx != 0) or (skip_start and idx == 0)
+         _skip_end = (idx != len(language) - 1) or skip_end
+         (
+             temp_bert,
+             temp_phones,
+             temp_tones,
+             temp_lang_ids,
+         ) = get_text(
+             txt,
+             lang,
+             hps,
+             device,
+             style_text=style_text,
+             style_weight=style_weight,
+         )
+         if _skip_start:
+             temp_bert = temp_bert[:, 3:]
+             temp_phones = temp_phones[3:]
+             temp_tones = temp_tones[3:]
+             temp_lang_ids = temp_lang_ids[3:]
+         if _skip_end:
+             temp_bert = temp_bert[:, :-2]
+             temp_phones = temp_phones[:-2]
+             temp_tones = temp_tones[:-2]
+             temp_lang_ids = temp_lang_ids[:-2]
+         bert.append(temp_bert)
+         phones.append(temp_phones)
+         tones.append(temp_tones)
+         lang_ids.append(temp_lang_ids)
+     bert = torch.concatenate(bert, dim=1)
+     phones = torch.concatenate(phones, dim=0)
+     tones = torch.concatenate(tones, dim=0)
+     lang_ids = torch.concatenate(lang_ids, dim=0)
+     with torch.no_grad():
+         x_tst = phones.to(device).unsqueeze(0)
+         tones = tones.to(device).unsqueeze(0)
+         lang_ids = lang_ids.to(device).unsqueeze(0)
+         bert = bert.to(device).unsqueeze(0)
+         emo = emo.to(device).unsqueeze(0)
+         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+         del phones
+         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+         audio = (
+             net_g.infer(
+                 x_tst,
+                 x_tst_lengths,
+                 speakers,
+                 tones,
+                 lang_ids,
+                 bert,
+                 emo,
+                 sdp_ratio=sdp_ratio,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=length_scale,
+             )[0][0, 0]
+             .data.cpu()
+             .float()
+             .numpy()
+         )
+         del (
+             x_tst,
+             tones,
+             lang_ids,
+             bert,
+             x_tst_lengths,
+             speakers,
+             emo,
+         )  # , emo
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return audio
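
For orientation, below is a minimal driver sketch of how the `infer` entry point in this file is typically called. It assumes the repo's `utils.get_hparams_from_file` helper and a Bert-VITS2-style config; the config path, checkpoint path, speaker name, and emotion prompt are placeholders, not values taken from this commit.

    import torch
    import utils
    from infer import get_net_g, infer, latest_version

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical paths; substitute your own config and checkpoint.
    hps = utils.get_hparams_from_file("configs/config.json")
    version = getattr(hps, "version", latest_version)
    net_g = get_net_g("models/G_latest.pth", version, device, hps)

    # "SpeakerA" must be a key in hps.data.spk2id; emotion is a free-text
    # prompt fed to the CLAP text encoder when no reference audio is given.
    audio = infer(
        "こんにちは、テストです。",  # infer() forces the language to "JP" internally
        sdp_ratio=0.2,
        noise_scale=0.6,
        noise_scale_w=0.8,
        length_scale=1.0,
        sid="SpeakerA",
        language="JP",
        hps=hps,
        net_g=net_g,
        device=device,
        emotion="calm",
    )
    # audio is a float32 NumPy array at hps.data.sampling_rate.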