import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "MeanPooling",
    "AttentiveStatisticsPooling",
]


class Pooling(nn.Module):
    def __init__(self):
        super().__init__()

    def compute_length_from_mask(self, mask):
        """
        mask: (batch_size, T)
        Assumes a 16 kHz sampling rate and a 20 ms frame shift.
        """
        # Number of valid waveform samples per utterance
        wav_lens = torch.sum(mask, dim=1)  # (batch_size,)
        # Convert sample counts to frame counts: floor((n - 1) / 320) + 1
        feat_lens = torch.div(wav_lens - 1, 16000 * 0.02, rounding_mode="floor") + 1
        feat_lens = feat_lens.int().tolist()
        return feat_lens
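    # Worked example: a 1 s utterance at 16 kHz gives wav_len = 16000 samples,
    # so feat_len = floor((16000 - 1) / 320) + 1 = 50 frames.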

    def forward(self, x, mask):
        raise NotImplementedError


class MeanPooling(Pooling):
    def forward(self, x, mask):
        feat_lens = self.compute_length_from_mask(mask)
        pooled_list = []
        for seq, feat_len in zip(x, feat_lens):
            # Keep only the valid frames according to the mask
            seq = seq[:feat_len]
            # Average along the time axis
            pooled = torch.mean(seq, dim=0)
            pooled_list.append(pooled)
        # Return as a stacked tensor: (batch_size, feat_dim)
        return torch.stack(pooled_list)
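
# Usage sketch (shapes are illustrative; feat_dim = 768 is an assumed feature size):
#   pool = MeanPooling()
#   mask = torch.ones(2, 16000)        # two 1 s utterances, all samples valid
#   feats = torch.randn(2, 50, 768)    # 50 frames at a 20 ms shift
#   pooled = pool(feats, mask)         # -> (2, 768)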
class AttentiveStatisticsPooling(Pooling): | |
""" | |
Attentive Statistics Pooling with Multi-Head Attention. | |
Paper: Attentive Statistics Pooling for Deep Speaker Embedding | |
Link: https://arxiv.org/pdf/1803.10963.pdf | |
""" | |
def __init__(self, input_size, num_heads=4): | |
super().__init__() | |
self._indim = input_size | |
self.num_heads = num_heads | |
# Linear transformation to project input features | |
self.sap_linear = nn.Linear(input_size, input_size * num_heads) | |
# Attention parameters for each head | |
self.attention = nn.Parameter(torch.FloatTensor(num_heads, input_size, 1)) | |
torch.nn.init.normal_(self.attention, mean=0, std=1) | |

    def forward(self, xs, mask):
        """
        Args:
            xs: (batch_size, T, feat_dim) - Input sequence.
            mask: (batch_size, T) - Mask for valid frames.
        Returns:
            output: (batch_size, feat_dim * 2 * num_heads) - Pooled representation.
            attention: list of (T_i, num_heads) tensors - Per-frame attention
                weights; valid lengths differ per sequence, so they are kept as a list.
        """
        feat_lens = self.compute_length_from_mask(mask)
        pooled_list = []
        attention_list = []
        # Treat a (batch_size, feat_dim) input as a single-frame sequence
        if xs.dim() == 2:
            xs = xs.unsqueeze(1)
        for x, feat_len in zip(xs, feat_lens):
            # Keep only the valid frames according to the mask
            x = x[:feat_len].unsqueeze(0)  # (1, T, feat_dim)
            # Project and reshape for multi-head attention
            h = torch.tanh(self.sap_linear(x))  # (1, T, feat_dim * num_heads)
            h = h.view(x.size(0), x.size(1), self.num_heads, -1)  # (1, T, num_heads, feat_dim)
            # Per-head attention scores: contract the feature axis with each head's vector
            w = torch.einsum("bthd,hdf->bthf", h, self.attention)  # (1, T, num_heads, 1)
            w = w.squeeze(-1)  # (1, T, num_heads)
            # Normalize across the time axis
            w = F.softmax(w, dim=1).unsqueeze(-1)  # (1, T, num_heads, 1)
            attention_list.append(w.squeeze(-1).squeeze(0).cpu().detach())  # (T, num_heads)
            # Weighted mean (mu) and standard deviation (rh) per head:
            # std = sqrt(E[x^2] - mu^2), clamped for numerical stability
            mu = torch.sum(x.unsqueeze(2) * w, dim=1)  # (1, num_heads, feat_dim)
            rh = torch.sqrt(
                (torch.sum((x.unsqueeze(2) ** 2) * w, dim=1) - mu ** 2).clamp(min=1e-5)
            )
            # Concatenate mean and standard deviation per head
            pooled = torch.cat((mu, rh), 2).squeeze(0)  # (num_heads, feat_dim * 2)
            # Flatten heads into one vector: (feat_dim * 2 * num_heads,)
            pooled_list.append(pooled.flatten())
        # Stack pooled features for the batch; attention weights stay a list
        # because T differs per sequence and cannot be stacked in general.
        return torch.stack(pooled_list), attention_list  # (batch_size, feat_dim * 2 * num_heads)
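

# Minimal smoke test: a sketch with assumed shapes (feat_dim = 768 and 1 s
# utterances are illustrative choices, not part of the module's contract).
if __name__ == "__main__":
    batch_size, feat_dim = 2, 768
    wav_len = 16000                               # 1 s at 16 kHz
    mask = torch.ones(batch_size, wav_len)        # all samples valid
    feat_len = (wav_len - 1) // 320 + 1           # 50 frames at a 20 ms shift
    feats = torch.randn(batch_size, feat_len, feat_dim)

    mean_pool = MeanPooling()
    print(mean_pool(feats, mask).shape)           # torch.Size([2, 768])

    asp = AttentiveStatisticsPooling(feat_dim, num_heads=4)
    pooled, attention = asp(feats, mask)
    print(pooled.shape)                           # torch.Size([2, 6144]) = feat_dim * 2 * num_heads
    print(len(attention), attention[0].shape)     # 2 torch.Size([50, 4])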