|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
import math |
|
from collections import defaultdict |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
|
|
class SamePad(nn.Module):
    """Trim trailing time steps so a padded convolution keeps the expected length.

    Args:
        kernel_size: Kernel size of the preceding convolution.
        causal: When True, drop ``kernel_size - 1`` trailing steps; otherwise
            drop a single step, and only when ``kernel_size`` is even.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if not causal:
            self.remove = 0 if kernel_size % 2 != 0 else 1
        else:
            self.remove = kernel_size - 1

    def forward(self, x):
        """Return ``x`` with ``self.remove`` elements cut from the end of dim 2."""
        if self.remove <= 0:
            return x
        return x[:, :, : -self.remove]
|
|
|
|
|
class TransposeLast(nn.Module):
    """Swap dimension ``tranpose_dim`` with the last dimension.

    Args:
        deconstruct_idx: Optional index applied to the input before transposing.
        tranpose_dim: Dimension swapped with ``-1`` (spelling kept for compatibility).
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        selected = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return selected.transpose(self.tranpose_dim, -1)
|
|
|
|
|
class Swish(nn.Module):
    """Swish activation: ``inputs * sigmoid(inputs)``."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(inputs)
        return inputs * gate
|
|
|
|
|
class GLU(nn.Module):
    """Gated linear unit.

    Splits ``inputs`` into two halves along ``dim`` and returns
    ``first_half * sigmoid(second_half)``.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        value, gate = torch.chunk(inputs, 2, dim=self.dim)
        return value * torch.sigmoid(gate)
|
|
|
|
|
class ResidualConnectionModule(nn.Module):
    """Weighted residual wrapper: ``module(x) * module_factor + x * input_factor``."""

    def __init__(
        self,
        module: nn.Module,
        module_factor: float = 1.0,
        input_factor: float = 1.0,
    ):
        super().__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        branch = self.module(inputs) * self.module_factor
        skip = inputs * self.input_factor
        return branch + skip
|
|
|
|
|
class Linear(nn.Module):
    """``nn.Linear`` wrapper with Xavier-uniform weights and a zero-initialised bias."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        layer = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_uniform_(layer.weight)
        # nn.Linear only creates a bias parameter when `bias` is truthy.
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
|
|
|
|
|
class View(nn.Module):
    """Reshape with ``Tensor.view(*shape)``, optionally calling ``contiguous()`` first."""

    def __init__(self, shape: tuple, contiguous: bool = False):
        super().__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        source = x.contiguous() if self.contiguous else x
        return source.view(*self.shape)
|
|
|
|
|
class Transpose(nn.Module):
    """Swap the pair of dimensions given by ``shape``."""

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.transpose(x, *self.shape)
|
|
|
|
|
class FeedForwardModule(nn.Module):
    """Conformer feed-forward module.

    Applies, over the last dimension (``encoder_dim``):
    LayerNorm -> Linear(encoder_dim -> encoder_dim * expansion_factor) -> Swish
    -> Dropout -> Linear(back to encoder_dim) -> Dropout.
    """

    def __init__(
        self,
        encoder_dim: int = 512,
        expansion_factor: int = 4,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        hidden_dim = encoder_dim * expansion_factor
        layers = [
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, hidden_dim, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(hidden_dim, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        ]
        self.sequential = nn.Sequential(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.sequential(inputs)
|
|
|
|
|
class DepthwiseConv1d(nn.Module):
    """Depthwise 1-D convolution (``groups == in_channels``).

    ``out_channels`` must be an integer multiple of ``in_channels``. Input and
    output follow the ``nn.Conv1d`` layout ``(batch, channels, length)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super().__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)
|
|
|
|
|
class PointwiseConv1d(nn.Module):
    """1x1 (pointwise) 1-D convolution mixing channels at each position."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)
|
|
|
|
|
class ConformerConvModule(nn.Module):
    """Conformer convolution module.

    Takes ``(batch, time, in_channels)`` (LayerNorm runs over the last dim) and
    returns the same layout. Internally:
    LayerNorm -> transpose to channels-first -> pointwise conv (x2 channels)
    -> GLU (back to in_channels) -> depthwise conv with "same" padding
    -> BatchNorm -> Swish -> pointwise conv -> Dropout -> transpose back.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        same_padding = (kernel_size - 1) // 2
        stages = [
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(
                in_channels,
                in_channels * expansion_factor,
                stride=1,
                padding=0,
                bias=True,
            ),
            # GLU halves the channel dim, undoing the x2 expansion above.
            GLU(dim=1),
            DepthwiseConv1d(
                in_channels,
                in_channels,
                kernel_size,
                stride=1,
                padding=same_padding,
            ),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        ]
        self.sequential = nn.Sequential(*stages)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        channels_first = self.sequential(inputs)
        return channels_first.transpose(1, 2)
