"""Copied from https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py""" import math from abc import ABC, abstractmethod from typing import List, Optional, Tuple, Union import torch from torch import Tensor, nn # try: # from flash_attn import flash_attn_func # IMPORT_FLASH = True # except Exception as err: # IMPORT_FLASH = False # IMPORT_FLASH_ERR = err IMPORT_FLASH = False IMPORT_FLASH_ERR = "Flash Attention not installed." # from .utils import apply_masked_flash_attn, apply_rotary_pos_emb def rtt_half(x: Tensor) -> Tensor: x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat([-x2, x1], dim=x1.ndim - 1) def apply_rotary_pos_emb( q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0 ) -> Tuple[Tensor, Tensor]: """ Applies Rotary Position Embeddings to query and key tensors. """ cos, sin = ( cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...], ) return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin) # def apply_masked_flash_attn( # q: Tensor, # k: Tensor, # v: Tensor, # mask: Tensor, # h: int, # d_k: int, # ) -> Tensor: # """ # Applies Flash Attention with padding masks. # """ # from einops import rearrange # from flash_attn import flash_attn_varlen_func # from flash_attn.bert_padding import pad_input, unpad_input # pad_mask = ~mask[:, 0, :] # b, t = pad_mask.shape # q = q.view(b, t, h * d_k) # k = k.view(b, t, h * d_k) # v = v.view(b, t, h * d_k) # q_unpad, indices_q, _, max_seqlen_q = unpad_input(q, pad_mask)[:4] # q_unpad = rearrange(q_unpad, "nnz (h d) -> nnz h d", h=h) # k_unpad = unpad_input(k, pad_mask)[0] # k_unpad = rearrange(k_unpad, "nnz (h d) -> nnz h d", h=h) # v_unpad = unpad_input(v, pad_mask)[0] # v_unpad = rearrange(v_unpad, "nnz (h d) -> nnz h d", h=h) # lengths_q = pad_mask.sum(1).to(torch.int32).to(q.device) # cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32) # max_seqlen_q = torch.max(lengths_q) # output_unpad = flash_attn_varlen_func( # q_unpad, # k_unpad, # v_unpad, # cu_seqlens_q, # cu_seqlens_q, # max_seqlen_q, # max_seqlen_q, # ) # scores = pad_input( # rearrange(output_unpad, "nnz h d -> nnz (h d)"), # indices_q, # b, # t, # ) # return scores class StridingSubsampling(nn.Module): """ Strided Subsampling layer used to reduce the sequence length. """ def __init__( self, subsampling_factor: int, feat_in: int, feat_out: int, conv_channels: int, ): super().__init__() self._sampling_num = int(math.log(subsampling_factor, 2)) self._stride = 2 self._kernel_size = 3 self._padding = (self._kernel_size - 1) // 2 layers: List[nn.Module] = [] in_channels = 1 for _ in range(self._sampling_num): layers.append( torch.nn.Conv2d( in_channels=in_channels, out_channels=conv_channels, kernel_size=self._kernel_size, stride=self._stride, padding=self._padding, ) ) layers.append(nn.ReLU()) in_channels = conv_channels out_length = self.calc_output_length(torch.tensor(feat_in)) self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) self.conv = torch.nn.Sequential(*layers) def calc_output_length(self, lengths: Tensor) -> Tensor: """ Calculates the output length after applying the subsampling. """ lengths = lengths.to(torch.float) add_pad = 2 * self._padding - self._kernel_size for _ in range(self._sampling_num): lengths = torch.div(lengths + add_pad, self._stride) + 1.0 lengths = torch.floor(lengths) return lengths.to(dtype=torch.int) def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]: x = self.conv(x.unsqueeze(1)) b, _, t, _ = x.size() x = self.out(x.transpose(1, 2).reshape(b, t, -1)) return x, self.calc_output_length(lengths) class MultiHeadAttention(nn.Module, ABC): """ Base class of Multi-Head Attention Mechanisms. """ def __init__(self, n_head: int, n_feat: int, flash_attn=False): super().__init__() assert n_feat % n_head == 0 self.d_k = n_feat // n_head self.h = n_head self.linear_q = nn.Linear(n_feat, n_feat) self.linear_k = nn.Linear(n_feat, n_feat) self.linear_v = nn.Linear(n_feat, n_feat) self.linear_out = nn.Linear(n_feat, n_feat) self.flash_attn = flash_attn if self.flash_attn and not IMPORT_FLASH: raise RuntimeError( f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. " "Please install flash_attn or use --no_flash flag. " "If you have already done this, " "--force-reinstall flag might be useful" ) def forward_qkv( self, query: Tensor, key: Tensor, value: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: """ Projects the inputs into queries, keys, and values for multi-head attention. """ b = query.size(0) q = self.linear_q(query).view(b, -1, self.h, self.d_k) k = self.linear_k(key).view(b, -1, self.h, self.d_k) v = self.linear_v(value).view(b, -1, self.h, self.d_k) if self.flash_attn: return q, k, v return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) def forward_attention( self, value: Tensor, scores: Tensor, mask: Optional[Tensor] ) -> Tensor: """ Computes the scaled dot-product attention given the projected values and scores. """ b = value.size(0) if mask is not None: mask = mask.unsqueeze(1) scores = scores.masked_fill(mask, -10000.0) attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) else: attn = torch.softmax(scores, dim=-1) x = torch.matmul(attn, value) x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k) return self.linear_out(x) class RelPositionMultiHeadAttention(MultiHeadAttention): """ Relative Position Multi-Head Attention module. """ def __init__(self, n_head: int, n_feat: int): super().__init__(n_head, n_feat) self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) def rel_shift(self, x: Tensor) -> Tensor: b, h, qlen, pos_len = x.size() x = torch.nn.functional.pad(x, pad=(1, 0)) x = x.view(b, h, -1, qlen) return x[:, :, 1:].view(b, h, qlen, pos_len) def forward( self, query: Tensor, key: Tensor, value: Tensor, pos_emb: Tensor, mask: Optional[Tensor] = None, ) -> Tensor: q, k, v = self.forward_qkv(query, key, value) q = q.transpose(1, 2) p = self.linear_pos(pos_emb) p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2) q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) matrix_bd = self.rel_shift(matrix_bd) matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)] scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask) class RotaryPositionMultiHeadAttention(MultiHeadAttention): """ Rotary Position Multi-Head Attention module. """ def forward( self, query: Tensor, key: Tensor, value: Tensor, pos_emb: List[Tensor], mask: Optional[Tensor] = None, ) -> Tensor: b, t, _ = value.size() query = query.transpose(0, 1).view(t, b, self.h, self.d_k) key = key.transpose(0, 1).view(t, b, self.h, self.d_k) value = value.transpose(0, 1).view(t, b, self.h, self.d_k) cos, sin = pos_emb query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0) q, k, v = self.forward_qkv( query.view(t, b, self.h * self.d_k).transpose(0, 1), key.view(t, b, self.h * self.d_k).transpose(0, 1), value.view(t, b, self.h * self.d_k).transpose(0, 1), ) # if not self.flash_attn: scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k)) out = self.forward_attention(v, scores, mask) # else: # if mask is None: # scores = flash_attn_func(q, k, v) # else: # scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k) # scores = scores.view(b, -1, self.h * self.d_k) # out = self.linear_out(scores) return out class PositionalEncoding(nn.Module, ABC): """ Base class of Positional Encodings. """ def __init__(self, dim: int, base: int): super().__init__() self.dim = dim self.base = base @abstractmethod def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]: pass def extend_pe(self, length: int, device: torch.device): """ Extends the positional encoding buffer to process longer sequences. """ pe = self.create_pe(length, device) if pe is None: return if hasattr(self, "pe"): self.pe = pe else: self.register_buffer("pe", pe, persistent=False) class RelPositionalEmbedding(PositionalEncoding): """ Relative Positional Embedding module. """ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]: """ Creates the relative positional encoding matrix. """ if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1: return None positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1) pos_length = positions.size(0) pe = torch.zeros(pos_length, self.dim, device=positions.device) div_term = torch.exp( torch.arange(0, self.dim, 2, device=pe.device) * -(math.log(10000.0) / self.dim) ) pe[:, 0::2] = torch.sin(positions * div_term) pe[:, 1::2] = torch.cos(positions * div_term) return pe.unsqueeze(0) def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: input_len = x.size(1) center_pos = self.pe.size(1) // 2 + 1 start_pos = center_pos - input_len end_pos = center_pos + input_len - 1 return x, self.pe[:, start_pos:end_pos] class RotaryPositionalEmbedding(PositionalEncoding): """ Rotary Positional Embedding module. """ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]: """ Creates or extends the rotary positional encoding matrix. """ if hasattr(self, "pe") and self.pe.size(0) >= 2 * length: return None positions = torch.arange(0, length, dtype=torch.float32, device=device) inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2).float() / self.dim) ) t = torch.arange(length, device=positions.device).type_as(inv_freq) freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(positions.device) return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]]) def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]: cos_emb = self.pe[0 : x.shape[1]] half_pe = self.pe.shape[0] // 2 sin_emb = self.pe[half_pe : half_pe + x.shape[1]] return x, [cos_emb, sin_emb] class ConformerConvolution(nn.Module): """ Conformer Convolution module. """ def __init__( self, d_model: int, kernel_size: int, ): super().__init__() assert (kernel_size - 1) % 2 == 0 self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1) self.depthwise_conv = nn.Conv1d( in_channels=d_model, out_channels=d_model, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, groups=d_model, bias=True, ) self.batch_norm = nn.BatchNorm1d(d_model) self.activation = nn.SiLU() self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1) def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor: x = x.transpose(1, 2) x = self.pointwise_conv1(x) x = nn.functional.glu(x, dim=1) if pad_mask is not None: x = x.masked_fill(pad_mask.unsqueeze(1), 0.0) x = self.depthwise_conv(x) x = self.batch_norm(x) x = self.activation(x) x = self.pointwise_conv2(x) return x.transpose(1, 2) class ConformerFeedForward(nn.Module): """ Conformer Feed Forward module. """ def __init__(self, d_model: int, d_ff: int, use_bias=True): super().__init__() self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias) self.activation = nn.SiLU() self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias) def forward(self, x: Tensor) -> Tensor: return self.linear2(self.activation(self.linear1(x))) class ConformerLayer(nn.Module): """ Conformer Layer module. This module combines several submodules including feed forward networks, depthwise separable convolution, and multi-head self-attention to form a single Conformer block. """ def __init__( self, d_model: int, d_ff: int, self_attention_model: str, n_heads: int = 16, conv_kernel_size: int = 31, flash_attn: bool = False, ): super().__init__() self.fc_factor = 0.5 self.norm_feed_forward1 = nn.LayerNorm(d_model) self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff) self.norm_conv = nn.LayerNorm(d_model) self.conv = ConformerConvolution( d_model=d_model, kernel_size=conv_kernel_size, ) self.norm_self_att = nn.LayerNorm(d_model) if self_attention_model == "rotary": self.self_attn: nn.Module = RotaryPositionMultiHeadAttention( n_head=n_heads, n_feat=d_model, flash_attn=flash_attn, ) else: assert not flash_attn, "Not supported flash_attn for rel_pos" self.self_attn = RelPositionMultiHeadAttention( n_head=n_heads, n_feat=d_model, ) self.norm_feed_forward2 = nn.LayerNorm(d_model) self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff) self.norm_out = nn.LayerNorm(d_model) def forward( self, x: Tensor, pos_emb: Union[Tensor, List[Tensor]], att_mask: Optional[Tensor] = None, pad_mask: Optional[Tensor] = None, ) -> Tensor: residual = x x = self.norm_feed_forward1(x) x = self.feed_forward1(x) residual = residual + x * self.fc_factor x = self.norm_self_att(residual) x = self.self_attn(x, x, x, pos_emb, mask=att_mask) residual = residual + x x = self.norm_conv(residual) x = self.conv(x, pad_mask=pad_mask) residual = residual + x x = self.norm_feed_forward2(residual) x = self.feed_forward2(x) residual = residual + x * self.fc_factor x = self.norm_out(residual) return x class ConformerEncoder(nn.Module): """ Conformer Encoder module. This module encapsulates the entire Conformer encoder architecture, consisting of a StridingSubsampling layer, positional embeddings, and a stack of Conformer Layers. It serves as the main component responsible for processing speech features. """ def __init__( self, feat_in: int = 64, n_layers: int = 16, d_model: int = 768, subsampling_factor: int = 4, ff_expansion_factor: int = 4, self_attention_model: str = "rotary", n_heads: int = 16, pos_emb_max_len: int = 5000, conv_kernel_size: int = 31, flash_attn: bool = False, ): super().__init__() self.feat_in = feat_in assert self_attention_model in [ "rotary", "rel_pos", ], f"Not supported attn = {self_attention_model}" self.pre_encode = StridingSubsampling( subsampling_factor=subsampling_factor, feat_in=feat_in, feat_out=d_model, conv_channels=d_model, ) if self_attention_model == "rotary": self.pos_enc: nn.Module = RotaryPositionalEmbedding( d_model // n_heads, pos_emb_max_len ) else: self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len) self.layers = nn.ModuleList() for _ in range(n_layers): layer = ConformerLayer( d_model=d_model, d_ff=d_model * ff_expansion_factor, self_attention_model=self_attention_model, n_heads=n_heads, conv_kernel_size=conv_kernel_size, flash_attn=flash_attn, ) self.layers.append(layer) self.pos_enc.extend_pe(pos_emb_max_len, next(self.parameters()).device) def input_example( self, batch_size: int = 1, seqlen: int = 200, ): device = next(self.parameters()).device features = torch.zeros(batch_size, self.feat_in, seqlen) feature_lengths = torch.full([batch_size], features.shape[-1]) return features.float().to(device), feature_lengths.to(device) def input_names(self): return ["audio_signal", "length"] def output_names(self): return ["encoded", "encoded_len"] def dynamic_axes(self): return { "audio_signal": {0: "batch_size", 2: "seq_len"}, "length": {0: "batch_size"}, "encoded": {0: "batch_size", 1: "seq_len"}, "encoded_len": {0: "batch_size"}, } def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]: audio_signal, length = self.pre_encode( x=audio_signal.transpose(1, 2), lengths=length ) max_len = audio_signal.size(1) audio_signal, pos_emb = self.pos_enc(x=audio_signal) pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand( length.size(0), -1 ) < length.unsqueeze(-1) att_mask = None if audio_signal.shape[0] > 1: att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1]) att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2)) att_mask = ~att_mask pad_mask = ~pad_mask for layer in self.layers: audio_signal = layer( x=audio_signal, pos_emb=pos_emb, att_mask=att_mask, pad_mask=pad_mask, ) return audio_signal.transpose(1, 2), length