yangwang825 committed on
Commit e5dee27 · verified · 1 Parent(s): 140492e

Upload XVectorForSequenceClassification

Files changed (13)
  1. angular_loss.py +68 -0
  2. audio_processing.py +411 -0
  3. cnn.py +247 -0
  4. config.json +6 -2
  5. conv_asr.py +189 -0
  6. features.py +560 -0
  7. logging.py +55 -0
  8. model.safetensors +3 -0
  9. modeling_xvector.py +153 -0
  10. module.py +105 -0
  11. normalization.py +99 -0
  12. spectrogram_augment.py +223 -0
  13. tdnn_attention.py +550 -0
angular_loss.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class Loss(nn.modules.loss._Loss):
6
+ """Inherit this class to implement custom loss."""
7
+
8
+ def __init__(self, **kwargs):
9
+ super(Loss, self).__init__(**kwargs)
10
+
11
+
12
+ class AdditiveMarginSoftmaxLoss(Loss):
13
+ """Computes Additive Margin Softmax (CosFace) Loss
14
+
15
+ Paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition
16
+
17
+ Args:
18
+ scale: scale value for cosine angle
19
+ margin: margin value added to cosine angle
20
+ """
21
+
22
+ def __init__(self, scale=30.0, margin=0.2):
23
+ super().__init__()
24
+
25
+ self.eps = 1e-7
26
+ self.scale = scale
27
+ self.margin = margin
28
+
29
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
30
+ # Extract the logits corresponding to the true class
31
+ logits_target = logits[torch.arange(logits.size(0)), labels] # Faster indexing
32
+ numerator = self.scale * (logits_target - self.margin) # Apply additive margin
33
+ # Exclude the target logits from denominator calculation
34
+ logits.scatter_(1, labels.unsqueeze(1), float('-inf')) # Mask target class
35
+ denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * logits), dim=1)
36
+ # Compute final loss
37
+ loss = -torch.log(torch.exp(numerator) / denominator)
38
+ return loss.mean()
39
+
40
+
41
+ class AdditiveAngularMarginSoftmaxLoss(Loss):
42
+ """Computes Additive Angular Margin Softmax (ArcFace) Loss
43
+
44
+ Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition
45
+
46
+ Args:
47
+ scale: scale value for cosine angle
48
+ margin: margin value added to cosine angle
49
+ """
50
+
51
+ def __init__(self, scale=20.0, margin=1.35):
52
+ super().__init__()
53
+
54
+ self.eps = 1e-7
55
+ self.scale = scale
56
+ self.margin = margin
57
+
58
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
59
+ numerator = self.scale * torch.cos(
60
+ torch.acos(torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps))
61
+ + self.margin
62
+ )
63
+ excl = torch.cat(
64
+ [torch.cat((logits[i, :y], logits[i, y + 1 :])).unsqueeze(0) for i, y in enumerate(labels)], dim=0
65
+ )
66
+ denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * excl), dim=1)
67
+ L = numerator - torch.log(denominator)
68
+ return -torch.mean(L)
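A minimal usage sketch, assuming angular_loss.py is importable on its own and that `logits` are already cosine similarities in [-1, 1] (e.g. produced by SpeakerDecoder with angular=True); the shapes and the scale/margin values below are illustrative.

    import torch
    import torch.nn.functional as F
    from angular_loss import AdditiveMarginSoftmaxLoss, AdditiveAngularMarginSoftmaxLoss  # assumes the file is on PYTHONPATH

    batch_size, num_classes, emb_dim = 4, 10, 16
    # Cosine similarities between L2-normalized embeddings and L2-normalized class weights
    logits = F.normalize(torch.randn(batch_size, emb_dim), dim=1) @ F.normalize(torch.randn(num_classes, emb_dim), dim=1).t()
    labels = torch.randint(0, num_classes, (batch_size,))

    cosface = AdditiveMarginSoftmaxLoss(scale=30.0, margin=0.2)
    arcface = AdditiveAngularMarginSoftmaxLoss(scale=20.0, margin=1.35)
    # Note: the CosFace forward mutates `logits` in place via scatter_, hence the clone()
    print(cosface(logits.clone(), labels), arcface(logits, labels))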
audio_processing.py ADDED
@@ -0,0 +1,411 @@
1
+ import math
2
+ from packaging import version
3
+ from dataclasses import dataclass
4
+ from abc import ABC, abstractmethod
5
+
6
+ import torch
7
+
8
+ try:
9
+ import torchaudio
10
+ import torchaudio.functional
11
+ import torchaudio.transforms
12
+
13
+ TORCHAUDIO_VERSION = version.parse(torchaudio.__version__)
14
+ TORCHAUDIO_VERSION_MIN = version.parse('0.5')
15
+
16
+ HAVE_TORCHAUDIO = True
17
+ except ModuleNotFoundError:
18
+ HAVE_TORCHAUDIO = False
19
+
20
+ from .logging import logger
21
+ from .module import NeuralModule
22
+ from .features import FilterbankFeatures, FilterbankFeaturesTA
23
+ from .spectrogram_augment import SpecCutout, SpecAugment
24
+
25
+
26
+ class AudioPreprocessor(NeuralModule, ABC):
27
+ """
28
+ An interface for Neural Modules that perform audio pre-processing,
29
+ transforming the wav files to features.
30
+ """
31
+
32
+ def __init__(self, win_length, hop_length):
33
+ super().__init__()
34
+
35
+ self.win_length = win_length
36
+ self.hop_length = hop_length
37
+
38
+ self.torch_windows = {
39
+ 'hann': torch.hann_window,
40
+ 'hamming': torch.hamming_window,
41
+ 'blackman': torch.blackman_window,
42
+ 'bartlett': torch.bartlett_window,
43
+ 'ones': torch.ones,
44
+ None: torch.ones,
45
+ }
46
+
47
+ # Normally, when you call to(dtype) on a torch.nn.Module, all
48
+ # floating point parameters and buffers will change to that
49
+ # dtype, rather than being float32. The AudioPreprocessor
50
+ # classes, uniquely, don't actually have any parameters or
51
+ # buffers from what I see. In addition, we want the input to
52
+ # the preprocessor to be float32, but need to create the
53
+ # output in appropriate precision. We have this empty tensor
54
+ # here just to detect which dtype tensor this module should
55
+ # output at the end of execution.
56
+ self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)
57
+
58
+ @torch.no_grad()
59
+ def forward(self, input_signal, length):
60
+ processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
61
+ processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
62
+ return processed_signal, processed_length
63
+
64
+ @abstractmethod
65
+ def get_features(self, input_signal, length):
66
+ # Called by forward(). Subclasses should implement this.
67
+ pass
68
+
69
+
70
+ class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
71
+ """Featurizer module that converts wavs to mel spectrograms.
72
+
73
+ Args:
74
+ sample_rate (int): Sample rate of the input audio data.
75
+ Defaults to 16000
76
+ window_size (float): Size of window for fft in seconds
77
+ Defaults to 0.02
78
+ window_stride (float): Stride of window for fft in seconds
79
+ Defaults to 0.01
80
+ n_window_size (int): Size of window for fft in samples
81
+ Defaults to None. Use one of window_size or n_window_size.
82
+ n_window_stride (int): Stride of window for fft in samples
83
+ Defaults to None. Use one of window_stride or n_window_stride.
84
+ window (str): Windowing function for fft. can be one of ['hann',
85
+ 'hamming', 'blackman', 'bartlett']
86
+ Defaults to "hann"
87
+ normalize (str): Can be one of ['per_feature', 'all_features']; all
88
+ other options disable feature normalization. 'all_features'
89
+ normalizes the entire spectrogram to be mean 0 with std 1.
90
+ 'per_feature' normalizes per channel / freq instead.
91
+ Defaults to "per_feature"
92
+ n_fft (int): Length of FT window. If None, it uses the smallest power
93
+ of 2 that is larger than n_window_size.
94
+ Defaults to None
95
+ preemph (float): Amount of pre emphasis to add to audio. Can be
96
+ disabled by passing None.
97
+ Defaults to 0.97
98
+ features (int): Number of mel spectrogram freq bins to output.
99
+ Defaults to 64
100
+ lowfreq (int): Lower bound on mel basis in Hz.
101
+ Defaults to 0
102
+ highfreq (int): Upper bound on mel basis in Hz.
103
+ Defaults to None
104
+ log (bool): Log features.
105
+ Defaults to True
106
+ log_zero_guard_type(str): Need to avoid taking the log of zero. There
107
+ are two options: "add" or "clamp".
108
+ Defaults to "add".
109
+ log_zero_guard_value(float, or str): Add or clamp requires the number
110
+ to add with or clamp to. log_zero_guard_value can either be a float
111
+ or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
112
+ passed.
113
+ Defaults to 2**-24.
114
+ dither (float): Amount of white-noise dithering.
115
+ Defaults to 1e-5
116
+ pad_to (int): Ensures that the output size of the time dimension is
117
+ a multiple of pad_to.
118
+ Defaults to 16
119
+ frame_splicing (int): Defaults to 1
120
+ exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
121
+ // hop_length. Defaults to False.
122
+ pad_value (float): The value that shorter mels are padded with.
123
+ Defaults to 0
124
+ mag_power (float): The power that the linear spectrogram is raised to
125
+ prior to multiplication with mel basis.
126
+ Defaults to 2 for a power spec
127
+ rng : Random number generator
128
+ nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
129
+ samples in the batch.
130
+ Defaults to 0.0
131
+ nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
132
+ Defaults to 4000
133
+ use_torchaudio: Whether to use the `torchaudio` implementation.
134
+ mel_norm: Normalization used for mel filterbank weights.
135
+ Defaults to 'slaney' (area normalization)
136
+ stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
137
+ stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
138
+ """
139
+
140
+ def __init__(
141
+ self,
142
+ sample_rate=16000,
143
+ window_size=0.02,
144
+ window_stride=0.01,
145
+ n_window_size=None,
146
+ n_window_stride=None,
147
+ window="hann",
148
+ normalize="per_feature",
149
+ n_fft=None,
150
+ preemph=0.97,
151
+ features=64,
152
+ lowfreq=0,
153
+ highfreq=None,
154
+ log=True,
155
+ log_zero_guard_type="add",
156
+ log_zero_guard_value=2**-24,
157
+ dither=1e-5,
158
+ pad_to=16,
159
+ frame_splicing=1,
160
+ exact_pad=False,
161
+ pad_value=0,
162
+ mag_power=2.0,
163
+ rng=None,
164
+ nb_augmentation_prob=0.0,
165
+ nb_max_freq=4000,
166
+ use_torchaudio: bool = False,
167
+ mel_norm="slaney",
168
+ ):
169
+ super().__init__(n_window_size, n_window_stride)
170
+
171
+ self._sample_rate = sample_rate
172
+ if window_size and n_window_size:
173
+ raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
174
+ if window_stride and n_window_stride:
175
+ raise ValueError(
176
+ f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
177
+ )
178
+ if window_size:
179
+ n_window_size = int(window_size * self._sample_rate)
180
+ if window_stride:
181
+ n_window_stride = int(window_stride * self._sample_rate)
182
+
183
+ # Given the long and similar argument list, point to the class and instantiate it by reference
184
+ if not use_torchaudio:
185
+ logger.warning("Current only support FilterbankFeatures with torchaudio.")
186
+ featurizer_class = FilterbankFeaturesTA
187
+ else:
188
+ featurizer_class = FilterbankFeaturesTA
189
+ self.featurizer = featurizer_class(
190
+ sample_rate=self._sample_rate,
191
+ n_window_size=n_window_size,
192
+ n_window_stride=n_window_stride,
193
+ window=window,
194
+ normalize=normalize,
195
+ n_fft=n_fft,
196
+ preemph=preemph,
197
+ nfilt=features,
198
+ lowfreq=lowfreq,
199
+ highfreq=highfreq,
200
+ log=log,
201
+ log_zero_guard_type=log_zero_guard_type,
202
+ log_zero_guard_value=log_zero_guard_value,
203
+ dither=dither,
204
+ pad_to=pad_to,
205
+ frame_splicing=frame_splicing,
206
+ exact_pad=exact_pad,
207
+ pad_value=pad_value,
208
+ mag_power=mag_power,
209
+ rng=rng,
210
+ nb_augmentation_prob=nb_augmentation_prob,
211
+ nb_max_freq=nb_max_freq,
212
+ mel_norm=mel_norm,
213
+ )
214
+
215
+ def get_features(self, input_signal, length):
216
+ return self.featurizer(input_signal, length) # return tensor shape of (B, D, T)
217
+
218
+ @property
219
+ def filter_banks(self):
220
+ return self.featurizer.filter_banks
221
+
222
+
223
+ class AudioToMFCCPreprocessor(AudioPreprocessor):
224
+ """Preprocessor that converts wavs to MFCCs.
225
+ Uses torchaudio.transforms.MFCC.
226
+
227
+ Args:
228
+ sample_rate: The sample rate of the audio.
229
+ Defaults to 16000.
230
+ window_size: Size of window for fft in seconds. Used to calculate the
231
+ win_length arg for mel spectrogram.
232
+ Defaults to 0.02
233
+ window_stride: Stride of window for fft in seconds. Used to calculate
234
+ the hop_length arg for mel spect.
235
+ Defaults to 0.01
236
+ n_window_size: Size of window for fft in samples
237
+ Defaults to None. Use one of window_size or n_window_size.
238
+ n_window_stride: Stride of window for fft in samples
239
+ Defaults to None. Use one of window_stride or n_window_stride.
240
+ window: Windowing function for fft. can be one of ['hann',
241
+ 'hamming', 'blackman', 'bartlett', 'none', 'null'].
242
+ Defaults to 'hann'
243
+ n_fft: Length of FT window. If None, it uses the smallest power of 2
244
+ that is larger than n_window_size.
245
+ Defaults to None
246
+ lowfreq (int): Lower bound on mel basis in Hz.
247
+ Defaults to 0
248
+ highfreq (int): Upper bound on mel basis in Hz.
249
+ Defaults to None
250
+ n_mels: Number of mel filterbanks.
251
+ Defaults to 64
252
+ n_mfcc: Number of coefficients to retain
253
+ Defaults to 64
254
+ dct_type: Type of discrete cosine transform to use
255
+ norm: Type of norm to use
256
+ log: Whether to use log-mel spectrograms instead of db-scaled.
257
+ Defaults to True.
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ sample_rate=16000,
263
+ window_size=0.02,
264
+ window_stride=0.01,
265
+ n_window_size=None,
266
+ n_window_stride=None,
267
+ window='hann',
268
+ n_fft=None,
269
+ lowfreq=0.0,
270
+ highfreq=None,
271
+ n_mels=64,
272
+ n_mfcc=64,
273
+ dct_type=2,
274
+ norm='ortho',
275
+ log=True,
276
+ ):
277
+ self._sample_rate = sample_rate
278
+ if not HAVE_TORCHAUDIO:
279
+ logger.warning('Could not import torchaudio. Some features might not work.')
280
+
281
+ raise ModuleNotFoundError(
282
+ "torchaudio is not installed but is necessary for "
283
+ "AudioToMFCCPreprocessor. We recommend you try "
284
+ "building it from source for the PyTorch version you have."
285
+ )
286
+ if window_size and n_window_size:
287
+ raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
288
+ if window_stride and n_window_stride:
289
+ raise ValueError(
290
+ f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
291
+ )
292
+ # Get win_length (n_window_size) and hop_length (n_window_stride)
293
+ if window_size:
294
+ n_window_size = int(window_size * self._sample_rate)
295
+ if window_stride:
296
+ n_window_stride = int(window_stride * self._sample_rate)
297
+
298
+ super().__init__(n_window_size, n_window_stride)
299
+
300
+ mel_kwargs = {}
301
+
302
+ mel_kwargs['f_min'] = lowfreq
303
+ mel_kwargs['f_max'] = highfreq
304
+ mel_kwargs['n_mels'] = n_mels
305
+
306
+ mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))
307
+
308
+ mel_kwargs['win_length'] = n_window_size
309
+ mel_kwargs['hop_length'] = n_window_stride
310
+
311
+ # Set window_fn. None defaults to torch.ones.
312
+ window_fn = self.torch_windows.get(window, None)
313
+ if window_fn is None:
314
+ raise ValueError(
315
+ f"Window argument for AudioProcessor is invalid: {window}."
316
+ f"For no window function, use 'ones' or None."
317
+ )
318
+ mel_kwargs['window_fn'] = window_fn
319
+
320
+ # Use torchaudio's implementation of MFCCs as featurizer
321
+ self.featurizer = torchaudio.transforms.MFCC(
322
+ sample_rate=self._sample_rate,
323
+ n_mfcc=n_mfcc,
324
+ dct_type=dct_type,
325
+ norm=norm,
326
+ log_mels=log,
327
+ melkwargs=mel_kwargs,
328
+ )
329
+
330
+ def get_features(self, input_signal, length):
331
+ features = self.featurizer(input_signal)
332
+ seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long)
333
+ return features, seq_len
334
+
335
+
336
+ class SpectrogramAugmentation(NeuralModule):
337
+ """
338
+ Performs time and freq cuts in one of two ways.
339
+ SpecAugment zeroes out vertical and horizontal sections as described in
340
+ SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with
341
+ SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`.
342
+ SpecCutout zeroes out rectangular regions as described in Cutout
343
+ (https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are
344
+ `rect_masks`, `rect_freq`, and `rect_time`.
345
+
346
+ Args:
347
+ freq_masks (int): how many frequency segments should be cut.
348
+ Defaults to 0.
349
+ time_masks (int): how many time segments should be cut
350
+ Defaults to 0.
351
+ freq_width (int): maximum number of frequencies to be cut in one
352
+ segment.
353
+ Defaults to 10.
354
+ time_width (int): maximum number of time steps to be cut in one
355
+ segment
356
+ Defaults to 10.
357
+ rect_masks (int): how many rectangular masks should be cut
358
+ Defaults to 0.
359
+ rect_freq (int): maximum size of cut rectangles along the frequency
360
+ dimension
361
+ Defaults to 5.
362
+ rect_time (int): maximum size of cut rectangles along the time
363
+ dimension
364
+ Defaults to 25.
365
+ use_numba_spec_augment: use numba code for Spectrogram augmentation
366
+ use_vectorized_spec_augment: use vectorized code for Spectrogram augmentation
367
+
368
+ """
369
+
370
+ def __init__(
371
+ self,
372
+ freq_masks=0,
373
+ time_masks=0,
374
+ freq_width=10,
375
+ time_width=10,
376
+ rect_masks=0,
377
+ rect_time=5,
378
+ rect_freq=20,
379
+ rng=None,
380
+ mask_value=0.0,
381
+ use_vectorized_spec_augment: bool = True,
382
+ ):
383
+ super().__init__()
384
+
385
+ if rect_masks > 0:
386
+ self.spec_cutout = SpecCutout(
387
+ rect_masks=rect_masks,
388
+ rect_time=rect_time,
389
+ rect_freq=rect_freq,
390
+ rng=rng,
391
+ )
392
+ # self.spec_cutout.to(self._device)
393
+ else:
394
+ self.spec_cutout = lambda input_spec: input_spec
395
+ if freq_masks + time_masks > 0:
396
+ self.spec_augment = SpecAugment(
397
+ freq_masks=freq_masks,
398
+ time_masks=time_masks,
399
+ freq_width=freq_width,
400
+ time_width=time_width,
401
+ rng=rng,
402
+ mask_value=mask_value,
403
+ use_vectorized_code=use_vectorized_spec_augment,
404
+ )
405
+ else:
406
+ self.spec_augment = lambda input_spec, length: input_spec
407
+
408
+ def forward(self, input_spec, length):
409
+ augmented_spec = self.spec_cutout(input_spec=input_spec)
410
+ augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
411
+ return augmented_spec  # return tensor shape of (B, D, T)
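A usage sketch for running a padded waveform batch through the preprocessor and SpecAugment. The package name `xvector` is an assumption (the files use relative imports, so they must be imported as a package), and the augmentation values are illustrative.

    import torch
    from xvector.audio_processing import AudioToMelSpectrogramPreprocessor, SpectrogramAugmentation

    preproc = AudioToMelSpectrogramPreprocessor(sample_rate=16000, features=64, use_torchaudio=True)
    augment = SpectrogramAugmentation(freq_masks=3, time_masks=5, freq_width=4, time_width=0.03)

    waveform = torch.randn(2, 32000)              # (B, samples): two 2-second clips at 16 kHz
    lengths = torch.tensor([32000, 24000])        # true (unpadded) lengths in samples
    feats, feat_lengths = preproc(input_signal=waveform, length=lengths)   # feats: (B, 64, T)
    feats = augment(input_spec=feats, length=feat_lengths)                 # same shape, masked regions zeroed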
cnn.py ADDED
@@ -0,0 +1,247 @@
1
+ import math
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class Conv1d(nn.Module):
8
+ """This function implements 1d convolution.
9
+
10
+ Arguments
11
+ ---------
12
+ out_channels : int
13
+ It is the number of output channels.
14
+ kernel_size : int
15
+ Kernel size of the convolutional filters.
16
+ input_shape : tuple
17
+ The shape of the input. Alternatively use ``in_channels``.
18
+ in_channels : int
19
+ The number of input channels. Alternatively use ``input_shape``.
20
+ stride : int
21
+ Stride factor of the convolutional filters. When the stride factor > 1,
22
+ a decimation in time is performed.
23
+ dilation : int
24
+ Dilation factor of the convolutional filters.
25
+ padding : str
26
+ (same, valid, causal). If "valid", no padding is performed.
27
+ If "same" and stride is 1, output shape is the same as the input shape.
28
+ "causal" results in causal (dilated) convolutions.
29
+ groups : int
30
+ Number of blocked connections from input channels to output channels.
31
+ bias : bool
32
+ Whether to add a bias term to convolution operation.
33
+ padding_mode : str
34
+ This flag specifies the type of padding. See torch.nn documentation
35
+ for more information.
36
+ skip_transpose : bool
37
+ If False, uses batch x time x channel convention of speechbrain.
38
+ If True, uses batch x channel x time convention.
39
+ weight_norm : bool
40
+ If True, use weight normalization,
41
+ to be removed with self.remove_weight_norm() at inference
42
+ conv_init : str
43
+ Weight initialization for the convolution network
44
+ default_padding: str or int
45
+ This sets the default padding mode that will be used by the pytorch Conv1d backend.
46
+
47
+ Example
48
+ -------
49
+ >>> inp_tensor = torch.rand([10, 40, 16])
50
+ >>> cnn_1d = Conv1d(
51
+ ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
52
+ ... )
53
+ >>> out_tensor = cnn_1d(inp_tensor)
54
+ >>> out_tensor.shape
55
+ torch.Size([10, 40, 8])
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ out_channels,
61
+ kernel_size,
62
+ input_shape=None,
63
+ in_channels=None,
64
+ stride=1,
65
+ dilation=1,
66
+ padding="same",
67
+ groups=1,
68
+ bias=True,
69
+ padding_mode="reflect",
70
+ skip_transpose=False,
71
+ weight_norm=False,
72
+ conv_init=None,
73
+ default_padding=0,
74
+ ):
75
+ super().__init__()
76
+ self.kernel_size = kernel_size
77
+ self.stride = stride
78
+ self.dilation = dilation
79
+ self.padding = padding
80
+ self.padding_mode = padding_mode
81
+ self.unsqueeze = False
82
+ self.skip_transpose = skip_transpose
83
+
84
+ if input_shape is None and in_channels is None:
85
+ raise ValueError("Must provide one of input_shape or in_channels")
86
+
87
+ if in_channels is None:
88
+ in_channels = self._check_input_shape(input_shape)
89
+
90
+ self.in_channels = in_channels
91
+
92
+ self.conv = nn.Conv1d(
93
+ in_channels,
94
+ out_channels,
95
+ self.kernel_size,
96
+ stride=self.stride,
97
+ dilation=self.dilation,
98
+ padding=default_padding,
99
+ groups=groups,
100
+ bias=bias,
101
+ )
102
+
103
+ if conv_init == "kaiming":
104
+ nn.init.kaiming_normal_(self.conv.weight)
105
+ elif conv_init == "zero":
106
+ nn.init.zeros_(self.conv.weight)
107
+ elif conv_init == "normal":
108
+ nn.init.normal_(self.conv.weight, std=1e-6)
109
+
110
+ if weight_norm:
111
+ self.conv = nn.utils.weight_norm(self.conv)
112
+
113
+ def forward(self, x, *args, **kwargs):
114
+ """Returns the output of the convolution.
115
+
116
+ Arguments
117
+ ---------
118
+ x : torch.Tensor (batch, time, channel)
119
+ input to convolve. 2d or 4d tensors are expected.
120
+
121
+ Returns
122
+ -------
123
+ wx : torch.Tensor
124
+ The convolved outputs.
125
+ """
126
+ if not self.skip_transpose:
127
+ x = x.transpose(1, -1)
128
+
129
+ if self.unsqueeze:
130
+ x = x.unsqueeze(1)
131
+
132
+ if self.padding == "same":
133
+ x = self._manage_padding(
134
+ x, self.kernel_size, self.dilation, self.stride
135
+ )
136
+
137
+ elif self.padding == "causal":
138
+ num_pad = (self.kernel_size - 1) * self.dilation
139
+ x = F.pad(x, (num_pad, 0))
140
+
141
+ elif self.padding == "valid":
142
+ pass
143
+
144
+ else:
145
+ raise ValueError(
146
+ "Padding must be 'same', 'valid' or 'causal'. Got "
147
+ + self.padding
148
+ )
149
+
150
+ wx = self.conv(x)
151
+
152
+ if self.unsqueeze:
153
+ wx = wx.squeeze(1)
154
+
155
+ if not self.skip_transpose:
156
+ wx = wx.transpose(1, -1)
157
+
158
+ return wx
159
+
160
+ def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
161
+ """This function performs zero-padding on the time axis
162
+ such that the length is unchanged after the convolution.
163
+
164
+ Arguments
165
+ ---------
166
+ x : torch.Tensor
167
+ Input tensor.
168
+ kernel_size : int
169
+ Size of kernel.
170
+ dilation : int
171
+ Dilation used.
172
+ stride : int
173
+ Stride.
174
+
175
+ Returns
176
+ -------
177
+ x : torch.Tensor
178
+ The padded outputs.
179
+ """
180
+
181
+ # Detecting input shape
182
+ L_in = self.in_channels
183
+
184
+ # Time padding
185
+ padding = get_padding_elem(L_in, stride, kernel_size, dilation)
186
+
187
+ # Applying padding
188
+ x = F.pad(x, padding, mode=self.padding_mode)
189
+
190
+ return x
191
+
192
+ def _check_input_shape(self, shape):
193
+ """Checks the input shape and returns the number of input channels."""
194
+
195
+ if len(shape) == 2:
196
+ self.unsqueeze = True
197
+ in_channels = 1
198
+ elif self.skip_transpose:
199
+ in_channels = shape[1]
200
+ elif len(shape) == 3:
201
+ in_channels = shape[2]
202
+ else:
203
+ raise ValueError(
204
+ "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
205
+ )
206
+
207
+ # Kernel size must be odd
208
+ if not self.padding == "valid" and self.kernel_size % 2 == 0:
209
+ raise ValueError(
210
+ "The field kernel size must be an odd number. Got %s."
211
+ % (self.kernel_size)
212
+ )
213
+
214
+ return in_channels
215
+
216
+ def remove_weight_norm(self):
217
+ """Removes weight normalization at inference if used during training."""
218
+ self.conv = nn.utils.remove_weight_norm(self.conv)
219
+
220
+
221
+ def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
222
+ """This function computes the number of elements to add for zero-padding.
223
+
224
+ Arguments
225
+ ---------
226
+ L_in : int
227
+ stride: int
228
+ kernel_size : int
229
+ dilation : int
230
+
231
+ Returns
232
+ -------
233
+ padding : int
234
+ The size of the padding to be added
235
+ """
236
+ if stride > 1:
237
+ padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
238
+
239
+ else:
240
+ L_out = (
241
+ math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
242
+ )
243
+ padding = [
244
+ math.floor((L_in - L_out) / 2),
245
+ math.floor((L_in - L_out) / 2),
246
+ ]
247
+ return padding
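An illustrative sketch of the "same"-padding arithmetic: for stride 1, get_padding_elem pads (kernel_size - 1) * dilation // 2 elements on each side, so Conv1d preserves the time length. The values below are arbitrary.

    import torch
    from cnn import Conv1d, get_padding_elem  # assumes cnn.py is on PYTHONPATH

    print(get_padding_elem(L_in=100, stride=1, kernel_size=5, dilation=2))   # [4, 4]

    x = torch.rand(8, 100, 40)                                               # (batch, time, channel) convention
    conv = Conv1d(input_shape=x.shape, out_channels=64, kernel_size=5, dilation=2)
    print(conv(x).shape)                                                     # torch.Size([8, 100, 64]); time dim preserved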
config.json CHANGED
@@ -1,9 +1,12 @@
 {
-  "_attn_implementation_autoset": true,
   "angular": false,
+  "architectures": [
+    "XVectorForSequenceClassification"
+  ],
   "attention_channels": 128,
   "auto_map": {
-    "AutoConfig": "configuration_xvector.XVectorConfig"
+    "AutoConfig": "configuration_xvector.XVectorConfig",
+    "AutoModelForAudioClassification": "modeling_xvector.XVectorForSequenceClassification"
   },
   "bos_token_id": 1,
   "decoder_config": {
@@ -2603,6 +2606,7 @@
   },
   "time_masks": 5,
   "time_width": 0.03,
+  "torch_dtype": "float32",
   "transformers_version": "4.48.3",
   "use_torchaudio": true,
   "use_vectorized_spec_augment": true,
conv_asr.py ADDED
@@ -0,0 +1,189 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .module import NeuralModule
8
+ from .tdnn_attention import StatsPoolLayer, AttentivePoolLayer, init_weights
9
+ from .cnn import Conv1d
10
+ from .normalization import BatchNorm1d
11
+
12
+
13
+ class TDNNLayer(nn.Module):
14
+
15
+ def __init__(self, in_conv_dim, out_conv_dim, kernel_size, dilation):
16
+ super().__init__()
17
+ self.in_conv_dim = in_conv_dim
18
+ self.out_conv_dim = out_conv_dim
19
+ self.kernel_size = kernel_size
20
+ self.dilation = dilation
21
+
22
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
23
+ self.activation = nn.ReLU()
24
+
25
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
26
+ # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
27
+ weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
28
+ hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
29
+ hidden_states = self.activation(hidden_states)
30
+ return hidden_states
31
+
32
+
33
+ class XVectorEncoder(NeuralModule):
34
+ """
35
+ input:
36
+ feat_in: input feature shape (mel spec feature shape)
37
+ filters: list of filter shapes for SE_TDNN modules
38
+ kernel_sizes: list of kernel shapes for SE_TDNN modules
39
+ dilations: list of dilations for group conv se layer
40
+ scale: scale value to group wider conv channels (default: 8)
41
+
42
+ output:
43
+ outputs : encoded output
44
+ output_length: masked output lengths
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ feat_in: int,
50
+ filters: list,
51
+ kernel_sizes: list,
52
+ dilations: list,
53
+ init_mode: str = 'xavier_uniform',
54
+ ):
55
+ super().__init__()
56
+ self.blocks = nn.ModuleList()
57
+
58
+ # TDNN layers
59
+ in_channels = feat_in
60
+ tdnn_blocks = len(filters)
61
+ for block_index in range(tdnn_blocks):
62
+ out_channels = filters[block_index]
63
+ self.blocks.extend(
64
+ [
65
+ Conv1d(
66
+ in_channels=in_channels,
67
+ out_channels=out_channels,
68
+ kernel_size=kernel_sizes[block_index],
69
+ dilation=dilations[block_index],
70
+ ),
71
+ torch.nn.LeakyReLU(),
72
+ BatchNorm1d(input_size=out_channels),
73
+ ]
74
+ )
75
+ in_channels = filters[block_index]
76
+
77
+ self.apply(lambda x: init_weights(x, mode=init_mode))
78
+
79
+ def forward(self, audio_signal: torch.Tensor, length: torch.Tensor = None):
80
+ """
81
+ audio_signal: tensor shape of (B, D, T)
82
+ output: tensor shape of (B, D, T)
83
+ """
84
+ x = audio_signal.transpose(1, 2)
85
+ for layer in self.blocks:
86
+ x = layer(x)
87
+ output = x.transpose(1, 2)
88
+ return output, length
89
+
90
+
91
+ class SpeakerDecoder(NeuralModule):
92
+ """
93
+ Speaker Decoder creates the final neural layers that maps from the outputs
94
+ of Jasper Encoder to the embedding layer followed by speaker based softmax loss.
95
+
96
+ Args:
97
+ feat_in (int): Number of channels being input to this module
98
+ num_classes (int): Number of unique speakers in dataset
99
+ emb_sizes (list) : shapes of intermediate embedding layers (the speaker embedding is taken
100
+ from the first of these layers). Defaults to [1024, 1024]
101
+ pool_mode (str) : Pooling strategy type. options are 'xvector','tap', 'attention'
102
+ Defaults to 'xvector' (mean and variance)
103
+ tap (temporal average pooling: just mean)
104
+ attention (attention based pooling)
105
+ init_mode (str): Describes how neural network parameters are
106
+ initialized. Options are ['xavier_uniform', 'xavier_normal',
107
+ 'kaiming_uniform','kaiming_normal'].
108
+ Defaults to "xavier_uniform".
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ feat_in: int,
114
+ num_classes: int,
115
+ emb_sizes: Optional[Union[int, list]] = 256,
116
+ pool_mode: str = 'xvector',
117
+ angular: bool = False,
118
+ attention_channels: int = 128,
119
+ init_mode: str = "xavier_uniform",
120
+ ):
121
+ super().__init__()
122
+ self.angular = angular
123
+ self.emb_id = 2
124
+ bias = False if self.angular else True
125
+ emb_sizes = [emb_sizes] if type(emb_sizes) is int else emb_sizes
126
+
127
+ self._num_classes = num_classes
128
+ self.pool_mode = pool_mode.lower()
129
+ if self.pool_mode == 'xvector' or self.pool_mode == 'tap':
130
+ self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
131
+ affine_type = 'linear'
132
+ elif self.pool_mode == 'attention':
133
+ self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
134
+ affine_type = 'conv'
135
+
136
+ shapes = [self._pooling.feat_in]
137
+ for size in emb_sizes:
138
+ shapes.append(int(size))
139
+
140
+ emb_layers = []
141
+ for shape_in, shape_out in zip(shapes[:-1], shapes[1:]):
142
+ layer = self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
143
+ emb_layers.append(layer)
144
+
145
+ self.emb_layers = nn.ModuleList(emb_layers)
146
+
147
+ self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)
148
+
149
+ self.apply(lambda x: init_weights(x, mode=init_mode))
150
+
151
+ def affine_layer(
152
+ self,
153
+ inp_shape,
154
+ out_shape,
155
+ learn_mean=True,
156
+ affine_type='conv',
157
+ ):
158
+ if affine_type == 'conv':
159
+ layer = nn.Sequential(
160
+ nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
161
+ nn.Conv1d(inp_shape, out_shape, kernel_size=1),
162
+ )
163
+
164
+ else:
165
+ layer = nn.Sequential(
166
+ nn.Linear(inp_shape, out_shape),
167
+ nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
168
+ nn.ReLU(),
169
+ )
170
+
171
+ return layer
172
+
173
+ def forward(self, encoder_output, length: torch.Tensor = None):
174
+ pool = self._pooling(encoder_output, length)
175
+ embs = []
176
+
177
+ for layer in self.emb_layers:
178
+ pool, emb = layer(pool), layer[: self.emb_id](pool)
179
+ embs.append(emb)
180
+
181
+ pool = pool.squeeze(-1)
182
+ if self.angular:
183
+ for W in self.final.parameters():
184
+ W = F.normalize(W, p=2, dim=1)
185
+ pool = F.normalize(pool, p=2, dim=1)
186
+
187
+ out = self.final(pool)
188
+
189
+ return out, embs[-1].squeeze(-1)
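A shape-level sketch of the encoder/decoder pair. The package name is assumed (the files use relative imports), the filter/kernel/dilation lists and class count are illustrative x-vector-style settings rather than the shipped config, and the pooling shapes assume the NeMo-style StatsPoolLayer in tdnn_attention.py.

    import torch
    from xvector.conv_asr import XVectorEncoder, SpeakerDecoder

    encoder = XVectorEncoder(
        feat_in=64,
        filters=[512, 512, 512, 512, 1500],
        kernel_sizes=[5, 3, 3, 1, 1],
        dilations=[1, 2, 3, 1, 1],
    )
    decoder = SpeakerDecoder(feat_in=1500, num_classes=7205, emb_sizes=512, pool_mode='xvector')

    feats = torch.randn(2, 64, 300)                             # (B, D, T) mel features
    lengths = torch.tensor([300, 250])
    encoded, _ = encoder(audio_signal=feats, length=lengths)    # (B, 1500, T)
    logits, embedding = decoder(encoder_output=encoded, length=lengths)
    print(logits.shape, embedding.shape)                        # (2, 7205) class logits, (2, 512) speaker embedding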
features.py ADDED
@@ -0,0 +1,560 @@
1
+ import math
2
+ import random
3
+ from typing import Optional, Union, Tuple
4
+
5
+ import librosa
6
+ import torchaudio
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import torchaudio
13
+
14
+ HAVE_TORCHAUDIO = True
15
+ except ModuleNotFoundError:
16
+ HAVE_TORCHAUDIO = False
17
+
18
+ CONSTANT = 1e-5
19
+
20
+
21
+ def normalize_batch(x, seq_len, normalize_type):
22
+ x_mean = None
23
+ x_std = None
24
+ if normalize_type == "per_feature":
25
+ batch_size = x.shape[0]
26
+ max_time = x.shape[2]
27
+
28
+ # When doing stream capture to a graph, item() is not allowed
29
+ # because it calls cudaStreamSynchronize(). Therefore, we are
30
+ # sacrificing some error checking when running with cuda graphs.
31
+ if (
32
+ torch.cuda.is_available()
33
+ and not torch.cuda.is_current_stream_capturing()
34
+ and torch.any(seq_len == 1).item()
35
+ ):
36
+ raise ValueError(
37
+ "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
38
+ "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
39
+ "feature (ex. at least `hop_length` for Mel Spectrograms)."
40
+ )
41
+ time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
42
+ valid_mask = time_steps < seq_len.unsqueeze(1)
43
+ x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
44
+ x_mean_denominator = valid_mask.sum(axis=1)
45
+ x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)
46
+
47
+ # Subtract 1 in the denominator to correct for the bias.
48
+ x_std = torch.sqrt(
49
+ torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2)
50
+ / (x_mean_denominator.unsqueeze(1) - 1.0)
51
+ )
52
+ # make sure x_std is not zero
53
+ x_std += CONSTANT
54
+ return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
55
+ elif normalize_type == "all_features":
56
+ x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
57
+ x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
58
+ for i in range(x.shape[0]):
59
+ x_mean[i] = x[i, :, : seq_len[i].item()].mean()
60
+ x_std[i] = x[i, :, : seq_len[i].item()].std()
61
+ # make sure x_std is not zero
62
+ x_std += CONSTANT
63
+ return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
64
+ elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
65
+ x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
66
+ x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
67
+ return (
68
+ (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
69
+ x_mean,
70
+ x_std,
71
+ )
72
+ else:
73
+ return x, x_mean, x_std
74
+
75
+
76
+ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Tensor, fill_value=0.0) -> torch.Tensor:
77
+ """
78
+ Fill spectrogram values outside the length with `fill_value`
79
+
80
+ Args:
81
+ spectrogram: Tensor with shape [B, C, L] containing batched spectrograms
82
+ spectrogram_len: Tensor with shape [B] containing the sequence length of each batch element
83
+ fill_value: value to fill with, 0.0 by default
84
+
85
+ Returns:
86
+ cleaned spectrogram, tensor with shape equal to `spectrogram`
87
+ """
88
+ device = spectrogram.device
89
+ batch_size, _, max_len = spectrogram.shape
90
+ mask = torch.arange(max_len, device=device)[None, :] >= spectrogram_len[:, None]
91
+ mask = mask.unsqueeze(1).expand_as(spectrogram)
92
+ return spectrogram.masked_fill(mask, fill_value)
93
+
94
+
95
+ def splice_frames(x, frame_splicing):
96
+ """Stacks frames together across feature dim
97
+
98
+ input is batch_size, feature_dim, num_frames
99
+ output is batch_size, feature_dim*frame_splicing, num_frames
100
+
101
+ """
102
+ seq = [x]
103
+ for n in range(1, frame_splicing):
104
+ seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
105
+ return torch.cat(seq, dim=1)
106
+
107
+
108
+ @torch.jit.script_if_tracing
109
+ def make_seq_mask_like(
110
+ lengths: torch.Tensor, like: torch.Tensor, time_dim: int = -1, valid_ones: bool = True
111
+ ) -> torch.Tensor:
112
+ """
113
+
114
+ Args:
115
+ lengths: Tensor with shape [B] containing the sequence length of each batch element
116
+ like: The mask will contain the same number of dimensions as this Tensor, and will have the same max
117
+ length in the time dimension of this Tensor.
118
+ time_dim: Time dimension of `like` and the resulting mask. Zero-based.
119
+ valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert.
120
+
121
+ Returns:
122
+ A :class:`torch.Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else
123
+ vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match
124
+ the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and
125
+ `time_dim == -1`, mask will have shape `[3, 1, 5]`.
126
+ """
127
+ # Mask with shape [B, T]
128
+ mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.view(-1, 1))
129
+ # [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor
130
+ for _ in range(like.dim() - mask.dim()):
131
+ mask = mask.unsqueeze(1)
132
+ # If needed, transpose time dim
133
+ if time_dim != -1 and time_dim != mask.dim() - 1:
134
+ mask = mask.transpose(-1, time_dim)
135
+ # Maybe invert the padded vs. valid token values
136
+ if not valid_ones:
137
+ mask = ~mask
138
+ return mask
139
+
140
+
141
+ class FilterbankFeatures(nn.Module):
142
+ """Featurizer that converts wavs to Mel Spectrograms.
143
+ See AudioToMelSpectrogramPreprocessor for args.
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ sample_rate=16000,
149
+ n_window_size=320,
150
+ n_window_stride=160,
151
+ window="hann",
152
+ normalize="per_feature",
153
+ n_fft=None,
154
+ preemph=0.97,
155
+ nfilt=64,
156
+ lowfreq=0,
157
+ highfreq=None,
158
+ log=True,
159
+ log_zero_guard_type="add",
160
+ log_zero_guard_value=2**-24,
161
+ dither=CONSTANT,
162
+ pad_to=16,
163
+ max_duration=16.7,
164
+ frame_splicing=1,
165
+ exact_pad=False,
166
+ pad_value=0,
167
+ mag_power=2.0,
168
+ use_grads=False,
169
+ rng=None,
170
+ nb_augmentation_prob=0.0,
171
+ nb_max_freq=4000,
172
+ mel_norm="slaney",
173
+ stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
174
+ stft_conv=False, # Deprecated arguments; kept for config compatibility
175
+ ):
176
+ super().__init__()
177
+ if stft_conv or stft_exact_pad:
178
+ print(
179
+ "Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
180
+ "for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
181
+ "as needed."
182
+ )
183
+ if exact_pad and n_window_stride % 2 == 1:
184
+ raise NotImplementedError(
185
+ f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
186
+ "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
187
+ )
188
+ self.log_zero_guard_value = log_zero_guard_value
189
+ if (
190
+ n_window_size is None
191
+ or n_window_stride is None
192
+ or not isinstance(n_window_size, int)
193
+ or not isinstance(n_window_stride, int)
194
+ or n_window_size <= 0
195
+ or n_window_stride <= 0
196
+ ):
197
+ raise ValueError(
198
+ f"{self} got an invalid value for either n_window_size or "
199
+ f"n_window_stride. Both must be positive ints."
200
+ )
201
+
202
+ self.win_length = n_window_size
203
+ self.hop_length = n_window_stride
204
+ self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
205
+ self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
206
+ self.exact_pad = exact_pad
207
+
208
+ if exact_pad:
209
+ print("STFT using exact pad")
210
+ torch_windows = {
211
+ 'hann': torch.hann_window,
212
+ 'hamming': torch.hamming_window,
213
+ 'blackman': torch.blackman_window,
214
+ 'bartlett': torch.bartlett_window,
215
+ 'none': None,
216
+ }
217
+ window_fn = torch_windows.get(window, None)
218
+ window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
219
+ self.register_buffer("window", window_tensor)
220
+
221
+ self.normalize = normalize
222
+ self.log = log
223
+ self.dither = dither
224
+ self.frame_splicing = frame_splicing
225
+ self.nfilt = nfilt
226
+ self.preemph = preemph
227
+ self.pad_to = pad_to
228
+ highfreq = highfreq or sample_rate / 2
229
+
230
+ filterbanks = torch.tensor(
231
+ librosa.filters.mel(
232
+ sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
233
+ ),
234
+ dtype=torch.float,
235
+ ).unsqueeze(0)
236
+ self.register_buffer("fb", filterbanks)
237
+
238
+ # Calculate maximum sequence length
239
+ max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
240
+ max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
241
+ self.max_length = max_length + max_pad
242
+ self.pad_value = pad_value
243
+ self.mag_power = mag_power
244
+
245
+ # We want to avoid taking the log of zero
246
+ # There are two options: either adding or clamping to a small value
247
+ if log_zero_guard_type not in ["add", "clamp"]:
248
+ raise ValueError(
249
+ f"{self} received {log_zero_guard_type} for the "
250
+ f"log_zero_guard_type parameter. It must be either 'add' or "
251
+ f"'clamp'."
252
+ )
253
+
254
+ self.use_grads = use_grads
255
+ if not use_grads:
256
+ self.forward = torch.no_grad()(self.forward)
257
+ self._rng = random.Random() if rng is None else rng
258
+ self.nb_augmentation_prob = nb_augmentation_prob
259
+ if self.nb_augmentation_prob > 0.0:
260
+ if nb_max_freq >= sample_rate / 2:
261
+ self.nb_augmentation_prob = 0.0
262
+ else:
263
+ self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
264
+
265
+ # log_zero_guard_value is the small value we want to use; we support
266
+ # an actual number, or "tiny", or "eps"
267
+ self.log_zero_guard_type = log_zero_guard_type
268
+
269
+ def stft(self, x):
270
+ return torch.stft(
271
+ x,
272
+ n_fft=self.n_fft,
273
+ hop_length=self.hop_length,
274
+ win_length=self.win_length,
275
+ center=False if self.exact_pad else True,
276
+ window=self.window.to(dtype=torch.float),
277
+ return_complex=True,
278
+ )
279
+
280
+ def log_zero_guard_value_fn(self, x):
281
+ if isinstance(self.log_zero_guard_value, str):
282
+ if self.log_zero_guard_value == "tiny":
283
+ return torch.finfo(x.dtype).tiny
284
+ elif self.log_zero_guard_value == "eps":
285
+ return torch.finfo(x.dtype).eps
286
+ else:
287
+ raise ValueError(
288
+ f"{self} received {self.log_zero_guard_value} for the "
289
+ f"log_zero_guard_type parameter. It must be either a "
290
+ f"number, 'tiny', or 'eps'"
291
+ )
292
+ else:
293
+ return self.log_zero_guard_value
294
+
295
+ def get_seq_len(self, seq_len):
296
+ # Assumes center=True when stft_pad_amount is None, i.e. torch.stft pads n_fft // 2 on each side
297
+ pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
298
+ seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
299
+ return seq_len.to(dtype=torch.long)
300
+
301
+ @property
302
+ def filter_banks(self):
303
+ return self.fb
304
+
305
+ def forward(self, x, seq_len, linear_spec=False):
306
+ seq_len = self.get_seq_len(seq_len)
307
+
308
+ if self.stft_pad_amount is not None:
309
+ x = torch.nn.functional.pad(
310
+ x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
311
+ ).squeeze(1)
312
+
313
+ # dither (only in training mode for eval determinism)
314
+ if self.training and self.dither > 0:
315
+ x += self.dither * torch.randn_like(x)
316
+
317
+ # do preemphasis
318
+ if self.preemph is not None:
319
+ x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
320
+
321
+ # disable autocast to get full range of stft values
322
+ with torch.amp.autocast(x.device.type, enabled=False):
323
+ x = self.stft(x)
324
+
325
+ # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
326
+ # guard is needed for sqrt if grads are passed through
327
+ guard = 0 if not self.use_grads else CONSTANT
328
+ x = torch.view_as_real(x)
329
+ x = torch.sqrt(x.pow(2).sum(-1) + guard)
330
+
331
+ if self.training and self.nb_augmentation_prob > 0.0:
332
+ for idx in range(x.shape[0]):
333
+ if self._rng.random() < self.nb_augmentation_prob:
334
+ x[idx, self._nb_max_fft_bin :, :] = 0.0
335
+
336
+ # get power spectrum
337
+ if self.mag_power != 1.0:
338
+ x = x.pow(self.mag_power)
339
+
340
+ # return plain spectrogram if required
341
+ if linear_spec:
342
+ return x, seq_len
343
+
344
+ # dot with filterbank energies
345
+ x = torch.matmul(self.fb.to(x.dtype), x)
346
+ # log features if required
347
+ if self.log:
348
+ if self.log_zero_guard_type == "add":
349
+ x = torch.log(x + self.log_zero_guard_value_fn(x))
350
+ elif self.log_zero_guard_type == "clamp":
351
+ x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
352
+ else:
353
+ raise ValueError("log_zero_guard_type was not understood")
354
+
355
+ # frame splicing if required
356
+ if self.frame_splicing > 1:
357
+ x = splice_frames(x, self.frame_splicing)
358
+
359
+ # normalize if required
360
+ if self.normalize:
361
+ x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
362
+
363
+ # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
364
+ max_len = x.size(-1)
365
+ mask = torch.arange(max_len, device=x.device)
366
+ mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
367
+ x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
368
+ del mask
369
+ pad_to = self.pad_to
370
+ if pad_to == "max":
371
+ x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
372
+ elif pad_to > 0:
373
+ pad_amt = x.size(-1) % pad_to
374
+ if pad_amt != 0:
375
+ x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
376
+ return x, seq_len
377
+
378
+
379
+ class FilterbankFeaturesTA(nn.Module):
380
+ """
381
+ Exportable, `torchaudio`-based implementation of Mel Spectrogram extraction.
382
+
383
+ See `AudioToMelSpectrogramPreprocessor` for args.
384
+
385
+ """
386
+
387
+ def __init__(
388
+ self,
389
+ sample_rate: int = 16000,
390
+ n_window_size: int = 320,
391
+ n_window_stride: int = 160,
392
+ normalize: Optional[str] = "per_feature",
393
+ nfilt: int = 64,
394
+ n_fft: Optional[int] = None,
395
+ preemph: float = 0.97,
396
+ lowfreq: float = 0,
397
+ highfreq: Optional[float] = None,
398
+ log: bool = True,
399
+ log_zero_guard_type: str = "add",
400
+ log_zero_guard_value: Union[float, str] = 2**-24,
401
+ dither: float = 1e-5,
402
+ window: str = "hann",
403
+ pad_to: int = 0,
404
+ pad_value: float = 0.0,
405
+ mel_norm="slaney",
406
+ # Seems like no one uses these options anymore. Don't complicate the code by supporting them.
407
+ use_grads: bool = False, # Deprecated arguments; kept for config compatibility
408
+ max_duration: float = 16.7, # Deprecated arguments; kept for config compatibility
409
+ frame_splicing: int = 1, # Deprecated arguments; kept for config compatibility
410
+ exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
411
+ nb_augmentation_prob: float = 0.0, # Deprecated arguments; kept for config compatibility
412
+ nb_max_freq: int = 4000, # Deprecated arguments; kept for config compatibility
413
+ mag_power: float = 2.0, # Deprecated arguments; kept for config compatibility
414
+ rng: Optional[random.Random] = None, # Deprecated arguments; kept for config compatibility
415
+ stft_exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
416
+ stft_conv: bool = False, # Deprecated arguments; kept for config compatibility
417
+ ):
418
+ super().__init__()
419
+ if not HAVE_TORCHAUDIO:
420
+ raise ValueError(f"Need to install torchaudio to instantiate a {self.__class__.__name__}")
421
+
422
+ # Make sure log zero guard is supported, if given as a string
423
+ supported_log_zero_guard_strings = {"eps", "tiny"}
424
+ if isinstance(log_zero_guard_value, str) and log_zero_guard_value not in supported_log_zero_guard_strings:
425
+ raise ValueError(
426
+ f"Log zero guard value must either be a float or a member of {supported_log_zero_guard_strings}"
427
+ )
428
+
429
+ # Copied from `AudioPreprocessor` due to the ad-hoc structuring of the Mel Spec extractor class
430
+ self.torch_windows = {
431
+ 'hann': torch.hann_window,
432
+ 'hamming': torch.hamming_window,
433
+ 'blackman': torch.blackman_window,
434
+ 'bartlett': torch.bartlett_window,
435
+ 'ones': torch.ones,
436
+ None: torch.ones,
437
+ }
438
+
439
+ # Ensure we can look up the window function
440
+ if window not in self.torch_windows:
441
+ raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")
442
+
443
+ self.win_length = n_window_size
444
+ self.hop_length = n_window_stride
445
+ self._sample_rate = sample_rate
446
+ self._normalize_strategy = normalize
447
+ self._use_log = log
448
+ self._preemphasis_value = preemph
449
+ self.log_zero_guard_type = log_zero_guard_type
450
+ self.log_zero_guard_value: Union[str, float] = log_zero_guard_value
451
+ self.dither = dither
452
+ self.pad_to = pad_to
453
+ self.pad_value = pad_value
454
+ self.n_fft = n_fft
455
+ self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
456
+ sample_rate=self._sample_rate,
457
+ win_length=self.win_length,
458
+ hop_length=self.hop_length,
459
+ n_mels=nfilt,
460
+ window_fn=self.torch_windows[window],
461
+ mel_scale="slaney",
462
+ norm=mel_norm,
463
+ n_fft=n_fft,
464
+ f_max=highfreq,
465
+ f_min=lowfreq,
466
+ wkwargs={"periodic": False},
467
+ )
468
+
469
+ @property
470
+ def filter_banks(self):
471
+ """Matches the analogous class"""
472
+ return self._mel_spec_extractor.mel_scale.fb
473
+
474
+ def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
475
+ if isinstance(self.log_zero_guard_value, float):
476
+ return self.log_zero_guard_value
477
+ return getattr(torch.finfo(dtype), self.log_zero_guard_value)
478
+
479
+ def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
480
+ if self.training and self.dither > 0.0:
481
+ noise = torch.randn_like(signals) * self.dither
482
+ signals = signals + noise
483
+ return signals
484
+
485
+ def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
486
+ if self._preemphasis_value is not None:
487
+ padded = torch.nn.functional.pad(signals, (1, 0))
488
+ signals = signals - self._preemphasis_value * padded[:, :-1]
489
+ return signals
490
+
491
+ def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
492
+ out_lengths = input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
493
+ return out_lengths
494
+
495
+ def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
496
+ # Only apply during training; else need to capture dynamic shape for exported models
497
+ if not self.training or self.pad_to == 0 or features.shape[-1] % self.pad_to == 0:
498
+ return features
499
+ pad_length = self.pad_to - (features.shape[-1] % self.pad_to)
500
+ return torch.nn.functional.pad(features, pad=(0, pad_length), value=self.pad_value)
501
+
502
+ def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
503
+ if self._use_log:
504
+ zero_guard = self._resolve_log_zero_guard_value(features.dtype)
505
+ if self.log_zero_guard_type == "add":
506
+ features = features + zero_guard
507
+ elif self.log_zero_guard_type == "clamp":
508
+ features = features.clamp(min=zero_guard)
509
+ else:
510
+ raise ValueError(f"Unsupported log zero guard type: '{self.log_zero_guard_type}'")
511
+ features = features.log()
512
+ return features
513
+
514
+ def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor:
515
+ # Complex FFT needs to be done in single precision
516
+ with torch.amp.autocast('cuda', enabled=False):
517
+ features = self._mel_spec_extractor(waveform=signals)
518
+ return features
519
+
520
+ def _apply_normalization(self, features: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
521
+ # For consistency, this function always does a masked fill even if not normalizing.
522
+ mask: torch.Tensor = make_seq_mask_like(lengths=lengths, like=features, time_dim=-1, valid_ones=False)
523
+ features = features.masked_fill(mask, 0.0)
524
+ # Maybe don't normalize
525
+ if self._normalize_strategy is None:
526
+ return features
527
+ # Use the log zero guard for the sqrt zero guard
528
+ guard_value = self._resolve_log_zero_guard_value(features.dtype)
529
+ if self._normalize_strategy == "per_feature" or self._normalize_strategy == "all_features":
530
+ # 'all_features' reduces over each sample; 'per_feature' reduces over each channel
531
+ reduce_dim = 2
532
+ if self._normalize_strategy == "all_features":
533
+ reduce_dim = [1, 2]
534
+ # [B, D, T] -> [B, D, 1] or [B, 1, 1]
535
+ means = features.sum(dim=reduce_dim, keepdim=True).div(lengths.view(-1, 1, 1))
536
+ stds = (
537
+ features.sub(means)
538
+ .masked_fill(mask, 0.0)
539
+ .pow(2.0)
540
+ .sum(dim=reduce_dim, keepdim=True) # [B, D, T] -> [B, D, 1] or [B, 1, 1]
541
+ .div(lengths.view(-1, 1, 1) - 1) # unbiased estimator (divide by N - 1)
542
+ .clamp(min=guard_value) # avoid sqrt(0)
543
+ .sqrt()
544
+ )
545
+ features = (features - means) / (stds + eps)
546
+ else:
547
+ # Deprecating constant std/mean
548
+ raise ValueError(f"Unsupported norm type: '{self._normalize_strategy}")
549
+ features = features.masked_fill(mask, 0.0)
550
+ return features
551
+
552
+ def forward(self, input_signal: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
553
+ feature_lengths = self._compute_output_lengths(input_lengths=length)
554
+ signals = self._apply_dithering(signals=input_signal)
555
+ signals = self._apply_preemphasis(signals=signals)
556
+ features = self._extract_spectrograms(signals=signals)
557
+ features = self._apply_log(features=features)
558
+ features = self._apply_normalization(features=features, lengths=feature_lengths)
559
+ features = self._apply_pad_to(features=features)
560
+ return features, feature_lengths
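
A minimal sketch of driving this preprocessor on its own, assuming the checkpoint's `mel_spectrogram_config` (the same dict `XVectorModel` unpacks below) supplies the constructor arguments; the repo id and audio shapes are placeholders, not values from this commit.

import torch

# Illustrative only: constructor kwargs come from the checkpoint config, repo id is a placeholder.
config = XVectorConfig.from_pretrained("user/xvector-checkpoint")
preprocessor = AudioToMelSpectrogramPreprocessor(**config.mel_spectrogram_config)

waveforms = torch.randn(2, 32000)            # [batch, samples] of 16 kHz audio (assumed rate)
lengths = torch.tensor([32000, 24000])       # valid samples per example
mel, mel_lengths = preprocessor(input_signal=waveforms, length=lengths)
# mel: [batch, n_mels, frames]; mel_lengths: valid frames per example
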
logging.py ADDED
@@ -0,0 +1,55 @@
1
+ import logging
2
+
3
+
4
+ # Function to convert HEX to ANSI 24-bit escape code
5
+ def hex_to_ansi(hex_color, is_background=False):
6
+ """Convert a hex color code to an ANSI escape sequence."""
7
+ hex_color = hex_color.lstrip("#") # Remove '#' if present
8
+ if len(hex_color) != 6:
9
+ raise ValueError("Invalid hex color format. Use #RRGGBB.")
10
+
11
+ r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
12
+ return f"\033[{48 if is_background else 38};2;{r};{g};{b}m"
13
+
14
+ # Custom log formatter with level-specific colors
15
+ class ColoredFormatter(logging.Formatter):
16
+ # Define hex colors per log level
17
+ COLORS = {
18
+ "DEBUG": {"HEADER": "#1E3A8A", "TIMESTAMP": "#2563EB"}, # Dark Blue / Blue
19
+ "INFO": {"HEADER": "#166534", "TIMESTAMP": "#22C55E"}, # Dark Green / Green
20
+ "WARNING": {"HEADER": "#92400E", "TIMESTAMP": "#FACC15"}, # Dark Yellow / Yellow
21
+ "ERROR": {"HEADER": "#7F1D1D", "TIMESTAMP": "#EF4444"}, # Dark Red / Red
22
+ "CRITICAL": {"HEADER": "#581C87", "TIMESTAMP": "#C084FC"}, # Dark Purple / Purple
23
+ }
24
+
25
+ def format(self, record):
26
+ # Extract filename and line number
27
+ filename = record.pathname.split("/")[-1]
28
+ line_no = record.lineno
29
+ level_name = record.levelname
30
+
31
+ # Choose colors based on log level
32
+ level_colors = self.COLORS.get(level_name, self.COLORS["INFO"])
33
+ header_color = hex_to_ansi(level_colors["HEADER"])
34
+ timestamp_color = hex_to_ansi(level_colors["TIMESTAMP"])
35
+ reset_color = "\033[0m" # Reset to default terminal color
36
+
37
+ # Format header as "[LEVEL|file.py:line]"
38
+ header = f"{header_color}[{level_name}|{filename}:{line_no}]{reset_color}"
39
+
40
+ # Format timestamp
41
+ timestamp = f"{timestamp_color}{self.formatTime(record, self.datefmt)}{reset_color}"
42
+
43
+ # Format message
44
+ message = f"\033[37m{record.getMessage()}{reset_color}" # White message
45
+
46
+ return f"{header} {timestamp} >> {message}"
47
+
48
+
49
+ # Set up logger
50
+ logger = logging.getLogger(__name__)
51
+ logger.setLevel(logging.INFO)
52
+ console_handler = logging.StreamHandler()
53
+ formatter = ColoredFormatter(datefmt="%Y-%m-%d %H:%M:%S")
54
+ console_handler.setFormatter(formatter)
55
+ logger.addHandler(console_handler)
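
For reference, a short illustrative sketch of using the module-level logger defined above from elsewhere in this package; the messages themselves are made up.

from .logging import logger

logger.info("Preprocessor initialized")
logger.warning("attention_mask not provided; assuming all samples are valid")
# Each record renders as "[LEVEL|file.py:line] YYYY-MM-DD HH:MM:SS >> message" with level-specific colors.
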
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:553b3b3f58b772c5a184701cc761f96c36457df105da5fdc4336e9ad73f0209d
3
+ size 28476124
modeling_xvector.py ADDED
@@ -0,0 +1,153 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Union, Tuple
3
+
4
+ from rich import print
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from transformers import PreTrainedModel
10
+ from transformers.utils import ModelOutput
11
+
12
+ from .configuration_xvector import XVectorConfig
13
+ from .audio_processing import AudioToMelSpectrogramPreprocessor
14
+ from .audio_processing import SpectrogramAugmentation
15
+ from .conv_asr import XVectorEncoder, SpeakerDecoder
16
+ from .angular_loss import AdditiveMarginSoftmaxLoss, AdditiveAngularMarginSoftmaxLoss
17
+
18
+
19
+ @dataclass
20
+ class XVectorBaseModelOutput(ModelOutput):
21
+
22
+ encoder_outputs: torch.FloatTensor = None
23
+ extract_features: torch.FloatTensor = None
24
+ output_lengths: torch.FloatTensor = None
25
+
26
+
27
+ @dataclass
28
+ class XVectorSequenceClassifierOutput(ModelOutput):
29
+
30
+ loss: torch.FloatTensor = None
31
+ logits: torch.FloatTensor = None
32
+ embeddings: torch.FloatTensor = None
33
+
34
+
35
+ class XVectorPreTrainedModel(PreTrainedModel):
36
+
37
+ config_class = XVectorConfig
38
+ base_model_prefix = "xvector"
39
+ main_input_name = "input_values"
40
+
41
+ def _init_weights(self, module):
42
+ """Initialize the weights"""
43
+ config: XVectorConfig = self.config
44
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
45
+ module.weight.data.normal_(mean=0.0, std=config.initializer_range)
46
+ if module.bias is not None:
47
+ module.bias.data.zero_()
48
+ elif isinstance(module, nn.Conv2d):
49
+ module.weight.data.normal_(mean=0.0, std=config.initializer_range)
50
+ if module.bias is not None:
51
+ module.bias.data.zero_()
52
+ elif isinstance(module, nn.LayerNorm):
53
+ module.bias.data.zero_()
54
+ module.weight.data.fill_(1.0)
55
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
56
+ nn.init.constant_(module.weight, 1)
57
+ nn.init.constant_(module.bias, 0)
58
+
59
+ @property
60
+ def num_weights(self):
61
+ """
62
+ Utility property that returns the total number of trainable parameters of the model.
63
+ """
64
+ return self._num_weights()
65
+
66
+ @torch.jit.ignore
67
+ def _num_weights(self):
68
+ num: int = 0
69
+ for p in self.parameters():
70
+ if p.requires_grad:
71
+ num += p.numel()
72
+ return num
73
+
74
+
75
+ class XVectorModel(XVectorPreTrainedModel):
76
+
77
+ def __init__(self, config: XVectorConfig):
78
+ super().__init__(config)
79
+ self.config = config
80
+
81
+ self.preprocessor = AudioToMelSpectrogramPreprocessor(**config.mel_spectrogram_config)
82
+ self.spec_augment = SpectrogramAugmentation(**config.spectrogram_augmentation_config)
83
+ self.encoder = XVectorEncoder(**config.encoder_config)
84
+
85
+ # Initialize weights and apply final processing
86
+ self.post_init()
87
+
88
+ def forward(
89
+ self,
90
+ input_values: Optional[torch.Tensor],
91
+ attention_mask: Optional[torch.Tensor] = None,
92
+ ) -> Union[Tuple, XVectorBaseModelOutput]:
93
+ if attention_mask is None:
94
+ attention_mask = torch.ones_like(input_values).to(input_values)
95
+ lengths = attention_mask.sum(dim=1).long()
96
+ extract_features, output_lengths = self.preprocessor(input_values, lengths)
97
+ if self.training:
98
+ extract_features = self.spec_augment(extract_features, output_lengths)
99
+ encoder_outputs, output_lengths = self.encoder(extract_features, output_lengths)
100
+
101
+ return XVectorBaseModelOutput(
102
+ encoder_outputs=encoder_outputs,
103
+ extract_features=extract_features,
104
+ output_lengths=output_lengths,
105
+ )
106
+
107
+
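
A hedged sketch of using the base `XVectorModel` on its own to obtain frame-level encoder features; the repo id is a placeholder.

import torch

# Illustrative only: "user/xvector-checkpoint" is a placeholder repo id.
base = XVectorModel.from_pretrained("user/xvector-checkpoint")
wave = torch.randn(2, 16000)                      # one second of 16 kHz audio per example
out = base(input_values=wave)                     # attention_mask defaults to all ones
# out.encoder_outputs: [batch, channels, frames]; out.output_lengths: valid frame counts
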
108
+ class XVectorForSequenceClassification(XVectorPreTrainedModel):
109
+
110
+ def __init__(self, config: XVectorConfig):
111
+ super().__init__(config)
112
+
113
+ self.xvector = XVectorModel(config)
114
+ self.classifier = SpeakerDecoder(**config.decoder_config)
115
+
116
+ if config.objective == 'additive_angular_margin':
117
+ self.loss_fct = AdditiveAngularMarginSoftmaxLoss(**config.objective_config)
118
+ elif config.objective == 'additive_margin':
119
+ self.loss_fct = AdditiveMarginSoftmaxLoss(**config.objective_config)
120
+ elif config.objective == 'cross_entropy':
121
+ self.loss_fct = nn.CrossEntropyLoss(**config.objective_config)
122
+
123
+ self.init_weights()
124
+
125
+ def freeze_base_model(self):
126
+ for param in self.xvector.parameters():
127
+ param.requires_grad = False
128
+
129
+ def forward(
130
+ self,
131
+ input_values: Optional[torch.Tensor],
132
+ attention_mask: Optional[torch.Tensor] = None,
133
+ labels: Optional[torch.Tensor] = None,
134
+ ) -> Union[Tuple, XVectorSequenceClassifierOutput]:
135
+ xvector_outputs = self.xvector(
136
+ input_values,
137
+ attention_mask,
138
+ )
139
+ logits, output_embeddings = self.classifier(
140
+ xvector_outputs.encoder_outputs,
141
+ xvector_outputs.output_lengths
142
+ )
143
+ logits = logits.view(-1, self.config.num_labels)
144
+
145
+ loss = None
146
+ if labels is not None:
147
+ loss = self.loss_fct(logits, labels.view(-1))
148
+
149
+ return XVectorSequenceClassifierOutput(
150
+ loss=loss,
151
+ logits=logits,
152
+ embeddings=output_embeddings,
153
+ )
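
An end-to-end usage sketch for the classifier; the repo id and label lookup are placeholders, not values shipped with this commit.

import torch

model = XVectorForSequenceClassification.from_pretrained("user/xvector-checkpoint")
model.eval()

waveform = torch.randn(1, 48000)                  # three seconds of 16 kHz audio (assumed rate)
attention_mask = torch.ones_like(waveform)
with torch.no_grad():
    outputs = model(input_values=waveform, attention_mask=attention_mask)
predicted_id = outputs.logits.argmax(dim=-1).item()
print(model.config.id2label.get(predicted_id, predicted_id))
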
module.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class NeuralModule(nn.Module):
6
+
7
+ @property
8
+ def num_weights(self):
9
+ """
10
+ Utility property that returns the total number of trainable parameters of the NeuralModule.
11
+ """
12
+ return self._num_weights()
13
+
14
+ @torch.jit.ignore
15
+ def _num_weights(self):
16
+ num: int = 0
17
+ for p in self.parameters():
18
+ if p.requires_grad:
19
+ num += p.numel()
20
+ return num
21
+
22
+ def freeze(self) -> None:
23
+ r"""
24
+ Freeze all params for inference.
25
+
26
+ This method sets `requires_grad` to False for all parameters of the module.
27
+ It also stores the original `requires_grad` state of each parameter in a dictionary,
28
+ so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`.
29
+ """
30
+ grad_map = {}
31
+
32
+ for pname, param in self.named_parameters():
33
+ # Store the original grad state
34
+ grad_map[pname] = param.requires_grad
35
+ # Freeze the parameter
36
+ param.requires_grad = False
37
+
38
+ # Store the frozen grad map
39
+ if not hasattr(self, '_frozen_grad_map'):
40
+ self._frozen_grad_map = grad_map
41
+ else:
42
+ self._frozen_grad_map.update(grad_map)
43
+
44
+ self.eval()
45
+
46
+ def unfreeze(self, partial: bool = False) -> None:
47
+ """
48
+ Unfreeze all parameters for training.
49
+
50
+ Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
51
+ The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were
52
+ trainable before `freeze()` was called.
53
+
54
+ Example:
55
+ Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always.
56
+
57
+ ```python
58
+ model.encoder.freeze() # Freezes all parameters in the encoder explicitly
59
+ ```
60
+
61
+ During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method.
62
+ This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called,
63
+ we should keep the encoder parameters frozen.
64
+
65
+ ```python
66
+ model.freeze() # Freezes all parameters in the model; encoder remains frozen
67
+ ```
68
+
69
+ Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling
70
+ `unfreeze(partial=True)`.
71
+
72
+ ```python
73
+ model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen
74
+ ```
75
+
76
+ Args:
77
+ partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen
78
+ when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`.
79
+ """
80
+ if partial and not hasattr(self, '_frozen_grad_map'):
81
+ raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")
82
+
83
+ for pname, param in self.named_parameters():
84
+ if not partial:
85
+ # Unfreeze all parameters
86
+ param.requires_grad = True
87
+ else:
88
+ # Unfreeze only parameters that were previously frozen
89
+
90
+ # Check if the parameter was frozen
91
+ if pname in self._frozen_grad_map:
92
+ param.requires_grad = self._frozen_grad_map[pname]
93
+ else:
94
+ # Log a warning if the parameter was not found in the frozen grad map
95
+ print(
96
+ f"Parameter {pname} not found in list of previously frozen parameters. "
97
+ f"Unfreezing this parameter."
98
+ )
99
+ param.requires_grad = True
100
+
101
+ # Clean up the frozen grad map
102
+ if hasattr(self, '_frozen_grad_map'):
103
+ delattr(self, '_frozen_grad_map')
104
+
105
+ self.train()
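
A small sketch (toy module, assumed sizes) exercising the freeze()/unfreeze(partial=True) bookkeeping described in the docstrings above.

import torch.nn as nn

class ToyModule(NeuralModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)
        self.decoder = nn.Linear(16, 4)

toy = ToyModule()
for p in toy.encoder.parameters():
    p.requires_grad = False           # encoder frozen before the snapshot
toy.freeze()                          # records pre-freeze state, then freezes everything
toy.unfreeze(partial=True)            # decoder trainable again; encoder stays frozen
print(toy.num_weights)                # counts only trainable parameters (here, the decoder's 68)
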
normalization.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class BatchNorm1d(nn.Module):
5
+ """Applies 1d batch normalization to the input tensor.
6
+
7
+ Arguments
8
+ ---------
9
+ input_shape : tuple
10
+ The expected shape of the input. Alternatively, use ``input_size``.
11
+ input_size : int
12
+ The expected size of the input. Alternatively, use ``input_shape``.
13
+ eps : float
14
+ This value is added to std deviation estimation to improve the numerical
15
+ stability.
16
+ momentum : float
17
+ It is a value used for the running_mean and running_var computation.
18
+ affine : bool
19
+ When set to True, the affine parameters are learned.
20
+ track_running_stats : bool
21
+ When set to True, this module tracks the running mean and variance,
22
+ and when set to False, this module does not track such statistics.
23
+ combine_batch_time : bool
24
+ When True, it combines the batch and time axes.
25
+ skip_transpose : bool
26
+ Whether to skip the transposition.
27
+
28
+
29
+ Example
30
+ -------
31
+ >>> input = torch.randn(100, 10)
32
+ >>> norm = BatchNorm1d(input_shape=input.shape)
33
+ >>> output = norm(input)
34
+ >>> output.shape
35
+ torch.Size([100, 10])
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ input_shape=None,
41
+ input_size=None,
42
+ eps=1e-05,
43
+ momentum=0.1,
44
+ affine=True,
45
+ track_running_stats=True,
46
+ combine_batch_time=False,
47
+ skip_transpose=False,
48
+ ):
49
+ super().__init__()
50
+ self.combine_batch_time = combine_batch_time
51
+ self.skip_transpose = skip_transpose
52
+
53
+ if input_size is None and skip_transpose:
54
+ input_size = input_shape[1]
55
+ elif input_size is None:
56
+ input_size = input_shape[-1]
57
+
58
+ self.norm = nn.BatchNorm1d(
59
+ input_size,
60
+ eps=eps,
61
+ momentum=momentum,
62
+ affine=affine,
63
+ track_running_stats=track_running_stats,
64
+ )
65
+
66
+ def forward(self, x, *args, **kwargs):
67
+ """Returns the normalized input tensor.
68
+
69
+ Arguments
70
+ ---------
71
+ x : torch.Tensor (batch, time, [channels])
72
+ input to normalize. 2d or 3d tensors are expected as input;
73
+ 4d tensors can be used when combine_batch_time=True.
74
+
75
+ Returns
76
+ -------
77
+ x_n : torch.Tensor
78
+ The normalized outputs.
79
+ """
80
+ shape_or = x.shape
81
+ if self.combine_batch_time:
82
+ if x.ndim == 3:
83
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
84
+ else:
85
+ x = x.reshape(
86
+ shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
87
+ )
88
+
89
+ elif not self.skip_transpose:
90
+ x = x.transpose(-1, 1)
91
+
92
+ x_n = self.norm(x)
93
+
94
+ if self.combine_batch_time:
95
+ x_n = x_n.reshape(shape_or)
96
+ elif not self.skip_transpose:
97
+ x_n = x_n.transpose(1, -1)
98
+
99
+ return x_n
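
To complement the 2-D example in the docstring, a sketch of the 3-D case (shapes assumed), where the default skip_transpose=False moves channels into the position nn.BatchNorm1d expects.

import torch

norm = BatchNorm1d(input_size=40)
x = torch.randn(8, 120, 40)      # [batch, time, channels]
y = norm(x)                      # transposed internally, normalized per channel
assert y.shape == x.shape
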
spectrogram_augment.py ADDED
@@ -0,0 +1,223 @@
1
+ import random
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class SpecAugment(nn.Module):
11
+ """
12
+ Zeroes out (cuts) random continuous horizontal or
13
+ vertical segments of the spectrogram as described in
14
+ SpecAugment (https://arxiv.org/abs/1904.08779).
15
+
16
+ params:
17
+ freq_masks - how many frequency segments should be cut
18
+ time_masks - how many time segments should be cut
19
+ freq_width - maximum number of frequencies to be cut in one segment
20
+ time_width - maximum number of time steps to be cut in one segment.
21
+ Can be a positive integer or a float value in the range [0, 1].
22
+ If positive integer value, defines maximum number of time steps
23
+ to be cut in one segment.
24
+ If a float value, defines maximum percentage of timesteps that
25
+ are cut adaptively.
26
+ use_vectorized_code - GPU-based implementation with batched masking and GPU rng,
27
+ setting it to False reverts to the legacy implementation.
28
+ Fast implementation is inspired by torchaudio:
29
+ https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816
30
+ """
31
+
32
+ FREQ_AXIS = 1 # Frequency axis in the spectrogram tensor
33
+ TIME_AXIS = 2 # Time axis in the spectrogram tensor
34
+
35
+ def __init__(
36
+ self,
37
+ freq_masks: int = 0,
38
+ time_masks: int = 0,
39
+ freq_width: int = 10,
40
+ time_width: Union[int, float] = 10,
41
+ rng: random.Random = None,
42
+ mask_value: float = 0.0,
43
+ use_vectorized_code: bool = True,
44
+ ):
45
+ super().__init__()
46
+
47
+ self._rng = random.Random() if rng is None else rng
48
+
49
+ self.freq_masks = freq_masks
50
+ self.time_masks = time_masks
51
+
52
+ self.freq_width = freq_width
53
+ self.time_width = time_width
54
+
55
+ self.mask_value = mask_value
56
+ self.use_vectorized_code = use_vectorized_code
57
+
58
+ if isinstance(time_width, int):
59
+ self.adaptive_temporal_width = False
60
+ else:
61
+ if time_width > 1.0 or time_width < 0.0:
62
+ raise ValueError("If `time_width` is a float value, must be in range [0, 1]")
63
+
64
+ self.adaptive_temporal_width = True
65
+
66
+ @torch.no_grad()
67
+ def forward(self, input_spec, length):
68
+ if self.use_vectorized_code:
69
+ return self._forward_vectorized(input_spec, length)
70
+ else:
71
+ return self._forward_legacy(input_spec, length)
72
+
73
+ def _forward_legacy(self, input_spec, length):
74
+ batch_size, num_freq_bins, _ = input_spec.shape
75
+ # Move lengths to CPU before repeated indexing
76
+ lengths_cpu = length.cpu().numpy()
77
+ # Generate a numpy boolean mask. `True` elements represent where the input spec will be augmented.
78
+ fill_mask: np.array = np.full(shape=input_spec.shape, fill_value=False)
79
+ freq_start_upper_bound = num_freq_bins - self.freq_width
80
+ # Choose different mask ranges for each element of the batch
81
+ for idx in range(batch_size):
82
+ # Set freq masking
83
+ for _ in range(self.freq_masks):
84
+ start = self._rng.randint(0, freq_start_upper_bound)
85
+ width = self._rng.randint(0, self.freq_width)
86
+ fill_mask[idx, start : start + width, :] = True
87
+
88
+ # Derive time width, sometimes based on a percentage of the input length.
89
+ if self.adaptive_temporal_width:
90
+ time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
91
+ else:
92
+ time_max_width = self.time_width
93
+ time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)
94
+
95
+ # Set time masking
96
+ for _ in range(self.time_masks):
97
+ start = self._rng.randint(0, time_start_upper_bound)
98
+ width = self._rng.randint(0, time_max_width)
99
+ fill_mask[idx, :, start : start + width] = True
100
+ # Bring the mask to device and fill spec
101
+ fill_mask = torch.from_numpy(fill_mask).to(input_spec.device)
102
+ masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
103
+ return masked_spec
104
+
105
+ def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
106
+ # time masks
107
+ input_spec = self._apply_masks(
108
+ input_spec=input_spec,
109
+ num_masks=self.time_masks,
110
+ length=length,
111
+ width=self.time_width,
112
+ axis=self.TIME_AXIS,
113
+ mask_value=self.mask_value,
114
+ )
115
+ # freq masks
116
+ input_spec = self._apply_masks(
117
+ input_spec=input_spec,
118
+ num_masks=self.freq_masks,
119
+ length=length,
120
+ width=self.freq_width,
121
+ axis=self.FREQ_AXIS,
122
+ mask_value=self.mask_value,
123
+ )
124
+ return input_spec
125
+
126
+ def _apply_masks(
127
+ self,
128
+ input_spec: torch.Tensor,
129
+ num_masks: int,
130
+ length: torch.Tensor,
131
+ width: Union[int, float],
132
+ mask_value: float,
133
+ axis: int,
134
+ ) -> torch.Tensor:
135
+
136
+ assert axis in (
137
+ self.FREQ_AXIS,
138
+ self.TIME_AXIS,
139
+ ), f"Axis can only be equal to frequency \
140
+ ({self.FREQ_AXIS}) or time ({self.TIME_AXIS}). Received: {axis=}"
141
+ assert not (
142
+ isinstance(width, float) and axis == self.FREQ_AXIS
143
+ ), "Float width supported \
144
+ only with time axis."
145
+
146
+ batch_size = input_spec.shape[0]
147
+ axis_length = input_spec.shape[axis]
148
+
149
+ # If width is float then it is transformed into a tensor
150
+ if axis == self.TIME_AXIS and isinstance(width, float):
151
+ width = torch.clamp(width * length, max=axis_length).unsqueeze(1)
152
+
153
+ # Generate [0-1) random numbers and then scale the tensors.
154
+ # Use float32 dtype for begin/end mask markers before they are quantized to long.
155
+ mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
156
+ mask_width = mask_width.long()
157
+ mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32)
158
+
159
+ if axis == self.TIME_AXIS:
160
+ # length can only be used for the time axis
161
+ mask_start = mask_start * (length.unsqueeze(1) - mask_width)
162
+ else:
163
+ mask_start = mask_start * (axis_length - mask_width)
164
+
165
+ mask_start = mask_start.long()
166
+ mask_end = mask_start + mask_width
167
+
168
+ # Create mask values using vectorized indexing
169
+ indices = torch.arange(axis_length, device=input_spec.device)
170
+ # Create a mask_tensor with all the indices.
171
+ # The mask_tensor shape is (batch_size, num_masks, axis_length).
172
+ mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))
173
+
174
+ # Reduce masks to one mask
175
+ mask_tensor = mask_tensor.any(dim=1)
176
+
177
+ # Create a final mask that aligns with the full tensor
178
+ mask = torch.zeros_like(input_spec, dtype=torch.bool)
179
+ if axis == self.TIME_AXIS:
180
+ mask_ranges = mask_tensor[:, None, :]
181
+ else: # axis == self.FREQ_AXIS
182
+ mask_ranges = mask_tensor[:, :, None]
183
+ mask[:, :, :] = mask_ranges
184
+
185
+ # Apply the mask value
186
+ return input_spec.masked_fill(mask=mask, value=mask_value)
187
+
188
+
189
+ class SpecCutout(nn.Module):
190
+ """
191
+ Zeroes out(cuts) random rectangles in the spectrogram
192
+ as described in (https://arxiv.org/abs/1708.04552).
193
+
194
+ params:
195
+ rect_masks - how many rectangular masks should be cut
196
+ rect_freq - maximum size of cut rectangles along the frequency dimension
197
+ rect_time - maximum size of cut rectangles along the time dimension
198
+ """
199
+
200
+ def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
201
+ super(SpecCutout, self).__init__()
202
+
203
+ self._rng = random.Random() if rng is None else rng
204
+
205
+ self.rect_masks = rect_masks
206
+ self.rect_time = rect_time
207
+ self.rect_freq = rect_freq
208
+
209
+ @torch.no_grad()
210
+ def forward(self, input_spec):
211
+ sh = input_spec.shape
212
+
213
+ for idx in range(sh[0]):
214
+ for i in range(self.rect_masks):
215
+ rect_x = self._rng.randint(0, sh[1] - self.rect_freq)
216
+ rect_y = self._rng.randint(0, sh[2] - self.rect_time)
217
+
218
+ w_x = self._rng.randint(0, self.rect_freq)
219
+ w_y = self._rng.randint(0, self.rect_time)
220
+
221
+ input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
222
+
223
+ return input_spec
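
A brief sketch of applying the vectorized SpecAugment path to a dummy batch of log-mel spectrograms; the hyperparameters below are arbitrary choices, not this checkpoint's settings.

import torch

augment = SpecAugment(freq_masks=2, time_masks=2, freq_width=15, time_width=0.05)
spec = torch.randn(4, 80, 300)                    # [batch, n_mels, frames]
lengths = torch.tensor([300, 250, 280, 190])      # valid frames per example
augmented = augment(spec, lengths)                # masked copy; runs under torch.no_grad()
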
tdnn_attention.py ADDED
@@ -0,0 +1,550 @@
1
+ import math
2
+ from typing import List, Optional
3
+
4
+ from numpy import inf
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.init import _calculate_correct_fan
10
+
11
+
12
+ class StatsPoolLayer(nn.Module):
13
+ """Statistics and time average pooling (TAP) layer
14
+
15
+ This computes mean and, optionally, standard deviation statistics across the time dimension.
16
+
17
+ Args:
18
+ feat_in: Input features with shape [B, D, T]
19
+ pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
20
+ average pooling, i.e., mean)
21
+ eps: Epsilon, minimum value before taking the square root, when using 'xvector' mode.
22
+ unbiased: Whether to use the unbiased estimator for the standard deviation when using 'xvector' mode. The default
23
+ for torch.Tensor.std() is True.
24
+
25
+ Returns:
26
+ Pooled statistics with shape [B, D].
27
+
28
+ Raises:
29
+ ValueError if an unsupported pooling mode is specified.
30
+ """
31
+
32
+ def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, unbiased: bool = True):
33
+ super().__init__()
34
+ supported_modes = {"xvector", "tap"}
35
+ if pool_mode not in supported_modes:
36
+ raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'")
37
+ self.pool_mode = pool_mode
38
+ self.feat_in = feat_in
39
+ self.eps = eps
40
+ self.unbiased = unbiased
41
+ if self.pool_mode == 'xvector':
42
+ # Mean + std
43
+ self.feat_in *= 2
44
+
45
+ def forward(self, encoder_output, length=None):
46
+ if length is None:
47
+ mean = encoder_output.mean(dim=-1) # Time Axis
48
+ if self.pool_mode == 'xvector':
49
+ correction = 1 if self.unbiased else 0
50
+ std = encoder_output.std(dim=-1, correction=correction).clamp(min=self.eps)
51
+ pooled = torch.cat([mean, std], dim=-1)
52
+ else:
53
+ pooled = mean
54
+ else:
55
+ mask = make_seq_mask_like(like=encoder_output, lengths=length, valid_ones=False)
56
+ encoder_output = encoder_output.masked_fill(mask, 0.0)
57
+ # [B, D, T] -> [B, D]
58
+ means = encoder_output.mean(dim=-1)
59
+ # Re-scale so the mean is taken over valid (non-padded) frames only
60
+ means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
61
+ if self.pool_mode == "xvector":
62
+ correction = 1 if self.unbiased else 0
63
+ stds = (
64
+ encoder_output.sub(means.unsqueeze(-1))
65
+ .masked_fill(mask, 0.0)
66
+ .pow(2.0)
67
+ .sum(-1) # [B, D, T] -> [B, D]
68
+ .div(length.view(-1, 1).sub(correction))
69
+ .clamp(min=self.eps)
70
+ .sqrt()
71
+ )
72
+ pooled = torch.cat((means, stds), dim=-1)
73
+ else:
74
+ pooled = means
75
+ return pooled
76
+
77
+
78
+ class AttentivePoolLayer(nn.Module):
79
+ """
80
+ Attention pooling layer for pooling speaker embeddings
81
+ Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
82
+ inputs:
83
+ inp_filters: input feature channel length from encoder
84
+ attention_channels: intermediate attention channel size
85
+ kernel_size: kernel_size for TDNN and attention conv1d layers (default: 1)
86
+ dilation: dilation size for TDNN and attention conv1d layers (default: 1)
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ inp_filters: int,
92
+ attention_channels: int = 128,
93
+ kernel_size: int = 1,
94
+ dilation: int = 1,
95
+ eps: float = 1e-10,
96
+ ):
97
+ super().__init__()
98
+
99
+ self.feat_in = 2 * inp_filters
100
+
101
+ self.attention_layer = nn.Sequential(
102
+ TDNNModule(inp_filters * 3, attention_channels, kernel_size=kernel_size, dilation=dilation),
103
+ nn.Tanh(),
104
+ nn.Conv1d(
105
+ in_channels=attention_channels,
106
+ out_channels=inp_filters,
107
+ kernel_size=kernel_size,
108
+ dilation=dilation,
109
+ ),
110
+ )
111
+ self.eps = eps
112
+
113
+ def forward(self, x, length=None):
114
+ max_len = x.size(2)
115
+
116
+ if length is None:
117
+ length = torch.ones(x.shape[0], device=x.device)
118
+
119
+ mask, num_values = lens_to_mask(length, max_len=max_len, device=x.device)
120
+
121
+ # encoder statistics
122
+ mean, std = get_statistics_with_mask(x, mask / num_values)
123
+ mean = mean.unsqueeze(2).repeat(1, 1, max_len)
124
+ std = std.unsqueeze(2).repeat(1, 1, max_len)
125
+ attn = torch.cat([x, mean, std], dim=1)
126
+
127
+ # attention statistics
128
+ attn = self.attention_layer(attn) # attention pass
129
+ attn = attn.masked_fill(mask == 0, -inf)
130
+ alpha = F.softmax(attn, dim=2) # attention values, α
131
+ mu, sg = get_statistics_with_mask(x, alpha) # µ and ∑
132
+
133
+ # gather
134
+ return torch.cat((mu, sg), dim=1).unsqueeze(2)
135
+
136
+
137
+ class TDNNModule(nn.Module):
138
+ """
139
+ Time Delayed Neural Module (TDNN) - 1D
140
+ input:
141
+ inp_filters: input filter channels for conv layer
142
+ out_filters: output filter channels for conv layer
143
+ kernel_size: kernel weight size for conv layer
144
+ dilation: dilation for conv layer
145
+ stride: stride for conv layer
146
+ padding: padding for conv layer (default None: chooses padding value such that input and output feature shape matches)
147
+ output:
148
+ tdnn layer output
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ inp_filters: int,
154
+ out_filters: int,
155
+ kernel_size: int = 1,
156
+ dilation: int = 1,
157
+ stride: int = 1,
158
+ padding: int = None,
159
+ ):
160
+ super().__init__()
161
+ if padding is None:
162
+ padding = get_same_padding(kernel_size, stride=stride, dilation=dilation)
163
+
164
+ self.conv_layer = nn.Conv1d(
165
+ in_channels=inp_filters,
166
+ out_channels=out_filters,
167
+ kernel_size=kernel_size,
168
+ dilation=dilation,
169
+ padding=padding,
170
+ )
171
+
172
+ self.activation = nn.ReLU()
173
+ self.bn = nn.BatchNorm1d(out_filters)
174
+
175
+ def forward(self, x, length=None):
176
+ x = self.conv_layer(x)
177
+ x = self.activation(x)
178
+ return self.bn(x)
179
+
180
+
181
+ class MaskedSEModule(nn.Module):
182
+ """
183
+ Squeeze and Excite module implementation with conv1d layers
184
+ input:
185
+ inp_filters: input filter channel size
186
+ se_filters: intermediate squeeze and excite channel output and input size
187
+ out_filters: output filter channel size
188
+ kernel_size: kernel_size for both conv1d layers
189
+ dilation: dilation size for both conv1d layers
190
+
191
+ output:
192
+ squeeze and excite layer output
193
+ """
194
+
195
+ def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
196
+ super().__init__()
197
+ self.se_layer = nn.Sequential(
198
+ nn.Conv1d(
199
+ inp_filters,
200
+ se_filters,
201
+ kernel_size=kernel_size,
202
+ dilation=dilation,
203
+ ),
204
+ nn.ReLU(),
205
+ nn.BatchNorm1d(se_filters),
206
+ nn.Conv1d(
207
+ se_filters,
208
+ out_filters,
209
+ kernel_size=kernel_size,
210
+ dilation=dilation,
211
+ ),
212
+ nn.Sigmoid(),
213
+ )
214
+
215
+ def forward(self, input, length=None):
216
+ if length is None:
217
+ x = torch.mean(input, dim=2, keepdim=True)
218
+ else:
219
+ max_len = input.size(2)
220
+ mask, num_values = lens_to_mask(length, max_len=max_len, device=input.device)
221
+ x = torch.sum((input * mask), dim=2, keepdim=True) / (num_values)
222
+
223
+ out = self.se_layer(x)
224
+ return out * input
225
+
226
+
227
+ class TDNNSEModule(nn.Module):
228
+ """
229
+ Modified building SE_TDNN group module block from ECAPA implementation for faster training and inference
230
+ Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
231
+ inputs:
232
+ inp_filters: input filter channel size
233
+ out_filters: output filter channel size
234
+ group_scale: scale value to group wider conv channels (default: 8)
235
+ se_channels: squeeze and excite output channel size (default: 1024/8 = 128)
236
+ kernel_size: kernel_size for group conv1d layers (default: 1)
237
+ dilation: dilation size for group conv1d layers (default: 1)
238
+ """
239
+
240
+ def __init__(
241
+ self,
242
+ inp_filters: int,
243
+ out_filters: int,
244
+ group_scale: int = 8,
245
+ se_channels: int = 128,
246
+ kernel_size: int = 1,
247
+ dilation: int = 1,
248
+ init_mode: str = 'xavier_uniform',
249
+ ):
250
+ super().__init__()
251
+ self.out_filters = out_filters
252
+ padding_val = get_same_padding(kernel_size=kernel_size, dilation=dilation, stride=1)
253
+
254
+ group_conv = nn.Conv1d(
255
+ out_filters,
256
+ out_filters,
257
+ kernel_size=kernel_size,
258
+ dilation=dilation,
259
+ padding=padding_val,
260
+ groups=group_scale,
261
+ )
262
+ self.group_tdnn_block = nn.Sequential(
263
+ TDNNModule(inp_filters, out_filters, kernel_size=1, dilation=1),
264
+ group_conv,
265
+ nn.ReLU(),
266
+ nn.BatchNorm1d(out_filters),
267
+ TDNNModule(out_filters, out_filters, kernel_size=1, dilation=1),
268
+ )
269
+
270
+ self.se_layer = MaskedSEModule(out_filters, se_channels, out_filters)
271
+
272
+ self.apply(lambda x: init_weights(x, mode=init_mode))
273
+
274
+ def forward(self, input, length=None):
275
+ x = self.group_tdnn_block(input)
276
+ x = self.se_layer(x, length)
277
+ return x + input
278
+
279
+
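
An illustrative forward pass through one SE-TDNN block (channel sizes assumed); note the residual connection requires inp_filters == out_filters.

import torch

block = TDNNSEModule(inp_filters=512, out_filters=512, se_channels=128, kernel_size=3, dilation=2)
feats = torch.randn(4, 512, 200)                  # [batch, channels, frames]
lengths = torch.tensor([200, 180, 150, 90])
out = block(feats, lengths)                       # same shape as feats, with the residual added
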
280
+ class MaskedConv1d(nn.Module):
281
+ __constants__ = ["use_conv_mask", "real_out_channels", "heads"]
282
+
283
+ def __init__(
284
+ self,
285
+ in_channels,
286
+ out_channels,
287
+ kernel_size,
288
+ stride=1,
289
+ padding=0,
290
+ dilation=1,
291
+ groups=1,
292
+ heads=-1,
293
+ bias=False,
294
+ use_mask=True,
295
+ quantize=False,
296
+ ):
297
+ super(MaskedConv1d, self).__init__()
298
+
299
+ if not (heads == -1 or groups == in_channels):
300
+ raise ValueError("Only use heads for depthwise convolutions")
301
+
302
+ self.real_out_channels = out_channels
303
+ if heads != -1:
304
+ in_channels = heads
305
+ out_channels = heads
306
+ groups = heads
307
+
308
+ # preserve original padding
309
+ self._padding = padding
310
+
311
+ # if padding is a tuple/list, it is considered as asymmetric padding
312
+ if type(padding) in (tuple, list):
313
+ self.pad_layer = nn.ConstantPad1d(padding, value=0.0)
314
+ # reset padding for conv since pad_layer will handle this
315
+ padding = 0
316
+ else:
317
+ self.pad_layer = None
318
+
319
+ self.conv = nn.Conv1d(
320
+ in_channels,
321
+ out_channels,
322
+ kernel_size,
323
+ stride=stride,
324
+ padding=padding,
325
+ dilation=dilation,
326
+ groups=groups,
327
+ bias=bias,
328
+ )
329
+ self.use_mask = use_mask
330
+ self.heads = heads
331
+
332
+ # Calculations for "same" padding cache
333
+ self.same_padding = (self.conv.stride[0] == 1) and (
334
+ 2 * self.conv.padding[0] == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
335
+ )
336
+ if self.pad_layer is None:
337
+ self.same_padding_asymmetric = False
338
+ else:
339
+ self.same_padding_asymmetric = (self.conv.stride[0] == 1) and (
340
+ sum(self._padding) == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
341
+ )
342
+
343
+ # `self.lens` caches consecutive integers from 0 to `self.max_len` that are used to compute the mask for a
344
+ # batch. Recomputed to bigger size as needed. Stored on a device of the latest batch lens.
345
+ if self.use_mask:
346
+ self.max_len = torch.tensor(0)
347
+ self.lens = torch.tensor(0)
348
+
349
+ def get_seq_len(self, lens):
350
+ if self.same_padding or self.same_padding_asymmetric:
351
+ return lens
352
+
353
+ if self.pad_layer is None:
354
+ return (
355
+ torch.div(
356
+ lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
357
+ self.conv.stride[0],
358
+ rounding_mode='trunc',
359
+ )
360
+ + 1
361
+ )
362
+ else:
363
+ return (
364
+ torch.div(
365
+ lens + sum(self._padding) - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
366
+ self.conv.stride[0],
367
+ rounding_mode='trunc',
368
+ )
369
+ + 1
370
+ )
371
+
372
+ def forward(self, x, lens):
373
+ if self.use_mask:
374
+ # Generally will be called by ConvASREncoder, but kept as single gpu backup.
375
+ if x.size(2) > self.max_len:
376
+ self.update_masked_length(x.size(2), device=lens.device)
377
+ x = self.mask_input(x, lens)
378
+
379
+ # Update lengths
380
+ lens = self.get_seq_len(lens)
381
+
382
+ # asymmetric pad if necessary
383
+ if self.pad_layer is not None:
384
+ x = self.pad_layer(x)
385
+
386
+ sh = x.shape
387
+ if self.heads != -1:
388
+ x = x.view(-1, self.heads, sh[-1])
389
+
390
+ out = self.conv(x)
391
+
392
+ if self.heads != -1:
393
+ out = out.view(sh[0], self.real_out_channels, -1)
394
+
395
+ return out, lens
396
+
397
+ def update_masked_length(self, max_len, seq_range=None, device=None):
398
+ if seq_range is None:
399
+ self.lens, self.max_len = _masked_conv_init_lens(self.lens, max_len, self.max_len)
400
+ self.lens = self.lens.to(device)
401
+ else:
402
+ self.lens = seq_range
403
+ self.max_len = torch.tensor(max_len)
404
+
405
+ def mask_input(self, x, lens):
406
+ max_len = x.size(2)
407
+ mask = self.lens[:max_len].unsqueeze(0).to(lens.device) < lens.unsqueeze(1)
408
+ x = x * mask.unsqueeze(1).to(device=x.device)
409
+ return x
410
+
411
+
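
A quick sanity check (illustrative numbers) of the length arithmetic in get_seq_len: with kernel_size=3, stride=2, padding=1, an input of 100 valid frames maps to floor((100 + 2 - 2 - 1) / 2) + 1 = 50.

import torch

conv = MaskedConv1d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1)
x = torch.randn(2, 64, 100)
lens = torch.tensor([100, 73])
out, out_lens = conv(x, lens)     # out: [2, 64, 50]; out_lens: tensor([50, 37])
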
412
+ @torch.jit.script
413
+ def _masked_conv_init_lens(lens: torch.Tensor, current_maxlen: int, original_maxlen: torch.Tensor):
414
+ if current_maxlen > original_maxlen:
415
+ new_lens = torch.arange(current_maxlen)
416
+ new_max_lens = torch.tensor(current_maxlen)
417
+ else:
418
+ new_lens = lens
419
+ new_max_lens = original_maxlen
420
+ return new_lens, new_max_lens
421
+
422
+
423
+ def get_same_padding(kernel_size, stride, dilation) -> int:
424
+ if stride > 1 and dilation > 1:
425
+ raise ValueError("Only stride OR dilation may be greater than 1")
426
+ return (dilation * (kernel_size - 1)) // 2
427
+
428
+
429
+ def lens_to_mask(lens: torch.Tensor, max_len: int, device: str = None):
430
+ """
431
+ outputs masking labels for list of lengths of audio features, with max length of any
432
+ mask as max_len
433
+ input:
434
+ lens: tensor of sequence lengths
435
+ max_len: max length of any audio feature
436
+ output:
437
+ mask: masked labels
438
+ num_values: sum of mask values for each feature (useful for computing statistics later)
439
+ """
440
+ lens_mat = torch.arange(max_len).to(device)
441
+ mask = lens_mat[:max_len].unsqueeze(0) < lens.unsqueeze(1)
442
+ mask = mask.unsqueeze(1)
443
+ num_values = torch.sum(mask, dim=2, keepdim=True)
444
+ return mask, num_values
445
+
446
+
447
+ def get_statistics_with_mask(x: torch.Tensor, m: torch.Tensor, dim: int = 2, eps: float = 1e-10):
448
+ """
449
+ compute mean and standard deviation of input(x) provided with its masking labels (m)
450
+ input:
451
+ x: feature input
452
+ m: averaged mask labels
453
+ output:
454
+ mean: mean of input features
455
+ std: standard deviation of input features
456
+ """
457
+ mean = torch.sum((m * x), dim=dim)
458
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
459
+ return mean, std
460
+
461
+
462
+ @torch.jit.script_if_tracing
463
+ def make_seq_mask_like(
464
+ like: torch.Tensor, lengths: torch.Tensor, valid_ones: bool = True, time_dim: int = -1
465
+ ) -> torch.Tensor:
466
+ mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.unsqueeze(-1))
467
+ # Match number of dims in `like` tensor
468
+ for _ in range(like.dim() - mask.dim()):
469
+ mask = mask.unsqueeze(1)
470
+ # If time dim != -1, transpose to proper dim.
471
+ if time_dim != -1:
472
+ mask = mask.transpose(time_dim, -1)
473
+ if not valid_ones:
474
+ mask = ~mask
475
+ return mask
476
+
477
+
478
+ def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
479
+ if isinstance(m, MaskedConv1d):
480
+ init_weights(m.conv, mode)
481
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
482
+ if mode is not None:
483
+ if mode == 'xavier_uniform':
484
+ nn.init.xavier_uniform_(m.weight, gain=1.0)
485
+ elif mode == 'xavier_normal':
486
+ nn.init.xavier_normal_(m.weight, gain=1.0)
487
+ elif mode == 'kaiming_uniform':
488
+ nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
489
+ elif mode == 'kaiming_normal':
490
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
491
+ elif mode == 'tds_uniform':
492
+ tds_uniform_(m.weight)
493
+ elif mode == 'tds_normal':
494
+ tds_normal_(m.weight)
495
+ else:
496
+ raise ValueError("Unknown Initialization mode: {0}".format(mode))
497
+ elif isinstance(m, nn.BatchNorm1d):
498
+ if m.track_running_stats:
499
+ m.running_mean.zero_()
500
+ m.running_var.fill_(1)
501
+ m.num_batches_tracked.zero_()
502
+ if m.affine:
503
+ nn.init.ones_(m.weight)
504
+ nn.init.zeros_(m.bias)
505
+
506
+
507
+ def tds_uniform_(tensor, mode='fan_in'):
508
+ """
509
+ Uniform Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
510
+ Normalized to -
511
+
512
+ .. math::
513
+ \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}
514
+
515
+ Args:
516
+ tensor: an n-dimensional `torch.Tensor`
517
+ mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
518
+ preserves the magnitude of the variance of the weights in the
519
+ forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
520
+ backwards pass.
521
+ """
522
+ fan = _calculate_correct_fan(tensor, mode)
523
+ gain = 2.0 # sqrt(4.0) = 2
524
+ std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
525
+ bound = std # Calculate uniform bounds from standard deviation
526
+ with torch.no_grad():
527
+ return tensor.uniform_(-bound, bound)
528
+
529
+
530
+ def tds_normal_(tensor, mode='fan_in'):
531
+ """
532
+ Normal Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
533
+ Normalized to -
534
+
535
+ .. math::
536
+ \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}
537
+
538
+ Args:
539
+ tensor: an n-dimensional `torch.Tensor`
540
+ mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
541
+ preserves the magnitude of the variance of the weights in the
542
+ forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
543
+ backwards pass.
544
+ """
545
+ fan = _calculate_correct_fan(tensor, mode)
546
+ gain = 2.0
547
+ std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
548
+ bound = std # Calculate uniform bounds from standard deviation
549
+ with torch.no_grad():
550
+ return tensor.normal_(0.0, bound)
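
Finally, a sketch contrasting the two pooling layers above on a dummy encoder output (dimensions assumed).

import torch

feats = torch.randn(4, 512, 200)                  # [batch, channels, frames]
lengths = torch.tensor([200, 180, 150, 90])

stats_pool = StatsPoolLayer(feat_in=512, pool_mode='xvector')
attn_pool = AttentivePoolLayer(inp_filters=512, attention_channels=128)

pooled_stats = stats_pool(feats, lengths)         # [4, 1024]: mean and std concatenated
pooled_attn = attn_pool(feats, lengths)           # [4, 1024, 1]: attentive mean and std
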