|
|
|
|
|
class FramewiseConv2dSubampling(nn.Module):
    """Two-layer Conv2d front-end that subsamples the time axis by 2 or 4.

    Both convs use kernel 3 with stride 2 along the feature axis, so a feature
    size ``D`` becomes ``((D - 1) // 2 - 1) // 2``. Along time, the second conv
    uses stride 2 (rate 4) or stride 1 with padding 1 (rate 2).

    Args:
        out_channels: Number of conv channels.
        subsample_rate: Time subsampling factor; must be 2 or 4.
    """

    def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
        super(FramewiseConv2dSubampling, self).__init__()
        assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
        self.subsample_rate = subsample_rate
        self.cnn = nn.Sequential(
            nn.Conv2d(1, out_channels, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=(2 if subsample_rate == 4 else 1, 2),
                padding=(0 if subsample_rate == 4 else 1, 0),
            ),
            nn.ReLU(),
        )

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Subsample ``inputs`` laid out as ``(batch, time, feature)``.

        Args:
            inputs: Padded features; dim 1 is time, dim 2 is the feature axis.
            input_lengths: Per-sample valid lengths along time.

        Returns:
            outputs: ``(batch, time', out_channels * feature')``.
            output_lengths: Per-sample valid lengths after subsampling, never negative.
        """
        if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
            # Pad one frame at the end of time so an even length T yields exactly T // 2 frames.
            inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
        outputs = self.cnn(inputs.unsqueeze(1))  # (batch, channels, time', feature')
        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()

        # Move time next to batch, then merge channels with the reduced feature axis.
        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(
            batch_size, subsampled_lengths, channels * subsampled_dim
        )

        if self.subsample_rate == 4:
            output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1
        else:
            output_lengths = input_lengths >> 1
        # At rate 4 the formula goes negative for lengths < 3; a length can't be negative.
        output_lengths = output_lengths.clamp(min=0)

        return outputs, output_lengths
|
|
|
|
|
class PatchwiseConv2dSubampling(nn.Module):
    """Split a ``(batch, time, mel_dim)`` spectrogram into non-overlapping patches.

    Each ``patch_size_time x patch_size_freq`` patch becomes one token. Two
    3x3 same-padded convs refine the patch grid, which is then flattened
    time-major (frequency patches vary fastest).

    Args:
        mel_dim: Expected size of the input's last dimension.
        out_channels: Channels per output token.
        patch_size_time: Patch extent (and stride) along time.
        patch_size_freq: Patch extent (and stride) along frequency.
    """

    def __init__(
        self,
        mel_dim: int,
        out_channels: int,
        patch_size_time: int = 16,
        patch_size_freq: int = 16,
    ) -> None:
        super().__init__()

        self.mel_dim = mel_dim
        self.patch_size_time = patch_size_time
        self.patch_size_freq = patch_size_freq

        patch = (patch_size_time, patch_size_freq)
        self.proj = nn.Conv2d(1, out_channels, kernel_size=patch, stride=patch, padding=0)
        self.cnn = nn.Sequential(
            *[
                layer
                for _ in range(2)
                for layer in (
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(),
                )
            ]
        )

    @property
    def subsample_rate(self) -> int:
        """Input frames per output token (integer division)."""
        return (self.patch_size_time * self.patch_size_freq) // self.mel_dim

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        assert (
            inputs.shape[2] == self.mel_dim
        ), "inputs.shape[2] should be equal to mel_dim"

        patch_grid = self.cnn(self.proj(inputs.unsqueeze(1)))  # (B, C, T', F')
        tokens = patch_grid.flatten(2, 3).transpose(1, 2)  # (B, T' * F', C)

        freq_patches = self.mel_dim // self.patch_size_freq
        output_lengths = (input_lengths // self.patch_size_time) * freq_patches
        return tokens, output_lengths
|
|
|
|
|
class RelPositionalEncoding(nn.Module):
    """Sinusoidal relative positional encodings for offsets ``+(T-1) .. -(T-1)``.

    ``self.pe`` has shape ``(1, 2 * L - 1, d_model)``, where ``L`` is the
    length it was last built for (initially ``max_len``). Row ``L - 1`` is
    offset 0. Rows before it are positive offsets in descending order; rows
    after it are negative offsets. The table is rebuilt only when a longer
    input arrives, and is cast to the input's dtype/device on demand.
    """

    def __init__(self, d_model: int, max_len: int = 10000) -> None:
        super().__init__()
        self.d_model = d_model
        self.pe = None
        # Build the initial table from a dummy tensor whose dim 1 is max_len.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor) -> None:
        """Ensure ``self.pe`` covers ``x.size(1)`` and matches ``x``'s dtype/device."""
        length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length * 2 - 1:
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term
        neg_angles = (-1 * position) * div_term

        pe_positive = torch.zeros(length, self.d_model)
        pe_negative = torch.zeros(length, self.d_model)
        pe_positive[:, 0::2] = torch.sin(angles)
        pe_positive[:, 1::2] = torch.cos(angles)
        pe_negative[:, 0::2] = torch.sin(neg_angles)
        pe_negative[:, 1::2] = torch.cos(neg_angles)

        # Positive offsets descending, then negative offsets without the duplicate 0.
        table = torch.cat([torch.flip(pe_positive, [0]), pe_negative[1:]], dim=0)
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(1, 2 * x.size(1) - 1, d_model)`` slice centred on offset 0."""
        self.extend_pe(x)
        center = self.pe.size(1) // 2
        length = x.size(1)
        return self.pe[:, center - length + 1 : center + length]
|
|
|
|
|
class RelativeMultiHeadAttention(nn.Module):
    """Multi-head attention with relative-position scores and learned u/v biases.

    score = ((q + u) k^T + rel_shift((q + v) p^T)) / sqrt(d_head), where ``p``
    is the projected relative positional embedding.

    Args:
        d_model: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_p: Dropout applied to the attention weights.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
        dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(self.d_head)

        self.query_proj = Linear(d_model, d_model)
        self.key_proj = Linear(d_model, d_model)
        self.value_proj = Linear(d_model, d_model)
        self.pos_proj = Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        # Per-head content (u) and position (v) biases; xavier init overwrites the empty storage.
        self.u_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.empty(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(d_model, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_embedding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend ``query`` over ``key``/``value`` with relative position scores.

        Args:
            query, key, value: ``(batch, time, d_model)``.
            pos_embedding: ``(batch, P, d_model)`` relative embeddings; after
                ``_relative_shift`` keeps ``P // 2 + 1`` columns, so ``P`` should be
                ``2 * time - 1`` for square scores.
            mask: Optional boolean mask, unsqueezed at dim 1 to broadcast over
                heads; True entries are excluded from attention.

        Returns:
            ``(output, attn)``: output is ``(batch, time, d_model)``; attn is the
            ``(batch, heads, time, time)`` attention weights after dropout.
        """
        batch_size = value.size(0)

        # query stays (B, T, H, d_head) so the (H, d_head) biases broadcast before transposing.
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head
        )

        content_score = torch.matmul(
            (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
        )
        pos_score = torch.matmul(
            (query + self.v_bias).transpose(1, 2),
            pos_embedding.permute(0, 2, 3, 1),
        )
        pos_score = self._relative_shift(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            # -1e9 does not fit in float16 and masked_fill_ would raise; clamp to the
            # dtype's most negative finite value (fp32/bf16 still use -1e9).
            fill_value = max(-1e9, torch.finfo(score.dtype).min)
            score.masked_fill_(mask, fill_value)

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
        """Re-index scores from relative-offset columns to key-position columns.

        Pads a zero column on the left of ``(B, H, T1, T2)`` scores, reinterprets
        the buffer as ``(B, H, T2 + 1, T1)``, and drops the first row. Viewing the
        result back as ``(B, H, T1, T2)`` shifts each query row by one more offset.
        Only the first ``T2 // 2 + 1`` columns are kept.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1
        )
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
            :, :, :, : seq_length2 // 2 + 1
        ]

        return pos_score
