JotunnBurton committed
Commit 8cc38b0 · verified · 1 parent: f96eb4f

Upload models.py

Files changed (1)
models.py +1105 -0
models.py ADDED
@@ -0,0 +1,1105 @@
+ import math
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ import commons
+ import modules
+ import attentions
+ import monotonic_align
+
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+ from commons import init_weights, get_padding
+ from text import symbols, num_tones, num_languages
+
+ from vector_quantize_pytorch import VectorQuantize
+
+
+ class DurationDiscriminator(nn.Module):  # vits2
+     def __init__(
+         self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.gin_channels = gin_channels
+
+         self.drop = nn.Dropout(p_dropout)
+         self.conv_1 = nn.Conv1d(
+             in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+         )
+         self.norm_1 = modules.LayerNorm(filter_channels)
+         self.conv_2 = nn.Conv1d(
+             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+         )
+         self.norm_2 = modules.LayerNorm(filter_channels)
+         self.dur_proj = nn.Conv1d(1, filter_channels, 1)
+
+         self.LSTM = nn.LSTM(
+             2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
+         )
+
+         if gin_channels != 0:
+             self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+         self.output_layer = nn.Sequential(
+             nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
+         )
+
+     def forward_probability(self, x, dur):
+         dur = self.dur_proj(dur)
+         x = torch.cat([x, dur], dim=1)
+         x = x.transpose(1, 2)
+         x, _ = self.LSTM(x)
+         output_prob = self.output_layer(x)
+         return output_prob
+
+     def forward(self, x, x_mask, dur_r, dur_hat, g=None):
+         x = torch.detach(x)
+         if g is not None:
+             g = torch.detach(g)
+             x = x + self.cond(g)
+         x = self.conv_1(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_1(x)
+         x = self.drop(x)
+         x = self.conv_2(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_2(x)
+         x = self.drop(x)
+
+         output_probs = []
+         for dur in [dur_r, dur_hat]:
+             output_prob = self.forward_probability(x, dur)
+             output_probs.append(output_prob)
+
+         return output_probs
+
+
+ class TransformerCouplingBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size,
+         p_dropout,
+         n_flows=4,
+         gin_channels=0,
+         share_parameter=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.n_layers = n_layers
+         self.n_flows = n_flows
+         self.gin_channels = gin_channels
+
+         self.flows = nn.ModuleList()
+
+         self.wn = (
+             attentions.FFT(
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout,
+                 isflow=True,
+                 gin_channels=self.gin_channels,
+             )
+             if share_parameter
+             else None
+         )
+
+         for i in range(n_flows):
+             self.flows.append(
+                 modules.TransformerCouplingLayer(
+                     channels,
+                     hidden_channels,
+                     kernel_size,
+                     n_layers,
+                     n_heads,
+                     p_dropout,
+                     filter_channels,
+                     mean_only=True,
+                     wn_sharing_parameter=self.wn,
+                     gin_channels=self.gin_channels,
+                 )
+             )
+             self.flows.append(modules.Flip())
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         if not reverse:
+             for flow in self.flows:
+                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
+         else:
+             for flow in reversed(self.flows):
+                 x = flow(x, x_mask, g=g, reverse=reverse)
+         return x
+
+
+ class StochasticDurationPredictor(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         filter_channels,
+         kernel_size,
+         p_dropout,
+         n_flows=4,
+         gin_channels=0,
+     ):
+         super().__init__()
+         filter_channels = in_channels  # this override needs to be removed in a future version.
+         self.in_channels = in_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.n_flows = n_flows
+         self.gin_channels = gin_channels
+
+         self.log_flow = modules.Log()
+         self.flows = nn.ModuleList()
+         self.flows.append(modules.ElementwiseAffine(2))
+         for i in range(n_flows):
+             self.flows.append(
+                 modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
+             )
+             self.flows.append(modules.Flip())
+
+         self.post_pre = nn.Conv1d(1, filter_channels, 1)
+         self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
+         self.post_convs = modules.DDSConv(
+             filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
+         )
+         self.post_flows = nn.ModuleList()
+         self.post_flows.append(modules.ElementwiseAffine(2))
+         for i in range(4):
+             self.post_flows.append(
+                 modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
+             )
+             self.post_flows.append(modules.Flip())
+
+         self.pre = nn.Conv1d(in_channels, filter_channels, 1)
+         self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
+         self.convs = modules.DDSConv(
+             filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
+         )
+         if gin_channels != 0:
+             self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
+
+     def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
+         x = torch.detach(x)
+         x = self.pre(x)
+         if g is not None:
+             g = torch.detach(g)
+             x = x + self.cond(g)
+         x = self.convs(x, x_mask)
+         x = self.proj(x) * x_mask
+
+         if not reverse:
+             flows = self.flows
+             assert w is not None
+
+             logdet_tot_q = 0
+             h_w = self.post_pre(w)
+             h_w = self.post_convs(h_w, x_mask)
+             h_w = self.post_proj(h_w) * x_mask
+             e_q = (
+                 torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
+                 * x_mask
+             )
+             z_q = e_q
+             for flow in self.post_flows:
+                 z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+                 logdet_tot_q += logdet_q
+             z_u, z1 = torch.split(z_q, [1, 1], 1)
+             u = torch.sigmoid(z_u) * x_mask
+             z0 = (w - u) * x_mask
+             logdet_tot_q += torch.sum(
+                 (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
+             )
+             logq = (
+                 torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
+                 - logdet_tot_q
+             )
+
+             logdet_tot = 0
+             z0, logdet = self.log_flow(z0, x_mask)
+             logdet_tot += logdet
+             z = torch.cat([z0, z1], 1)
+             for flow in flows:
+                 z, logdet = flow(z, x_mask, g=x, reverse=reverse)
+                 logdet_tot = logdet_tot + logdet
+             nll = (
+                 torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
+                 - logdet_tot
+             )
+             return nll + logq  # [b]
+         else:
+             flows = list(reversed(self.flows))
+             flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
+             z = (
+                 torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
+                 * noise_scale
+             )
+             for flow in flows:
+                 z = flow(z, x_mask, g=x, reverse=reverse)
+             z0, z1 = torch.split(z, [1, 1], 1)
+             logw = z0
+             return logw
+
+
+ class DurationPredictor(nn.Module):
+     def __init__(
+         self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.gin_channels = gin_channels
+
+         self.drop = nn.Dropout(p_dropout)
+         self.conv_1 = nn.Conv1d(
+             in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+         )
+         self.norm_1 = modules.LayerNorm(filter_channels)
+         self.conv_2 = nn.Conv1d(
+             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+         )
+         self.norm_2 = modules.LayerNorm(filter_channels)
+         self.proj = nn.Conv1d(filter_channels, 1, 1)
+
+         if gin_channels != 0:
+             self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+     def forward(self, x, x_mask, g=None):
+         x = torch.detach(x)
+         if g is not None:
+             g = torch.detach(g)
+             x = x + self.cond(g)
+         x = self.conv_1(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_1(x)
+         x = self.drop(x)
+         x = self.conv_2(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_2(x)
+         x = self.drop(x)
+         x = self.proj(x * x_mask)
+         return x * x_mask
+
+
+ class Bottleneck(nn.Sequential):
+     def __init__(self, in_dim, hidden_dim):
+         c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
+         c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
+         super().__init__(*[c_fc1, c_fc2])
+
+
+ class Block(nn.Module):
+     def __init__(self, in_dim, hidden_dim) -> None:
+         super().__init__()
+         self.norm = nn.LayerNorm(in_dim)
+         self.mlp = MLP(in_dim, hidden_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.mlp(self.norm(x))
+         return x
+
+
+ class MLP(nn.Module):
+     def __init__(self, in_dim, hidden_dim):
+         super().__init__()
+         self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
+         self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
+         self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
+
+     def forward(self, x: torch.Tensor):
+         x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class TextEncoder(nn.Module):
+     def __init__(
+         self,
+         n_vocab,
+         out_channels,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size,
+         p_dropout,
+         gin_channels=0,
+     ):
+         super().__init__()
+         self.n_vocab = n_vocab
+         self.out_channels = out_channels
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.gin_channels = gin_channels
+         self.emb = nn.Embedding(len(symbols), hidden_channels)
+         nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+         self.tone_emb = nn.Embedding(num_tones, hidden_channels)
+         nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
+         self.language_emb = nn.Embedding(num_languages, hidden_channels)
+         nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
+         self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
+         # self.bert_pre_proj = nn.Conv1d(2048, 1024, 1)
+         # self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
+         self.in_feature_net = nn.Sequential(
+             # input is assumed to be an already normalized embedding
+             nn.Linear(512, 1028, bias=False),
+             nn.GELU(),
+             nn.LayerNorm(1028),
+             *[Block(1028, 512) for _ in range(1)],
+             nn.Linear(1028, 512, bias=False),
+             # normalize before passing to VQ?
+             # nn.GELU(),
+             # nn.LayerNorm(512),
+         )
+         self.emo_vq = VectorQuantize(
+             dim=512,
+             # codebook_size=128,
+             codebook_size=256,
+             codebook_dim=16,
+             # codebook_dim=32,
+             commitment_weight=0.1,
+             decay=0.99,
+             heads=32,
+             kmeans_iters=20,
+             separate_codebook_per_head=True,
+             stochastic_sample_codes=True,
+             threshold_ema_dead_code=2,
+             use_cosine_sim=True,
+         )
+         self.out_feature_net = nn.Linear(512, hidden_channels)
+
+         self.encoder = attentions.Encoder(
+             hidden_channels,
+             filter_channels,
+             n_heads,
+             n_layers,
+             kernel_size,
+             p_dropout,
+             gin_channels=self.gin_channels,
+         )
+         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+     def forward(self, x, x_lengths, tone, language, bert, emo, g=None):
+         bert_emb = self.bert_proj(bert).transpose(1, 2)
+         # en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
+         emo_emb = self.in_feature_net(emo)
+         emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
+         loss_commit = loss_commit.mean()
+         emo_emb = self.out_feature_net(emo_emb)
+         x = (
+             self.emb(x)
+             + self.tone_emb(tone)
+             + self.language_emb(language)
+             + bert_emb
+             # + en_bert_emb
+             + emo_emb
+         ) * math.sqrt(
+             self.hidden_channels
+         )  # [b, t, h]
+         x = torch.transpose(x, 1, -1)  # [b, h, t]
+         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+             x.dtype
+         )
+
+         x = self.encoder(x * x_mask, x_mask, g=g)
+         stats = self.proj(x) * x_mask
+
+         m, logs = torch.split(stats, self.out_channels, dim=1)
+         return x, m, logs, x_mask, loss_commit
+
+
+ class ResidualCouplingBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         n_flows=4,
+         gin_channels=0,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.n_flows = n_flows
+         self.gin_channels = gin_channels
+
+         self.flows = nn.ModuleList()
+         for i in range(n_flows):
+             self.flows.append(
+                 modules.ResidualCouplingLayer(
+                     channels,
+                     hidden_channels,
+                     kernel_size,
+                     dilation_rate,
+                     n_layers,
+                     gin_channels=gin_channels,
+                     mean_only=True,
+                 )
+             )
+             self.flows.append(modules.Flip())
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         if not reverse:
+             for flow in self.flows:
+                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
+         else:
+             for flow in reversed(self.flows):
+                 x = flow(x, x_mask, g=g, reverse=reverse)
+         return x
+
+
+ class PosteriorEncoder(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         gin_channels=0,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.gin_channels = gin_channels
+
+         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+         self.enc = modules.WN(
+             hidden_channels,
+             kernel_size,
+             dilation_rate,
+             n_layers,
+             gin_channels=gin_channels,
+         )
+         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+     def forward(self, x, x_lengths, g=None):
+         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+             x.dtype
+         )
+         x = self.pre(x) * x_mask
+         x = self.enc(x, x_mask, g=g)
+         stats = self.proj(x) * x_mask
+         m, logs = torch.split(stats, self.out_channels, dim=1)
+         z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+         return z, m, logs, x_mask
+
+
+ class Generator(torch.nn.Module):
+     def __init__(
+         self,
+         initial_channel,
+         resblock,
+         resblock_kernel_sizes,
+         resblock_dilation_sizes,
+         upsample_rates,
+         upsample_initial_channel,
+         upsample_kernel_sizes,
+         gin_channels=0,
+     ):
+         super(Generator, self).__init__()
+         self.num_kernels = len(resblock_kernel_sizes)
+         self.num_upsamples = len(upsample_rates)
+         self.conv_pre = Conv1d(
+             initial_channel, upsample_initial_channel, 7, 1, padding=3
+         )
+         resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+             self.ups.append(
+                 weight_norm(
+                     ConvTranspose1d(
+                         upsample_initial_channel // (2**i),
+                         upsample_initial_channel // (2 ** (i + 1)),
+                         k,
+                         u,
+                         padding=(k - u) // 2,
+                     )
+                 )
+             )
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = upsample_initial_channel // (2 ** (i + 1))
+             for j, (k, d) in enumerate(
+                 zip(resblock_kernel_sizes, resblock_dilation_sizes)
+             ):
+                 self.resblocks.append(resblock(ch, k, d))
+
+         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+         self.ups.apply(init_weights)
+
+         if gin_channels != 0:
+             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+     def forward(self, x, g=None):
+         x = self.conv_pre(x)
+         if g is not None:
+             x = x + self.cond(g)
+
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, modules.LRELU_SLOPE)
+             x = self.ups[i](x)
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         print("Removing weight norm...")
+         for layer in self.ups:
+             remove_weight_norm(layer)
+         for layer in self.resblocks:
+             layer.remove_weight_norm()
+
+
+ class DiscriminatorP(torch.nn.Module):
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+         super(DiscriminatorP, self).__init__()
+         self.period = period
+         self.use_spectral_norm = use_spectral_norm
+         norm_f = weight_norm if not use_spectral_norm else spectral_norm
+         self.convs = nn.ModuleList(
+             [
+                 norm_f(
+                     Conv2d(
+                         1,
+                         32,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(kernel_size, 1), 0),
+                     )
+                 ),
+                 norm_f(
+                     Conv2d(
+                         32,
+                         128,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(kernel_size, 1), 0),
+                     )
+                 ),
+                 norm_f(
+                     Conv2d(
+                         128,
+                         512,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(kernel_size, 1), 0),
+                     )
+                 ),
+                 norm_f(
+                     Conv2d(
+                         512,
+                         1024,
+                         (kernel_size, 1),
+                         (stride, 1),
+                         padding=(get_padding(kernel_size, 1), 0),
+                     )
+                 ),
+                 norm_f(
+                     Conv2d(
+                         1024,
+                         1024,
+                         (kernel_size, 1),
+                         1,
+                         padding=(get_padding(kernel_size, 1), 0),
+                     )
+                 ),
+             ]
+         )
+         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+     def forward(self, x):
+         fmap = []
+
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for layer in self.convs:
+             x = layer(x)
+             x = F.leaky_relu(x, modules.LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class DiscriminatorS(torch.nn.Module):
+     def __init__(self, use_spectral_norm=False):
+         super(DiscriminatorS, self).__init__()
+         norm_f = weight_norm if not use_spectral_norm else spectral_norm
+         self.convs = nn.ModuleList(
+             [
+                 norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                 norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                 norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                 norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                 norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                 norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+             ]
+         )
+         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+     def forward(self, x):
+         fmap = []
+
+         for layer in self.convs:
+             x = layer(x)
+             x = F.leaky_relu(x, modules.LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(torch.nn.Module):
+     def __init__(self, use_spectral_norm=False):
+         super(MultiPeriodDiscriminator, self).__init__()
+         periods = [2, 3, 5, 7, 11]
+
+         discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+         discs = discs + [
+             DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
+         ]
+         self.discriminators = nn.ModuleList(discs)
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             y_d_gs.append(y_d_g)
+             fmap_rs.append(fmap_r)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class WavLMDiscriminator(nn.Module):
+     """Discriminator over stacked WavLM hidden-layer features."""
+
+     def __init__(
+         self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
+     ):
+         super(WavLMDiscriminator, self).__init__()
+         norm_f = weight_norm if not use_spectral_norm else spectral_norm
+         self.pre = norm_f(
+             Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
+         )
+
+         self.convs = nn.ModuleList(
+             [
+                 norm_f(
+                     nn.Conv1d(
+                         initial_channel, initial_channel * 2, kernel_size=5, padding=2
+                     )
+                 ),
+                 norm_f(
+                     nn.Conv1d(
+                         initial_channel * 2,
+                         initial_channel * 4,
+                         kernel_size=5,
+                         padding=2,
+                     )
+                 ),
+                 norm_f(
+                     nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
+                 ),
+             ]
+         )
+
+         self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
+
+     def forward(self, x):
+         x = self.pre(x)
+
+         fmap = []
+         for layer in self.convs:
+             x = layer(x)
+             x = F.leaky_relu(x, modules.LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x
+
+
+ class ReferenceEncoder(nn.Module):
+     """
+     inputs --- [N, Ty/r, n_mels*r] mels
+     outputs --- [N, ref_enc_gru_size]
+     """
+
+     def __init__(self, spec_channels, gin_channels=0):
+         super().__init__()
+         self.spec_channels = spec_channels
+         ref_enc_filters = [32, 32, 64, 64, 128, 128]
+         K = len(ref_enc_filters)
+         filters = [1] + ref_enc_filters
+         convs = [
+             weight_norm(
+                 nn.Conv2d(
+                     in_channels=filters[i],
+                     out_channels=filters[i + 1],
+                     kernel_size=(3, 3),
+                     stride=(2, 2),
+                     padding=(1, 1),
+                 )
+             )
+             for i in range(K)
+         ]
+         self.convs = nn.ModuleList(convs)
+         # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])  # noqa: E501
+
+         out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
+         self.gru = nn.GRU(
+             input_size=ref_enc_filters[-1] * out_channels,
+             hidden_size=256 // 2,
+             batch_first=True,
+         )
+         self.proj = nn.Linear(128, gin_channels)
+
+     def forward(self, inputs, mask=None):
+         N = inputs.size(0)
+         out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
+         for conv in self.convs:
+             out = conv(out)
+             # out = wn(out)
+             out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]
+
+         out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
+         T = out.size(1)
+         N = out.size(0)
+         out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]
+
+         self.gru.flatten_parameters()
+         memory, out = self.gru(out)  # out --- [1, N, 128]
+
+         return self.proj(out.squeeze(0))
+
+     def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
+         for i in range(n_convs):
+             L = (L - kernel_size + 2 * pad) // stride + 1
+         return L
+
+
+ class SynthesizerTrn(nn.Module):
+     """
+     Synthesizer for Training
+     """
+
+     def __init__(
+         self,
+         n_vocab,
+         spec_channels,
+         segment_size,
+         inter_channels,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size,
+         p_dropout,
+         resblock,
+         resblock_kernel_sizes,
+         resblock_dilation_sizes,
+         upsample_rates,
+         upsample_initial_channel,
+         upsample_kernel_sizes,
+         n_speakers=256,
+         gin_channels=256,
+         use_sdp=True,
+         n_flow_layer=4,
+         n_layers_trans_flow=6,
+         flow_share_parameter=False,
+         use_transformer_flow=True,
+         **kwargs
+     ):
+         super().__init__()
+         self.n_vocab = n_vocab
+         self.spec_channels = spec_channels
+         self.inter_channels = inter_channels
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.resblock = resblock
+         self.resblock_kernel_sizes = resblock_kernel_sizes
+         self.resblock_dilation_sizes = resblock_dilation_sizes
+         self.upsample_rates = upsample_rates
+         self.upsample_initial_channel = upsample_initial_channel
+         self.upsample_kernel_sizes = upsample_kernel_sizes
+         self.segment_size = segment_size
+         self.n_speakers = n_speakers
+         self.gin_channels = gin_channels
+         self.n_layers_trans_flow = n_layers_trans_flow
+         self.use_spk_conditioned_encoder = kwargs.get(
+             "use_spk_conditioned_encoder", True
+         )
+         self.use_sdp = use_sdp
+         self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
+         self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
+         self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
+         self.current_mas_noise_scale = self.mas_noise_scale_initial
+         self.enc_gin_channels = 0  # fallback so the attribute is always defined below
+         if self.use_spk_conditioned_encoder and gin_channels > 0:
+             self.enc_gin_channels = gin_channels
+         self.enc_p = TextEncoder(
+             n_vocab,
+             inter_channels,
+             hidden_channels,
+             filter_channels,
+             n_heads,
+             n_layers,
+             kernel_size,
+             p_dropout,
+             gin_channels=self.enc_gin_channels,
+         )
+         self.dec = Generator(
+             inter_channels,
+             resblock,
+             resblock_kernel_sizes,
+             resblock_dilation_sizes,
+             upsample_rates,
+             upsample_initial_channel,
+             upsample_kernel_sizes,
+             gin_channels=gin_channels,
+         )
+         self.enc_q = PosteriorEncoder(
+             spec_channels,
+             inter_channels,
+             hidden_channels,
+             5,
+             1,
+             16,
+             gin_channels=gin_channels,
+         )
+         if use_transformer_flow:
+             self.flow = TransformerCouplingBlock(
+                 inter_channels,
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers_trans_flow,
+                 5,
+                 p_dropout,
+                 n_flow_layer,
+                 gin_channels=gin_channels,
+                 share_parameter=flow_share_parameter,
+             )
+         else:
+             self.flow = ResidualCouplingBlock(
+                 inter_channels,
+                 hidden_channels,
+                 5,
+                 1,
+                 n_flow_layer,
+                 gin_channels=gin_channels,
+             )
+         self.sdp = StochasticDurationPredictor(
+             hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
+         )
+         self.dp = DurationPredictor(
+             hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
+         )
+
+         if n_speakers >= 1:
+             self.emb_g = nn.Embedding(n_speakers, gin_channels)
+         else:
+             self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
+
+     def forward(
+         self,
+         x,
+         x_lengths,
+         y,
+         y_lengths,
+         sid,
+         tone,
+         language,
+         bert,
+         emo,
+     ):
+         if self.n_speakers > 0:
+             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+         else:
+             g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
+         x, m_p, logs_p, x_mask, loss_commit = self.enc_p(
+             x, x_lengths, tone, language, bert, emo, g=g
+         )
+         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
+         z_p = self.flow(z, y_mask, g=g)
+
+         with torch.no_grad():
+             # negative cross-entropy
+             s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
+             neg_cent1 = torch.sum(
+                 -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
+             )  # [b, 1, t_s]
+             neg_cent2 = torch.matmul(
+                 -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
+             )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+             neg_cent3 = torch.matmul(
+                 z_p.transpose(1, 2), (m_p * s_p_sq_r)
+             )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+             neg_cent4 = torch.sum(
+                 -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
+             )  # [b, 1, t_s]
+             neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
+             if self.use_noise_scaled_mas:
+                 epsilon = (
+                     torch.std(neg_cent)
+                     * torch.randn_like(neg_cent)
+                     * self.current_mas_noise_scale
+                 )
+                 neg_cent = neg_cent + epsilon
+
+             attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+             attn = (
+                 monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
+                 .unsqueeze(1)
+                 .detach()
+             )
+
+         w = attn.sum(2)
+
+         l_length_sdp = self.sdp(x, x_mask, w, g=g)
+         l_length_sdp = l_length_sdp / torch.sum(x_mask)
+
+         logw_ = torch.log(w + 1e-6) * x_mask
+         logw = self.dp(x, x_mask, g=g)
+         # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
+         l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
+             x_mask
+         )  # for averaging
+         # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
+
+         l_length = l_length_dp + l_length_sdp
+
+         # expand prior
+         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
+         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
+
+         z_slice, ids_slice = commons.rand_slice_segments(
+             z, y_lengths, self.segment_size
+         )
+         o = self.dec(z_slice, g=g)
+         return (
+             o,
+             l_length,
+             attn,
+             ids_slice,
+             x_mask,
+             y_mask,
+             (z, z_p, m_p, logs_p, m_q, logs_q),
+             (x, logw, logw_),  # , logw_sdp),
+             g,
+             loss_commit,
+         )
+
+     def infer(
+         self,
+         x,
+         x_lengths,
+         sid,
+         tone,
+         language,
+         bert,
+         emo,
+         noise_scale=0.667,
+         length_scale=1,
+         noise_scale_w=0.8,
+         max_len=None,
+         sdp_ratio=0,
+         y=None,
+     ):
+         # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
+         # g = self.gst(y)
+         if self.n_speakers > 0:
+             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+         else:
+             g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
+         x, m_p, logs_p, x_mask, _ = self.enc_p(
+             x, x_lengths, tone, language, bert, emo, g=g
+         )
+         logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
+             sdp_ratio
+         ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
+         w = torch.exp(logw) * x_mask * length_scale
+         w_ceil = torch.ceil(w)
+         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
+             x_mask.dtype
+         )
+         attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+         attn = commons.generate_path(w_ceil, attn_mask)
+
+         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
+             1, 2
+         )  # [b, t', t], [b, t, d] -> [b, d, t']
+         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
+             1, 2
+         )  # [b, t', t], [b, t, d] -> [b, d, t']
+
+         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+         z = self.flow(z_p, y_mask, g=g, reverse=True)
+         o = self.dec((z * y_mask)[:, :, :max_len], g=g)
+         return o, attn, y_mask, (z, z_p, m_p, logs_p)
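
A minimal inference sketch (illustrative only, not part of this commit). It assumes the repository's sibling modules (commons, modules, attentions, monotonic_align, text) are importable, and the hyperparameter values below are assumptions chosen only to match the constructor signature above; with untrained weights this demonstrates tensor shapes, not usable audio.

import torch

net_g = SynthesizerTrn(
    n_vocab=0,                      # unused here: TextEncoder sizes its table from `symbols`
    spec_channels=1025,             # e.g. n_fft // 2 + 1 for a 2048-point STFT (assumed)
    segment_size=32,                # training slice length in frames (assumed)
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5]] * 3,
    upsample_rates=[8, 8, 2, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 8, 2, 2],
    n_speakers=256,
    gin_channels=256,
).eval()

# Dummy inputs with the shapes infer() consumes: phoneme ids [b, t],
# per-phoneme tone/language ids [b, t], BERT features [b, 1024, t],
# and a 512-dim emotion embedding [b, 512].
t = 20
x = torch.zeros(1, t, dtype=torch.long)
x_lengths = torch.tensor([t])
sid = torch.tensor([0])
tone = torch.zeros(1, t, dtype=torch.long)
language = torch.zeros(1, t, dtype=torch.long)
bert = torch.zeros(1, 1024, t)
emo = torch.zeros(1, 512)

with torch.no_grad():
    audio, attn, y_mask, _ = net_g.infer(
        x, x_lengths, sid, tone, language, bert, emo, sdp_ratio=0.4
    )
print(audio.shape)  # [1, 1, frames * prod(upsample_rates)]

Note the sdp_ratio argument: infer() blends log-durations from the stochastic predictor (self.sdp, reversed flow) and the deterministic one (self.dp), so 0 uses only the deterministic predictor and 1 only the stochastic one.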