# SER_Naturalistic/net/pooling.py
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
    "MeanPooling",
    "AttentiveStatisticsPooling",
]
class Pooling(nn.Module):
    def __init__(self):
        super().__init__()

    def compute_length_from_mask(self, mask):
        """
        mask: (batch_size, T_wav) - mask over waveform samples.
        Assuming a 16 kHz sampling rate and a 20 ms frame shift, convert the
        number of valid waveform samples into the number of valid feature frames.
        """
        wav_lens = torch.sum(mask, dim=1)  # (batch_size, )
        feat_lens = torch.div(wav_lens - 1, 16000 * 0.02, rounding_mode="floor") + 1
        feat_lens = feat_lens.int().tolist()
        return feat_lens

    def forward(self, x, mask):
        raise NotImplementedError
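
# Worked example of the length formula above: with a 16 kHz sampling rate and a
# 20 ms frame shift, one second of valid samples (wav_len = 16000) gives
#   feat_len = floor((16000 - 1) / (16000 * 0.02)) + 1 = floor(15999 / 320) + 1 = 50
# valid feature frames, i.e. one frame roughly every 320 samples.
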
class MeanPooling(Pooling):
    def forward(self, x, mask):
        feat_lens = self.compute_length_from_mask(mask)
        pooled_list = []
        for seq, feat_len in zip(x, feat_lens):
            # Take only the valid frames according to the mask
            seq = seq[:feat_len]
            # Compute the mean along the time axis
            pooled = torch.mean(seq, dim=0)
            pooled_list.append(pooled)
        # Return as a stacked tensor: (batch_size, feat_dim)
        return torch.stack(pooled_list)
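
# Shape summary for MeanPooling: given features x of shape (batch_size, T, feat_dim)
# and a waveform-sample mask of shape (batch_size, T_wav), the output has shape
# (batch_size, feat_dim), holding the mean over each sequence's valid frames.
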
class AttentiveStatisticsPooling(Pooling):
    """
    Attentive Statistics Pooling with Multi-Head Attention.
    Paper: Attentive Statistics Pooling for Deep Speaker Embedding
    Link: https://arxiv.org/pdf/1803.10963.pdf
    """
    def __init__(self, input_size, num_heads=4):
        super().__init__()
        self._indim = input_size
        self.num_heads = num_heads
        # Linear transformation that projects the input features to one
        # input_size-dimensional subspace per attention head
        self.sap_linear = nn.Linear(input_size, input_size * num_heads)
        # Per-head attention parameters: (num_heads, head_dim, 1)
        self.attention = nn.Parameter(torch.FloatTensor(num_heads, input_size, 1))
        torch.nn.init.normal_(self.attention, mean=0, std=1)
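
    # For each head h, the forward pass below computes attention weights
    #   w_t^h = softmax_t( tanh(W x_t)^h . v_h )
    # and summarizes the sequence by the attention-weighted mean and standard deviation,
    #   mu^h = sum_t w_t^h x_t,   sigma^h = sqrt( sum_t w_t^h x_t^2 - (mu^h)^2 ),
    # following the attentive statistics pooling formulation of the paper cited above.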
    def forward(self, xs, mask):
        """
        Args:
            xs: (batch_size, T, feat_dim) - Input sequence.
            mask: (batch_size, T_wav) - Mask over valid waveform samples.
        Returns:
            output: (batch_size, feat_dim * 2) - Concatenated attentive mean and
                standard deviation, averaged over heads.
            attention: (1, batch_size, T, num_heads) - Attention weights per head
                (stacking assumes equal valid lengths across the batch).
        """
        feat_lens = self.compute_length_from_mask(mask)
        pooled_list = []
        attention_list = []
        # If the input is 2-D, insert a singleton time axis:
        # (batch_size, feat_dim) -> (batch_size, 1, feat_dim)
        if xs.dim() == 2:
            xs = xs.unsqueeze(1)
        for x, feat_len in zip(xs, feat_lens):
            # Extract the valid frames based on the mask
            x = x[:feat_len].unsqueeze(0)  # (1, T, feat_dim)
            # Apply the linear projection and reshape for multi-head attention
            h = torch.tanh(self.sap_linear(x))  # (1, T, feat_dim * num_heads)
            h = h.view(x.size(0), x.size(1), self.num_heads, -1)  # (1, T, num_heads, head_dim)
            # Compute attention scores for each head
            w = torch.einsum('bthd,hdf->bthf', h, self.attention)  # (1, T, num_heads, 1)
            w = w.squeeze(-1) if w.dim() == 4 else w  # (1, T, num_heads)
            # Normalize across the time axis
            w = F.softmax(w, dim=1).unsqueeze(-1)  # (1, T, num_heads, 1)
            attention_list.append(w.squeeze(-1).squeeze(0).cpu().detach())  # (T, num_heads)
            # Attention-weighted mean (mu) and standard deviation (rh) for each head
            mu = torch.sum(x.unsqueeze(2) * w, dim=1)  # (1, num_heads, feat_dim)
            rh = torch.sqrt((torch.sum((x.unsqueeze(2) ** 2) * w, dim=1) - mu ** 2).clamp(min=1e-5))
            # Concatenate mean and standard deviation per head
            pooled = torch.cat((mu, rh), 2).squeeze(0)  # (num_heads, feat_dim * 2)
            # Average over heads: (feat_dim * 2,)
            pooled_avg = pooled.mean(dim=0)
            pooled_list.append(pooled_avg)
        # Stack pooled features and attention weights for the batch:
        # (batch_size, feat_dim * 2), (1, batch_size, T, num_heads)
        return torch.stack(pooled_list), torch.stack(attention_list, dim=0).unsqueeze(0)
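

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the feature dimension, batch size,
    # and lengths below are arbitrary choices, and the waveform mask is built by
    # hand rather than coming from an upstream feature extractor.
    batch_size, feat_dim = 2, 8
    wav_lens = [16000, 8000]          # 1.0 s and 0.5 s of valid audio at 16 kHz
    max_wav_len = max(wav_lens)
    mask = torch.zeros(batch_size, max_wav_len)
    for i, n in enumerate(wav_lens):
        mask[i, :n] = 1.0             # 1 marks a valid waveform sample

    # With a 20 ms frame shift, 16000 samples -> 50 frames and 8000 samples -> 25 frames,
    # so the padded feature tensor needs at least 50 time steps.
    feats = torch.randn(batch_size, 50, feat_dim)

    mean_pool = MeanPooling()
    print(mean_pool(feats, mask).shape)            # torch.Size([2, 8])

    # The attention weights are stacked inside forward, so this call uses a mask
    # with equal valid lengths across the batch.
    asp = AttentiveStatisticsPooling(input_size=feat_dim, num_heads=4)
    equal_mask = torch.ones(batch_size, max_wav_len)
    pooled, attn = asp(feats, equal_mask)
    print(pooled.shape)                            # torch.Size([2, 16])  (mean ++ std)
    print(attn.shape)                              # torch.Size([1, 2, 50, 4])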