ggmbr commited on
Commit
d81b7c1
·
1 Parent(s): 7c7793f

fix link to script

Browse files
Files changed (1) hide show
  1. spk_embeddings.py +0 -54
spk_embeddings.py DELETED
@@ -1,54 +0,0 @@
1
- '''
2
- * Software Name : spk_embeddings.py
3
- * SPDX-FileCopyrightText: Copyright (c) Orange SA
4
- * SPDX-License-Identifier: CC-BY-SA-3.0
5
- *
6
- * This software is distributed under the Creative Commons Attribution Share Alike 3.0 Unported,
7
- * see the "LICENSE.txt" file for more details or https://huggingface.co/Orange/w-pro/blob/main/LICENSE.txt
8
- '''
9
-
10
- import torch, torchaudio
11
- import torch.nn as nn
12
- from transformers.models.wavlm.modeling_wavlm import WavLMPreTrainedModel, WavLMModel
13
-
14
- class TopLayers(nn.Module):
15
- def __init__(self, embd_size = 250, top_interm_size = 512):
16
- super(TopLayers, self).__init__()
17
- self.affine1 = nn.Conv1d(in_channels=2048, out_channels=top_interm_size, kernel_size=1)
18
- self.batchnorm1 = nn.BatchNorm1d(num_features=top_interm_size, affine=False, eps=1e-03)
19
- self.affine2 = nn.Conv1d(in_channels=top_interm_size, out_channels=embd_size, kernel_size=1)
20
- self.batchnorm2 = nn.BatchNorm1d(num_features=embd_size, affine=False, eps=1e-03)
21
- self.activation = nn.ReLU(inplace=True)
22
-
23
- def forward(self, x):
24
- out = self.batchnorm1(self.activation(self.affine1(x)))
25
- out = self.batchnorm2(self.activation(self.affine2(out)))
26
- return nn.functional.normalize(out[:,:,0])
27
-
28
- class EmbeddingsModel(WavLMPreTrainedModel):
29
- def __init__(self, config):
30
- super().__init__(config)
31
- self.wavlm = WavLMModel(config)
32
- self.top_layers = TopLayers(config.embd_size, config.top_interm_size)
33
-
34
- def forward(self, input_values):
35
- # MVN normalization
36
- x_norm = (input_values - input_values.mean(dim=1).unsqueeze(1)) / (input_values.std(dim=1).unsqueeze(1))
37
- # wavlm fwd
38
- base_out = self.wavlm(input_values=x_norm, output_hidden_states=False).last_hidden_state
39
- # stats pooling
40
- v = base_out.var(dim=1).clamp(min=1e-10)
41
- x_stats = torch.cat((base_out.mean(dim=1),v.pow(0.5)),dim=1).unsqueeze(dim=2)
42
- # top layers fwd
43
- return self.top_layers(x_stats)
44
-
45
-
46
- def compute_embedding(fnm, model, max_size=320000):
47
- sig, sr = torchaudio.load(fnm)
48
- assert sr == 16000, "please convert your audio file to a sampling rate of 16 kHz"
49
- sig = sig.mean(dim=0)
50
- if sig.shape[0] > max_size:
51
- print(f"truncating long signal {fnm}")
52
- sig = sig[:max_size]
53
- embd = model(sig.unsqueeze(dim=0))
54
- return embd.clone().detach()