|
|
|
|
|
class MultiHeadedSelfAttentionModule(nn.Module):
    """Pre-norm relative-position self-attention followed by output dropout.

    Returns ``(dropout(attention_output), attention_weights)``.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
        super().__init__()
        self.positional_encoding = RelPositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The positional table is taken from the un-normalised inputs (only their
        # length/dtype/device are used) and copied once per batch element.
        rel_pos = self.positional_encoding(inputs).repeat(inputs.size(0), 1, 1)

        normed = self.layer_norm(inputs)
        context, weights = self.attention(
            normed, normed, normed, pos_embedding=rel_pos, mask=mask
        )
        return self.dropout(context), weights
|
|
|
|
|
class ConformerBlock(nn.Module):
    """One Conformer block, or a reduced transformer-style block.

    Conformer layout (``transformer_style=False``)::

        x = x + r * FFN1(x)
        x = x + MHSA(x)
        x = x + Conv(x)
        x = x + r * FFN2(x)
        x = LayerNorm(x)

    ``r`` is 0.5 when ``half_step_residual`` is True, else 1. With
    ``transformer_style=True`` only FFN1 (with ``r = 1``) and MHSA run,
    followed by LayerNorm.

    Only ``attention_type="mhsa"`` is implemented here. ``"mamba"`` raises
    NotImplementedError, and the ``mamba_*`` arguments are currently unused.
    """

    def __init__(
        self,
        encoder_dim: int = 512,
        attention_type: str = "mhsa",
        num_attention_heads: int = 8,
        mamba_d_state: int = 16,
        mamba_d_conv: int = 4,
        mamba_expand: int = 2,
        mamba_bidirectional: bool = True,
        feed_forward_expansion_factor: int = 4,
        conv_expansion_factor: int = 2,
        feed_forward_dropout_p: float = 0.1,
        attention_dropout_p: float = 0.1,
        conv_dropout_p: float = 0.1,
        conv_kernel_size: int = 31,
        half_step_residual: bool = True,
        transformer_style: bool = False,
    ):
        super(ConformerBlock, self).__init__()

        self.transformer_style = transformer_style
        self.attention_type = attention_type

        # Half-step FFN residuals apply only to the two-FFN conformer layout.
        if half_step_residual and not transformer_style:
            self.feed_forward_residual_factor = 0.5
        else:
            self.feed_forward_residual_factor = 1

        if attention_type == "mhsa":
            attention = MultiHeadedSelfAttentionModule(
                d_model=encoder_dim,
                num_heads=num_attention_heads,
                dropout_p=attention_dropout_p,
            )
        elif attention_type == "mamba":
            # No Mamba mixer is built in this block; fail explicitly rather than
            # leaving `attention` unbound below.
            raise NotImplementedError(
                "attention_type='mamba' is not implemented in ConformerBlock"
            )
        else:
            raise ValueError(
                f"attention_type must be 'mhsa' or 'mamba', got {attention_type!r}"
            )

        self.ffn_1 = FeedForwardModule(
            encoder_dim=encoder_dim,
            expansion_factor=feed_forward_expansion_factor,
            dropout_p=feed_forward_dropout_p,
        )
        self.attention = attention
        if not transformer_style:
            self.conv = ConformerConvModule(
                in_channels=encoder_dim,
                kernel_size=conv_kernel_size,
                expansion_factor=conv_expansion_factor,
                dropout_p=conv_dropout_p,
            )
            self.ffn_2 = FeedForwardModule(
                encoder_dim=encoder_dim,
                expansion_factor=feed_forward_expansion_factor,
                dropout_p=feed_forward_dropout_p,
            )
        self.layernorm = nn.LayerNorm(encoder_dim)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
        """Apply the block.

        Returns:
            ``(x, other)``. ``other`` maps ``"ffn_1"``, ``"attn"``, ``"conv"``
            and ``"ffn_2"`` to the corresponding branch outputs. ``"attn"`` holds
            the attention weights returned by MHSA (None for other attention
            modules). ``"conv"`` and ``"ffn_2"`` are None in transformer style.
        """
        ffn_1_out = self.ffn_1(x)
        x = ffn_1_out * self.feed_forward_residual_factor + x

        if isinstance(self.attention, MultiHeadedSelfAttentionModule):
            attn_out, attn = self.attention(x)
        else:
            # Non-MHSA attention modules are assumed to return only their output.
            attn_out = self.attention(x)
            attn = None
        x = attn_out + x

        if self.transformer_style:
            x = self.layernorm(x)
            return x, {
                "ffn_1": ffn_1_out,
                "attn": attn,
                "conv": None,
                "ffn_2": None,
            }

        conv_out = self.conv(x)
        x = conv_out + x

        ffn_2_out = self.ffn_2(x)
        x = ffn_2_out * self.feed_forward_residual_factor + x
        x = self.layernorm(x)

        other = {
            "ffn_1": ffn_1_out,
            "attn": attn,
            "conv": conv_out,
            "ffn_2": ffn_2_out,
        }

        return x, other
|
|
|
|
|
class ConformerEncoder(nn.Module):
    """Conformer encoder with frame-wise and/or patch-wise Conv2d subsampling.

    Args:
        cfg: Configuration object. Attributes read (some only when the related
            feature is enabled): ``use_framewise_subsample``,
            ``use_patchwise_subsample``, ``conv_subsample_channels``,
            ``conv_subsample_rate``, ``input_dim``, ``encoder_dim``,
            ``input_dropout_p``, ``patch_size_time``, ``patch_size_freq``,
            ``conv_pos_depth``, ``conv_pos_width``, ``conv_pos_groups``,
            ``num_layers`` and the per-block settings forwarded to
            ``ConformerBlock``. ``subsample_normalization``, ``conv_pos`` and
            ``transformer_style`` are optional and default to False.
    """

    def __init__(self, cfg):
        super(ConformerEncoder, self).__init__()

        self.cfg = cfg
        self.framewise_subsample = None
        self.patchwise_subsample = None
        self.framewise_in_proj = None
        self.patchwise_in_proj = None
        assert (
            cfg.use_framewise_subsample or cfg.use_patchwise_subsample
        ), "At least one subsampling method should be used"
        if cfg.use_framewise_subsample:
            self.framewise_subsample = FramewiseConv2dSubampling(
                out_channels=cfg.conv_subsample_channels,
                subsample_rate=cfg.conv_subsample_rate,
            )
            # Feature axis after two kernel-3, stride-2 convs: ((D - 1) // 2 - 1) // 2.
            self.framewise_in_proj = nn.Sequential(
                Linear(
                    cfg.conv_subsample_channels * (((cfg.input_dim - 1) // 2 - 1) // 2),
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )
        if cfg.use_patchwise_subsample:
            self.patchwise_subsample = PatchwiseConv2dSubampling(
                mel_dim=cfg.input_dim,
                out_channels=cfg.conv_subsample_channels,
                patch_size_time=cfg.patch_size_time,
                patch_size_freq=cfg.patch_size_freq,
            )
            self.patchwise_in_proj = nn.Sequential(
                Linear(
                    cfg.conv_subsample_channels,
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )
            # Both streams are summed position-wise, so their rates must agree.
            assert not cfg.use_framewise_subsample or (
                cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
            ), (
                f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate"
                f"({self.patchwise_subsample.subsample_rate})"
            )

        self.framewise_norm, self.patchwise_norm = None, None
        if getattr(cfg, "subsample_normalization", False):
            if cfg.use_framewise_subsample:
                self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
            if cfg.use_patchwise_subsample:
                self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)

        self.conv_pos = None
        if getattr(cfg, "conv_pos", False):
            num_pos_layers = cfg.conv_pos_depth
            k = max(3, cfg.conv_pos_width // num_pos_layers)
            self.conv_pos = nn.Sequential(
                TransposeLast(),
                *[
                    nn.Sequential(
                        nn.Conv1d(
                            cfg.encoder_dim,
                            cfg.encoder_dim,
                            kernel_size=k,
                            padding=k // 2,
                            groups=cfg.conv_pos_groups,
                        ),
                        SamePad(k),
                        TransposeLast(),
                        nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
                        TransposeLast(),
                        nn.GELU(),
                    )
                    for _ in range(num_pos_layers)
                ],
                TransposeLast(),
            )
            self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)

        self.layers = nn.ModuleList(
            [
                ConformerBlock(
                    encoder_dim=cfg.encoder_dim,
                    attention_type=cfg.attention_type,
                    num_attention_heads=cfg.num_attention_heads,
                    mamba_d_state=cfg.mamba_d_state,
                    mamba_d_conv=cfg.mamba_d_conv,
                    mamba_expand=cfg.mamba_expand,
                    mamba_bidirectional=cfg.mamba_bidirectional,
                    feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
                    conv_expansion_factor=cfg.conv_expansion_factor,
                    feed_forward_dropout_p=cfg.feed_forward_dropout_p,
                    attention_dropout_p=cfg.attention_dropout_p,
                    conv_dropout_p=cfg.conv_dropout_p,
                    conv_kernel_size=cfg.conv_kernel_size,
                    half_step_residual=cfg.half_step_residual,
                    transformer_style=getattr(cfg, "transformer_style", False),
                )
                for _ in range(cfg.num_layers)
            ]
        )

    def count_parameters(self) -> int:
        """Count trainable parameters of the encoder."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def update_dropout(self, dropout_p: float) -> None:
        """Set ``p`` on every ``nn.Dropout`` inside the encoder."""
        # modules() recurses into Sequentials and blocks; direct children alone
        # contain no nn.Dropout, so a children-only scan would change nothing.
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p

    def _subsample(
        self, inputs: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the enabled subsampling front-end(s), project, and merge them.

        When both front-ends are enabled, their outputs are truncated to the
        shorter sequence and summed.
        """
        frame_feat, patch_feat = None, None
        frame_lengths, patch_lengths = None, None
        if self.framewise_subsample is not None:
            frame_feat, frame_lengths = self.framewise_subsample(inputs, input_lengths)
            frame_feat = self.framewise_in_proj(frame_feat)
            if self.framewise_norm is not None:
                frame_feat = self.framewise_norm(frame_feat)

        if self.patchwise_subsample is not None:
            patch_feat, patch_lengths = self.patchwise_subsample(inputs, input_lengths)
            patch_feat = self.patchwise_in_proj(patch_feat)
            if self.patchwise_norm is not None:
                patch_feat = self.patchwise_norm(patch_feat)

        if frame_feat is not None and patch_feat is not None:
            min_len = min(frame_feat.size(1), patch_feat.size(1))
            features = frame_feat[:, :min_len] + patch_feat[:, :min_len]
            # A position is valid only if it is valid in both streams; the per-sample
            # minimum is also computed on-device (no .item() host sync).
            output_lengths = torch.minimum(frame_lengths, patch_lengths)
            return features, output_lengths
        if frame_feat is not None:
            return frame_feat, frame_lengths
        return patch_feat, patch_lengths

    def forward(
        self,
        inputs: torch.Tensor,
        input_lengths: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        freeze_input_layers: bool = False,
        target_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
        """Encode ``inputs``.

        Args:
            inputs: Padded features; the subsamplers treat dim 1 as time and
                dim 2 as the feature axis.
            input_lengths: Per-sample valid lengths; defaults to ``inputs.size(1)``
                for every sample.
            return_hidden: When True, collect each layer's output under
                ``"hidden_states"`` plus its branch outputs (``"ffn_1"``, ...).
            freeze_input_layers: Run the subsampling front-end under ``torch.no_grad()``.
            target_layer: Stop after the layer with this 0-based index.

        Returns:
            ``(outputs, output_lengths, layer_results)``. ``layer_results`` is a
            ``defaultdict(list)`` that stays empty unless ``return_hidden`` is True.
        """
        if input_lengths is None:
            input_lengths = torch.full(
                (inputs.size(0),),
                inputs.size(1),
                dtype=torch.long,
                device=inputs.device,
            )

        grad_context = torch.no_grad() if freeze_input_layers else contextlib.nullcontext()
        with grad_context:
            features, output_lengths = self._subsample(inputs, input_lengths)

        # NOTE(review): conv_pos runs with gradients even when freeze_input_layers
        # is set — confirm this placement is intended.
        if self.conv_pos is not None:
            features = features + self.conv_pos(features)
            features = self.conv_pos_post_ln(features)

        layer_results = defaultdict(list)

        outputs = features
        for i, layer in enumerate(self.layers):
            outputs, other = layer(outputs)
            if return_hidden:
                layer_results["hidden_states"].append(outputs)
                for k, v in other.items():
                    layer_results[k].append(v)

            if target_layer is not None and i == target_layer:
                break

        return outputs, output_lengths, layer_results
|
|