Respair commited on
Commit
bdb9479
·
verified ·
1 Parent(s): e3b8436

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +289 -0
utils.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import sys
4
+ import time
5
+ from collections import defaultdict
6
+
7
+ import matplotlib
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+ from torch import nn
12
+ import jiwer
13
+
14
+ import matplotlib.pylab as plt
15
+ import functools
16
+ import os
17
+ import random
18
+ import traceback
19
+ from pathlib import Path
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+
22
+ import librosa
23
+ import numpy as np
24
+ import torch
25
+ from einops import rearrange
26
+ from scipy import ndimage
27
+ from torch.special import gammaln
28
+
29
+
30
+ def calc_wer(target, pred, ignore_indexes=[0]):
31
+ target_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(target)))))
32
+ pred_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(pred)))))
33
+ target_str = ' '.join(target_chars)
34
+ pred_str = ' '.join(pred_chars)
35
+ error = jiwer.wer(target_str, pred_str)
36
+ return error
37
+
38
+ def drop_duplicated(chars):
39
+ ret_chars = [chars[0]]
40
+ for prev, curr in zip(chars[:-1], chars[1:]):
41
+ if prev != curr:
42
+ ret_chars.append(curr)
43
+ return ret_chars
44
+
45
+ # def build_criterion(critic_params={}):
46
+
47
+ # criterion = {
48
+ # "ce": nn.CrossEntropyLoss(ignore_index=-1),
49
+ # "ctc": torch.nn.CTCLoss(**critic_params.get('ctc', {})),
50
+ # "hinge": nn.HingeEmbeddingLoss(margin=critic_params.get('hinge', {}).get("margin", 1.0))
51
+ # }
52
+ # return criterion
53
+
54
+ def build_criterion(critic_params={}):
55
+ criterion = {
56
+ "ce": nn.CrossEntropyLoss(ignore_index=-1),
57
+ "ctc": torch.nn.CTCLoss(**critic_params.get('ctc', {})),
58
+ }
59
+ return criterion
60
+
61
+
62
+
63
+ def get_data_path_list(train_path=None, val_path=None):
64
+ if train_path is None:
65
+ train_path = "Data/train_list.txt"
66
+ if val_path is None:
67
+ val_path = "Data/val_list.txt"
68
+
69
+ with open(train_path, 'r') as f:
70
+ train_list = f.readlines()
71
+ with open(val_path, 'r') as f:
72
+ val_list = f.readlines()
73
+
74
+ return train_list, val_list
75
+
76
+
77
+ def plot_image(image):
78
+ fig, ax = plt.subplots(figsize=(10, 2))
79
+ im = ax.imshow(image, aspect="auto", origin="lower",
80
+ interpolation='none')
81
+
82
+ fig.canvas.draw()
83
+ plt.close()
84
+
85
+ return fig
86
+
87
+
88
+
89
+ class PartialConv1d(torch.nn.Conv1d):
90
+ """
91
+ Zero padding creates a unique identifier for where the edge of the data is, such that the model can almost always identify
92
+ exactly where it is relative to either edge given a sufficient receptive field. Partial padding goes to some lengths to remove
93
+ this affect.
94
+ """
95
+
96
+ __constants__ = ['slide_winsize']
97
+ slide_winsize: float
98
+
99
+ def __init__(self, *args, **kwargs):
100
+ super(PartialConv1d, self).__init__(*args, **kwargs)
101
+ weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
102
+ self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False)
103
+ self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
104
+
105
+ def forward(self, input, mask_in):
106
+ if mask_in is None:
107
+ mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device)
108
+ else:
109
+ mask = mask_in
110
+ input = torch.mul(input, mask)
111
+ with torch.no_grad():
112
+ update_mask = F.conv1d(
113
+ mask,
114
+ self.weight_maskUpdater,
115
+ bias=None,
116
+ stride=self.stride,
117
+ padding=self.padding,
118
+ dilation=self.dilation,
119
+ groups=1,
120
+ )
121
+ update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize)
122
+ mask_ratio = self.slide_winsize / update_mask_filled
123
+ update_mask = torch.clamp(update_mask, 0, 1)
124
+ mask_ratio = torch.mul(mask_ratio, update_mask)
125
+
126
+ raw_out = self._conv_forward(input, self.weight, self.bias)
127
+
128
+ if self.bias is not None:
129
+ bias_view = self.bias.view(1, self.out_channels, 1)
130
+ output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
131
+ output = torch.mul(output, update_mask)
132
+ else:
133
+ output = torch.mul(raw_out, mask_ratio)
134
+
135
+ return output
136
+
137
+
138
+ class LinearNorm(torch.nn.Module):
139
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
140
+ super().__init__()
141
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
142
+
143
+ torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
144
+
145
+ def forward(self, x):
146
+ return self.linear_layer(x)
147
+
148
+
149
+ class ConvNorm(torch.nn.Module):
150
+ __constants__ = ['use_partial_padding']
151
+ use_partial_padding: bool
152
+
153
+ def __init__(
154
+ self,
155
+ in_channels,
156
+ out_channels,
157
+ kernel_size=1,
158
+ stride=1,
159
+ padding=None,
160
+ dilation=1,
161
+ bias=True,
162
+ w_init_gain='linear',
163
+ use_partial_padding=False,
164
+ use_weight_norm=False,
165
+ norm_fn=None,
166
+ ):
167
+ super(ConvNorm, self).__init__()
168
+ if padding is None:
169
+ assert kernel_size % 2 == 1
170
+ padding = int(dilation * (kernel_size - 1) / 2)
171
+ self.use_partial_padding = use_partial_padding
172
+ conv_fn = torch.nn.Conv1d
173
+ if use_partial_padding:
174
+ conv_fn = PartialConv1d
175
+ self.conv = conv_fn(
176
+ in_channels,
177
+ out_channels,
178
+ kernel_size=kernel_size,
179
+ stride=stride,
180
+ padding=padding,
181
+ dilation=dilation,
182
+ bias=bias,
183
+ )
184
+ torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
185
+ if use_weight_norm:
186
+ self.conv = torch.nn.utils.weight_norm(self.conv)
187
+ if norm_fn is not None:
188
+ self.norm = norm_fn(out_channels, affine=True)
189
+ else:
190
+ self.norm = None
191
+
192
+ def forward(self, signal, mask=None):
193
+ if self.use_partial_padding:
194
+ ret = self.conv(signal, mask)
195
+ if self.norm is not None:
196
+ ret = self.norm(ret, mask)
197
+ else:
198
+ if mask is not None:
199
+ signal = signal.mul(mask)
200
+ ret = self.conv(signal)
201
+ if self.norm is not None:
202
+ ret = self.norm(ret)
203
+
204
+ # if self.is_adapter_available():
205
+ # ret = self.forward_enabled_adapters(ret.transpose(1, 2)).transpose(1, 2)
206
+
207
+ return ret
208
+
209
+
210
+
211
+ class BetaBinomialInterpolator:
212
+ """
213
+ This module calculates alignment prior matrices (based on beta-binomial distribution) using cached popular sizes and image interpolation.
214
+ The implementation is taken from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py
215
+ """
216
+
217
+ def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500, scaling_factor: float = 1.0):
218
+ self.round_mel_len_to = round_mel_len_to
219
+ self.round_text_len_to = round_text_len_to
220
+ cached_func = lambda x, y: beta_binomial_prior_distribution(x, y, scaling_factor=scaling_factor)
221
+ self.bank = functools.lru_cache(maxsize=cache_size)(cached_func)
222
+
223
+ @staticmethod
224
+ def round(val, to):
225
+ return max(1, int(np.round((val + 1) / to))) * to
226
+
227
+ def __call__(self, w, h):
228
+ bw = BetaBinomialInterpolator.round(w, to=self.round_mel_len_to)
229
+ bh = BetaBinomialInterpolator.round(h, to=self.round_text_len_to)
230
+ ret = ndimage.zoom(self.bank(bw, bh).T, zoom=(w / bw, h / bh), order=1)
231
+ assert ret.shape[0] == w, ret.shape
232
+ assert ret.shape[1] == h, ret.shape
233
+ return ret
234
+
235
+
236
+ def general_padding(item, item_len, max_len, pad_value=0):
237
+ if item_len < max_len:
238
+ item = torch.nn.functional.pad(item, (0, max_len - item_len), value=pad_value)
239
+ return item
240
+
241
+
242
+ def stack_tensors(tensors: List[torch.Tensor], max_lens: List[int], pad_value: float = 0.0) -> torch.Tensor:
243
+ """
244
+ Create batch by stacking input tensor list along the time axes.
245
+
246
+ Args:
247
+ tensors: List of tensors to pad and stack
248
+ max_lens: List of lengths to pad each axis to, starting with the last axis
249
+ pad_value: Value for padding
250
+
251
+ Returns:
252
+ Padded and stacked tensor.
253
+ """
254
+ padded_tensors = []
255
+ for tensor in tensors:
256
+ padding = []
257
+ for i, max_len in enumerate(max_lens, 1):
258
+ padding += [0, max_len - tensor.shape[-i]]
259
+
260
+ padded_tensor = torch.nn.functional.pad(tensor, pad=padding, value=pad_value)
261
+ padded_tensors.append(padded_tensor)
262
+
263
+ stacked_tensor = torch.stack(padded_tensors)
264
+ return stacked_tensor
265
+
266
+
267
+ def logbeta(x, y):
268
+ return gammaln(x) + gammaln(y) - gammaln(x + y)
269
+
270
+
271
+ def logcombinations(n, k):
272
+ return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
273
+
274
+
275
+ def logbetabinom(n, a, b, x):
276
+ return logcombinations(n, x) + logbeta(x + a, n - x + b) - logbeta(a, b)
277
+
278
+
279
+ def beta_binomial_prior_distribution(phoneme_count: int, mel_count: int, scaling_factor: float = 1.0) -> np.array:
280
+ x = rearrange(torch.arange(0, phoneme_count), "b -> 1 b")
281
+ y = rearrange(torch.arange(1, mel_count + 1), "b -> b 1")
282
+ a = scaling_factor * y
283
+ b = scaling_factor * (mel_count + 1 - y)
284
+ n = torch.FloatTensor([phoneme_count - 1])
285
+
286
+ return logbetabinom(n, a, b, x).exp().numpy()
287
+
288
+
289
+ # example : attn_prior = (torch.from_numpy(beta_binomial_interpolator(spect_len.item(), text_len.item())).unsqueeze(0).to(text.device))