# Copyright (c) 2021, Soohwan Kim. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import math
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None, transpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.transpose_dim = transpose_dim

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(self.transpose_dim, -1)


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return inputs * inputs.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()


class ResidualConnectionModule(nn.Module):
    def __init__(
        self,
        module: nn.Module,
        module_factor: float = 1.0,
        input_factor: float = 1.0,
    ):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class View(nn.Module):
    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.contiguous:
            x = x.contiguous()
        return x.view(*self.shape)


class Transpose(nn.Module):
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(*self.shape)


class FeedForwardModule(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 512,
        expansion_factor: int = 4,
        dropout_p: float = 0.1,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        self.sequential = nn.Sequential(
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.sequential(inputs)

class DepthwiseConv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)


class ConformerConvModule(nn.Module):
    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"
        self.sequential = nn.Sequential(
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(
                in_channels,
                in_channels * expansion_factor,
                stride=1,
                padding=0,
                bias=True,
            ),
            GLU(dim=1),
            DepthwiseConv1d(
                in_channels,
                in_channels,
                kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
            ),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.sequential(inputs).transpose(1, 2)


class FramewiseConv2dSubampling(nn.Module):
    def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
        super(FramewiseConv2dSubampling, self).__init__()
        assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
        self.subsample_rate = subsample_rate
        self.cnn = nn.Sequential(
            nn.Conv2d(1, out_channels, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=(2 if subsample_rate == 4 else 1, 2),
                padding=(0 if subsample_rate == 4 else 1, 0),
            ),
            nn.ReLU(),
        )

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # inputs: (B, T, C) -> (B, 1, T, C)
        if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
            inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
        outputs = self.cnn(inputs.unsqueeze(1))
        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()

        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(
            batch_size, subsampled_lengths, channels * subsampled_dim
        )

        if self.subsample_rate == 4:
            output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1
        else:
            output_lengths = input_lengths >> 1

        return outputs, output_lengths


class PatchwiseConv2dSubampling(nn.Module):
    def __init__(
        self,
        mel_dim: int,
        out_channels: int,
        patch_size_time: int = 16,
        patch_size_freq: int = 16,
    ) -> None:
        super(PatchwiseConv2dSubampling, self).__init__()
        self.mel_dim = mel_dim
        self.patch_size_time = patch_size_time
        self.patch_size_freq = patch_size_freq
        self.proj = nn.Conv2d(
            1,
            out_channels,
            kernel_size=(patch_size_time, patch_size_freq),
            stride=(patch_size_time, patch_size_freq),
            padding=0,
        )
        self.cnn = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    @property
    def subsample_rate(self) -> int:
        return self.patch_size_time * self.patch_size_freq // self.mel_dim

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        assert (
            inputs.shape[2] == self.mel_dim
        ), "inputs.shape[2] should be equal to mel_dim"
        # inputs: (B, Time, Freq) -> (B, 1, Time, Freq)
        outputs = self.proj(inputs.unsqueeze(1))
        outputs = self.cnn(outputs)
        # (B, channels, Time // patch_size_time, Freq // patch_size_freq)
        outputs = outputs.flatten(2, 3).transpose(1, 2)
        # (B, (Time // patch_size_time) * (Freq // patch_size_freq), channels)
        output_lengths = (
            input_lengths
            // self.patch_size_time
            * (self.mel_dim // self.patch_size_freq)
        )
        return outputs, output_lengths


class RelPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 10000) -> None:
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor) -> None:
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C)
        self.extend_pe(x)
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
        ]
        return pos_emb

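
# Shape note: for an input x of shape (B, T, d_model), RelPositionalEncoding
# returns pos_emb of shape (1, 2 * T - 1, d_model), laid out from relative
# position +(T - 1) down to -(T - 1). RelativeMultiHeadAttention below relies
# on this layout when it computes and shifts the position scores.
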
class RelativeMultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
        dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(self.d_head)

        self.query_proj = Linear(d_model, d_model)
        self.key_proj = Linear(d_model, d_model)
        self.value_proj = Linear(d_model, d_model)
        self.pos_proj = Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(d_model, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_embedding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head
        )

        content_score = torch.matmul(
            (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
        )
        pos_score = torch.matmul(
            (query + self.v_bias).transpose(1, 2),
            pos_embedding.permute(0, 2, 3, 1),
        )
        pos_score = self._relative_shift(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1
        )
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
            :, :, :, : seq_length2 // 2 + 1
        ]

        return pos_score


class MultiHeadedSelfAttentionModule(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        self.positional_encoding = RelPositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = inputs.size(0)
        pos_embedding = self.positional_encoding(inputs)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)

        inputs = self.layer_norm(inputs)
        outputs, attn = self.attention(
            inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
        )

        return self.dropout(outputs), attn


class ConformerBlock(nn.Module):
    def __init__(
        self,
        encoder_dim: int = 512,
        attention_type: str = "mhsa",
        num_attention_heads: int = 8,
        mamba_d_state: int = 16,
        mamba_d_conv: int = 4,
        mamba_expand: int = 2,
        mamba_bidirectional: bool = True,
        feed_forward_expansion_factor: int = 4,
        conv_expansion_factor: int = 2,
        feed_forward_dropout_p: float = 0.1,
        attention_dropout_p: float = 0.1,
        conv_dropout_p: float = 0.1,
        conv_kernel_size: int = 31,
        half_step_residual: bool = True,
        transformer_style: bool = False,
    ):
        super(ConformerBlock, self).__init__()
        self.transformer_style = transformer_style
        self.attention_type = attention_type

        if half_step_residual and not transformer_style:
            self.feed_forward_residual_factor = 0.5
        else:
            self.feed_forward_residual_factor = 1

        assert attention_type in ["mhsa", "mamba"]
        # Only the "mhsa" branch is constructed here; any other attention_type
        # would leave `attention` unbound at the assignment below.
        if attention_type == "mhsa":
            attention = MultiHeadedSelfAttentionModule(
                d_model=encoder_dim,
                num_heads=num_attention_heads,
                dropout_p=attention_dropout_p,
            )

        self.ffn_1 = FeedForwardModule(
            encoder_dim=encoder_dim,
            expansion_factor=feed_forward_expansion_factor,
            dropout_p=feed_forward_dropout_p,
        )
        self.attention = attention
        if not transformer_style:
            self.conv = ConformerConvModule(
                in_channels=encoder_dim,
                kernel_size=conv_kernel_size,
                expansion_factor=conv_expansion_factor,
                dropout_p=conv_dropout_p,
            )
            self.ffn_2 = FeedForwardModule(
                encoder_dim=encoder_dim,
                expansion_factor=feed_forward_expansion_factor,
                dropout_p=feed_forward_dropout_p,
            )
        self.layernorm = nn.LayerNorm(encoder_dim)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
        # FFN 1
        ffn_1_out = self.ffn_1(x)
        x = ffn_1_out * self.feed_forward_residual_factor + x

        # Attention
        if not isinstance(self.attention, MultiHeadedSelfAttentionModule):
            # MAMBA
            attn_out = self.attention(x)
            attn = None
        else:
            attn_out, attn = self.attention(x)
        x = attn_out + x

        if self.transformer_style:
            x = self.layernorm(x)
            return x, {
                "ffn_1": ffn_1_out,
                "attn": attn,
                "conv": None,
                "ffn_2": None,
            }

        # Convolution
        conv_out = self.conv(x)
        x = conv_out + x

        # FFN 2
        ffn_2_out = self.ffn_2(x)
        x = ffn_2_out * self.feed_forward_residual_factor + x

        x = self.layernorm(x)

        other = {
            "ffn_1": ffn_1_out,
            "attn": attn,
            "conv": conv_out,
            "ffn_2": ffn_2_out,
        }
        return x, other


class ConformerEncoder(nn.Module):
    def __init__(self, cfg):
        super(ConformerEncoder, self).__init__()
        self.cfg = cfg
        self.framewise_subsample = None
        self.patchwise_subsample = None
        self.framewise_in_proj = None
        self.patchwise_in_proj = None

        assert (
            cfg.use_framewise_subsample or cfg.use_patchwise_subsample
        ), "At least one subsampling method should be used"

        if cfg.use_framewise_subsample:
            self.framewise_subsample = FramewiseConv2dSubampling(
                out_channels=cfg.conv_subsample_channels,
                subsample_rate=cfg.conv_subsample_rate,
            )
            self.framewise_in_proj = nn.Sequential(
                Linear(
                    cfg.conv_subsample_channels * (((cfg.input_dim - 1) // 2 - 1) // 2),
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )
        if cfg.use_patchwise_subsample:
            self.patchwise_subsample = PatchwiseConv2dSubampling(
                mel_dim=cfg.input_dim,
                out_channels=cfg.conv_subsample_channels,
                patch_size_time=cfg.patch_size_time,
                patch_size_freq=cfg.patch_size_freq,
            )
            self.patchwise_in_proj = nn.Sequential(
                Linear(
                    cfg.conv_subsample_channels,
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )
            assert not cfg.use_framewise_subsample or (
                cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
            ), (
                f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate "
                f"({self.patchwise_subsample.subsample_rate})"
            )

        self.framewise_norm, self.patchwise_norm = None, None
        if getattr(cfg, "subsample_normalization", False):
            if cfg.use_framewise_subsample:
                self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
            if cfg.use_patchwise_subsample:
                self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)

        self.conv_pos = None
        if getattr(cfg, "conv_pos", False):
            num_pos_layers = cfg.conv_pos_depth
            k = max(3, cfg.conv_pos_width // num_pos_layers)
            self.conv_pos = nn.Sequential(
                TransposeLast(),
                *[
                    nn.Sequential(
                        nn.Conv1d(
                            cfg.encoder_dim,
                            cfg.encoder_dim,
                            kernel_size=k,
                            padding=k // 2,
                            groups=cfg.conv_pos_groups,
                        ),
                        SamePad(k),
                        TransposeLast(),
                        nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
                        TransposeLast(),
                        nn.GELU(),
                    )
                    for _ in range(num_pos_layers)
                ],
                TransposeLast(),
            )
            self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)

        self.layers = nn.ModuleList(
            [
                ConformerBlock(
                    encoder_dim=cfg.encoder_dim,
                    attention_type=cfg.attention_type,
                    num_attention_heads=cfg.num_attention_heads,
                    mamba_d_state=cfg.mamba_d_state,
                    mamba_d_conv=cfg.mamba_d_conv,
                    mamba_expand=cfg.mamba_expand,
                    mamba_bidirectional=cfg.mamba_bidirectional,
                    feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
                    conv_expansion_factor=cfg.conv_expansion_factor,
                    feed_forward_dropout_p=cfg.feed_forward_dropout_p,
                    attention_dropout_p=cfg.attention_dropout_p,
                    conv_dropout_p=cfg.conv_dropout_p,
                    conv_kernel_size=cfg.conv_kernel_size,
                    half_step_residual=cfg.half_step_residual,
                    transformer_style=getattr(cfg, "transformer_style", False),
                )
                for _ in range(cfg.num_layers)
            ]
        )

    def count_parameters(self) -> int:
        """Count trainable parameters of the encoder."""
        return sum([p.numel() for p in self.parameters() if p.requires_grad])

    def update_dropout(self, dropout_p: float) -> None:
        """Update dropout probability of the encoder."""
        # named_modules() (rather than named_children()) is needed to reach the
        # Dropout layers nested inside the sub-modules.
        for name, child in self.named_modules():
            if isinstance(child, nn.Dropout):
                child.p = dropout_p

    def forward(
        self,
        inputs: torch.Tensor,
        input_lengths: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        freeze_input_layers: bool = False,
        target_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
        if input_lengths is None:
            input_lengths = torch.full(
                (inputs.size(0),),
                inputs.size(1),
                dtype=torch.long,
                device=inputs.device,
            )

        with torch.no_grad() if freeze_input_layers else contextlib.ExitStack():
            frame_feat, patch_feat = None, None
            if self.framewise_subsample is not None:
                frame_feat, frame_lengths = self.framewise_subsample(
                    inputs, input_lengths
                )
                frame_feat = self.framewise_in_proj(frame_feat)
                if self.framewise_norm is not None:
                    frame_feat = self.framewise_norm(frame_feat)
            if self.patchwise_subsample is not None:
                patch_feat, patch_lengths = self.patchwise_subsample(
                    inputs, input_lengths
                )
                patch_feat = self.patchwise_in_proj(patch_feat)
                if self.patchwise_norm is not None:
                    patch_feat = self.patchwise_norm(patch_feat)

            if frame_feat is not None and patch_feat is not None:
                min_len = min(frame_feat.size(1), patch_feat.size(1))
                frame_feat = frame_feat[:, :min_len]
                patch_feat = patch_feat[:, :min_len]
                features = frame_feat + patch_feat
                output_lengths = (
                    frame_lengths
                    if frame_lengths.max().item() < patch_lengths.max().item()
                    else patch_lengths
                )
            elif frame_feat is not None:
                features = frame_feat
                output_lengths = frame_lengths
            else:
                features = patch_feat
                output_lengths = patch_lengths

            if self.conv_pos is not None:
                features = features + self.conv_pos(features)
                features = self.conv_pos_post_ln(features)

        layer_results = defaultdict(list)
        outputs = features
        for i, layer in enumerate(self.layers):
            outputs, other = layer(outputs)
            if return_hidden:
                layer_results["hidden_states"].append(outputs)
                for k, v in other.items():
                    layer_results[k].append(v)
            if target_layer is not None and i == target_layer:
                break

        return outputs, output_lengths, layer_results
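

# ---------------------------------------------------------------------------
# Minimal usage sketch. The config below is a hypothetical stand-in for
# whatever config object the surrounding project normally passes in; its field
# names simply mirror the `cfg.*` attributes read by ConformerEncoder above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _DemoEncoderConfig:  # hypothetical config, not part of the project API
        input_dim: int = 80
        encoder_dim: int = 144
        num_layers: int = 2
        attention_type: str = "mhsa"
        num_attention_heads: int = 4
        mamba_d_state: int = 16
        mamba_d_conv: int = 4
        mamba_expand: int = 2
        mamba_bidirectional: bool = True
        feed_forward_expansion_factor: int = 4
        conv_expansion_factor: int = 2
        input_dropout_p: float = 0.1
        feed_forward_dropout_p: float = 0.1
        attention_dropout_p: float = 0.1
        conv_dropout_p: float = 0.1
        conv_kernel_size: int = 31
        half_step_residual: bool = True
        use_framewise_subsample: bool = True
        use_patchwise_subsample: bool = False
        conv_subsample_channels: int = 144
        conv_subsample_rate: int = 4
        patch_size_time: int = 16
        patch_size_freq: int = 16
        subsample_normalization: bool = False
        conv_pos: bool = False
        conv_pos_depth: int = 5
        conv_pos_width: int = 95
        conv_pos_groups: int = 16
        transformer_style: bool = False

    encoder = ConformerEncoder(_DemoEncoderConfig())
    mels = torch.randn(2, 160, 80)  # (batch, time, mel bins)
    lengths = torch.tensor([160, 120], dtype=torch.long)
    out, out_lengths, _ = encoder(mels, lengths)
    # With framewise subsampling at rate 4, time is reduced roughly 4x,
    # e.g. out.shape == (2, 39, 144) and out_lengths == tensor([39, 29]).
    print(out.shape, out_lengths)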