# SER_Naturalistic/net/pooling_org.py
# (Hugging Face file-viewer header removed: "Upload 96 files", commit 901595e, 2.44 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
"MeanPooling",
"AttentiveStatisticsPooling"
]
class Pooling(nn.Module):
    """Base class for temporal pooling layers over padded waveform batches.

    Subclasses implement ``forward(x, mask)`` and may use
    ``compute_length_from_mask`` to recover the number of valid feature
    frames per utterance from the sample-level padding mask.
    """
    def __init__(self):
        super().__init__()
    def compute_length_from_mask(self, mask, sample_rate=16000, frame_shift=0.02):
        """Convert a sample-level mask into per-utterance feature lengths.

        Args:
            mask: (batch_size, T) tensor; non-zero entries mark valid samples.
            sample_rate: audio sampling rate in Hz (default 16000).
            frame_shift: feature hop in seconds (default 0.02, i.e. 20 ms).

        Returns:
            list[int]: number of feature frames for each utterance,
            computed as floor((n_samples - 1) / hop) + 1.
        """
        wav_lens = torch.sum(mask, dim=1)  # (batch_size, ) valid samples per utterance
        hop = sample_rate * frame_shift    # hop size in samples (320 by default)
        feat_lens = torch.div(wav_lens - 1, hop, rounding_mode="floor") + 1
        feat_lens = feat_lens.int().tolist()
        return feat_lens
    def forward(self, x, mask):
        # Abstract: concrete pooling strategies are provided by subclasses.
        raise NotImplementedError
class MeanPooling(Pooling):
    """Average the valid (unpadded) frames of each sequence.

    Input x: (batch_size, T, feat_dim); mask: (batch_size, T).
    Output: (batch_size, feat_dim).
    """
    def forward(self, x, mask):
        valid_lens = self.compute_length_from_mask(mask)
        # For each sequence, keep only the first `n` frames and average over time.
        means = [frames[:n].mean(dim=0) for frames, n in zip(x, valid_lens)]
        return torch.stack(means)
class AttentiveStatisticsPooling(Pooling):
    """
    AttentiveStatisticsPooling
    Paper: Attentive Statistics Pooling for Deep Speaker Embedding
    Link: https://arxiv.org/pdf/1803.10963.pdf
    """
    def __init__(self, input_size):
        super().__init__()
        self._indim = input_size
        # Projection + learned context vector used to score each frame.
        self.sap_linear = nn.Linear(input_size, input_size)
        self.attention = nn.Parameter(torch.FloatTensor(input_size, 1))
        torch.nn.init.normal_(self.attention, mean=0, std=1)
    def forward(self, xs, mask):
        """
        xs: (batch_size, T, feat_dim)
        mask: (batch_size, T)
        => output: (batch_size, feat_dim*2)
        """
        valid_lens = self.compute_length_from_mask(mask)
        outputs = []
        for seq, n in zip(xs, valid_lens):
            frames = seq[:n]                                   # (n, D) valid frames only
            scores = torch.matmul(torch.tanh(self.sap_linear(frames)),
                                  self.attention)              # (n, 1)
            weights = F.softmax(scores, dim=0)                 # attention over time
            mean = torch.sum(frames * weights, dim=0)          # weighted first moment
            # Weighted std; clamp keeps the sqrt argument strictly positive.
            var = torch.sum((frames ** 2) * weights, dim=0) - mean ** 2
            std = torch.sqrt(var.clamp(min=1e-5))
            outputs.append(torch.cat((mean, std), dim=0))      # (2D,)
        return torch.stack(outputs)