# gigaam-ctc / encoder.py
# waveletdeboshir: "Upload encoder.py" (3e65276, verified)
"""Copied from https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py"""
import math
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
# Flash Attention is disabled in this copy of the file: the optional import
# below is kept only for reference and the two flags are hard-coded, so
# MultiHeadAttention raises a RuntimeError when flash_attn=True is requested.
# try:
#     from flash_attn import flash_attn_func
#     IMPORT_FLASH = True
# except Exception as err:
#     IMPORT_FLASH = False
#     IMPORT_FLASH_ERR = err
IMPORT_FLASH = False
IMPORT_FLASH_ERR = "Flash Attention not installed."
# apply_rotary_pos_emb is defined in this file rather than imported;
# apply_masked_flash_attn is kept only as commented-out code.
# from .utils import apply_masked_flash_attn, apply_rotary_pos_emb
def rtt_half(x: Tensor) -> Tensor:
    """
    Rotates the last dimension of ``x`` by half.

    The last axis is split at ``x.shape[-1] // 2`` into a leading and a
    trailing part; the result is the negated trailing part followed by the
    leading part, concatenated along the last axis.
    """
    split_at = x.shape[-1] // 2
    leading = x[..., :split_at]
    trailing = x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)
def apply_rotary_pos_emb(
    q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0
) -> Tuple[Tensor, Tensor]:
    """
    Applies Rotary Position Embeddings to query and key tensors.

    ``cos`` and ``sin`` are cropped along their first axis to the window
    ``[offset, offset + q.shape[0])`` and the same window is used for both
    ``q`` and ``k``. Each tensor ``t`` is mapped to
    ``t * cos + rtt_half(t) * sin``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    window = slice(offset, q.shape[0] + offset)
    cos_win = cos[window]
    sin_win = sin[window]
    rotated_q = (q * cos_win) + (rtt_half(q) * sin_win)
    rotated_k = (k * cos_win) + (rtt_half(k) * sin_win)
    return rotated_q, rotated_k
# def apply_masked_flash_attn(
# q: Tensor,
# k: Tensor,
# v: Tensor,
# mask: Tensor,
# h: int,
# d_k: int,
# ) -> Tensor:
# """
# Applies Flash Attention with padding masks.
# """
# from einops import rearrange
# from flash_attn import flash_attn_varlen_func
# from flash_attn.bert_padding import pad_input, unpad_input
# pad_mask = ~mask[:, 0, :]
# b, t = pad_mask.shape
# q = q.view(b, t, h * d_k)
# k = k.view(b, t, h * d_k)
# v = v.view(b, t, h * d_k)
# q_unpad, indices_q, _, max_seqlen_q = unpad_input(q, pad_mask)[:4]
# q_unpad = rearrange(q_unpad, "nnz (h d) -> nnz h d", h=h)
# k_unpad = unpad_input(k, pad_mask)[0]
# k_unpad = rearrange(k_unpad, "nnz (h d) -> nnz h d", h=h)
# v_unpad = unpad_input(v, pad_mask)[0]
# v_unpad = rearrange(v_unpad, "nnz (h d) -> nnz h d", h=h)
# lengths_q = pad_mask.sum(1).to(torch.int32).to(q.device)
# cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32)
# max_seqlen_q = torch.max(lengths_q)
# output_unpad = flash_attn_varlen_func(
# q_unpad,
# k_unpad,
# v_unpad,
# cu_seqlens_q,
# cu_seqlens_q,
# max_seqlen_q,
# max_seqlen_q,
# )
# scores = pad_input(
# rearrange(output_unpad, "nnz h d -> nnz (h d)"),
# indices_q,
# b,
# t,
# )
# return scores
class StridingSubsampling(nn.Module):
    """
    Strided Subsampling layer used to reduce the sequence length.

    A stack of ``int(log2(subsampling_factor))`` stride-2 ``Conv2d`` + ``ReLU``
    stages is followed by a linear projection of the flattened
    (channels x reduced features) axis to ``feat_out``.
    """

    def __init__(
        self,
        subsampling_factor: int,
        feat_in: int,
        feat_out: int,
        conv_channels: int,
    ):
        super().__init__()
        self._sampling_num = int(math.log(subsampling_factor, 2))
        self._stride = 2
        self._kernel_size = 3
        self._padding = (self._kernel_size - 1) // 2

        stages: List[nn.Module] = []
        for stage_idx in range(self._sampling_num):
            # Only the first stage sees the single-channel input "image".
            stage_in = 1 if stage_idx == 0 else conv_channels
            stages += [
                torch.nn.Conv2d(
                    in_channels=stage_in,
                    out_channels=conv_channels,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    padding=self._padding,
                ),
                nn.ReLU(),
            ]

        # The feature axis shrinks exactly like the time axis does.
        reduced_feat = int(self.calc_output_length(torch.tensor(feat_in)))
        self.out = torch.nn.Linear(conv_channels * reduced_feat, feat_out)
        self.conv = torch.nn.Sequential(*stages)

    def calc_output_length(self, lengths: Tensor) -> Tensor:
        """
        Calculates the output length after applying the subsampling.

        For every stage the length becomes
        ``floor((length + 2 * padding - kernel_size) / stride + 1)``.
        The result is returned with dtype ``torch.int``.
        """
        pad_term = 2 * self._padding - self._kernel_size
        current = lengths.to(torch.float)
        for _ in range(self._sampling_num):
            current = torch.floor(torch.div(current + pad_term, self._stride) + 1.0)
        return current.to(dtype=torch.int)

    def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Subsamples ``x`` and returns it together with the updated lengths.

        A channel axis is inserted at position 1 before the convolutions; the
        channel and feature axes are then flattened together and projected by
        ``self.out``.
        """
        conv_out = self.conv(x.unsqueeze(1))
        batch, _, frames, _ = conv_out.size()
        flattened = conv_out.transpose(1, 2).reshape(batch, frames, -1)
        projected = self.out(flattened)
        return projected, self.calc_output_length(lengths)
class MultiHeadAttention(nn.Module, ABC):
"""
Base class of Multi-Head Attention Mechanisms.
"""
def __init__(self, n_head: int, n_feat: int, flash_attn=False):
super().__init__()
assert n_feat % n_head == 0
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.flash_attn = flash_attn
if self.flash_attn and not IMPORT_FLASH:
raise RuntimeError(
f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. "
"Please install flash_attn or use --no_flash flag. "
"If you have already done this, "
"--force-reinstall flag might be useful"
)
def forward_qkv(
self, query: Tensor, key: Tensor, value: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Projects the inputs into queries, keys, and values for multi-head attention.
"""
b = query.size(0)
q = self.linear_q(query).view(b, -1, self.h, self.d_k)
k = self.linear_k(key).view(b, -1, self.h, self.d_k)
v = self.linear_v(value).view(b, -1, self.h, self.d_k)
if self.flash_attn:
return q, k, v
return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
def forward_attention(
self, value: Tensor, scores: Tensor, mask: Optional[Tensor]
) -> Tensor:
"""
Computes the scaled dot-product attention given the projected values and scores.
"""
b = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask, -10000.0)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
else:
attn = torch.softmax(scores, dim=-1)
x = torch.matmul(attn, value)
x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k)
return self.linear_out(x)
class RelPositionMultiHeadAttention(MultiHeadAttention):
    """
    Relative Position Multi-Head Attention module.

    On top of the base projections it owns ``linear_pos`` (a bias-free
    projection of the positional embeddings) and two learnable per-head query
    biases, ``pos_bias_u`` and ``pos_bias_v``, each of shape ``(h, d_k)``.
    """

    def __init__(self, n_head: int, n_feat: int):
        """
        Args:
            n_head: number of attention heads.
            n_feat: model dimension (the base class asserts it is divisible
                by ``n_head``).
        """
        super().__init__(n_head, n_feat)
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Zero-initialise the biases. The legacy ``torch.FloatTensor(h, d_k)``
        # constructor only allocates storage, so the parameters would hold
        # arbitrary memory (possibly NaN/Inf) whenever no checkpoint overwrites
        # them. Loading a state dict still replaces these values as before.
        self.pos_bias_u = nn.Parameter(
            torch.zeros(self.h, self.d_k, dtype=torch.float32)
        )
        self.pos_bias_v = nn.Parameter(
            torch.zeros(self.h, self.d_k, dtype=torch.float32)
        )

    def rel_shift(self, x: Tensor) -> Tensor:
        """
        Shifts the position scores so each query row is aligned with its keys.

        Args:
            x: 4-D tensor unpacked as ``(b, h, qlen, pos_len)``.

        Returns:
            Tensor of the same shape. One zero column is padded on the left of
            the last axis, the last two axes are reinterpreted as
            ``(-1, qlen)``, the first row is dropped and the remainder is
            viewed back as ``(b, h, qlen, pos_len)``.
        """
        b, h, qlen, pos_len = x.size()
        x = torch.nn.functional.pad(x, pad=(1, 0))
        x = x.view(b, h, -1, qlen)
        return x[:, :, 1:].view(b, h, qlen, pos_len)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Computes self-attention with relative positional scores.

        Args:
            query, key, value: inputs to the q/k/v projections.
            pos_emb: relative positional embeddings; the last axis must be
                ``n_feat`` wide — presumably the ``(1, 2 * T - 1, n_feat)``
                tensor produced by ``RelPositionalEmbedding`` (confirm
                against the caller).
            mask: optional boolean mask, ``True`` where attention is blocked.

        Returns:
            The attended and output-projected tensor.
        """
        q, k, v = self.forward_qkv(query, key, value)
        # Move heads after time so the (h, d_k) biases broadcast over q.
        q = q.transpose(1, 2)

        p = self.linear_pos(pos_emb)
        p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2)

        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # Position term: queries against the projected positional embeddings.
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        # Content term: queries against the keys.
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # Keep only as many position columns as there are keys.
        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
class RotaryPositionMultiHeadAttention(MultiHeadAttention):
    """
    Rotary Position Multi-Head Attention module.

    The rotary embedding is applied to the raw query and key inputs, split
    per head, before they go through the q/k/v projections.
    """

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: List[Tensor],
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Computes self-attention with rotary positional embeddings.

        Args:
            query, key, value: inputs unpacked as ``(b, t, features)`` with
                ``features == h * d_k``.
            pos_emb: the ``[cos, sin]`` pair.
            mask: optional boolean mask, ``True`` where attention is blocked.
        """
        batch, steps, _ = value.size()
        per_head = (steps, batch, self.h, self.d_k)
        merged = (steps, batch, self.h * self.d_k)

        # Time-major, per-head layout expected by apply_rotary_pos_emb.
        q_heads = query.transpose(0, 1).view(*per_head)
        k_heads = key.transpose(0, 1).view(*per_head)
        v_heads = value.transpose(0, 1).view(*per_head)

        cos, sin = pos_emb
        q_heads, k_heads = apply_rotary_pos_emb(q_heads, k_heads, cos, sin, offset=0)

        # Back to (b, t, features) for the linear projections.
        q, k, v = self.forward_qkv(
            q_heads.view(*merged).transpose(0, 1),
            k_heads.view(*merged).transpose(0, 1),
            v_heads.view(*merged).transpose(0, 1),
        )

        # Only the regular attention path exists here; the flash-attention
        # branch is disabled in this copy of the file.
        scaled_keys = k.transpose(-2, -1) / math.sqrt(self.d_k)
        scores = torch.matmul(q, scaled_keys)
        return self.forward_attention(v, scores, mask)
