fix link to script

spk_embeddings.py (DELETED, +0 -54)
```python
'''
* Software Name : spk_embeddings.py
* SPDX-FileCopyrightText: Copyright (c) Orange SA
* SPDX-License-Identifier: CC-BY-SA-3.0
*
* This software is distributed under the Creative Commons Attribution Share Alike 3.0 Unported,
* see the "LICENSE.txt" file for more details or https://huggingface.co/Orange/w-pro/blob/main/LICENSE.txt
'''

import torch, torchaudio
import torch.nn as nn
from transformers.models.wavlm.modeling_wavlm import WavLMPreTrainedModel, WavLMModel

class TopLayers(nn.Module):
    def __init__(self, embd_size=250, top_interm_size=512):
        super(TopLayers, self).__init__()
        self.affine1 = nn.Conv1d(in_channels=2048, out_channels=top_interm_size, kernel_size=1)
        self.batchnorm1 = nn.BatchNorm1d(num_features=top_interm_size, affine=False, eps=1e-03)
        self.affine2 = nn.Conv1d(in_channels=top_interm_size, out_channels=embd_size, kernel_size=1)
        self.batchnorm2 = nn.BatchNorm1d(num_features=embd_size, affine=False, eps=1e-03)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.batchnorm1(self.activation(self.affine1(x)))
        out = self.batchnorm2(self.activation(self.affine2(out)))
        return nn.functional.normalize(out[:, :, 0])

class EmbeddingsModel(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.wavlm = WavLMModel(config)
        self.top_layers = TopLayers(config.embd_size, config.top_interm_size)

    def forward(self, input_values):
        # MVN normalization
        x_norm = (input_values - input_values.mean(dim=1).unsqueeze(1)) / (input_values.std(dim=1).unsqueeze(1))
        # wavlm fwd
        base_out = self.wavlm(input_values=x_norm, output_hidden_states=False).last_hidden_state
        # stats pooling
        v = base_out.var(dim=1).clamp(min=1e-10)
        x_stats = torch.cat((base_out.mean(dim=1), v.pow(0.5)), dim=1).unsqueeze(dim=2)
        # top layers fwd
        return self.top_layers(x_stats)


def compute_embedding(fnm, model, max_size=320000):
    sig, sr = torchaudio.load(fnm)
    assert sr == 16000, "please convert your audio file to a sampling rate of 16 kHz"
    sig = sig.mean(dim=0)
    if sig.shape[0] > max_size:
        print(f"truncating long signal {fnm}")
        sig = sig[:max_size]
    embd = model(sig.unsqueeze(dim=0))
    return embd.clone().detach()
```