yangwang825 committed on
Commit 3957fd5 · verified · 1 Parent(s): eccfa9c

Upload ResNetForSequenceClassification

Files changed (10)
  1. angular_loss.py +68 -0
  2. audio_processing.py +413 -0
  3. config.json +6 -2
  4. conv_asr.py +528 -0
  5. features.py +560 -0
  6. model.safetensors +3 -0
  7. modeling_resnet.py +150 -0
  8. module.py +105 -0
  9. spectrogram_augment.py +223 -0
  10. tdnn_attention.py +749 -0
angular_loss.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class Loss(nn.modules.loss._Loss):
6
+ """Inherit this class to implement custom loss."""
7
+
8
+ def __init__(self, **kwargs):
9
+ super(Loss, self).__init__(**kwargs)
10
+
11
+
12
+ class AdditiveMarginSoftmaxLoss(Loss):
13
+ """Computes Additive Margin Softmax (CosFace) Loss
14
+
15
+ Paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition
16
+
17
+ Args:
18
+ scale: scale value for cosine angle
19
+ margin: margin value added to cosine angle
20
+ """
21
+
22
+ def __init__(self, scale=30.0, margin=0.2):
23
+ super().__init__()
24
+
25
+ self.eps = 1e-7
26
+ self.scale = scale
27
+ self.margin = margin
28
+
29
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
30
+ # Extract the logits corresponding to the true class
31
+ logits_target = logits[torch.arange(logits.size(0)), labels] # Faster indexing
32
+ numerator = self.scale * (logits_target - self.margin) # Apply additive margin
33
+ # Exclude the target logits from denominator calculation
34
+ logits = logits.scatter(1, labels.unsqueeze(1), float('-inf')) # Mask target class (out-of-place so the caller's logits are not mutated)
35
+ denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * logits), dim=1)
36
+ # Compute final loss
37
+ loss = -torch.log(torch.exp(numerator) / denominator)
38
+ return loss.mean()
39
+
40
+
41
+ class AdditiveAngularMarginSoftmaxLoss(Loss):
42
+ """Computes Additive Angular Margin Softmax (ArcFace) Loss
43
+
44
+ Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition
45
+
46
+ Args:
47
+ scale: scale value for cosine angle
48
+ margin: margin value added to cosine angle
49
+ """
50
+
51
+ def __init__(self, scale=20.0, margin=1.35):
52
+ super().__init__()
53
+
54
+ self.eps = 1e-7
55
+ self.scale = scale
56
+ self.margin = margin
57
+
58
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
59
+ numerator = self.scale * torch.cos(
60
+ torch.acos(torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps))
61
+ + self.margin
62
+ )
63
+ excl = torch.cat(
64
+ [torch.cat((logits[i, :y], logits[i, y + 1 :])).unsqueeze(0) for i, y in enumerate(labels)], dim=0
65
+ )
66
+ denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * excl), dim=1)
67
+ L = numerator - torch.log(denominator)
68
+ return -torch.mean(L)
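Both losses expect `logits` that are already cosine similarities in [-1, 1] (the decoder produces these when `angular=True`): CosFace scales `cos(theta_y) - m`, while ArcFace scales `cos(theta_y + m)`. A minimal sketch of exercising the ArcFace variant, using the `angular_scale` / `angular_margin` values from config.json; the embedding size, class count, and flat import path are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

from angular_loss import AdditiveAngularMarginSoftmaxLoss  # adjust to the actual package layout

# Hypothetical batch: 4 utterances, 10 speaker classes, 192-dim embeddings
embeddings = F.normalize(torch.randn(4, 192), dim=1)   # unit-norm speaker embeddings
weights = F.normalize(torch.randn(10, 192), dim=1)     # unit-norm class centres
cosine_logits = embeddings @ weights.t()                # cos(theta), values in [-1, 1]
labels = torch.randint(0, 10, (4,))

# scale=30, margin=0.2 mirror "angular_scale" / "angular_margin" in config.json
loss_fn = AdditiveAngularMarginSoftmaxLoss(scale=30.0, margin=0.2)
loss = loss_fn(cosine_logits, labels)
print(loss.item())
```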
audio_processing.py ADDED
@@ -0,0 +1,413 @@
1
+ import math
2
+ from packaging import version
3
+ from dataclasses import dataclass
4
+ from abc import ABC, abstractmethod
5
+
6
+ import torch
7
+
8
+ try:
9
+ import torchaudio
10
+ import torchaudio.functional
11
+ import torchaudio.transforms
12
+
13
+ TORCHAUDIO_VERSION = version.parse(torchaudio.__version__)
14
+ TORCHAUDIO_VERSION_MIN = version.parse('0.5')
15
+
16
+ HAVE_TORCHAUDIO = True
17
+ except ModuleNotFoundError:
18
+ HAVE_TORCHAUDIO = False
19
+
20
+ from .module import NeuralModule
21
+ from .features import FilterbankFeatures, FilterbankFeaturesTA
22
+ from .spectrogram_augment import SpecCutout, SpecAugment
23
+
24
+
25
+ class AudioPreprocessor(NeuralModule, ABC):
26
+ """
27
+ An interface for Neural Modules that performs audio pre-processing,
28
+ transforming the wav files to features.
29
+ """
30
+
31
+ def __init__(self, win_length, hop_length):
32
+ super().__init__()
33
+
34
+ self.win_length = win_length
35
+ self.hop_length = hop_length
36
+
37
+ self.torch_windows = {
38
+ 'hann': torch.hann_window,
39
+ 'hamming': torch.hamming_window,
40
+ 'blackman': torch.blackman_window,
41
+ 'bartlett': torch.bartlett_window,
42
+ 'ones': torch.ones,
43
+ None: torch.ones,
44
+ }
45
+
46
+ # Normally, when you call to(dtype) on a torch.nn.Module, all
47
+ # floating point parameters and buffers will change to that
48
+ # dtype, rather than being float32. The AudioPreprocessor
49
+ # classes, uniquely, don't actually have any parameters or
50
+ # buffers from what I see. In addition, we want the input to
51
+ # the preprocessor to be float32, but need to create the
52
+ # output in appropriate precision. We have this empty tensor
53
+ # here just to detect which dtype tensor this module should
54
+ # output at the end of execution.
55
+ self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)
56
+
57
+ @torch.no_grad()
58
+ def forward(self, input_signal, length):
59
+ processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
60
+ processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
61
+ return processed_signal, processed_length
62
+
63
+ @abstractmethod
64
+ def get_features(self, input_signal, length):
65
+ # Called by forward(). Subclasses should implement this.
66
+ pass
67
+
68
+
69
+ class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
70
+ """Featurizer module that converts wavs to mel spectrograms.
71
+
72
+ Args:
73
+ sample_rate (int): Sample rate of the input audio data.
74
+ Defaults to 16000
75
+ window_size (float): Size of window for fft in seconds
76
+ Defaults to 0.02
77
+ window_stride (float): Stride of window for fft in seconds
78
+ Defaults to 0.01
79
+ n_window_size (int): Size of window for fft in samples
80
+ Defaults to None. Use one of window_size or n_window_size.
81
+ n_window_stride (int): Stride of window for fft in samples
82
+ Defaults to None. Use one of window_stride or n_window_stride.
83
+ window (str): Windowing function for fft. can be one of ['hann',
84
+ 'hamming', 'blackman', 'bartlett']
85
+ Defaults to "hann"
86
+ normalize (str): Can be one of ['per_feature', 'all_features']; all
87
+ other options disable feature normalization. 'all_features'
88
+ normalizes the entire spectrogram to be mean 0 with std 1.
89
+ 'per_feature' normalizes per channel / freq instead.
90
+ Defaults to "per_feature"
91
+ n_fft (int): Length of FT window. If None, it uses the smallest power
92
+ of 2 that is larger than n_window_size.
93
+ Defaults to None
94
+ preemph (float): Amount of pre emphasis to add to audio. Can be
95
+ disabled by passing None.
96
+ Defaults to 0.97
97
+ features (int): Number of mel spectrogram freq bins to output.
98
+ Defaults to 64
99
+ lowfreq (int): Lower bound on mel basis in Hz.
100
+ Defaults to 0
101
+ highfreq (int): Upper bound on mel basis in Hz.
102
+ Defaults to None
103
+ log (bool): Log features.
104
+ Defaults to True
105
+ log_zero_guard_type(str): Need to avoid taking the log of zero. There
106
+ are two options: "add" or "clamp".
107
+ Defaults to "add".
108
+ log_zero_guard_value(float, or str): Add or clamp requires the number
109
+ to add with or clamp to. log_zero_guard_value can either be a float
110
+ or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
111
+ passed.
112
+ Defaults to 2**-24.
113
+ dither (float): Amount of white-noise dithering.
114
+ Defaults to 1e-5
115
+ pad_to (int): Ensures that the output size of the time dimension is
116
+ a multiple of pad_to.
117
+ Defaults to 16
118
+ frame_splicing (int): Defaults to 1
119
+ exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
120
+ // hop_length. Defaults to False.
121
+ pad_value (float): The value that shorter mels are padded with.
122
+ Defaults to 0
123
+ mag_power (float): The power that the linear spectrogram is raised to
124
+ prior to multiplication with mel basis.
125
+ Defaults to 2 for a power spec
126
+ rng : Random number generator
127
+ nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
128
+ samples in the batch.
129
+ Defaults to 0.0
130
+ nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
131
+ Defaults to 4000
132
+ use_torchaudio: Whether to use the `torchaudio` implementation.
133
+ mel_norm: Normalization used for mel filterbank weights.
134
+ Defaults to 'slaney' (area normalization)
135
+ stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
136
+ stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ sample_rate=16000,
142
+ window_size=0.02,
143
+ window_stride=0.01,
144
+ n_window_size=None,
145
+ n_window_stride=None,
146
+ window="hann",
147
+ normalize="per_feature",
148
+ n_fft=None,
149
+ preemph=0.97,
150
+ features=64,
151
+ lowfreq=0,
152
+ highfreq=None,
153
+ log=True,
154
+ log_zero_guard_type="add",
155
+ log_zero_guard_value=2**-24,
156
+ dither=1e-5,
157
+ pad_to=16,
158
+ frame_splicing=1,
159
+ exact_pad=False,
160
+ pad_value=0,
161
+ mag_power=2.0,
162
+ rng=None,
163
+ nb_augmentation_prob=0.0,
164
+ nb_max_freq=4000,
165
+ use_torchaudio: bool = False,
166
+ mel_norm="slaney",
167
+ stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
168
+ stft_conv=False, # Deprecated arguments; kept for config compatibility
169
+ ):
170
+ super().__init__(n_window_size, n_window_stride)
171
+
172
+ self._sample_rate = sample_rate
173
+ if window_size and n_window_size:
174
+ raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
175
+ if window_stride and n_window_stride:
176
+ raise ValueError(
177
+ f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
178
+ )
179
+ if window_size:
180
+ n_window_size = int(window_size * self._sample_rate)
181
+ if window_stride:
182
+ n_window_stride = int(window_stride * self._sample_rate)
183
+
184
+ # Given the long and similar argument list, point to the class and instantiate it by reference
185
+ if not use_torchaudio:
186
+ featurizer_class = FilterbankFeatures
187
+ else:
188
+ featurizer_class = FilterbankFeaturesTA
189
+ self.featurizer = featurizer_class(
190
+ sample_rate=self._sample_rate,
191
+ n_window_size=n_window_size,
192
+ n_window_stride=n_window_stride,
193
+ window=window,
194
+ normalize=normalize,
195
+ n_fft=n_fft,
196
+ preemph=preemph,
197
+ nfilt=features,
198
+ lowfreq=lowfreq,
199
+ highfreq=highfreq,
200
+ log=log,
201
+ log_zero_guard_type=log_zero_guard_type,
202
+ log_zero_guard_value=log_zero_guard_value,
203
+ dither=dither,
204
+ pad_to=pad_to,
205
+ frame_splicing=frame_splicing,
206
+ exact_pad=exact_pad,
207
+ pad_value=pad_value,
208
+ mag_power=mag_power,
209
+ rng=rng,
210
+ nb_augmentation_prob=nb_augmentation_prob,
211
+ nb_max_freq=nb_max_freq,
212
+ mel_norm=mel_norm,
213
+ stft_exact_pad=stft_exact_pad, # Deprecated arguments; kept for config compatibility
214
+ stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility
215
+ )
216
+
217
+ def get_features(self, input_signal, length):
218
+ return self.featurizer(input_signal, length)
219
+
220
+ @property
221
+ def filter_banks(self):
222
+ return self.featurizer.filter_banks
223
+
224
+
225
+ class AudioToMFCCPreprocessor(AudioPreprocessor):
226
+ """Preprocessor that converts wavs to MFCCs.
227
+ Uses torchaudio.transforms.MFCC.
228
+
229
+ Args:
230
+ sample_rate: The sample rate of the audio.
231
+ Defaults to 16000.
232
+ window_size: Size of window for fft in seconds. Used to calculate the
233
+ win_length arg for mel spectrogram.
234
+ Defaults to 0.02
235
+ window_stride: Stride of window for fft in seconds. Used to calculate
236
+ the hop_length arg for mel spect.
237
+ Defaults to 0.01
238
+ n_window_size: Size of window for fft in samples
239
+ Defaults to None. Use one of window_size or n_window_size.
240
+ n_window_stride: Stride of window for fft in samples
241
+ Defaults to None. Use one of window_stride or n_window_stride.
242
+ window: Windowing function for fft. can be one of ['hann',
243
+ 'hamming', 'blackman', 'bartlett', 'ones'].
244
+ Defaults to 'hann'
245
+ n_fft: Length of FT window. If None, it uses the smallest power of 2
246
+ that is larger than n_window_size.
247
+ Defaults to None
248
+ lowfreq (int): Lower bound on mel basis in Hz.
249
+ Defaults to 0
250
+ highfreq (int): Upper bound on mel basis in Hz.
251
+ Defaults to None
252
+ n_mels: Number of mel filterbanks.
253
+ Defaults to 64
254
+ n_mfcc: Number of coefficients to retain
255
+ Defaults to 64
256
+ dct_type: Type of discrete cosine transform to use
257
+ norm: Type of norm to use
258
+ log: Whether to use log-mel spectrograms instead of db-scaled.
259
+ Defaults to True.
260
+ """
261
+
262
+ def __init__(
263
+ self,
264
+ sample_rate=16000,
265
+ window_size=0.02,
266
+ window_stride=0.01,
267
+ n_window_size=None,
268
+ n_window_stride=None,
269
+ window='hann',
270
+ n_fft=None,
271
+ lowfreq=0.0,
272
+ highfreq=None,
273
+ n_mels=64,
274
+ n_mfcc=64,
275
+ dct_type=2,
276
+ norm='ortho',
277
+ log=True,
278
+ ):
279
+ self._sample_rate = sample_rate
280
+ if not HAVE_TORCHAUDIO:
281
+ print('Could not import torchaudio. Some features might not work.')
282
+
283
+ raise ModuleNotFoundError(
284
+ "torchaudio is not installed but is necessary for "
285
+ "AudioToMFCCPreprocessor. We recommend you try "
286
+ "building it from source for the PyTorch version you have."
287
+ )
288
+ if window_size and n_window_size:
289
+ raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
290
+ if window_stride and n_window_stride:
291
+ raise ValueError(
292
+ f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
293
+ )
294
+ # Get win_length (n_window_size) and hop_length (n_window_stride)
295
+ if window_size:
296
+ n_window_size = int(window_size * self._sample_rate)
297
+ if window_stride:
298
+ n_window_stride = int(window_stride * self._sample_rate)
299
+
300
+ super().__init__(n_window_size, n_window_stride)
301
+
302
+ mel_kwargs = {}
303
+
304
+ mel_kwargs['f_min'] = lowfreq
305
+ mel_kwargs['f_max'] = highfreq
306
+ mel_kwargs['n_mels'] = n_mels
307
+
308
+ mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))
309
+
310
+ mel_kwargs['win_length'] = n_window_size
311
+ mel_kwargs['hop_length'] = n_window_stride
312
+
313
+ # Set window_fn. None defaults to torch.ones.
314
+ window_fn = self.torch_windows.get(window, None)
315
+ if window_fn is None:
316
+ raise ValueError(
317
+ f"Window argument for AudioProcessor is invalid: {window}."
318
+ f"For no window function, use 'ones' or None."
319
+ )
320
+ mel_kwargs['window_fn'] = window_fn
321
+
322
+ # Use torchaudio's implementation of MFCCs as featurizer
323
+ self.featurizer = torchaudio.transforms.MFCC(
324
+ sample_rate=self._sample_rate,
325
+ n_mfcc=n_mfcc,
326
+ dct_type=dct_type,
327
+ norm=norm,
328
+ log_mels=log,
329
+ melkwargs=mel_kwargs,
330
+ )
331
+
332
+ def get_features(self, input_signal, length):
333
+ features = self.featurizer(input_signal)
334
+ seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long)
335
+ return features, seq_len
336
+
337
+
338
+ class SpectrogramAugmentation(NeuralModule):
339
+ """
340
+ Performs time and freq cuts in one of two ways.
341
+ SpecAugment zeroes out vertical and horizontal sections as described in
342
+ SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with
343
+ SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`.
344
+ SpecCutout zeroes out rectangular regions as described in Cutout
345
+ (https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are
346
+ `rect_masks`, `rect_freq`, and `rect_time`.
347
+
348
+ Args:
349
+ freq_masks (int): how many frequency segments should be cut.
350
+ Defaults to 0.
351
+ time_masks (int): how many time segments should be cut
352
+ Defaults to 0.
353
+ freq_width (int): maximum number of frequencies to be cut in one
354
+ segment.
355
+ Defaults to 10.
356
+ time_width (int): maximum number of time steps to be cut in one
357
+ segment
358
+ Defaults to 10.
359
+ rect_masks (int): how many rectangular masks should be cut
360
+ Defaults to 0.
361
+ rect_freq (int): maximum size of cut rectangles along the frequency
362
+ dimension
363
+ Defaults to 20.
364
+ rect_time (int): maximum size of cut rectangles along the time
365
+ dimension
366
+ Defaults to 5.
367
+ use_numba_spec_augment: use numba code for Spectrogram augmentation
368
+ use_vectorized_spec_augment: use vectorized code for Spectrogram augmentation
369
+
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ freq_masks=0,
375
+ time_masks=0,
376
+ freq_width=10,
377
+ time_width=10,
378
+ rect_masks=0,
379
+ rect_time=5,
380
+ rect_freq=20,
381
+ rng=None,
382
+ mask_value=0.0,
383
+ use_vectorized_spec_augment: bool = True,
384
+ ):
385
+ super().__init__()
386
+
387
+ if rect_masks > 0:
388
+ self.spec_cutout = SpecCutout(
389
+ rect_masks=rect_masks,
390
+ rect_time=rect_time,
391
+ rect_freq=rect_freq,
392
+ rng=rng,
393
+ )
394
+ # self.spec_cutout.to(self._device)
395
+ else:
396
+ self.spec_cutout = lambda input_spec: input_spec
397
+ if freq_masks + time_masks > 0:
398
+ self.spec_augment = SpecAugment(
399
+ freq_masks=freq_masks,
400
+ time_masks=time_masks,
401
+ freq_width=freq_width,
402
+ time_width=time_width,
403
+ rng=rng,
404
+ mask_value=mask_value,
405
+ use_vectorized_code=use_vectorized_spec_augment,
406
+ )
407
+ else:
408
+ self.spec_augment = lambda input_spec, length: input_spec
409
+
410
+ def forward(self, input_spec, length):
411
+ augmented_spec = self.spec_cutout(input_spec=input_spec)
412
+ augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
413
+ return augmented_spec
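As a rough usage sketch, the preprocessor maps a batch of float waveforms plus per-sample lengths to mel features and frame counts. The window/stride/feature values below are illustrative, and the flat import assumes these files are importable as a package; the default path (`use_torchaudio=False`) additionally requires `librosa` for the mel filterbank:

```python
import torch

from audio_processing import AudioToMelSpectrogramPreprocessor  # adjust to the actual package layout

preproc = AudioToMelSpectrogramPreprocessor(
    sample_rate=16000,
    window_size=0.025,   # 25 ms analysis window
    window_stride=0.01,  # 10 ms hop
    features=80,         # mel bins
)
preproc.eval()  # disable dithering / narrowband augmentation outside training

waveforms = torch.randn(2, 16000)        # two 1-second dummy clips at 16 kHz
lengths = torch.tensor([16000, 12000])   # second clip is zero-padded beyond 12000 samples

feats, feat_lens = preproc(input_signal=waveforms, length=lengths)
print(feats.shape)   # (2, 80, T), with T padded up to a multiple of pad_to=16
print(feat_lens)     # valid frame count per clip
```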
config.json CHANGED
@@ -1,11 +1,14 @@
 {
-  "_attn_implementation_autoset": true,
   "angular": true,
   "angular_margin": 0.2,
   "angular_scale": 30,
+  "architectures": [
+    "ResNetForSequenceClassification"
+  ],
   "attention_channels": 128,
   "auto_map": {
-    "AutoConfig": "configuration_resnet.ResNetConfig"
+    "AutoConfig": "configuration_resnet.ResNetConfig",
+    "AutoModelForAudioClassification": "modeling_resnet.ResNetForSequenceClassification"
   },
   "block_sizes": [
     3,
@@ -2575,6 +2578,7 @@
   ],
   "time_masks": 5,
   "time_width": 0.03,
+  "torch_dtype": "float32",
   "transformers_version": "4.48.3",
   "use_torchaudio": true,
   "use_vectorized_spec_augment": true,
conv_asr.py ADDED
@@ -0,0 +1,528 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .module import NeuralModule
8
+ from .tdnn_attention import (
9
+ StatsPoolLayer,
10
+ AttentivePoolLayer,
11
+ ChannelDependentAttentiveStatisticsPoolLayer,
12
+ TdnnModule,
13
+ TdnnSeModule,
14
+ TdnnSeRes2NetModule,
15
+ init_weights
16
+ )
17
+
18
+
19
+ def conv3x3(in_planes, out_planes, stride=1, padding=1):
20
+ """2D convolution with kernel_size = 3"""
21
+ return nn.Conv2d(
22
+ in_planes,
23
+ out_planes,
24
+ kernel_size=3,
25
+ stride=stride,
26
+ padding=padding,
27
+ bias=False,
28
+ )
29
+
30
+
31
+ def conv1x1(in_planes, out_planes, stride=1):
32
+ """2D convolution with kernel_size = 1"""
33
+ return nn.Conv2d(
34
+ in_planes, out_planes, kernel_size=1, stride=stride, bias=False
35
+ )
36
+
37
+
38
+ class BasicBlock(nn.Module):
39
+
40
+ def __init__(
41
+ self,
42
+ in_channels,
43
+ out_channels,
44
+ stride=1,
45
+ downsample=None,
46
+ activation=nn.ReLU,
47
+ ):
48
+ super(BasicBlock, self).__init__()
49
+ self.activation = activation()
50
+
51
+ self.bn1 = nn.BatchNorm2d(in_channels)
52
+ self.conv1 = conv3x3(in_channels, out_channels, stride)
53
+
54
+ self.bn2 = nn.BatchNorm2d(out_channels)
55
+ self.conv2 = conv3x3(out_channels, out_channels)
56
+
57
+ self.bn3 = nn.BatchNorm2d(out_channels)
58
+ self.conv3 = conv1x1(out_channels, out_channels)
59
+
60
+ self.downsample = downsample
61
+ self.stride = stride
62
+
63
+ def forward(self, x):
64
+ residual = x
65
+ out = self.bn1(x)
66
+ out = self.activation(out)
67
+ out = self.conv1(out)
68
+
69
+ out = self.bn2(out)
70
+ out = self.activation(out)
71
+ out = self.conv2(out)
72
+
73
+ out = self.bn3(out)
74
+ out = self.activation(out)
75
+ out = self.conv3(out)
76
+
77
+ if self.downsample is not None:
78
+ residual = self.downsample(x)
79
+
80
+ out += residual
81
+
82
+ return out
83
+
84
+
85
+ class SEBlock(nn.Module):
86
+
87
+ def __init__(self, channels, reduction=1, activation=nn.ReLU):
88
+ super(SEBlock, self).__init__()
89
+
90
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
91
+
92
+ self.fc = nn.Sequential(
93
+ nn.Linear(channels, channels // reduction),
94
+ activation(),
95
+ nn.Linear(channels // reduction, channels),
96
+ nn.Sigmoid(),
97
+ )
98
+
99
+ def forward(self, x):
100
+ """Intermediate step. Processes the input tensor x
101
+ and returns an output tensor.
102
+ """
103
+ b, c, _, _ = x.size()
104
+ y = self.avg_pool(x).view(b, c)
105
+ y = self.fc(y).view(b, c, 1, 1)
106
+ return x * y
107
+
108
+
109
+ class SEBasicBlock(nn.Module):
110
+
111
+ def __init__(
112
+ self,
113
+ in_channels,
114
+ out_channels,
115
+ stride=1,
116
+ downsample=None,
117
+ activation=nn.ReLU,
118
+ reduction=1,
119
+ ):
120
+ super(SEBasicBlock, self).__init__()
121
+ self.activation = activation()
122
+
123
+ self.bn1 = nn.BatchNorm2d(in_channels)
124
+ self.conv1 = conv3x3(in_channels, out_channels, stride)
125
+
126
+ self.bn2 = nn.BatchNorm2d(out_channels)
127
+ self.conv2 = conv3x3(out_channels, out_channels)
128
+
129
+ self.bn3 = nn.BatchNorm2d(out_channels)
130
+ self.conv3 = conv1x1(out_channels, out_channels)
131
+
132
+ self.downsample = downsample
133
+ self.stride = stride
134
+
135
+ self.se = SEBlock(out_channels, reduction)
136
+
137
+ def forward(self, x):
138
+ residual = x
139
+
140
+ out = self.bn1(x)
141
+ out = self.activation(out)
142
+ out = self.conv1(out)
143
+
144
+ out = self.bn2(out)
145
+ out = self.activation(out)
146
+ out = self.conv2(out)
147
+
148
+ out = self.bn3(out)
149
+ out = self.activation(out)
150
+ out = self.conv3(out)
151
+
152
+ out = self.se(out)
153
+
154
+ if self.downsample is not None:
155
+ residual = self.downsample(x)
156
+
157
+ out += residual
158
+
159
+ return out
160
+
161
+
162
+ class SEBottleneck(nn.Module):
163
+
164
+ def __init__(
165
+ self,
166
+ in_channels,
167
+ out_channels,
168
+ stride=1,
169
+ downsample=None,
170
+ activation=nn.ReLU,
171
+ reduction=16, # Reduction ratio for SE block
172
+ ):
173
+ super(SEBottleneck, self).__init__()
174
+ self.activation = activation()
175
+
176
+ # 1x1 convolution to reduce channels
177
+ self.conv1 = conv1x1(in_channels, out_channels // 4, stride)
178
+ self.bn1 = nn.BatchNorm2d(out_channels // 4)
179
+
180
+ # 3x3 convolution
181
+ self.conv2 = conv3x3(out_channels // 4, out_channels // 4)
182
+ self.bn2 = nn.BatchNorm2d(out_channels // 4)
183
+
184
+ # 1x1 convolution to restore channels
185
+ self.conv3 = conv1x1(out_channels // 4, out_channels)
186
+ self.bn3 = nn.BatchNorm2d(out_channels)
187
+
188
+ # Squeeze-and-Excitation block
189
+ self.se = SEBlock(out_channels, reduction)
190
+
191
+ self.downsample = downsample
192
+ self.stride = stride
193
+
194
+ def forward(self, x):
195
+ residual = x
196
+
197
+ # First 1x1 convolution
198
+ out = self.conv1(x)
199
+ out = self.bn1(out)
200
+ out = self.activation(out)
201
+
202
+ # 3x3 convolution
203
+ out = self.conv2(out)
204
+ out = self.bn2(out)
205
+ out = self.activation(out)
206
+
207
+ # Second 1x1 convolution
208
+ out = self.conv3(out)
209
+ out = self.bn3(out)
210
+
211
+ # Apply SE block
212
+ out = self.se(out)
213
+
214
+ # Downsample residual if needed
215
+ if self.downsample is not None:
216
+ residual = self.downsample(x)
217
+
218
+ # Add residual
219
+ out += residual
220
+ out = self.activation(out)
221
+
222
+ return out
223
+
224
+
225
+ class Bottleneck(nn.Module):
226
+
227
+ def __init__(
228
+ self,
229
+ in_channels,
230
+ out_channels,
231
+ stride=1,
232
+ downsample=None,
233
+ activation=nn.ReLU,
234
+ ):
235
+ super(Bottleneck, self).__init__()
236
+ self.activation = activation()
237
+
238
+ # 1x1 convolution to reduce channels
239
+ self.conv1 = conv1x1(in_channels, out_channels // 4, stride)
240
+ self.bn1 = nn.BatchNorm2d(out_channels // 4)
241
+
242
+ # 3x3 convolution
243
+ self.conv2 = conv3x3(out_channels // 4, out_channels // 4)
244
+ self.bn2 = nn.BatchNorm2d(out_channels // 4)
245
+
246
+ # 1x1 convolution to restore channels
247
+ self.conv3 = conv1x1(out_channels // 4, out_channels)
248
+ self.bn3 = nn.BatchNorm2d(out_channels)
249
+
250
+ self.downsample = downsample
251
+ self.stride = stride
252
+
253
+ def forward(self, x):
254
+ residual = x
255
+
256
+ # First 1x1 convolution
257
+ out = self.conv1(x)
258
+ out = self.bn1(out)
259
+ out = self.activation(out)
260
+
261
+ # 3x3 convolution
262
+ out = self.conv2(out)
263
+ out = self.bn2(out)
264
+ out = self.activation(out)
265
+
266
+ # Second 1x1 convolution
267
+ out = self.conv3(out)
268
+ out = self.bn3(out)
269
+
270
+ # Downsample residual if needed
271
+ if self.downsample is not None:
272
+ residual = self.downsample(x)
273
+
274
+ # Add residual
275
+ out += residual
276
+ out = self.activation(out)
277
+
278
+ return out
279
+
280
+
281
+ class ResNetEncoder(NeuralModule):
282
+
283
+ def __init__(
284
+ self,
285
+ feat_in: int,
286
+ filters: list = [16, 32, 64, 128],
287
+ block_sizes: list = [3, 4, 6, 3],
288
+ strides: list = [1, 2, 2, 1],
289
+ block_type: str = 'basic', # basic, bottleneck
290
+ reduction: int = 8, # reduction for SE layer
291
+ init_mode: str = 'xavier_uniform',
292
+ ):
293
+ super().__init__()
294
+ if block_type == 'basic':
295
+ self.block_class = BasicBlock
296
+ self.se_block_class = SEBasicBlock
297
+ elif block_type == 'bottleneck':
298
+ self.block_class = Bottleneck
299
+ self.se_block_class = SEBottleneck
300
+
301
+ self.pre_conv = nn.Sequential(
302
+ nn.Conv2d(
303
+ in_channels=1,
304
+ out_channels=filters[0],
305
+ kernel_size=3,
306
+ stride=1,
307
+ padding=1,
308
+ bias=False
309
+ ),
310
+ nn.BatchNorm2d(filters[0]),
311
+ nn.ReLU(inplace=True)
312
+ )
313
+
314
+ self.layer1 = self._make_layer_se(
315
+ filters[0], filters[0], block_sizes[0], stride=strides[0], reduction=reduction
316
+ )
317
+ self.layer2 = self._make_layer_se(
318
+ filters[0], filters[1], block_sizes[1], stride=strides[1], reduction=reduction
319
+ )
320
+ self.layer3 = self._make_layer(
321
+ filters[1], filters[2], block_sizes[2], stride=strides[2]
322
+ )
323
+ self.layer4 = self._make_layer(
324
+ filters[2], filters[3], block_sizes[3], stride=strides[3]
325
+ )
326
+
327
+ self.apply(lambda x: init_weights(x, mode=init_mode))
328
+
329
+ def _make_layer_se(self, in_channels, out_channels, block_num, stride=1, reduction=1):
330
+ """Construct the squeeze-and-excitation block layer.
331
+
332
+ Arguments
333
+ ---------
334
+ in_channels : int
335
+ Number of input channels.
336
+ out_channels : int
337
+ The number of output channels.
338
+ block_num: int
339
+ Number of ResNet blocks for the network.
340
+ stride : int
341
+ Factor that reduces the spatial dimensionality. Defaults to 1
342
+
343
+ Returns
344
+ -------
345
+ se_block : nn.Sequential
346
+ Squeeze-and-excitation block
347
+ """
348
+ downsample = None
349
+ if stride != 1 or in_channels != out_channels:
350
+ downsample = nn.Sequential(
351
+ nn.Conv2d(
352
+ in_channels,
353
+ out_channels,
354
+ kernel_size=1,
355
+ stride=stride,
356
+ bias=False,
357
+ ),
358
+ nn.BatchNorm2d(out_channels),
359
+ )
360
+
361
+ layers = []
362
+ layers.append(
363
+ self.se_block_class(in_channels, out_channels, stride, downsample, reduction=reduction)
364
+ )
365
+
366
+ for i in range(1, block_num):
367
+ layers.append(self.se_block_class(out_channels, out_channels, reduction=reduction))
368
+
369
+ return nn.Sequential(*layers)
370
+
371
+ def _make_layer(self, in_channels, out_channels, block_num, stride=1):
372
+ """
373
+ Construct the ResNet block layer.
374
+
375
+ Arguments
376
+ ---------
377
+ in_channels : int
378
+ Number of input channels.
379
+ out_channels : int
380
+ The number of output channels.
381
+ block_num: int
382
+ Number of ResNet blocks for the network.
383
+ stride : int
384
+ Factor that reduces the spatial dimensionality. Defaults to 1
385
+
386
+ Returns
387
+ -------
388
+ block : nn.Sequential
389
+ ResNet block
390
+ """
391
+ downsample = None
392
+ if stride != 1 or in_channels != out_channels:
393
+ downsample = nn.Sequential(
394
+ nn.Conv2d(
395
+ in_channels,
396
+ out_channels,
397
+ kernel_size=1,
398
+ stride=stride,
399
+ bias=False,
400
+ ),
401
+ nn.BatchNorm2d(out_channels),
402
+ )
403
+
404
+ layers = []
405
+ layers.append(self.block_class(in_channels, out_channels, stride, downsample))
406
+
407
+ for i in range(1, block_num):
408
+ layers.append(self.block_class(out_channels, out_channels))
409
+ return nn.Sequential(*layers)
410
+
411
+ def forward(self, audio_signal: torch.Tensor, length: torch.Tensor = None):
412
+ x = audio_signal
413
+ x = x.unsqueeze(dim=1) # (B, 1, C, T)
414
+
415
+ x = self.pre_conv(x)
416
+ x = self.layer1(x)
417
+ x = self.layer2(x)
418
+ x = self.layer3(x)
419
+ x = self.layer4(x)
420
+ x = x.flatten(1, 2)
421
+
422
+ return x, length
423
+
424
+
425
+ class SpeakerDecoder(NeuralModule):
426
+ """
427
+ Speaker Decoder creates the final neural layers that map from the outputs
428
+ of the encoder to the embedding layer, followed by a speaker-based softmax loss.
429
+
430
+ Args:
431
+ feat_in (int): Number of channels being input to this module
432
+ num_classes (int): Number of unique speakers in dataset
433
+ emb_sizes (int or list): shapes of the intermediate embedding layers (the speaker embedding is taken
434
+ from these layers). Defaults to 256
435
+ pool_mode (str): Pooling strategy type. Options are 'xvector', 'tap', 'attention', 'ecapa2'
436
+ Defaults to 'xvector' (mean and variance pooling)
437
+ tap (temporal average pooling: just mean)
438
+ attention (attention based pooling)
439
+ init_mode (str): Describes how neural network parameters are
440
+ initialized. Options are ['xavier_uniform', 'xavier_normal',
441
+ 'kaiming_uniform','kaiming_normal'].
442
+ Defaults to "xavier_uniform".
443
+ """
444
+
445
+ def __init__(
446
+ self,
447
+ feat_in: int,
448
+ num_classes: int,
449
+ emb_sizes: Optional[Union[int, list]] = 256,
450
+ pool_mode: str = 'xvector',
451
+ angular: bool = False,
452
+ attention_channels: int = 128,
453
+ init_mode: str = "xavier_uniform",
454
+ ):
455
+ super().__init__()
456
+ self.angular = angular
457
+ self.emb_id = 2
458
+ bias = False if self.angular else True
459
+ emb_sizes = [emb_sizes] if type(emb_sizes) is int else emb_sizes
460
+
461
+ self._num_classes = num_classes
462
+ self.pool_mode = pool_mode.lower()
463
+ if self.pool_mode == 'xvector' or self.pool_mode == 'tap':
464
+ self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
465
+ affine_type = 'linear'
466
+ elif self.pool_mode == 'attention':
467
+ self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
468
+ affine_type = 'conv'
469
+ elif self.pool_mode == 'ecapa2':
470
+ self._pooling = ChannelDependentAttentiveStatisticsPoolLayer(
471
+ inp_filters=feat_in, attention_channels=attention_channels
472
+ )
473
+ affine_type = 'conv'
474
+
475
+ shapes = [self._pooling.feat_in]
476
+ for size in emb_sizes:
477
+ shapes.append(int(size))
478
+
479
+ emb_layers = []
480
+ for shape_in, shape_out in zip(shapes[:-1], shapes[1:]):
481
+ layer = self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
482
+ emb_layers.append(layer)
483
+
484
+ self.emb_layers = nn.ModuleList(emb_layers)
485
+
486
+ self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)
487
+
488
+ self.apply(lambda x: init_weights(x, mode=init_mode))
489
+
490
+ def affine_layer(
491
+ self,
492
+ inp_shape,
493
+ out_shape,
494
+ learn_mean=True,
495
+ affine_type='conv',
496
+ ):
497
+ if affine_type == 'conv':
498
+ layer = nn.Sequential(
499
+ nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
500
+ nn.Conv1d(inp_shape, out_shape, kernel_size=1),
501
+ )
502
+
503
+ else:
504
+ layer = nn.Sequential(
505
+ nn.Linear(inp_shape, out_shape),
506
+ nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
507
+ nn.ReLU(),
508
+ )
509
+
510
+ return layer
511
+
512
+ def forward(self, encoder_output, length=None):
513
+ pool = self._pooling(encoder_output, length)
514
+ embs = []
515
+
516
+ for layer in self.emb_layers:
517
+ pool, emb = layer(pool), layer[: self.emb_id](pool)
518
+ embs.append(emb)
519
+
520
+ pool = pool.squeeze(-1)
521
+ if self.angular:
522
+ for W in self.final.parameters():
523
+ W = F.normalize(W, p=2, dim=1)
524
+ pool = F.normalize(pool, p=2, dim=1)
525
+
526
+ out = self.final(pool)
527
+
528
+ return out, embs[-1].squeeze(-1)
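Shape-wise, `ResNetEncoder` treats the (batch, mel bins, frames) spectrogram as a one-channel image, halves the frequency and time axes at each strided stage (default strides [1, 2, 2, 1]), and flattens channels × frequency into the feature dimension that `SpeakerDecoder` pools over time. A small sketch under assumed sizes (80 mel bins, a placeholder class count), assuming the pooling layers in `tdnn_attention.py` behave as their NeMo counterparts:

```python
import torch

from conv_asr import ResNetEncoder, SpeakerDecoder  # adjust to the actual package layout

n_mels, n_classes = 80, 1000   # illustrative sizes

encoder = ResNetEncoder(feat_in=n_mels)  # default filters [16, 32, 64, 128]
# With 80 mel bins and strides [1, 2, 2, 1], the flattened feature dim is 128 * (80 // 4) = 2560
decoder = SpeakerDecoder(feat_in=2560, num_classes=n_classes, emb_sizes=256,
                         pool_mode='xvector', angular=True)

feats = torch.randn(2, n_mels, 200)   # (batch, mel bins, frames) from the preprocessor
lengths = torch.tensor([200, 160])

enc_out, enc_len = encoder(feats, lengths)   # enc_out: (2, 2560, 50); lengths pass through unchanged
logits, embedding = decoder(enc_out)         # (2, n_classes) logits and a (2, 256) speaker embedding
```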
features.py ADDED
@@ -0,0 +1,560 @@
1
+ import math
2
+ import random
3
+ from typing import Optional, Union, Tuple
4
+
5
+ import librosa
6
+ import torchaudio
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import torchaudio
13
+
14
+ HAVE_TORCHAUDIO = True
15
+ except ModuleNotFoundError:
16
+ HAVE_TORCHAUDIO = False
17
+
18
+ CONSTANT = 1e-5
19
+
20
+
21
+ def normalize_batch(x, seq_len, normalize_type):
22
+ x_mean = None
23
+ x_std = None
24
+ if normalize_type == "per_feature":
25
+ batch_size = x.shape[0]
26
+ max_time = x.shape[2]
27
+
28
+ # When doing stream capture to a graph, item() is not allowed
29
+ # because it calls cudaStreamSynchronize(). Therefore, we are
30
+ # sacrificing some error checking when running with cuda graphs.
31
+ if (
32
+ torch.cuda.is_available()
33
+ and not torch.cuda.is_current_stream_capturing()
34
+ and torch.any(seq_len == 1).item()
35
+ ):
36
+ raise ValueError(
37
+ "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
38
+ "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
39
+ "feature (ex. at least `hop_length` for Mel Spectrograms)."
40
+ )
41
+ time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
42
+ valid_mask = time_steps < seq_len.unsqueeze(1)
43
+ x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
44
+ x_mean_denominator = valid_mask.sum(axis=1)
45
+ x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)
46
+
47
+ # Subtract 1 in the denominator to correct for the bias.
48
+ x_std = torch.sqrt(
49
+ torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2)
50
+ / (x_mean_denominator.unsqueeze(1) - 1.0)
51
+ )
52
+ # make sure x_std is not zero
53
+ x_std += CONSTANT
54
+ return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
55
+ elif normalize_type == "all_features":
56
+ x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
57
+ x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
58
+ for i in range(x.shape[0]):
59
+ x_mean[i] = x[i, :, : seq_len[i].item()].mean()
60
+ x_std[i] = x[i, :, : seq_len[i].item()].std()
61
+ # make sure x_std is not zero
62
+ x_std += CONSTANT
63
+ return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
64
+ elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
65
+ x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
66
+ x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
67
+ return (
68
+ (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
69
+ x_mean,
70
+ x_std,
71
+ )
72
+ else:
73
+ return x, x_mean, x_std
74
+
75
+
76
+ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Tensor, fill_value=0.0) -> torch.Tensor:
77
+ """
78
+ Fill spectrogram values outside the length with `fill_value`
79
+
80
+ Args:
81
+ spectrogram: Tensor with shape [B, C, L] containing batched spectrograms
82
+ spectrogram_len: Tensor with shape [B] containing the sequence length of each batch element
83
+ fill_value: value to fill with, 0.0 by default
84
+
85
+ Returns:
86
+ cleaned spectrogram, tensor with shape equal to `spectrogram`
87
+ """
88
+ device = spectrogram.device
89
+ batch_size, _, max_len = spectrogram.shape
90
+ mask = torch.arange(max_len, device=device)[None, :] >= spectrogram_len[:, None]
91
+ mask = mask.unsqueeze(1).expand_as(spectrogram)
92
+ return spectrogram.masked_fill(mask, fill_value)
93
+
94
+
95
+ def splice_frames(x, frame_splicing):
96
+ """Stacks frames together across feature dim
97
+
98
+ input is batch_size, feature_dim, num_frames
99
+ output is batch_size, feature_dim*frame_splicing, num_frames
100
+
101
+ """
102
+ seq = [x]
103
+ for n in range(1, frame_splicing):
104
+ seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
105
+ return torch.cat(seq, dim=1)
106
+
107
+
108
+ @torch.jit.script_if_tracing
109
+ def make_seq_mask_like(
110
+ lengths: torch.Tensor, like: torch.Tensor, time_dim: int = -1, valid_ones: bool = True
111
+ ) -> torch.Tensor:
112
+ """
113
+
114
+ Args:
115
+ lengths: Tensor with shape [B] containing the sequence length of each batch element
116
+ like: The mask will contain the same number of dimensions as this Tensor, and will have the same max
117
+ length in the time dimension of this Tensor.
118
+ time_dim: Time dimension of the `shape_tensor` and the resulting mask. Zero-based.
119
+ valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert.
120
+
121
+ Returns:
122
+ A :class:`torch.Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else
123
+ vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match
124
+ the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and
125
+ `time_dim == -1', mask will have shape `[3, 1, 5]`.
126
+ """
127
+ # Mask with shape [B, T]
128
+ mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.view(-1, 1))
129
+ # [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor
130
+ for _ in range(like.dim() - mask.dim()):
131
+ mask = mask.unsqueeze(1)
132
+ # If needed, transpose time dim
133
+ if time_dim != -1 and time_dim != mask.dim() - 1:
134
+ mask = mask.transpose(-1, time_dim)
135
+ # Maybe invert the padded vs. valid token values
136
+ if not valid_ones:
137
+ mask = ~mask
138
+ return mask
139
+
140
+
141
+ class FilterbankFeatures(nn.Module):
142
+ """Featurizer that converts wavs to Mel Spectrograms.
143
+ See AudioToMelSpectrogramPreprocessor for args.
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ sample_rate=16000,
149
+ n_window_size=320,
150
+ n_window_stride=160,
151
+ window="hann",
152
+ normalize="per_feature",
153
+ n_fft=None,
154
+ preemph=0.97,
155
+ nfilt=64,
156
+ lowfreq=0,
157
+ highfreq=None,
158
+ log=True,
159
+ log_zero_guard_type="add",
160
+ log_zero_guard_value=2**-24,
161
+ dither=CONSTANT,
162
+ pad_to=16,
163
+ max_duration=16.7,
164
+ frame_splicing=1,
165
+ exact_pad=False,
166
+ pad_value=0,
167
+ mag_power=2.0,
168
+ use_grads=False,
169
+ rng=None,
170
+ nb_augmentation_prob=0.0,
171
+ nb_max_freq=4000,
172
+ mel_norm="slaney",
173
+ stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
174
+ stft_conv=False, # Deprecated arguments; kept for config compatibility
175
+ ):
176
+ super().__init__()
177
+ if stft_conv or stft_exact_pad:
178
+ print(
179
+ "Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
180
+ "for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
181
+ "as needed."
182
+ )
183
+ if exact_pad and n_window_stride % 2 == 1:
184
+ raise NotImplementedError(
185
+ f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
186
+ "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
187
+ )
188
+ self.log_zero_guard_value = log_zero_guard_value
189
+ if (
190
+ n_window_size is None
191
+ or n_window_stride is None
192
+ or not isinstance(n_window_size, int)
193
+ or not isinstance(n_window_stride, int)
194
+ or n_window_size <= 0
195
+ or n_window_stride <= 0
196
+ ):
197
+ raise ValueError(
198
+ f"{self} got an invalid value for either n_window_size or "
199
+ f"n_window_stride. Both must be positive ints."
200
+ )
201
+
202
+ self.win_length = n_window_size
203
+ self.hop_length = n_window_stride
204
+ self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
205
+ self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
206
+ self.exact_pad = exact_pad
207
+
208
+ if exact_pad:
209
+ print("STFT using exact pad")
210
+ torch_windows = {
211
+ 'hann': torch.hann_window,
212
+ 'hamming': torch.hamming_window,
213
+ 'blackman': torch.blackman_window,
214
+ 'bartlett': torch.bartlett_window,
215
+ 'none': None,
216
+ }
217
+ window_fn = torch_windows.get(window, None)
218
+ window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
219
+ self.register_buffer("window", window_tensor)
220
+
221
+ self.normalize = normalize
222
+ self.log = log
223
+ self.dither = dither
224
+ self.frame_splicing = frame_splicing
225
+ self.nfilt = nfilt
226
+ self.preemph = preemph
227
+ self.pad_to = pad_to
228
+ highfreq = highfreq or sample_rate / 2
229
+
230
+ filterbanks = torch.tensor(
231
+ librosa.filters.mel(
232
+ sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
233
+ ),
234
+ dtype=torch.float,
235
+ ).unsqueeze(0)
236
+ self.register_buffer("fb", filterbanks)
237
+
238
+ # Calculate maximum sequence length
239
+ max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
240
+ max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
241
+ self.max_length = max_length + max_pad
242
+ self.pad_value = pad_value
243
+ self.mag_power = mag_power
244
+
245
+ # We want to avoid taking the log of zero
246
+ # There are two options: either adding or clamping to a small value
247
+ if log_zero_guard_type not in ["add", "clamp"]:
248
+ raise ValueError(
249
+ f"{self} received {log_zero_guard_type} for the "
250
+ f"log_zero_guard_type parameter. It must be either 'add' or "
251
+ f"'clamp'."
252
+ )
253
+
254
+ self.use_grads = use_grads
255
+ if not use_grads:
256
+ self.forward = torch.no_grad()(self.forward)
257
+ self._rng = random.Random() if rng is None else rng
258
+ self.nb_augmentation_prob = nb_augmentation_prob
259
+ if self.nb_augmentation_prob > 0.0:
260
+ if nb_max_freq >= sample_rate / 2:
261
+ self.nb_augmentation_prob = 0.0
262
+ else:
263
+ self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
264
+
265
+ # log_zero_guard_value is the small value we want to use; we support
266
+ # an actual number, or "tiny", or "eps"
267
+ self.log_zero_guard_type = log_zero_guard_type
268
+
269
+ def stft(self, x):
270
+ return torch.stft(
271
+ x,
272
+ n_fft=self.n_fft,
273
+ hop_length=self.hop_length,
274
+ win_length=self.win_length,
275
+ center=False if self.exact_pad else True,
276
+ window=self.window.to(dtype=torch.float),
277
+ return_complex=True,
278
+ )
279
+
280
+ def log_zero_guard_value_fn(self, x):
281
+ if isinstance(self.log_zero_guard_value, str):
282
+ if self.log_zero_guard_value == "tiny":
283
+ return torch.finfo(x.dtype).tiny
284
+ elif self.log_zero_guard_value == "eps":
285
+ return torch.finfo(x.dtype).eps
286
+ else:
287
+ raise ValueError(
288
+ f"{self} received {self.log_zero_guard_value} for the "
289
+ f"log_zero_guard_type parameter. It must be either a "
290
+ f"number, 'tiny', or 'eps'"
291
+ )
292
+ else:
293
+ return self.log_zero_guard_value
294
+
295
+ def get_seq_len(self, seq_len):
296
+ # Assumes center=True when stft_pad_amount is None (i.e. exact_pad is False)
297
+ pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
298
+ seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
299
+ return seq_len.to(dtype=torch.long)
300
+
301
+ @property
302
+ def filter_banks(self):
303
+ return self.fb
304
+
305
+ def forward(self, x, seq_len, linear_spec=False):
306
+ seq_len = self.get_seq_len(seq_len)
307
+
308
+ if self.stft_pad_amount is not None:
309
+ x = torch.nn.functional.pad(
310
+ x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
311
+ ).squeeze(1)
312
+
313
+ # dither (only in training mode for eval determinism)
314
+ if self.training and self.dither > 0:
315
+ x += self.dither * torch.randn_like(x)
316
+
317
+ # do preemphasis
318
+ if self.preemph is not None:
319
+ x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
320
+
321
+ # disable autocast to get full range of stft values
322
+ with torch.amp.autocast(x.device.type, enabled=False):
323
+ x = self.stft(x)
324
+
325
+ # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
326
+ # guard is needed for sqrt if grads are passed through
327
+ guard = 0 if not self.use_grads else CONSTANT
328
+ x = torch.view_as_real(x)
329
+ x = torch.sqrt(x.pow(2).sum(-1) + guard)
330
+
331
+ if self.training and self.nb_augmentation_prob > 0.0:
332
+ for idx in range(x.shape[0]):
333
+ if self._rng.random() < self.nb_augmentation_prob:
334
+ x[idx, self._nb_max_fft_bin :, :] = 0.0
335
+
336
+ # get power spectrum
337
+ if self.mag_power != 1.0:
338
+ x = x.pow(self.mag_power)
339
+
340
+ # return plain spectrogram if required
341
+ if linear_spec:
342
+ return x, seq_len
343
+
344
+ # dot with filterbank energies
345
+ x = torch.matmul(self.fb.to(x.dtype), x)
346
+ # log features if required
347
+ if self.log:
348
+ if self.log_zero_guard_type == "add":
349
+ x = torch.log(x + self.log_zero_guard_value_fn(x))
350
+ elif self.log_zero_guard_type == "clamp":
351
+ x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
352
+ else:
353
+ raise ValueError("log_zero_guard_type was not understood")
354
+
355
+ # frame splicing if required
356
+ if self.frame_splicing > 1:
357
+ x = splice_frames(x, self.frame_splicing)
358
+
359
+ # normalize if required
360
+ if self.normalize:
361
+ x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
362
+
363
+ # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
364
+ max_len = x.size(-1)
365
+ mask = torch.arange(max_len, device=x.device)
366
+ mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
367
+ x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
368
+ del mask
369
+ pad_to = self.pad_to
370
+ if pad_to == "max":
371
+ x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
372
+ elif pad_to > 0:
373
+ pad_amt = x.size(-1) % pad_to
374
+ if pad_amt != 0:
375
+ x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
376
+ return x, seq_len
377
+
378
+
379
+ class FilterbankFeaturesTA(nn.Module):
380
+ """
381
+ Exportable, `torchaudio`-based implementation of Mel Spectrogram extraction.
382
+
383
+ See `AudioToMelSpectrogramPreprocessor` for args.
384
+
385
+ """
386
+
387
+ def __init__(
388
+ self,
389
+ sample_rate: int = 16000,
390
+ n_window_size: int = 320,
391
+ n_window_stride: int = 160,
392
+ normalize: Optional[str] = "per_feature",
393
+ nfilt: int = 64,
394
+ n_fft: Optional[int] = None,
395
+ preemph: float = 0.97,
396
+ lowfreq: float = 0,
397
+ highfreq: Optional[float] = None,
398
+ log: bool = True,
399
+ log_zero_guard_type: str = "add",
400
+ log_zero_guard_value: Union[float, str] = 2**-24,
401
+ dither: float = 1e-5,
402
+ window: str = "hann",
403
+ pad_to: int = 0,
404
+ pad_value: float = 0.0,
405
+ mel_norm="slaney",
406
+ # Seems like no one uses these options anymore. Don't complicate the code by supporting them.
407
+ use_grads: bool = False, # Deprecated arguments; kept for config compatibility
408
+ max_duration: float = 16.7, # Deprecated arguments; kept for config compatibility
409
+ frame_splicing: int = 1, # Deprecated arguments; kept for config compatibility
410
+ exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
411
+ nb_augmentation_prob: float = 0.0, # Deprecated arguments; kept for config compatibility
412
+ nb_max_freq: int = 4000, # Deprecated arguments; kept for config compatibility
413
+ mag_power: float = 2.0, # Deprecated arguments; kept for config compatibility
414
+ rng: Optional[random.Random] = None, # Deprecated arguments; kept for config compatibility
415
+ stft_exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
416
+ stft_conv: bool = False, # Deprecated arguments; kept for config compatibility
417
+ ):
418
+ super().__init__()
419
+ if not HAVE_TORCHAUDIO:
420
+ raise ValueError(f"Need to install torchaudio to instantiate a {self.__class__.__name__}")
421
+
422
+ # Make sure log zero guard is supported, if given as a string
423
+ supported_log_zero_guard_strings = {"eps", "tiny"}
424
+ if isinstance(log_zero_guard_value, str) and log_zero_guard_value not in supported_log_zero_guard_strings:
425
+ raise ValueError(
426
+ f"Log zero guard value must either be a float or a member of {supported_log_zero_guard_strings}"
427
+ )
428
+
429
+ # Copied from `AudioPreprocessor` due to the ad-hoc structuring of the Mel Spec extractor class
430
+ self.torch_windows = {
431
+ 'hann': torch.hann_window,
432
+ 'hamming': torch.hamming_window,
433
+ 'blackman': torch.blackman_window,
434
+ 'bartlett': torch.bartlett_window,
435
+ 'ones': torch.ones,
436
+ None: torch.ones,
437
+ }
438
+
439
+ # Ensure we can look up the window function
440
+ if window not in self.torch_windows:
441
+ raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")
442
+
443
+ self.win_length = n_window_size
444
+ self.hop_length = n_window_stride
445
+ self._sample_rate = sample_rate
446
+ self._normalize_strategy = normalize
447
+ self._use_log = log
448
+ self._preemphasis_value = preemph
449
+ self.log_zero_guard_type = log_zero_guard_type
450
+ self.log_zero_guard_value: Union[str, float] = log_zero_guard_value
451
+ self.dither = dither
452
+ self.pad_to = pad_to
453
+ self.pad_value = pad_value
454
+ self.n_fft = n_fft
455
+ self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
456
+ sample_rate=self._sample_rate,
457
+ win_length=self.win_length,
458
+ hop_length=self.hop_length,
459
+ n_mels=nfilt,
460
+ window_fn=self.torch_windows[window],
461
+ mel_scale="slaney",
462
+ norm=mel_norm,
463
+ n_fft=n_fft,
464
+ f_max=highfreq,
465
+ f_min=lowfreq,
466
+ wkwargs={"periodic": False},
467
+ )
468
+
469
+ @property
470
+ def filter_banks(self):
471
+ """Matches the analogous class"""
472
+ return self._mel_spec_extractor.mel_scale.fb
473
+
474
+ def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
475
+ if isinstance(self.log_zero_guard_value, float):
476
+ return self.log_zero_guard_value
477
+ return getattr(torch.finfo(dtype), self.log_zero_guard_value)
478
+
479
+ def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
480
+ if self.training and self.dither > 0.0:
481
+ noise = torch.randn_like(signals) * self.dither
482
+ signals = signals + noise
483
+ return signals
484
+
485
+ def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
486
+ if self._preemphasis_value is not None:
487
+ padded = torch.nn.functional.pad(signals, (1, 0))
488
+ signals = signals - self._preemphasis_value * padded[:, :-1]
489
+ return signals
490
+
491
+ def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
492
+ out_lengths = input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
493
+ return out_lengths
494
+
495
+ def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
496
+ # Only apply during training; else need to capture dynamic shape for exported models
497
+ if not self.training or self.pad_to == 0 or features.shape[-1] % self.pad_to == 0:
498
+ return features
499
+ pad_length = self.pad_to - (features.shape[-1] % self.pad_to)
500
+ return torch.nn.functional.pad(features, pad=(0, pad_length), value=self.pad_value)
501
+
502
+ def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
503
+ if self._use_log:
504
+ zero_guard = self._resolve_log_zero_guard_value(features.dtype)
505
+ if self.log_zero_guard_type == "add":
506
+ features = features + zero_guard
507
+ elif self.log_zero_guard_type == "clamp":
508
+ features = features.clamp(min=zero_guard)
509
+ else:
510
+ raise ValueError(f"Unsupported log zero guard type: '{self.log_zero_guard_type}'")
511
+ features = features.log()
512
+ return features
513
+
514
+ def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor:
515
+ # Complex FFT needs to be done in single precision
516
+ with torch.amp.autocast('cuda', enabled=False):
517
+ features = self._mel_spec_extractor(waveform=signals)
518
+ return features
519
+
520
+ def _apply_normalization(self, features: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
521
+ # For consistency, this function always does a masked fill even if not normalizing.
522
+ mask: torch.Tensor = make_seq_mask_like(lengths=lengths, like=features, time_dim=-1, valid_ones=False)
523
+ features = features.masked_fill(mask, 0.0)
524
+ # Maybe don't normalize
525
+ if self._normalize_strategy is None:
526
+ return features
527
+ # Use the log zero guard for the sqrt zero guard
528
+ guard_value = self._resolve_log_zero_guard_value(features.dtype)
529
+ if self._normalize_strategy == "per_feature" or self._normalize_strategy == "all_features":
530
+ # 'all_features' reduces over each sample; 'per_feature' reduces over each channel
531
+ reduce_dim = 2
532
+ if self._normalize_strategy == "all_features":
533
+ reduce_dim = [1, 2]
534
+ # [B, D, T] -> [B, D, 1] or [B, 1, 1]
535
+ means = features.sum(dim=reduce_dim, keepdim=True).div(lengths.view(-1, 1, 1))
536
+ stds = (
537
+ features.sub(means)
538
+ .masked_fill(mask, 0.0)
539
+ .pow(2.0)
540
+ .sum(dim=reduce_dim, keepdim=True) # [B, D, T] -> [B, D, 1] or [B, 1, 1]
541
+ .div(lengths.view(-1, 1, 1) - 1) # unbiased estimator (divide by N - 1)
542
+ .clamp(min=guard_value) # avoid sqrt(0)
543
+ .sqrt()
544
+ )
545
+ features = (features - means) / (stds + eps)
546
+ else:
547
+ # Deprecating constant std/mean
548
+ raise ValueError(f"Unsupported norm type: '{self._normalize_strategy}'")
549
+ features = features.masked_fill(mask, 0.0)
550
+ return features
551
+
552
+ def forward(self, input_signal: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
553
+ feature_lengths = self._compute_output_lengths(input_lengths=length)
554
+ signals = self._apply_dithering(signals=input_signal)
555
+ signals = self._apply_preemphasis(signals=signals)
556
+ features = self._extract_spectrograms(signals=signals)
557
+ features = self._apply_log(features=features)
558
+ features = self._apply_normalization(features=features, lengths=feature_lengths)
559
+ features = self._apply_pad_to(features=features)
560
+ return features, feature_lengths
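For orientation, here is a minimal sketch of how the featurizer above can be exercised end to end (dither → pre-emphasis → mel spectrogram → log → normalization → optional pad-to). The flat import and the constructor arguments are assumptions for illustration; in the repo this module is built from the `ResNetConfig` sub-configs and imported relatively, and its actual defaults live in the full `audio_processing.py`.

```python
import torch
from audio_processing import AudioToMelSpectrogramPreprocessor  # hypothetical flat import

# Assumed constructor arguments; the real values come from config.json in this repo.
preprocessor = AudioToMelSpectrogramPreprocessor(sample_rate=16000)
preprocessor.eval()  # dithering and pad-to are applied only in training mode

waveforms = torch.randn(2, 16000)        # [B, T]: one second of audio per sample at 16 kHz
lengths = torch.tensor([16000, 12000])   # number of valid samples per waveform

with torch.no_grad():
    features, feature_lengths = preprocessor(waveforms, lengths)

print(features.shape)    # [2, n_mels, T'] where T' = floor(T / hop_length) + 1
print(feature_lengths)   # per-utterance frame counts derived from `lengths`
```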
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:506c87f1f3d6ea60a6c5f7bb91f13d69870ddbbd881a31d8e595868779f3be5a
3
+ size 7011792
modeling_resnet.py ADDED
@@ -0,0 +1,150 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Union, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from transformers import PreTrainedModel
8
+ from transformers.utils import ModelOutput
9
+
10
+ from .configuration_resnet import ResNetConfig
11
+ from .audio_processing import AudioToMelSpectrogramPreprocessor
12
+ from .audio_processing import SpectrogramAugmentation
13
+ from .conv_asr import ResNetEncoder, SpeakerDecoder
14
+ from .angular_loss import AdditiveMarginSoftmaxLoss, AdditiveAngularMarginSoftmaxLoss
15
+
16
+
17
+ @dataclass
18
+ class ResNetBaseModelOutput(ModelOutput):
19
+
20
+ encoder_outputs: torch.FloatTensor = None
21
+ extract_features: torch.FloatTensor = None
22
+ output_lengths: torch.LongTensor = None
23
+
24
+
25
+ @dataclass
26
+ class ResNetSequenceClassifierOutput(ModelOutput):
27
+
28
+ loss: torch.FloatTensor = None
29
+ logits: torch.FloatTensor = None
30
+ embeddings: torch.FloatTensor = None
31
+
32
+
33
+ class ResNetPreTrainedModel(PreTrainedModel):
34
+
35
+ config_class = ResNetConfig
36
+ base_model_prefix = "resnet"
37
+ main_input_name = "input_values"
38
+
39
+ def _init_weights(self, module):
40
+ """Initialize the weights"""
41
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
42
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
43
+ if module.bias is not None:
44
+ module.bias.data.zero_()
45
+ elif isinstance(module, nn.Conv2d):
46
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
47
+ if module.bias is not None:
48
+ module.bias.data.zero_()
49
+ elif isinstance(module, nn.LayerNorm):
50
+ module.bias.data.zero_()
51
+ module.weight.data.fill_(1.0)
52
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
53
+ nn.init.constant_(module.weight, 1)
54
+ nn.init.constant_(module.bias, 0)
55
+
56
+ @property
57
+ def num_weights(self):
58
+ """
59
+ Utility property that returns the total number of trainable parameters of the module.
60
+ """
61
+ return self._num_weights()
62
+
63
+ @torch.jit.ignore
64
+ def _num_weights(self):
65
+ num: int = 0
66
+ for p in self.parameters():
67
+ if p.requires_grad:
68
+ num += p.numel()
69
+ return num
70
+
71
+
72
+ class ResNetModel(ResNetPreTrainedModel):
73
+
74
+ def __init__(self, config: ResNetConfig):
75
+ super().__init__(config)
76
+ self.config = config
77
+
78
+ self.preprocessor = AudioToMelSpectrogramPreprocessor(**config.mel_spectrogram_config)
79
+ self.spec_augment = SpectrogramAugmentation(**config.spectrogram_augmentation_config)
80
+ self.encoder = ResNetEncoder(**config.encoder_config)
81
+
82
+ # Initialize weights and apply final processing
83
+ self.post_init()
84
+
85
+ def forward(
86
+ self,
87
+ input_values: Optional[torch.Tensor],
88
+ attention_mask: Optional[torch.Tensor] = None,
89
+ ) -> Union[Tuple, ResNetBaseModelOutput]:
90
+ if attention_mask is None:
91
+ attention_mask = torch.ones_like(input_values).to(input_values)
92
+ lengths = attention_mask.sum(dim=1).long()
93
+ extract_features, output_lengths = self.preprocessor(input_values, lengths)
94
+ if self.training:
95
+ extract_features = self.spec_augment(extract_features, output_lengths)
96
+ encoder_outputs, output_lengths = self.encoder(extract_features, output_lengths)
97
+
98
+ return ResNetBaseModelOutput(
99
+ encoder_outputs=encoder_outputs,
100
+ extract_features=extract_features,
101
+ output_lengths=output_lengths,
102
+ )
103
+
104
+
105
+ class ResNetForSequenceClassification(ResNetPreTrainedModel):
106
+
107
+ def __init__(self, config: ResNetConfig):
108
+ super().__init__(config)
109
+
110
+ self.resnet = ResNetModel(config)
111
+ self.classifier = SpeakerDecoder(**config.decoder_config)
112
+
113
+ if config.objective == 'additive_angular_margin':
114
+ self.loss_fct = AdditiveAngularMarginSoftmaxLoss(**config.objective_config)
115
+ elif config.objective == 'additive_margin':
116
+ self.loss_fct = AdditiveMarginSoftmaxLoss(**config.objective_config)
117
+ elif config.objective == 'cross_entropy':
118
+ self.loss_fct = nn.CrossEntropyLoss(**config.objective_config)
119
+
120
+ self.init_weights()
121
+
122
+ def freeze_base_model(self):
123
+ for param in self.resnet.parameters():
124
+ param.requires_grad = False
125
+
126
+ def forward(
127
+ self,
128
+ input_values: Optional[torch.Tensor],
129
+ attention_mask: Optional[torch.Tensor] = None,
130
+ labels: Optional[torch.Tensor] = None,
131
+ ) -> Union[Tuple, ResNetSequenceClassifierOutput]:
132
+ resnet_outputs = self.resnet(
133
+ input_values,
134
+ attention_mask,
135
+ )
136
+ logits, output_embeddings = self.classifier(
137
+ resnet_outputs.encoder_outputs,
138
+ resnet_outputs.output_lengths
139
+ )
140
+ logits = logits.view(-1, self.config.num_labels)
141
+
142
+ loss = None
143
+ if labels is not None:
144
+ loss = self.loss_fct(logits, labels.view(-1))
145
+
146
+ return ResNetSequenceClassifierOutput(
147
+ loss=loss,
148
+ logits=logits,
149
+ embeddings=output_embeddings,
150
+ )
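A hedged end-to-end sketch of the classifier defined above, assuming a populated `ResNetConfig` (for example the `config.json` uploaded with this commit). The flat imports and the `"path/to/this/repo"` placeholder are illustrative; inside the repo the modules are imported relatively.

```python
import torch
from configuration_resnet import ResNetConfig                # hypothetical flat imports
from modeling_resnet import ResNetForSequenceClassification

config = ResNetConfig.from_pretrained("path/to/this/repo")   # placeholder path
model = ResNetForSequenceClassification(config).eval()

waveforms = torch.randn(2, 32000)             # [B, T] raw audio at the configured sample rate
attention_mask = torch.ones_like(waveforms)   # all samples valid; lengths = attention_mask.sum(dim=1)

with torch.no_grad():
    outputs = model(input_values=waveforms, attention_mask=attention_mask)

print(outputs.logits.shape)      # [B, num_labels]
print(outputs.embeddings.shape)  # fixed-size speaker embeddings from SpeakerDecoder
```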
module.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class NeuralModule(nn.Module):
6
+
7
+ @property
8
+ def num_weights(self):
9
+ """
10
+ Utility property that returns the total number of parameters of NeuralModule.
11
+ """
12
+ return self._num_weights()
13
+
14
+ @torch.jit.ignore
15
+ def _num_weights(self):
16
+ num: int = 0
17
+ for p in self.parameters():
18
+ if p.requires_grad:
19
+ num += p.numel()
20
+ return num
21
+
22
+ def freeze(self) -> None:
23
+ r"""
24
+ Freeze all params for inference.
25
+
26
+ This method sets `requires_grad` to False for all parameters of the module.
27
+ It also stores the original `requires_grad` state of each parameter in a dictionary,
28
+ so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`.
29
+ """
30
+ grad_map = {}
31
+
32
+ for pname, param in self.named_parameters():
33
+ # Store the original grad state
34
+ grad_map[pname] = param.requires_grad
35
+ # Freeze the parameter
36
+ param.requires_grad = False
37
+
38
+ # Store the frozen grad map
39
+ if not hasattr(self, '_frozen_grad_map'):
40
+ self._frozen_grad_map = grad_map
41
+ else:
42
+ self._frozen_grad_map.update(grad_map)
43
+
44
+ self.eval()
45
+
46
+ def unfreeze(self, partial: bool = False) -> None:
47
+ """
48
+ Unfreeze all parameters for training.
49
+
50
+ Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
51
+ The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were
52
+ previously unfrozen prior to calling `freeze()`.
53
+
54
+ Example:
55
+ Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always.
56
+
57
+ ```python
58
+ model.encoder.freeze() # Freezes all parameters in the encoder explicitly
59
+ ```
60
+
61
+ During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method.
62
+ This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called,
63
+ we should keep the encoder parameters frozen.
64
+
65
+ ```python
66
+ model.freeze() # Freezes all parameters in the model; encoder remains frozen
67
+ ```
68
+
69
+ Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling
70
+ `unfreeze(partial=True)`.
71
+
72
+ ```python
73
+ model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen
74
+ ```
75
+
76
+ Args:
77
+ partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen
78
+ when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`.
79
+ """
80
+ if partial and not hasattr(self, '_frozen_grad_map'):
81
+ raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")
82
+
83
+ for pname, param in self.named_parameters():
84
+ if not partial:
85
+ # Unfreeze all parameters
86
+ param.requires_grad = True
87
+ else:
88
+ # Unfreeze only parameters that were previously frozen
89
+
90
+ # Check if the parameter was frozen
91
+ if pname in self._frozen_grad_map:
92
+ param.requires_grad = self._frozen_grad_map[pname]
93
+ else:
94
+ # Log a warning if the parameter was not found in the frozen grad map
95
+ print(
96
+ f"Parameter {pname} not found in list of previously frozen parameters. "
97
+ f"Unfreezing this parameter."
98
+ )
99
+ param.requires_grad = True
100
+
101
+ # Clean up the frozen grad map
102
+ if hasattr(self, '_frozen_grad_map'):
103
+ delattr(self, '_frozen_grad_map')
104
+
105
+ self.train()
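The `freeze()`/`unfreeze(partial=True)` contract documented above can be illustrated with a small toy model; the `Encoder`/`ToyModel` classes below are hypothetical stand-ins that mirror the encoder/decoder example in the docstring.

```python
import torch.nn as nn
from module import NeuralModule  # hypothetical flat import

class Encoder(NeuralModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

class ToyModel(NeuralModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = nn.Linear(4, 2)

model = ToyModel()
model.encoder.freeze()        # encoder explicitly frozen; its state is recorded
model.freeze()                # freeze everything (e.g., for inference)
model.unfreeze(partial=True)  # decoder becomes trainable again, encoder stays frozen

print(any(p.requires_grad for p in model.encoder.parameters()))  # False
print(all(p.requires_grad for p in model.decoder.parameters()))  # True
```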
spectrogram_augment.py ADDED
@@ -0,0 +1,223 @@
1
+ import random
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class SpecAugment(nn.Module):
11
+ """
12
+ Zeroes out (cuts) random continuous horizontal or
13
+ vertical segments of the spectrogram as described in
14
+ SpecAugment (https://arxiv.org/abs/1904.08779).
15
+
16
+ params:
17
+ freq_masks - how many frequency segments should be cut
18
+ time_masks - how many time segments should be cut
19
+ freq_width - maximum number of frequencies to be cut in one segment
20
+ time_width - maximum number of time steps to be cut in one segment.
21
+ Can be a positive integer or a float value in the range [0, 1].
22
+ If positive integer value, defines maximum number of time steps
23
+ to be cut in one segment.
24
+ If a float value, defines maximum percentage of timesteps that
25
+ are cut adaptively.
26
+ use_vectorized_code - GPU-based implementation with batched masking and GPU rng,
27
+ setting it to False reverts to the legacy implementation.
28
+ Fast implementation is inspired by torchaudio:
29
+ https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816
30
+ """
31
+
32
+ FREQ_AXIS = 1 # Frequency axis in the spectrogram tensor
33
+ TIME_AXIS = 2 # Time axis in the spectrogram tensor
34
+
35
+ def __init__(
36
+ self,
37
+ freq_masks: int = 0,
38
+ time_masks: int = 0,
39
+ freq_width: int = 10,
40
+ time_width: Union[int, float] = 10,
41
+ rng: random.Random = None,
42
+ mask_value: float = 0.0,
43
+ use_vectorized_code: bool = True,
44
+ ):
45
+ super().__init__()
46
+
47
+ self._rng = random.Random() if rng is None else rng
48
+
49
+ self.freq_masks = freq_masks
50
+ self.time_masks = time_masks
51
+
52
+ self.freq_width = freq_width
53
+ self.time_width = time_width
54
+
55
+ self.mask_value = mask_value
56
+ self.use_vectorized_code = use_vectorized_code
57
+
58
+ if isinstance(time_width, int):
59
+ self.adaptive_temporal_width = False
60
+ else:
61
+ if time_width > 1.0 or time_width < 0.0:
62
+ raise ValueError("If `time_width` is a float value, must be in range [0, 1]")
63
+
64
+ self.adaptive_temporal_width = True
65
+
66
+ @torch.no_grad()
67
+ def forward(self, input_spec, length):
68
+ if self.use_vectorized_code:
69
+ return self._forward_vectorized(input_spec, length)
70
+ else:
71
+ return self._forward_legacy(input_spec, length)
72
+
73
+ def _forward_legacy(self, input_spec, length):
74
+ batch_size, num_freq_bins, _ = input_spec.shape
75
+ # Move lengths to CPU before repeated indexing
76
+ lengths_cpu = length.cpu().numpy()
77
+ # Generate a numpy boolean mask. `True` elements represent where the input spec will be augmented.
78
+ fill_mask: np.array = np.full(shape=input_spec.shape, fill_value=False)
79
+ freq_start_upper_bound = num_freq_bins - self.freq_width
80
+ # Choose different mask ranges for each element of the batch
81
+ for idx in range(batch_size):
82
+ # Set freq masking
83
+ for _ in range(self.freq_masks):
84
+ start = self._rng.randint(0, freq_start_upper_bound)
85
+ width = self._rng.randint(0, self.freq_width)
86
+ fill_mask[idx, start : start + width, :] = True
87
+
88
+ # Derive time width, sometimes based on a percentage of the input length.
89
+ if self.adaptive_temporal_width:
90
+ time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
91
+ else:
92
+ time_max_width = self.time_width
93
+ time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)
94
+
95
+ # Set time masking
96
+ for _ in range(self.time_masks):
97
+ start = self._rng.randint(0, time_start_upper_bound)
98
+ width = self._rng.randint(0, time_max_width)
99
+ fill_mask[idx, :, start : start + width] = True
100
+ # Bring the mask to device and fill spec
101
+ fill_mask = torch.from_numpy(fill_mask).to(input_spec.device)
102
+ masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
103
+ return masked_spec
104
+
105
+ def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
106
+ # time masks
107
+ input_spec = self._apply_masks(
108
+ input_spec=input_spec,
109
+ num_masks=self.time_masks,
110
+ length=length,
111
+ width=self.time_width,
112
+ axis=self.TIME_AXIS,
113
+ mask_value=self.mask_value,
114
+ )
115
+ # freq masks
116
+ input_spec = self._apply_masks(
117
+ input_spec=input_spec,
118
+ num_masks=self.freq_masks,
119
+ length=length,
120
+ width=self.freq_width,
121
+ axis=self.FREQ_AXIS,
122
+ mask_value=self.mask_value,
123
+ )
124
+ return input_spec
125
+
126
+ def _apply_masks(
127
+ self,
128
+ input_spec: torch.Tensor,
129
+ num_masks: int,
130
+ length: torch.Tensor,
131
+ width: Union[int, float],
132
+ mask_value: float,
133
+ axis: int,
134
+ ) -> torch.Tensor:
135
+
136
+ assert axis in (
137
+ self.FREQ_AXIS,
138
+ self.TIME_AXIS,
139
+ ), f"Axis can only be equal to frequency \
140
+ ({self.FREQ_AXIS}) or time ({self.TIME_AXIS}). Received: {axis=}"
141
+ assert not (
142
+ isinstance(width, float) and axis == self.FREQ_AXIS
143
+ ), "Float width supported \
144
+ only with time axis."
145
+
146
+ batch_size = input_spec.shape[0]
147
+ axis_length = input_spec.shape[axis]
148
+
149
+ # If width is float then it is transformed into a tensor
150
+ if axis == self.TIME_AXIS and isinstance(width, float):
151
+ width = torch.clamp(width * length, max=axis_length).unsqueeze(1)
152
+
153
+ # Generate [0-1) random numbers and then scale the tensors.
154
+ # Use float32 dtype for begin/end mask markers before they are quantized to long.
155
+ mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
156
+ mask_width = mask_width.long()
157
+ mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32)
158
+
159
+ if axis == self.TIME_AXIS:
160
+ # length can only be used for the time axis
161
+ mask_start = mask_start * (length.unsqueeze(1) - mask_width)
162
+ else:
163
+ mask_start = mask_start * (axis_length - mask_width)
164
+
165
+ mask_start = mask_start.long()
166
+ mask_end = mask_start + mask_width
167
+
168
+ # Create mask values using vectorized indexing
169
+ indices = torch.arange(axis_length, device=input_spec.device)
170
+ # Create a mask_tensor with all the indices.
171
+ # The mask_tensor shape is (batch_size, num_masks, axis_length).
172
+ mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))
173
+
174
+ # Reduce masks to one mask
175
+ mask_tensor = mask_tensor.any(dim=1)
176
+
177
+ # Create a final mask that aligns with the full tensor
178
+ mask = torch.zeros_like(input_spec, dtype=torch.bool)
179
+ if axis == self.TIME_AXIS:
180
+ mask_ranges = mask_tensor[:, None, :]
181
+ else: # axis == self.FREQ_AXIS
182
+ mask_ranges = mask_tensor[:, :, None]
183
+ mask[:, :, :] = mask_ranges
184
+
185
+ # Apply the mask value
186
+ return input_spec.masked_fill(mask=mask, value=mask_value)
187
+
188
+
189
+ class SpecCutout(nn.Module):
190
+ """
191
+ Zeroes out (cuts) random rectangles in the spectrogram
192
+ as described in (https://arxiv.org/abs/1708.04552).
193
+
194
+ params:
195
+ rect_masks - how many rectangular masks should be cut
196
+ rect_freq - maximum size of cut rectangles along the frequency dimension
197
+ rect_time - maximum size of cut rectangles along the time dimension
198
+ """
199
+
200
+ def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
201
+ super(SpecCutout, self).__init__()
202
+
203
+ self._rng = random.Random() if rng is None else rng
204
+
205
+ self.rect_masks = rect_masks
206
+ self.rect_time = rect_time
207
+ self.rect_freq = rect_freq
208
+
209
+ @torch.no_grad()
210
+ def forward(self, input_spec):
211
+ sh = input_spec.shape
212
+
213
+ for idx in range(sh[0]):
214
+ for i in range(self.rect_masks):
215
+ rect_x = self._rng.randint(0, sh[1] - self.rect_freq)
216
+ rect_y = self._rng.randint(0, sh[2] - self.rect_time)
217
+
218
+ w_x = self._rng.randint(0, self.rect_freq)
219
+ w_y = self._rng.randint(0, self.rect_time)
220
+
221
+ input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
222
+
223
+ return input_spec
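A short sketch of applying the vectorized SpecAugment path to a dummy log-mel batch; the shapes and hyper-parameters are illustrative only (a float `time_width` triggers the adaptive, length-proportional time masks).

```python
import torch
from spectrogram_augment import SpecAugment  # hypothetical flat import

augment = SpecAugment(freq_masks=2, time_masks=2, freq_width=5, time_width=0.05)

spec = torch.randn(3, 80, 200)            # [B, n_mels, T] dummy log-mel spectrograms
lengths = torch.tensor([200, 160, 120])   # valid frames per utterance

masked = augment(spec, lengths)
print(masked.shape)                  # unchanged: [3, 80, 200]
print((masked == 0).float().mean())  # rough fraction of bins zeroed by the masks
```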
tdnn_attention.py ADDED
@@ -0,0 +1,749 @@
1
+ import math
2
+ from typing import List, Optional
3
+
4
+ from numpy import inf
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.init import _calculate_correct_fan
10
+
11
+
12
+ class StatsPoolLayer(nn.Module):
13
+ """Statistics and time average pooling (TAP) layer
14
+
15
+ This computes mean and, optionally, standard deviation statistics across the time dimension.
16
+
17
+ Args:
18
+ feat_in: Input features with shape [B, D, T]
19
+ pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
20
+ average pooling, i.e., mean)
21
+ eps: Epsilon, minimum value before taking the square root, when using 'xvector' mode.
22
+ unbiased: Whether to use the unbiased estimator for the standard deviation when using 'xvector' mode. The default
23
+ for torch.Tensor.std() is True.
24
+
25
+ Returns:
26
+ Pooled statistics with shape [B, D].
27
+
28
+ Raises:
29
+ ValueError if an unsupported pooling mode is specified.
30
+ """
31
+
32
+ def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, unbiased: bool = True):
33
+ super().__init__()
34
+ supported_modes = {"xvector", "tap"}
35
+ if pool_mode not in supported_modes:
36
+ raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'")
37
+ self.pool_mode = pool_mode
38
+ self.feat_in = feat_in
39
+ self.eps = eps
40
+ self.unbiased = unbiased
41
+ if self.pool_mode == 'xvector':
42
+ # Mean + std
43
+ self.feat_in *= 2
44
+
45
+ def forward(self, encoder_output, length=None):
46
+ if length is None:
47
+ mean = encoder_output.mean(dim=-1) # Time Axis
48
+ if self.pool_mode == 'xvector':
49
+ correction = 1 if self.unbiased else 0
50
+ std = encoder_output.std(dim=-1, correction=correction).clamp(min=self.eps)
51
+ pooled = torch.cat([mean, std], dim=-1)
52
+ else:
53
+ pooled = mean
54
+ else:
55
+ mask = make_seq_mask_like(like=encoder_output, lengths=length, valid_ones=False)
56
+ encoder_output = encoder_output.masked_fill(mask, 0.0)
57
+ # [B, D, T] -> [B, D]
58
+ means = encoder_output.mean(dim=-1)
59
+ # Re-scale to get padded means
60
+ means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
61
+ if self.pool_mode == "xvector":
62
+ correction = 1 if self.unbiased else 0
63
+ stds = (
64
+ encoder_output.sub(means.unsqueeze(-1))
65
+ .masked_fill(mask, 0.0)
66
+ .pow(2.0)
67
+ .sum(-1) # [B, D, T] -> [B, D]
68
+ .div(length.view(-1, 1).sub(correction))
69
+ .clamp(min=self.eps)
70
+ .sqrt()
71
+ )
72
+ pooled = torch.cat((means, stds), dim=-1)
73
+ else:
74
+ pooled = means
75
+ return pooled
76
+
77
+
78
+ class AttentivePoolLayer(nn.Module):
79
+ """
80
+ Attention pooling layer for pooling speaker embeddings
81
+ Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
82
+ inputs:
83
+ feat_in: input feature channel length from encoder
84
+ attention_channels: intermediate attention channel size
85
+ kernel_size: kernel_size for TDNN and attention conv1d layers (default: 1)
86
+ dilation: dilation size for TDNN and attention conv1d layers (default: 1)
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ inp_filters: int,
92
+ attention_channels: int = 128,
93
+ kernel_size: int = 1,
94
+ dilation: int = 1,
95
+ eps: float = 1e-10,
96
+ ):
97
+ super().__init__()
98
+
99
+ self.feat_in = 2 * inp_filters
100
+
101
+ self.attention_layer = nn.Sequential(
102
+ TdnnModule(inp_filters * 3, attention_channels, kernel_size=kernel_size, dilation=dilation),
103
+ nn.Tanh(),
104
+ nn.Conv1d(
105
+ in_channels=attention_channels,
106
+ out_channels=inp_filters,
107
+ kernel_size=kernel_size,
108
+ dilation=dilation,
109
+ ),
110
+ )
111
+ self.eps = eps
112
+
113
+ def forward(self, x, length=None):
114
+ max_len = x.size(2)
115
+
116
+ if length is None:
117
+ length = torch.ones(x.shape[0], device=x.device)
118
+
119
+ mask, num_values = lens_to_mask(length, max_len=max_len, device=x.device)
120
+
121
+ # encoder statistics
122
+ mean, std = get_statistics_with_mask(x, mask / num_values)
123
+ mean = mean.unsqueeze(2).repeat(1, 1, max_len)
124
+ std = std.unsqueeze(2).repeat(1, 1, max_len)
125
+ attn = torch.cat([x, mean, std], dim=1)
126
+
127
+ # attention statistics
128
+ attn = self.attention_layer(attn) # attention pass
129
+ attn = attn.masked_fill(mask == 0, -inf)
130
+ alpha = F.softmax(attn, dim=2) # attention values, α
131
+ mu, sg = get_statistics_with_mask(x, alpha) # µ and ∑
132
+
133
+ # gather
134
+ return torch.cat((mu, sg), dim=1).unsqueeze(2)
135
+
136
+
137
+ class ChannelDependentAttentiveStatisticsPoolLayer(nn.Module):
138
+
139
+ def __init__(
140
+ self,
141
+ inp_filters: int,
142
+ attention_channels=128,
143
+ eps: float = 1e-10,
144
+ unbiased: bool = True
145
+ ):
146
+ super().__init__()
147
+ self.feat_in = inp_filters
148
+ self.unbiased = unbiased
149
+ self.eps = eps
150
+
151
+ self.attention_channels = attention_channels
152
+
153
+ self.attention_layer = nn.Sequential(
154
+ nn.Linear(in_features=inp_filters * 3, out_features=attention_channels, bias=True),
155
+ nn.Tanh(),
156
+ nn.Linear(in_features=attention_channels, out_features=inp_filters * 3, bias=True),
157
+ nn.Softmax(dim=1)
158
+ )
159
+
160
+ def forward(self, encoder_output, length=None):
161
+ if length is None:
162
+ mean = encoder_output.mean(dim=-1) # Time Axis
163
+ correction = 1 if self.unbiased else 0
164
+ std = encoder_output.std(dim=-1, correction=correction).clamp(min=self.eps)
165
+ pooled = torch.cat([mean, std], dim=-1)
166
+ else:
167
+ mask = make_seq_mask_like(like=encoder_output, lengths=length, valid_ones=False)
168
+ encoder_output = encoder_output.masked_fill(mask, 0.0)
169
+ # [B, D, T] -> [B, D]
170
+ means = encoder_output.mean(dim=-1)
171
+ # Re-scale to get padded means
172
+ means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
173
+ correction = 1 if self.unbiased else 0
174
+ stds = (
175
+ encoder_output.sub(means.unsqueeze(-1))
176
+ .masked_fill(mask, 0.0)
177
+ .pow(2.0)
178
+ .sum(-1) # [B, D, T] -> [B, D]
179
+ .div(length.view(-1, 1).sub(correction))
180
+ .clamp(min=self.eps)
181
+ .sqrt()
182
+ )
183
+ pooled = torch.cat((means, stds), dim=-1)
184
+
185
+ ext = pooled.unsqueeze(2).expand(-1, -1, encoder_output.shape[-1])
186
+ h_ext = torch.cat((encoder_output, ext), dim=1)
187
+
188
+ alpha = self.attention_layer(h_ext.transpose(1, 2))
189
+ alpha = alpha[:, :, :self.feat_in]
190
+
191
+ mu = torch.mean(alpha * h_ext[:, :self.feat_in, :].transpose(1, 2), 1)
192
+ sg = torch.sqrt(
193
+ (torch.sum(alpha * h_ext[:, :self.feat_in, :].transpose(1, 2) ** 2, dim=1) - mu ** 2).clamp(min=self.eps)
194
+ )
195
+
196
+ return torch.cat((mu, sg), dim=1).unsqueeze(2)
197
+
198
+
199
+ class TdnnModule(nn.Module):
200
+ """
201
+ Time Delayed Neural Module (TDNN) - 1D
202
+ input:
203
+ inp_filters: input filter channels for conv layer
204
+ out_filters: output filter channels for conv layer
205
+ kernel_size: kernel weight size for conv layer
206
+ dilation: dilation for conv layer
207
+ stride: stride for conv layer
208
+ padding: padding for conv layer (default None: chooses padding value such that input and output feature shape matches)
209
+ output:
210
+ tdnn layer output
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ inp_filters: int,
216
+ out_filters: int,
217
+ kernel_size: int = 1,
218
+ dilation: int = 1,
219
+ stride: int = 1,
220
+ groups: int = 1,
221
+ padding: int = None,
222
+ ):
223
+ super().__init__()
224
+ if padding is None:
225
+ padding = get_same_padding(kernel_size, stride=stride, dilation=dilation)
226
+
227
+ self.conv_layer = nn.Conv1d(
228
+ in_channels=inp_filters,
229
+ out_channels=out_filters,
230
+ kernel_size=kernel_size,
231
+ dilation=dilation,
232
+ groups=groups,
233
+ padding=padding,
234
+ )
235
+
236
+ self.activation = nn.ReLU()
237
+ self.bn = nn.BatchNorm1d(out_filters)
238
+
239
+ def forward(self, x, length=None):
240
+ x = self.conv_layer(x)
241
+ x = self.activation(x)
242
+ return self.bn(x)
243
+
244
+
245
+ class MaskedSEModule(nn.Module):
246
+ """
247
+ Squeeze and Excite module implementation with conv1d layers
248
+ input:
249
+ inp_filters: input filter channel size
250
+ se_filters: intermediate squeeze and excite channel output and input size
251
+ out_filters: output filter channel size
252
+ kernel_size: kernel_size for both conv1d layers
253
+ dilation: dilation size for both conv1d layers
254
+
255
+ output:
256
+ squeeze and excite layer output
257
+ """
258
+
259
+ def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
260
+ super().__init__()
261
+ self.se_layer = nn.Sequential(
262
+ nn.Conv1d(
263
+ inp_filters,
264
+ se_filters,
265
+ kernel_size=kernel_size,
266
+ dilation=dilation,
267
+ ),
268
+ nn.ReLU(),
269
+ nn.BatchNorm1d(se_filters),
270
+ nn.Conv1d(
271
+ se_filters,
272
+ out_filters,
273
+ kernel_size=kernel_size,
274
+ dilation=dilation,
275
+ ),
276
+ nn.Sigmoid(),
277
+ )
278
+
279
+ def forward(self, input, length=None):
280
+ if length is None:
281
+ x = torch.mean(input, dim=2, keepdim=True)
282
+ else:
283
+ max_len = input.size(2)
284
+ mask, num_values = lens_to_mask(length, max_len=max_len, device=input.device)
285
+ x = torch.sum((input * mask), dim=2, keepdim=True) / (num_values)
286
+
287
+ out = self.se_layer(x)
288
+ return out * input
289
+
290
+
291
+ class MaskedSEModule2D(nn.Module):
292
+ """
293
+ Squeeze-and-Excitation module for 2D inputs (e.g., images or feature maps).
294
+
295
+ Args:
296
+ inp_filters (int): Number of input channels.
297
+ se_filters (int): Number of intermediate squeeze-and-excite channels.
298
+ out_filters (int): Number of output channels.
299
+ kernel_size (int, optional): Kernel size for both Conv2d layers. Default: 1.
300
+ dilation (int, optional): Dilation factor for both Conv2d layers. Default: 1.
301
+
302
+ Output:
303
+ Scaled feature map with channel-wise attention.
304
+ """
305
+
306
+ def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
307
+ super().__init__()
308
+ self.se_layer = nn.Sequential(
309
+ nn.Conv2d(inp_filters, se_filters, kernel_size=kernel_size, dilation=dilation),
310
+ nn.ReLU(),
311
+ nn.BatchNorm2d(se_filters),
312
+ nn.Conv2d(se_filters, out_filters, kernel_size=kernel_size, dilation=dilation),
313
+ nn.Sigmoid(),
314
+ )
315
+
316
+ def forward(self, inputs: torch.Tensor, length: torch.Tensor = None):
317
+ """
318
+ Forward pass with optional masking.
319
+
320
+ Args:
321
+ input (torch.Tensor): Input tensor of shape (B, C, H, W).
322
+ length (torch.Tensor, optional): Sequence lengths for dynamic masking. Default: None.
323
+
324
+ Returns:
325
+ torch.Tensor: Feature map rescaled by SE attention.
326
+ """
327
+ if length is None:
328
+ # Global average pooling over spatial dimensions (H, W)
329
+ x = torch.mean(inputs, dim=(2, 3), keepdim=True)
330
+ else:
331
+ max_h, max_w = inputs.shape[2], inputs.shape[3]
332
+ mask, num_values = lens_to_mask_2d(length, max_len=inputs.size(-1), device=inputs.device)
333
+ x = torch.sum(inputs * mask, dim=(2, 3), keepdim=True) / num_values
334
+
335
+ # Apply SE layers and scale the input
336
+ out = self.se_layer(x)
337
+ return out * inputs
338
+
339
+
340
+ def lens_to_mask_2d(lens: List[int], max_len: int, device: str = None):
341
+ """
342
+ outputs masking labels for list of lengths of audio features, with max length of any
343
+ mask as max_len
344
+ input:
345
+ lens: list of lens
346
+ max_len: max length of any audio feature
347
+ output:
348
+ mask: masked labels
349
+ num_values: sum of mask values for each feature (useful for computing statistics later)
350
+ """
351
+ lens_mat = torch.arange(max_len).to(device)
352
+ mask = lens_mat[:max_len].unsqueeze(0) < lens.unsqueeze(1)
353
+ mask = mask.unsqueeze(1).unsqueeze(1)
354
+ num_values = torch.sum(mask, dim=-1, keepdim=True)
355
+ return mask, num_values
356
+
357
+
358
+ class TdnnSeModule(nn.Module):
359
+ """
360
+ Modified building SE_TDNN group module block from ECAPA implementation for faster training and inference
361
+ Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
362
+ inputs:
363
+ inp_filters: input filter channel size
364
+ out_filters: output filter channel size
365
+ group_scale: scale value to group wider conv channels (default: 8)
366
+ se_channels: squeeze and excite output channel size (default: 1024/8 = 128)
367
+ kernel_size: kernel_size for group conv1d layers (default: 1)
368
+ dilation: dilation size for group conv1d layers (default: 1)
369
+ """
370
+
371
+ def __init__(
372
+ self,
373
+ inp_filters: int,
374
+ out_filters: int,
375
+ group_scale: int = 8,
376
+ se_channels: int = 128,
377
+ kernel_size: int = 1,
378
+ dilation: int = 1,
379
+ init_mode: str = 'xavier_uniform',
380
+ ):
381
+ super().__init__()
382
+ self.out_filters = out_filters
383
+ padding_val = get_same_padding(kernel_size=kernel_size, dilation=dilation, stride=1)
384
+
385
+ group_conv = nn.Conv1d(
386
+ out_filters,
387
+ out_filters,
388
+ kernel_size=kernel_size,
389
+ dilation=dilation,
390
+ padding=padding_val,
391
+ groups=group_scale,
392
+ )
393
+ self.group_tdnn_block = nn.Sequential(
394
+ TdnnModule(inp_filters, out_filters, kernel_size=1, dilation=1),
395
+ group_conv,
396
+ nn.ReLU(),
397
+ nn.BatchNorm1d(out_filters),
398
+ TdnnModule(out_filters, out_filters, kernel_size=1, dilation=1),
399
+ )
400
+
401
+ self.se_layer = MaskedSEModule(out_filters, se_channels, out_filters)
402
+
403
+ self.apply(lambda x: init_weights(x, mode=init_mode))
404
+
405
+ def forward(self, input, length=None):
406
+ x = self.group_tdnn_block(input)
407
+ x = self.se_layer(x, length)
408
+ return x + input
409
+
410
+
411
+ class Res2NetBlock(nn.Module):
412
+ """
413
+ Res2Net module that splits input channels into groups and processes them separately before merging.
414
+ This allows multi-scale feature extraction.
415
+ """
416
+ def __init__(self, in_channels, out_channels, scale=4, kernel_size=1, dilation=1):
417
+ super().__init__()
418
+ assert in_channels % scale == 0, "in_channels must be divisible by scale"
419
+
420
+ self.scale = scale
421
+ self.width = in_channels // scale # Number of channels per group
422
+
423
+ self.convs = nn.ModuleList([
424
+ nn.Conv1d(self.width, self.width, kernel_size=kernel_size, dilation=dilation, padding=dilation, bias=False)
425
+ for _ in range(scale - 1)
426
+ ])
427
+ self.bn = nn.BatchNorm1d(out_channels)
428
+ self.activation = nn.ReLU()
429
+
430
+ def forward(self, x):
431
+ """
432
+ x: [B, C, T]
433
+ """
434
+ splits = torch.split(x, self.width, dim=1)
435
+ outputs = [splits[0]] # First part remains unchanged
436
+
437
+ for i in range(1, self.scale):
438
+ conv_out = self.convs[i - 1](splits[i]) # Apply convolution on each group
439
+ outputs.append(conv_out + outputs[i - 1]) # Hierarchical aggregation
440
+
441
+ out = torch.cat(outputs, dim=1) # Merge groups
442
+ return self.activation(self.bn(out))
443
+
444
+
445
+ class TdnnSeRes2NetModule(nn.Module):
446
+ """
447
+ SE-TDNN module with Res2Net for ECAPA-TDNN.
448
+ """
449
+ def __init__(
450
+ self,
451
+ inp_filters: int,
452
+ out_filters: int,
453
+ group_scale: int = 1,
454
+ se_channels: int = 128,
455
+ kernel_size: int = 1,
456
+ dilation: int = 1,
457
+ res2net_scale: int = 8, # New Res2Net parameter
458
+ ):
459
+ super().__init__()
460
+
461
+ # First TDNN layer
462
+ self.tdnn1 = TdnnModule(inp_filters, out_filters, kernel_size=1, dilation=1, groups=group_scale)
463
+
464
+ # Res2Net block replaces grouped TDNN
465
+ self.res2net = Res2NetBlock(out_filters, out_filters, scale=res2net_scale, kernel_size=kernel_size, dilation=dilation)
466
+
467
+ # Squeeze-and-Excite module
468
+ self.se_layer = MaskedSEModule(out_filters, se_channels, out_filters)
469
+
470
+ def forward(self, x, length=None):
471
+ residual = x
472
+ x = self.tdnn1(x)
473
+ x = self.res2net(x) # Apply Res2Net block
474
+ x = self.se_layer(x, length)
475
+ return x + residual # Residual connection
476
+
477
+
478
+ class MaskedConv1d(nn.Module):
479
+
480
+ __constants__ = ["use_conv_mask", "real_out_channels", "heads"]
481
+
482
+ def __init__(
483
+ self,
484
+ in_channels,
485
+ out_channels,
486
+ kernel_size,
487
+ stride=1,
488
+ padding=0,
489
+ dilation=1,
490
+ groups=1,
491
+ heads=-1,
492
+ bias=False,
493
+ use_mask=True,
494
+ quantize=False,
495
+ ):
496
+ super(MaskedConv1d, self).__init__()
497
+
498
+ if not (heads == -1 or groups == in_channels):
499
+ raise ValueError("Only use heads for depthwise convolutions")
500
+
501
+ self.real_out_channels = out_channels
502
+ if heads != -1:
503
+ in_channels = heads
504
+ out_channels = heads
505
+ groups = heads
506
+
507
+ # preserve original padding
508
+ self._padding = padding
509
+
510
+ # if padding is a tuple/list, it is considered as asymmetric padding
511
+ if type(padding) in (tuple, list):
512
+ self.pad_layer = nn.ConstantPad1d(padding, value=0.0)
513
+ # reset padding for conv since pad_layer will handle this
514
+ padding = 0
515
+ else:
516
+ self.pad_layer = None
517
+
518
+ self.conv = nn.Conv1d(
519
+ in_channels,
520
+ out_channels,
521
+ kernel_size,
522
+ stride=stride,
523
+ padding=padding,
524
+ dilation=dilation,
525
+ groups=groups,
526
+ bias=bias,
527
+ )
528
+ self.use_mask = use_mask
529
+ self.heads = heads
530
+
531
+ # Calculations for "same" padding cache
532
+ self.same_padding = (self.conv.stride[0] == 1) and (
533
+ 2 * self.conv.padding[0] == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
534
+ )
535
+ if self.pad_layer is None:
536
+ self.same_padding_asymmetric = False
537
+ else:
538
+ self.same_padding_asymmetric = (self.conv.stride[0] == 1) and (
539
+ sum(self._padding) == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
540
+ )
541
+
542
+ # `self.lens` caches consecutive integers from 0 to `self.max_len` that are used to compute the mask for a
543
+ # batch. Recomputed to bigger size as needed. Stored on a device of the latest batch lens.
544
+ if self.use_mask:
545
+ self.max_len = torch.tensor(0)
546
+ self.lens = torch.tensor(0)
547
+
548
+ def get_seq_len(self, lens):
549
+ if self.same_padding or self.same_padding_asymmetric:
550
+ return lens
551
+
552
+ if self.pad_layer is None:
553
+ return (
554
+ torch.div(
555
+ lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
556
+ self.conv.stride[0],
557
+ rounding_mode='trunc',
558
+ )
559
+ + 1
560
+ )
561
+ else:
562
+ return (
563
+ torch.div(
564
+ lens + sum(self._padding) - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
565
+ self.conv.stride[0],
566
+ rounding_mode='trunc',
567
+ )
568
+ + 1
569
+ )
570
+
571
+ def forward(self, x, lens):
572
+ if self.use_mask:
573
+ # Generally will be called by ConvASREncoder, but kept as single gpu backup.
574
+ if x.size(2) > self.max_len:
575
+ self.update_masked_length(x.size(2), device=lens.device)
576
+ x = self.mask_input(x, lens)
577
+
578
+ # Update lengths
579
+ lens = self.get_seq_len(lens)
580
+
581
+ # asymmtric pad if necessary
582
+ if self.pad_layer is not None:
583
+ x = self.pad_layer(x)
584
+
585
+ sh = x.shape
586
+ if self.heads != -1:
587
+ x = x.view(-1, self.heads, sh[-1])
588
+
589
+ out = self.conv(x)
590
+
591
+ if self.heads != -1:
592
+ out = out.view(sh[0], self.real_out_channels, -1)
593
+
594
+ return out, lens
595
+
596
+ def update_masked_length(self, max_len, seq_range=None, device=None):
597
+ if seq_range is None:
598
+ self.lens, self.max_len = _masked_conv_init_lens(self.lens, max_len, self.max_len)
599
+ self.lens = self.lens.to(device)
600
+ else:
601
+ self.lens = seq_range
602
+ self.max_len = torch.tensor(max_len)
603
+
604
+ def mask_input(self, x, lens):
605
+ max_len = x.size(2)
606
+ mask = self.lens[:max_len].unsqueeze(0).to(lens.device) < lens.unsqueeze(1)
607
+ x = x * mask.unsqueeze(1).to(device=x.device)
608
+ return x
609
+
610
+
611
+ @torch.jit.script
612
+ def _masked_conv_init_lens(lens: torch.Tensor, current_maxlen: int, original_maxlen: torch.Tensor):
613
+ if current_maxlen > original_maxlen:
614
+ new_lens = torch.arange(current_maxlen)
615
+ new_max_lens = torch.tensor(current_maxlen)
616
+ else:
617
+ new_lens = lens
618
+ new_max_lens = original_maxlen
619
+ return new_lens, new_max_lens
620
+
621
+
622
+ def get_same_padding(kernel_size, stride, dilation) -> int:
623
+ if stride > 1 and dilation > 1:
624
+ raise ValueError("Only stride OR dilation may be greater than 1")
625
+ return (dilation * (kernel_size - 1)) // 2
626
+
627
+
628
+ def lens_to_mask(lens: List[int], max_len: int, device: str = None):
629
+ """
630
+ outputs masking labels for list of lengths of audio features, with max length of any
631
+ mask as max_len
632
+ input:
633
+ lens: list of lens
634
+ max_len: max length of any audio feature
635
+ output:
636
+ mask: masked labels
637
+ num_values: sum of mask values for each feature (useful for computing statistics later)
638
+ """
639
+ lens_mat = torch.arange(max_len).to(device)
640
+ mask = lens_mat[:max_len].unsqueeze(0) < lens.unsqueeze(1)
641
+ mask = mask.unsqueeze(1)
642
+ num_values = torch.sum(mask, dim=2, keepdim=True)
643
+ return mask, num_values
644
+
645
+
646
+ def get_statistics_with_mask(x: torch.Tensor, m: torch.Tensor, dim: int = 2, eps: float = 1e-10):
647
+ """
648
+ compute mean and standard deviation of input(x) provided with its masking labels (m)
649
+ input:
650
+ x: feature input
651
+ m: averaged mask labels
652
+ output:
653
+ mean: mean of input features
654
+ std: standard deviation of input features
655
+ """
656
+ mean = torch.sum((m * x), dim=dim)
657
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
658
+ return mean, std
659
+
660
+
661
+ @torch.jit.script_if_tracing
662
+ def make_seq_mask_like(
663
+ like: torch.Tensor, lengths: torch.Tensor, valid_ones: bool = True, time_dim: int = -1
664
+ ) -> torch.Tensor:
665
+ mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.unsqueeze(-1))
666
+ # Match number of dims in `like` tensor
667
+ for _ in range(like.dim() - mask.dim()):
668
+ mask = mask.unsqueeze(1)
669
+ # If time dim != -1, transpose to proper dim.
670
+ if time_dim != -1:
671
+ mask = mask.transpose(time_dim, -1)
672
+ if not valid_ones:
673
+ mask = ~mask
674
+ return mask
675
+
676
+
677
+ def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
678
+ if isinstance(m, MaskedConv1d):
679
+ init_weights(m.conv, mode)
680
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
681
+ if mode is not None:
682
+ if mode == 'xavier_uniform':
683
+ nn.init.xavier_uniform_(m.weight, gain=1.0)
684
+ elif mode == 'xavier_normal':
685
+ nn.init.xavier_normal_(m.weight, gain=1.0)
686
+ elif mode == 'kaiming_uniform':
687
+ nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
688
+ elif mode == 'kaiming_normal':
689
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
690
+ elif mode == 'tds_uniform':
691
+ tds_uniform_(m.weight)
692
+ elif mode == 'tds_normal':
693
+ tds_normal_(m.weight)
694
+ else:
695
+ raise ValueError("Unknown Initialization mode: {0}".format(mode))
696
+ elif isinstance(m, nn.BatchNorm1d):
697
+ if m.track_running_stats:
698
+ m.running_mean.zero_()
699
+ m.running_var.fill_(1)
700
+ m.num_batches_tracked.zero_()
701
+ if m.affine:
702
+ nn.init.ones_(m.weight)
703
+ nn.init.zeros_(m.bias)
704
+
705
+
706
+ def tds_uniform_(tensor, mode='fan_in'):
707
+ """
708
+ Uniform Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
709
+ Normalized to -
710
+
711
+ .. math::
712
+ \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}
713
+
714
+ Args:
715
+ tensor: an n-dimensional `torch.Tensor`
716
+ mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
717
+ preserves the magnitude of the variance of the weights in the
718
+ forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
719
+ backwards pass.
720
+ """
721
+ fan = _calculate_correct_fan(tensor, mode)
722
+ gain = 2.0 # sqrt(4.0) = 2
723
+ std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
724
+ bound = std # Calculate uniform bounds from standard deviation
725
+ with torch.no_grad():
726
+ return tensor.uniform_(-bound, bound)
727
+
728
+
729
+ def tds_normal_(tensor, mode='fan_in'):
730
+ """
731
+ Normal Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
732
+ Normalized to -
733
+
734
+ .. math::
735
+ \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}
736
+
737
+ Args:
738
+ tensor: an n-dimensional `torch.Tensor`
739
+ mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
740
+ preserves the magnitude of the variance of the weights in the
741
+ forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
742
+ backwards pass.
743
+ """
744
+ fan = _calculate_correct_fan(tensor, mode)
745
+ gain = 2.0
746
+ std = gain / math.sqrt(fan) # sqrt(4.0 / fan_in)
747
+ bound = std # used as the standard deviation of the normal distribution
748
+ with torch.no_grad():
749
+ return tensor.normal_(0.0, bound)
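Finally, a hedged sketch of the two main pooling layers above, which collapse frame-level encoder features `[B, D, T]` into utterance-level statistics; the tensor sizes are illustrative.

```python
import torch
from tdnn_attention import StatsPoolLayer, AttentivePoolLayer  # hypothetical flat import

encoder_out = torch.randn(4, 256, 120)       # [B, D, T] frame-level features
lengths = torch.tensor([120, 100, 80, 60])   # valid frames per utterance

stats_pool = StatsPoolLayer(feat_in=256, pool_mode='xvector')
attn_pool = AttentivePoolLayer(inp_filters=256, attention_channels=128)

print(stats_pool(encoder_out, lengths).shape)  # [4, 512]: mean and std concatenated
print(attn_pool(encoder_out, lengths).shape)   # [4, 512, 1]: attention-weighted mean and std
```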