class PositionalEncoding(nn.Module, ABC):
    """
    Base class of Positional Encodings.

    Subclasses build the encoding table in ``create_pe``; ``extend_pe`` stores
    it on the module as the non-persistent buffer ``pe``.
    """

    def __init__(self, dim: int, base: int):
        super().__init__()
        self.dim = dim
        self.base = base

    @abstractmethod
    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        """
        Builds the encoding table for ``length`` positions on ``device``.

        Returning ``None`` tells ``extend_pe`` to keep the current table.
        """

    def extend_pe(self, length: int, device: torch.device):
        """
        Extends the positional encoding buffer to process longer sequences.

        Nothing changes when ``create_pe`` returns ``None``. Otherwise the new
        table replaces ``self.pe``, which is registered as a non-persistent
        buffer the first time.
        """
        table = self.create_pe(length, device)
        if table is None:
            return
        if not hasattr(self, "pe"):
            self.register_buffer("pe", table, persistent=False)
        else:
            self.pe = table
class RelPositionalEmbedding(PositionalEncoding):
    """
    Relative Positional Embedding module.

    Keeps a sinusoidal table over the relative offsets
    ``length - 1, ..., 0, ..., -(length - 1)``.
    """

    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        """
        Creates the relative positional encoding matrix.

        Returns ``None`` when the existing table already has at least
        ``2 * length - 1`` positions; otherwise a new tensor of shape
        ``(1, 2 * length - 1, dim)`` with sines in the even columns and
        cosines in the odd columns.
        """
        required = 2 * length - 1
        if hasattr(self, "pe") and self.pe.shape[1] >= required:
            return None

        offsets = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
        table = torch.zeros(offsets.size(0), self.dim, device=offsets.device)
        even_cols = torch.arange(0, self.dim, 2, device=table.device)
        div_term = torch.exp(even_cols * -(math.log(10000.0) / self.dim))
        angles = offsets * div_term
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        return table.unsqueeze(0)

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
        """
        Returns ``x`` unchanged together with the slice of the table that is
        centred on offset zero and spans ``2 * x.size(1) - 1`` positions.
        """
        seq_len = x.size(1)
        center = self.pe.size(1) // 2 + 1
        lo = center - seq_len
        hi = center + seq_len - 1
        return x, self.pe[:, lo:hi]
