SER_Naturalistic / net /pooling_atte.py
Samara369's picture
Upload 96 files
901595e verified
raw
history blame
6.36 kB
'''import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
"MeanPooling",
"AttentiveStatisticsPooling"
]
class Pooling(nn.Module):
def __init__(self):
super().__init__()
def compute_length_from_mask(self, mask):
"""
mask: (batch_size, T)
Assuming that the sampling rate is 16kHz, the frame shift is 20ms
"""
wav_lens = torch.sum(mask, dim=1) # (batch_size, )
feat_lens = torch.div(wav_lens-1, 16000*0.02, rounding_mode="floor") + 1
feat_lens = feat_lens.int().tolist()
return feat_lens
def forward(self, x, mask):
raise NotImplementedError
class MeanPooling(Pooling):
"""
MeanPooling
A simple pooling mechanism that computes the mean of the input sequence.
x: (batch_size, T, feat_dim)
mask: (batch_size, T)
=> output: (batch_size, feat_dim)
"""
def forward(self, x, mask):
# Compute the lengths of the sequences from the mask
feat_lens = self.compute_length_from_mask(mask)
# Perform mean pooling
pooled_list = []
for seq, feat_len in zip(x, feat_lens):
# Take only the valid frames according to the mask
seq = seq[:feat_len]
# Compute the mean along the time axis
pooled = torch.mean(seq, dim=0)
pooled_list.append(pooled)
# Return as a stacked tensor
return torch.stack(pooled_list)
class AttentiveStatisticsPooling(Pooling):
"""
AttentiveStatisticsPooling
Paper: Attentive Statistics Pooling for Deep Speaker Embedding
Link: https://arxiv.org/pdf/1803.10963.pdf
"""
def __init__(self, input_size):
super().__init__()
self._indim = input_size
self.sap_linear = nn.Linear(input_size, input_size)
self.attention = nn.Parameter(torch.FloatTensor(input_size, 1))
torch.nn.init.normal_(self.attention, mean=0, std=1)
def forward(self, xs, mask):
"""
xs: (batch_size, T, feat_dim)
mask: (batch_size, T)
=> output: (batch_size, feat_dim*2)
"""
feat_lens = self.compute_length_from_mask(mask)
pooled_list = []
for x, feat_len in zip(xs, feat_lens):
x = x[:feat_len].unsqueeze(0)
h = torch.tanh(self.sap_linear(x))
w = torch.matmul(h, self.attention).squeeze(dim=2)
w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
mu = torch.sum(x * w, dim=1)
rh = torch.sqrt((torch.sum((x**2) * w, dim=1) - mu**2).clamp(min=1e-5))
x = torch.cat((mu, rh), 1).squeeze(0)
pooled_list.append(x)
return torch.stack(pooled_list)
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
# Public API of this module: the names exported by ``from <module> import *``.
# Python only honors this list when it is bound to the dunder name ``__all__``.
__all__ = [
    "MeanPooling",
    "AttentiveStatisticsPooling",
]
class Pooling(nn.Module):
    """
    Common base for the pooling layers in this module.

    Supplies the mask-to-frame-count conversion; subclasses provide the
    actual pooling in ``forward``.
    """

    def __init__(self):
        super().__init__()

    def compute_length_from_mask(self, mask):
        """
        Convert a sample-level mask into per-utterance feature-frame counts.

        mask: (batch_size, T)
        Assuming that the sampling rate is 16kHz, the frame shift is 20ms

        Returns a Python list with one integer frame count per batch row.
        """
        # Number of waveform samples covered by one frame shift.
        samples_per_frame = 16000 * 0.02
        valid_samples = mask.sum(dim=1)  # (batch_size, )
        frame_counts = torch.div(
            valid_samples - 1, samples_per_frame, rounding_mode="floor"
        ) + 1
        return frame_counts.int().tolist()

    def forward(self, x, mask):
        """Abstract: concrete pooling layers must override this method."""
        raise NotImplementedError
class MeanPooling(Pooling):
    """
    MeanPooling
    Averages the valid frames of every sequence in the batch.
    x: (batch_size, T, feat_dim)
    mask: (batch_size, T)
    => output: (batch_size, feat_dim)
    """

    def forward(self, x, mask):
        """Return the per-sequence mean over the frames the mask marks as valid."""
        valid_counts = self.compute_length_from_mask(mask)
        # Keep only the leading valid frames of each sequence, then average
        # them along the time axis.
        per_sequence_means = [
            frames[:count].mean(dim=0) for frames, count in zip(x, valid_counts)
        ]
        return torch.stack(per_sequence_means)
class AttentiveStatisticsPooling(Pooling):
    """
    Attentive Statistics Pooling with Multi-Head Attention.

    Every head computes its own attention distribution over the time axis and
    yields an attention-weighted mean and standard deviation of the input
    frames. The per-head statistics are flattened into one vector.

    Paper: Attentive Statistics Pooling for Deep Speaker Embedding
    Link: https://arxiv.org/pdf/1803.10963.pdf
    """

    def __init__(self, input_size, num_heads=4):
        """
        Args:
            input_size: Feature dimension of the input frames.
            num_heads: Number of attention heads.
        """
        super().__init__()
        self._indim = input_size
        self.num_heads = num_heads
        # A single linear layer produces one input_size-wide projection per head.
        self.sap_linear = nn.Linear(input_size, input_size * num_heads)
        # One scoring vector per head: (num_heads, input_size, 1).
        self.attention = nn.Parameter(torch.FloatTensor(num_heads, input_size, 1))
        torch.nn.init.normal_(self.attention, mean=0, std=1)

    def forward(self, xs, mask):
        """
        Args:
            xs: (batch_size, T, feat_dim) - Input sequence.
            mask: (batch_size, T) - Mask for valid frames.
        Returns:
            output: (batch_size, feat_dim * 2 * num_heads) - Pooled representation,
                laid out per head as [mean, std].
        """
        feat_lens = self.compute_length_from_mask(mask)
        pooled_list = []
        for x, feat_len in zip(xs, feat_lens):
            # Extract valid frames based on mask
            x = x[:feat_len].unsqueeze(0)  # (1, T, feat_dim)

            # Project, then split the projection into heads. The head width is
            # input_size because sap_linear outputs input_size * num_heads.
            h = torch.tanh(self.sap_linear(x))  # (1, T, feat_dim * num_heads)
            h = h.view(x.size(0), x.size(1), self.num_heads, self._indim)

            # One scalar score per frame and head, normalized across time.
            w = torch.einsum("bthd,hdf->bthf", h, self.attention).squeeze(-1)  # (1, T, num_heads)
            w = F.softmax(w, dim=1).unsqueeze(-1)  # (1, T, num_heads, 1)

            # Broadcast the frames against every head's weights.
            frames = x.unsqueeze(2)  # (1, T, 1, feat_dim)
            mu = torch.sum(frames * w, dim=1)  # (1, num_heads, feat_dim)

            # Weighted std = sqrt(E_w[x^2] - mu^2); both terms must be squares.
            # The clamp keeps the sqrt argument positive under rounding error.
            second_moment = torch.sum((frames ** 2) * w, dim=1)
            rh = torch.sqrt((second_moment - mu ** 2).clamp(min=1e-5))

            # Concatenate mean and standard deviation per head, then flatten.
            pooled = torch.cat((mu, rh), 2).squeeze(0)  # (num_heads, feat_dim * 2)
            pooled_list.append(pooled.reshape(-1))  # (feat_dim * 2 * num_heads,)

        # Stack pooled features for the batch
        return torch.stack(pooled_list)  # (batch_size, feat_dim * 2 * num_heads)