class RotaryPositionalEmbedding(PositionalEncoding):
    """
    Rotary Positional Embedding module.

    The buffer stacks the cosine table (first half of axis 0) on top of the
    sine table (second half of axis 0).
    """

    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        """
        Creates or extends the rotary positional encoding matrix.

        Returns ``None`` when the existing buffer already holds at least
        ``2 * length`` rows; otherwise a tensor whose first ``length`` rows
        are cosines and whose last ``length`` rows are sines, each row shaped
        ``(1, 1, D)`` where ``D`` is twice the number of inverse frequencies.
        """
        if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
            return None

        steps = torch.arange(length, device=device)
        target_device = steps.device
        # The frequencies are built without a device argument; ``type_as``
        # brings the step indices to the same dtype/device for the product.
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        inv_freq = 1.0 / (self.base**exponents)
        freqs = torch.einsum("i,j->ij", steps.type_as(inv_freq), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(target_device)
        cos_table = emb.cos()[:, None, None, :]
        sin_table = emb.sin()[:, None, None, :]
        return torch.cat([cos_table, sin_table])

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
        """
        Returns ``x`` unchanged together with ``[cos, sin]``, each cropped to
        the first ``x.shape[1]`` positions of its half of the buffer.
        """
        seq_len = x.shape[1]
        sin_start = self.pe.shape[0] // 2
        cos_emb = self.pe[:seq_len]
        sin_emb = self.pe[sin_start : sin_start + seq_len]
        return x, [cos_emb, sin_emb]
class ConformerConvolution(nn.Module):
    """
    Conformer Convolution module.

    Pipeline: pointwise conv (doubling channels) -> GLU -> optional zeroing of
    padded frames -> depthwise conv -> batch norm -> SiLU -> pointwise conv.
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0
        same_padding = (kernel_size - 1) // 2
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel_size,
            padding=same_padding,
            groups=d_model,
            bias=True,
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)

    def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor:
        """
        Runs the convolution module.

        Axes 1 and 2 of ``x`` are swapped on the way in and swapped back on
        the way out. When ``pad_mask`` is given, positions where it is ``True``
        are zeroed after the GLU, before the depthwise convolution.
        """
        hidden = self.pointwise_conv1(x.transpose(1, 2))
        hidden = nn.functional.glu(hidden, dim=1)
        if pad_mask is not None:
            hidden = hidden.masked_fill(pad_mask.unsqueeze(1), 0.0)
        for stage in (
            self.depthwise_conv,
            self.batch_norm,
            self.activation,
            self.pointwise_conv2,
        ):
            hidden = stage(hidden)
        return hidden.transpose(1, 2)
class ConformerFeedForward(nn.Module):
    """
    Conformer Feed Forward module.

    Two linear layers (``d_model -> d_ff -> d_model``) with a SiLU activation
    in between; ``use_bias`` toggles the bias of both linear layers.
    """

    def __init__(self, d_model: int, d_ff: int, use_bias=True):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
        self.activation = nn.SiLU()
        self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)

    def forward(self, x: Tensor) -> Tensor:
        """Applies ``linear1``, the activation and ``linear2`` in sequence."""
        expanded = self.linear1(x)
        activated = self.activation(expanded)
        return self.linear2(activated)
class ConformerLayer(nn.Module):
    """
    Conformer Layer module.

    One Conformer block: a half-weighted feed-forward module, multi-head
    self-attention, a convolution module and a second half-weighted
    feed-forward module, each wrapped in a pre-LayerNorm residual connection,
    followed by a final LayerNorm.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        self_attention_model: str,
        n_heads: int = 16,
        conv_kernel_size: int = 31,
        flash_attn: bool = False,
    ):
        super().__init__()
        # Weight applied to both feed-forward branches before the residual add.
        self.fc_factor = 0.5

        self.norm_feed_forward1 = nn.LayerNorm(d_model)
        self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)

        self.norm_conv = nn.LayerNorm(d_model)
        self.conv = ConformerConvolution(
            d_model=d_model,
            kernel_size=conv_kernel_size,
        )

        self.norm_self_att = nn.LayerNorm(d_model)
        use_rotary = self_attention_model == "rotary"
        if use_rotary:
            self.self_attn: nn.Module = RotaryPositionMultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
                flash_attn=flash_attn,
            )
        else:
            # Anything other than "rotary" gets the relative-position variant.
            assert not flash_attn, "Not supported flash_attn for rel_pos"
            self.self_attn = RelPositionMultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
            )

        self.norm_feed_forward2 = nn.LayerNorm(d_model)
        self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
        self.norm_out = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Tensor,
        pos_emb: Union[Tensor, List[Tensor]],
        att_mask: Optional[Tensor] = None,
        pad_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Runs the block on ``x``.

        Args:
            x: input features.
            pos_emb: positional embeddings forwarded to the attention module.
            att_mask: optional mask forwarded to the attention module.
            pad_mask: optional mask forwarded to the convolution module.
        """
        # First feed-forward branch.
        ff1 = self.feed_forward1(self.norm_feed_forward1(x))
        acc = x + ff1 * self.fc_factor

        # Self-attention: query, key and value are the same normalised tensor.
        normed = self.norm_self_att(acc)
        acc = acc + self.self_attn(normed, normed, normed, pos_emb, mask=att_mask)

        # Convolution module.
        acc = acc + self.conv(self.norm_conv(acc), pad_mask=pad_mask)

        # Second feed-forward branch.
        ff2 = self.feed_forward2(self.norm_feed_forward2(acc))
        acc = acc + ff2 * self.fc_factor

        return self.norm_out(acc)
class ConformerEncoder(nn.Module):
    """
    Conformer Encoder module.

    Wires together the StridingSubsampling front-end, the positional
    embedding selected by ``self_attention_model`` ("rotary" or "rel_pos")
    and a stack of ``n_layers`` Conformer layers.
    """

    def __init__(
        self,
        feat_in: int = 64,
        n_layers: int = 16,
        d_model: int = 768,
        subsampling_factor: int = 4,
        ff_expansion_factor: int = 4,
        self_attention_model: str = "rotary",
        n_heads: int = 16,
        pos_emb_max_len: int = 5000,
        conv_kernel_size: int = 31,
        flash_attn: bool = False,
    ):
        super().__init__()
        self.feat_in = feat_in
        supported_attention = ["rotary", "rel_pos"]
        assert (
            self_attention_model in supported_attention
        ), f"Not supported attn = {self_attention_model}"

        self.pre_encode = StridingSubsampling(
            subsampling_factor=subsampling_factor,
            feat_in=feat_in,
            feat_out=d_model,
            conv_channels=d_model,
        )

        # Both embeddings receive ``pos_emb_max_len`` as their ``base`` argument.
        if self_attention_model == "rotary":
            self.pos_enc: nn.Module = RotaryPositionalEmbedding(
                d_model // n_heads, pos_emb_max_len
            )
        else:
            self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len)

        self.layers = nn.ModuleList(
            [
                ConformerLayer(
                    d_model=d_model,
                    d_ff=d_model * ff_expansion_factor,
                    self_attention_model=self_attention_model,
                    n_heads=n_heads,
                    conv_kernel_size=conv_kernel_size,
                    flash_attn=flash_attn,
                )
                for _ in range(n_layers)
            ]
        )

        # Build the positional table up front, on the parameters' device.
        module_device = next(self.parameters()).device
        self.pos_enc.extend_pe(pos_emb_max_len, module_device)

    def input_example(
        self,
        batch_size: int = 1,
        seqlen: int = 200,
    ):
        """
        Returns an all-zero ``(batch_size, feat_in, seqlen)`` feature tensor
        and a matching lengths tensor, both on the module's device.
        """
        module_device = next(self.parameters()).device
        features = torch.zeros(batch_size, self.feat_in, seqlen)
        feature_lengths = torch.full([batch_size], features.shape[-1])
        return features.float().to(module_device), feature_lengths.to(module_device)

    def input_names(self):
        """Names of the two ``forward`` inputs."""
        return ["audio_signal", "length"]

    def output_names(self):
        """Names of the two ``forward`` outputs."""
        return ["encoded", "encoded_len"]

    def dynamic_axes(self):
        """Maps each input/output name to its variable-size axes."""
        batch_only = {0: "batch_size"}
        return {
            "audio_signal": {0: "batch_size", 2: "seq_len"},
            "length": dict(batch_only),
            "encoded": {0: "batch_size", 1: "seq_len"},
            "encoded_len": dict(batch_only),
        }

    def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Encodes a batch of features.

        Args:
            audio_signal: features with time on axis 2 (axes 1 and 2 are
                swapped before subsampling).
            length: per-item lengths along the time axis.

        Returns:
            The encoded features with axes 1 and 2 swapped back, and the
            subsampled lengths.
        """
        feats, out_len = self.pre_encode(
            x=audio_signal.transpose(1, 2), lengths=length
        )
        max_len = feats.size(1)
        feats, pos_emb = self.pos_enc(x=feats)

        # True for frames that lie inside each item's subsampled length.
        frame_index = torch.arange(0, max_len, device=feats.device).expand(
            out_len.size(0), -1
        )
        valid = frame_index < out_len.unsqueeze(-1)

        # An attention mask is only built for batches with more than one item.
        att_mask = None
        if feats.shape[0] > 1:
            pair_valid = valid.unsqueeze(1).repeat([1, max_len, 1])
            pair_valid = torch.logical_and(pair_valid, pair_valid.transpose(1, 2))
            att_mask = ~pair_valid

        # Downstream modules expect True to mark padded frames.
        pad_mask = ~valid

        for layer in self.layers:
            feats = layer(
                x=feats,
                pos_emb=pos_emb,
                att_mask=att_mask,
                pad_mask=pad_mask,
            )

        return feats.transpose(1, 2), out_len